0705
This commit is contained in:
parent
423d345fdf
commit
bbb3c3b7df
@ -0,0 +1,28 @@
|
||||
"""add_version_to_settlements
|
||||
|
||||
Revision ID: 071ac4268ccb
|
||||
Revises: be770eea8ec2
|
||||
Create Date: 2025-05-07 23:47:27.788572
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '071ac4268ccb'
|
||||
down_revision: Union[str, None] = 'be770eea8ec2'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Upgrade schema."""
|
||||
pass
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Downgrade schema."""
|
||||
pass
|
@ -0,0 +1,28 @@
|
||||
"""add_version_to_lists_table
|
||||
|
||||
Revision ID: 64a6614cb156
|
||||
Revises: 071ac4268ccb
|
||||
Create Date: 2025-05-08 00:48:49.027570
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '64a6614cb156'
|
||||
down_revision: Union[str, None] = '071ac4268ccb'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Upgrade schema."""
|
||||
pass
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Downgrade schema."""
|
||||
pass
|
@ -0,0 +1,28 @@
|
||||
"""add_expense_split_settlement_tables_and_relations
|
||||
|
||||
Revision ID: 8c2c0f83e2b9
|
||||
Revises: d53eedd151b7
|
||||
Create Date: 2025-05-07 23:30:48.621512
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '8c2c0f83e2b9'
|
||||
down_revision: Union[str, None] = 'd53eedd151b7'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Upgrade schema."""
|
||||
pass
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Downgrade schema."""
|
||||
pass
|
@ -0,0 +1,28 @@
|
||||
"""add_version_to_settlements
|
||||
|
||||
Revision ID: be770eea8ec2
|
||||
Revises: 8c2c0f83e2b9
|
||||
Create Date: 2025-05-07 23:41:26.669049
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = 'be770eea8ec2'
|
||||
down_revision: Union[str, None] = '8c2c0f83e2b9'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Upgrade schema."""
|
||||
pass
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Downgrade schema."""
|
||||
pass
|
@ -1,4 +1,3 @@
|
||||
# app/api/v1/api.py
|
||||
from fastapi import APIRouter
|
||||
|
||||
from app.api.v1.endpoints import health
|
||||
@ -10,10 +9,11 @@ from app.api.v1.endpoints import lists
|
||||
from app.api.v1.endpoints import items
|
||||
from app.api.v1.endpoints import ocr
|
||||
from app.api.v1.endpoints import costs
|
||||
from app.api.v1.endpoints import financials
|
||||
|
||||
api_router_v1 = APIRouter()
|
||||
|
||||
api_router_v1.include_router(health.router) # Path /health defined inside
|
||||
api_router_v1.include_router(health.router)
|
||||
api_router_v1.include_router(auth.router, prefix="/auth", tags=["Authentication"])
|
||||
api_router_v1.include_router(users.router, prefix="/users", tags=["Users"])
|
||||
api_router_v1.include_router(groups.router, prefix="/groups", tags=["Groups"])
|
||||
@ -22,5 +22,6 @@ api_router_v1.include_router(lists.router, prefix="/lists", tags=["Lists"])
|
||||
api_router_v1.include_router(items.router, tags=["Items"])
|
||||
api_router_v1.include_router(ocr.router, prefix="/ocr", tags=["OCR"])
|
||||
api_router_v1.include_router(costs.router, tags=["Costs"])
|
||||
api_router_v1.include_router(financials.router)
|
||||
# Add other v1 endpoint routers here later
|
||||
# e.g., api_router_v1.include_router(users.router, prefix="/users", tags=["Users"])
|
@ -1,15 +1,17 @@
|
||||
# app/api/v1/endpoints/costs.py
|
||||
import logging
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import Session, selectinload
|
||||
|
||||
from app.database import get_db
|
||||
from app.api.dependencies import get_current_user
|
||||
from app.models import User as UserModel # For get_current_user dependency
|
||||
from app.schemas.cost import ListCostSummary
|
||||
from app.models import User as UserModel, Group as GroupModel # For get_current_user dependency and Group model
|
||||
from app.schemas.cost import ListCostSummary, GroupBalanceSummary
|
||||
from app.crud import cost as crud_cost
|
||||
from app.crud import list as crud_list # For permission checking
|
||||
from app.core.exceptions import ListNotFoundError, ListPermissionError, UserNotFoundError
|
||||
from app.core.exceptions import ListNotFoundError, ListPermissionError, UserNotFoundError, GroupNotFoundError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter()
|
||||
@ -67,3 +69,60 @@ async def get_list_cost_summary(
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error generating cost summary for list {list_id} for user {current_user.email}: {str(e)}", exc_info=True)
|
||||
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="An unexpected error occurred while generating the cost summary.")
|
||||
|
||||
@router.get(
|
||||
"/groups/{group_id}/balance-summary",
|
||||
response_model=GroupBalanceSummary,
|
||||
summary="Get Detailed Balance Summary for a Group",
|
||||
tags=["Costs", "Groups"],
|
||||
responses={
|
||||
status.HTTP_403_FORBIDDEN: {"description": "User does not have permission to access this group"},
|
||||
status.HTTP_404_NOT_FOUND: {"description": "Group not found"}
|
||||
}
|
||||
)
|
||||
async def get_group_balance_summary(
|
||||
group_id: int,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: UserModel = Depends(get_current_user),
|
||||
):
|
||||
"""
|
||||
Retrieves a detailed financial balance summary for all users within a specific group.
|
||||
It considers all expenses, their splits, and all settlements recorded for the group.
|
||||
The user must be a member of the group to view its balance summary.
|
||||
"""
|
||||
logger.info(f"User {current_user.email} requesting balance summary for group {group_id}")
|
||||
|
||||
# 1. Verify user is a member of the target group (using crud_group.check_group_membership or similar)
|
||||
# Assuming a function like this exists in app.crud.group or we add it.
|
||||
# For now, let's placeholder this check logic.
|
||||
# await crud_group.check_group_membership(db=db, group_id=group_id, user_id=current_user.id)
|
||||
# A simpler check for now: fetch the group and see if user is part of member_associations
|
||||
group_check = await db.execute(
|
||||
select(GroupModel)
|
||||
.options(selectinload(GroupModel.member_associations))
|
||||
.where(GroupModel.id == group_id)
|
||||
)
|
||||
db_group_for_check = group_check.scalars().first()
|
||||
|
||||
if not db_group_for_check:
|
||||
raise GroupNotFoundError(group_id)
|
||||
|
||||
user_is_member = any(assoc.user_id == current_user.id for assoc in db_group_for_check.member_associations)
|
||||
if not user_is_member:
|
||||
# If ListPermissionError is generic enough for "access resource", use it, or a new GroupPermissionError
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=f"User not a member of group {group_id}")
|
||||
|
||||
# 2. Calculate the group balance summary
|
||||
try:
|
||||
balance_summary = await crud_cost.calculate_group_balance_summary(db=db, group_id=group_id)
|
||||
logger.info(f"Successfully generated balance summary for group {group_id} for user {current_user.email}")
|
||||
return balance_summary
|
||||
except GroupNotFoundError as e:
|
||||
logger.warning(f"Group {group_id} not found during balance summary calculation: {str(e)}")
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e))
|
||||
except UserNotFoundError as e: # Should not happen if group members are correctly fetched
|
||||
logger.error(f"User not found during balance summary for group {group_id}: {str(e)}")
|
||||
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="An internal error occurred finding a user for the summary.")
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error generating balance summary for group {group_id}: {str(e)}", exc_info=True)
|
||||
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="An unexpected error occurred while generating the group balance summary.")
|
442
be/app/api/v1/endpoints/financials.py
Normal file
442
be/app/api/v1/endpoints/financials.py
Normal file
@ -0,0 +1,442 @@
|
||||
# app/api/v1/endpoints/financials.py
|
||||
import logging
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Query, Response
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select
|
||||
from typing import List as PyList, Optional, Sequence
|
||||
|
||||
from app.database import get_db
|
||||
from app.api.dependencies import get_current_user
|
||||
from app.models import User as UserModel, Group as GroupModel, List as ListModel, UserGroup as UserGroupModel, UserRoleEnum
|
||||
from app.schemas.expense import (
|
||||
ExpenseCreate, ExpensePublic,
|
||||
SettlementCreate, SettlementPublic,
|
||||
ExpenseUpdate, SettlementUpdate
|
||||
)
|
||||
from app.crud import expense as crud_expense
|
||||
from app.crud import settlement as crud_settlement
|
||||
from app.crud import group as crud_group
|
||||
from app.crud import list as crud_list
|
||||
from app.core.exceptions import (
|
||||
ListNotFoundError, GroupNotFoundError, UserNotFoundError,
|
||||
InvalidOperationError, GroupPermissionError, ListPermissionError,
|
||||
ItemNotFoundError, GroupMembershipError
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter()
|
||||
|
||||
# --- Helper for permissions ---
|
||||
async def check_list_access_for_financials(db: AsyncSession, list_id: int, user_id: int, action: str = "access financial data for"):
|
||||
try:
|
||||
await crud_list.check_list_permission(db=db, list_id=list_id, user_id=user_id, require_member=True)
|
||||
except ListPermissionError as e:
|
||||
logger.warning(f"ListPermissionError in check_list_access_for_financials for list {list_id}, user {user_id}, action '{action}': {e.detail}")
|
||||
raise ListPermissionError(list_id, action=action)
|
||||
except ListNotFoundError:
|
||||
raise
|
||||
|
||||
# --- Expense Endpoints ---
|
||||
@router.post(
|
||||
"/expenses",
|
||||
response_model=ExpensePublic,
|
||||
status_code=status.HTTP_201_CREATED,
|
||||
summary="Create New Expense",
|
||||
tags=["Expenses"]
|
||||
)
|
||||
async def create_new_expense(
|
||||
expense_in: ExpenseCreate,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: UserModel = Depends(get_current_user),
|
||||
):
|
||||
logger.info(f"User {current_user.email} creating expense: {expense_in.description}")
|
||||
effective_group_id = expense_in.group_id
|
||||
is_group_context = False
|
||||
|
||||
if expense_in.list_id:
|
||||
# Check basic access to list (implies membership if list is in group)
|
||||
await check_list_access_for_financials(db, expense_in.list_id, current_user.id, action="create expenses for")
|
||||
list_obj = await db.get(ListModel, expense_in.list_id)
|
||||
if not list_obj:
|
||||
raise ListNotFoundError(expense_in.list_id)
|
||||
if list_obj.group_id:
|
||||
if expense_in.group_id and list_obj.group_id != expense_in.group_id:
|
||||
raise InvalidOperationError(f"List {list_obj.id} belongs to group {list_obj.group_id}, not group {expense_in.group_id} specified in expense.")
|
||||
effective_group_id = list_obj.group_id
|
||||
is_group_context = True # Expense is tied to a group via the list
|
||||
elif expense_in.group_id:
|
||||
raise InvalidOperationError(f"Personal list {list_obj.id} cannot have expense associated with group {expense_in.group_id}.")
|
||||
# If list is personal, no group check needed yet, handled by payer check below.
|
||||
|
||||
elif effective_group_id: # Only group_id provided for expense
|
||||
is_group_context = True
|
||||
# Ensure user is at least a member to create expense in group context
|
||||
await crud_group.check_group_membership(db, group_id=effective_group_id, user_id=current_user.id, action="create expenses for")
|
||||
else:
|
||||
# This case should ideally be caught by earlier checks if list_id was present but list was personal.
|
||||
# If somehow reached, it means no list_id and no group_id.
|
||||
raise InvalidOperationError("Expense must be linked to a list_id or group_id.")
|
||||
|
||||
# Finalize expense payload with correct group_id if derived
|
||||
expense_in_final = expense_in.model_copy(update={"group_id": effective_group_id})
|
||||
|
||||
# --- Granular Permission Check for Payer ---
|
||||
if expense_in_final.paid_by_user_id != current_user.id:
|
||||
logger.warning(f"User {current_user.email} attempting to create expense paid by other user {expense_in_final.paid_by_user_id}")
|
||||
# If creating expense paid by someone else, user MUST be owner IF in group context
|
||||
if is_group_context and effective_group_id:
|
||||
try:
|
||||
await crud_group.check_user_role_in_group(db, group_id=effective_group_id, user_id=current_user.id, required_role=UserRoleEnum.owner, action="create expense paid by another user")
|
||||
except GroupPermissionError as e:
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=f"Only group owners can create expenses paid by others. {str(e)}")
|
||||
else:
|
||||
# Cannot create expense paid by someone else for a personal list (no group context)
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Cannot create expense paid by another user for a personal list.")
|
||||
# If paying for self, basic list/group membership check above is sufficient.
|
||||
|
||||
try:
|
||||
created_expense = await crud_expense.create_expense(db=db, expense_in=expense_in_final, current_user_id=current_user.id)
|
||||
logger.info(f"Expense '{created_expense.description}' (ID: {created_expense.id}) created successfully.")
|
||||
return created_expense
|
||||
except (UserNotFoundError, ListNotFoundError, GroupNotFoundError, InvalidOperationError, GroupMembershipError) as e:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
|
||||
except NotImplementedError as e:
|
||||
raise HTTPException(status_code=status.HTTP_501_NOT_IMPLEMENTED, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error creating expense: {str(e)}", exc_info=True)
|
||||
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="An unexpected error occurred.")
|
||||
|
||||
@router.get("/expenses/{expense_id}", response_model=ExpensePublic, summary="Get Expense by ID", tags=["Expenses"])
|
||||
async def get_expense(
|
||||
expense_id: int,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: UserModel = Depends(get_current_user),
|
||||
):
|
||||
logger.info(f"User {current_user.email} requesting expense ID {expense_id}")
|
||||
expense = await crud_expense.get_expense_by_id(db, expense_id=expense_id)
|
||||
if not expense:
|
||||
raise ItemNotFoundError(item_id=expense_id)
|
||||
|
||||
if expense.list_id:
|
||||
await check_list_access_for_financials(db, expense.list_id, current_user.id)
|
||||
elif expense.group_id:
|
||||
await crud_group.check_group_membership(db, group_id=expense.group_id, user_id=current_user.id)
|
||||
elif expense.paid_by_user_id != current_user.id:
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Not authorized to view this expense")
|
||||
return expense
|
||||
|
||||
@router.get("/lists/{list_id}/expenses", response_model=PyList[ExpensePublic], summary="List Expenses for a List", tags=["Expenses", "Lists"])
|
||||
async def list_list_expenses(
|
||||
list_id: int,
|
||||
skip: int = Query(0, ge=0),
|
||||
limit: int = Query(100, ge=1, le=200),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: UserModel = Depends(get_current_user),
|
||||
):
|
||||
logger.info(f"User {current_user.email} listing expenses for list ID {list_id}")
|
||||
await check_list_access_for_financials(db, list_id, current_user.id)
|
||||
expenses = await crud_expense.get_expenses_for_list(db, list_id=list_id, skip=skip, limit=limit)
|
||||
return expenses
|
||||
|
||||
@router.get("/groups/{group_id}/expenses", response_model=PyList[ExpensePublic], summary="List Expenses for a Group", tags=["Expenses", "Groups"])
|
||||
async def list_group_expenses(
|
||||
group_id: int,
|
||||
skip: int = Query(0, ge=0),
|
||||
limit: int = Query(100, ge=1, le=200),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: UserModel = Depends(get_current_user),
|
||||
):
|
||||
logger.info(f"User {current_user.email} listing expenses for group ID {group_id}")
|
||||
await crud_group.check_group_membership(db, group_id=group_id, user_id=current_user.id, action="list expenses for")
|
||||
expenses = await crud_expense.get_expenses_for_group(db, group_id=group_id, skip=skip, limit=limit)
|
||||
return expenses
|
||||
|
||||
@router.put("/expenses/{expense_id}", response_model=ExpensePublic, summary="Update Expense", tags=["Expenses"])
|
||||
async def update_expense_details(
|
||||
expense_id: int,
|
||||
expense_in: ExpenseUpdate,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: UserModel = Depends(get_current_user),
|
||||
):
|
||||
"""
|
||||
Updates an existing expense (description, currency, expense_date only).
|
||||
Requires the current version number for optimistic locking.
|
||||
User must have permission to modify the expense (e.g., be the payer or group admin).
|
||||
"""
|
||||
logger.info(f"User {current_user.email} attempting to update expense ID {expense_id} (version {expense_in.version})")
|
||||
expense_db = await crud_expense.get_expense_by_id(db, expense_id=expense_id)
|
||||
if not expense_db:
|
||||
raise ItemNotFoundError(item_id=expense_id)
|
||||
|
||||
# --- Granular Permission Check ---
|
||||
can_modify = False
|
||||
# 1. User paid for the expense
|
||||
if expense_db.paid_by_user_id == current_user.id:
|
||||
can_modify = True
|
||||
# 2. OR User is owner of the group the expense belongs to
|
||||
elif expense_db.group_id:
|
||||
try:
|
||||
await crud_group.check_user_role_in_group(db, group_id=expense_db.group_id, user_id=current_user.id, required_role=UserRoleEnum.owner, action="modify group expenses")
|
||||
can_modify = True
|
||||
logger.info(f"Allowing update for expense {expense_id} by group owner {current_user.email}")
|
||||
except GroupMembershipError: # User not even a member
|
||||
pass # Keep can_modify as False
|
||||
except GroupPermissionError: # User is member but not owner
|
||||
pass # Keep can_modify as False
|
||||
except GroupNotFoundError: # Group doesn't exist (data integrity issue)
|
||||
logger.error(f"Group {expense_db.group_id} not found for expense {expense_id} during update check.")
|
||||
pass # Keep can_modify as False
|
||||
# Note: If expense is only linked to a personal list (no group), only payer can modify.
|
||||
|
||||
if not can_modify:
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="User cannot modify this expense (must be payer or group owner)")
|
||||
|
||||
try:
|
||||
updated_expense = await crud_expense.update_expense(db=db, expense_db=expense_db, expense_in=expense_in)
|
||||
logger.info(f"Expense ID {expense_id} updated successfully to version {updated_expense.version}.")
|
||||
return updated_expense
|
||||
except InvalidOperationError as e:
|
||||
# Check if it's a version conflict (409) or other validation error (400)
|
||||
status_code = status.HTTP_400_BAD_REQUEST
|
||||
if "version" in str(e).lower():
|
||||
status_code = status.HTTP_409_CONFLICT
|
||||
raise HTTPException(status_code=status_code, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error updating expense {expense_id}: {str(e)}", exc_info=True)
|
||||
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="An unexpected error occurred.")
|
||||
|
||||
@router.delete("/expenses/{expense_id}", status_code=status.HTTP_204_NO_CONTENT, summary="Delete Expense", tags=["Expenses"])
|
||||
async def delete_expense_record(
|
||||
expense_id: int,
|
||||
expected_version: Optional[int] = Query(None, description="Expected version for optimistic locking"),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: UserModel = Depends(get_current_user),
|
||||
):
|
||||
"""
|
||||
Deletes an expense and its associated splits.
|
||||
Requires expected_version query parameter for optimistic locking.
|
||||
User must have permission to delete the expense (e.g., be the payer or group admin).
|
||||
"""
|
||||
logger.info(f"User {current_user.email} attempting to delete expense ID {expense_id} (expected version {expected_version})")
|
||||
expense_db = await crud_expense.get_expense_by_id(db, expense_id=expense_id)
|
||||
if not expense_db:
|
||||
# Return 204 even if not found, as the end state is achieved (item is gone)
|
||||
logger.warning(f"Attempt to delete non-existent expense ID {expense_id}")
|
||||
return Response(status_code=status.HTTP_204_NO_CONTENT)
|
||||
# Alternatively, raise NotFoundError(detail=f"Expense {expense_id} not found") -> 404
|
||||
|
||||
# --- Granular Permission Check ---
|
||||
can_delete = False
|
||||
# 1. User paid for the expense
|
||||
if expense_db.paid_by_user_id == current_user.id:
|
||||
can_delete = True
|
||||
# 2. OR User is owner of the group the expense belongs to
|
||||
elif expense_db.group_id:
|
||||
try:
|
||||
await crud_group.check_user_role_in_group(db, group_id=expense_db.group_id, user_id=current_user.id, required_role=UserRoleEnum.owner, action="delete group expenses")
|
||||
can_delete = True
|
||||
logger.info(f"Allowing delete for expense {expense_id} by group owner {current_user.email}")
|
||||
except GroupMembershipError:
|
||||
pass
|
||||
except GroupPermissionError:
|
||||
pass
|
||||
except GroupNotFoundError:
|
||||
logger.error(f"Group {expense_db.group_id} not found for expense {expense_id} during delete check.")
|
||||
pass
|
||||
|
||||
if not can_delete:
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="User cannot delete this expense (must be payer or group owner)")
|
||||
|
||||
try:
|
||||
await crud_expense.delete_expense(db=db, expense_db=expense_db, expected_version=expected_version)
|
||||
logger.info(f"Expense ID {expense_id} deleted successfully.")
|
||||
# No need to return content on 204
|
||||
except InvalidOperationError as e:
|
||||
# Check if it's a version conflict (409) or other validation error (400)
|
||||
status_code = status.HTTP_400_BAD_REQUEST
|
||||
if "version" in str(e).lower():
|
||||
status_code = status.HTTP_409_CONFLICT
|
||||
raise HTTPException(status_code=status_code, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error deleting expense {expense_id}: {str(e)}", exc_info=True)
|
||||
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="An unexpected error occurred.")
|
||||
|
||||
return Response(status_code=status.HTTP_204_NO_CONTENT)
|
||||
|
||||
# --- Settlement Endpoints ---
|
||||
@router.post(
|
||||
"/settlements",
|
||||
response_model=SettlementPublic,
|
||||
status_code=status.HTTP_201_CREATED,
|
||||
summary="Record New Settlement",
|
||||
tags=["Settlements"]
|
||||
)
|
||||
async def create_new_settlement(
|
||||
settlement_in: SettlementCreate,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: UserModel = Depends(get_current_user),
|
||||
):
|
||||
logger.info(f"User {current_user.email} recording settlement in group {settlement_in.group_id}")
|
||||
await crud_group.check_group_membership(db, group_id=settlement_in.group_id, user_id=current_user.id, action="record settlements in")
|
||||
try:
|
||||
await crud_group.check_group_membership(db, group_id=settlement_in.group_id, user_id=settlement_in.paid_by_user_id, action="be a payer in this group's settlement")
|
||||
await crud_group.check_group_membership(db, group_id=settlement_in.group_id, user_id=settlement_in.paid_to_user_id, action="be a payee in this group's settlement")
|
||||
except GroupMembershipError as e:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"Payer or payee issue: {str(e)}")
|
||||
except GroupNotFoundError as e:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e))
|
||||
|
||||
try:
|
||||
created_settlement = await crud_settlement.create_settlement(db=db, settlement_in=settlement_in, current_user_id=current_user.id)
|
||||
logger.info(f"Settlement ID {created_settlement.id} recorded successfully in group {settlement_in.group_id}.")
|
||||
return created_settlement
|
||||
except (UserNotFoundError, GroupNotFoundError, InvalidOperationError, GroupMembershipError) as e:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error recording settlement: {str(e)}", exc_info=True)
|
||||
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="An unexpected error occurred.")
|
||||
|
||||
@router.get("/settlements/{settlement_id}", response_model=SettlementPublic, summary="Get Settlement by ID", tags=["Settlements"])
|
||||
async def get_settlement(
|
||||
settlement_id: int,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: UserModel = Depends(get_current_user),
|
||||
):
|
||||
logger.info(f"User {current_user.email} requesting settlement ID {settlement_id}")
|
||||
settlement = await crud_settlement.get_settlement_by_id(db, settlement_id=settlement_id)
|
||||
if not settlement:
|
||||
raise ItemNotFoundError(item_id=settlement_id)
|
||||
|
||||
is_party_to_settlement = current_user.id in [settlement.paid_by_user_id, settlement.paid_to_user_id]
|
||||
try:
|
||||
await crud_group.check_group_membership(db, group_id=settlement.group_id, user_id=current_user.id)
|
||||
except GroupMembershipError:
|
||||
if not is_party_to_settlement:
|
||||
raise GroupMembershipError(settlement.group_id, action="view this settlement's details")
|
||||
logger.info(f"User {current_user.email} (party to settlement) viewing settlement {settlement_id} for group {settlement.group_id}.")
|
||||
return settlement
|
||||
|
||||
@router.get("/groups/{group_id}/settlements", response_model=PyList[SettlementPublic], summary="List Settlements for a Group", tags=["Settlements", "Groups"])
|
||||
async def list_group_settlements(
|
||||
group_id: int,
|
||||
skip: int = Query(0, ge=0),
|
||||
limit: int = Query(100, ge=1, le=200),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: UserModel = Depends(get_current_user),
|
||||
):
|
||||
logger.info(f"User {current_user.email} listing settlements for group ID {group_id}")
|
||||
await crud_group.check_group_membership(db, group_id=group_id, user_id=current_user.id, action="list settlements for this group")
|
||||
settlements = await crud_settlement.get_settlements_for_group(db, group_id=group_id, skip=skip, limit=limit)
|
||||
return settlements
|
||||
|
||||
@router.put("/settlements/{settlement_id}", response_model=SettlementPublic, summary="Update Settlement", tags=["Settlements"])
|
||||
async def update_settlement_details(
|
||||
settlement_id: int,
|
||||
settlement_in: SettlementUpdate,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: UserModel = Depends(get_current_user),
|
||||
):
|
||||
"""
|
||||
Updates an existing settlement (description, settlement_date only).
|
||||
Requires the current version number for optimistic locking.
|
||||
User must have permission (e.g., be involved party or group admin).
|
||||
"""
|
||||
logger.info(f"User {current_user.email} attempting to update settlement ID {settlement_id} (version {settlement_in.version})")
|
||||
settlement_db = await crud_settlement.get_settlement_by_id(db, settlement_id=settlement_id)
|
||||
if not settlement_db:
|
||||
raise ItemNotFoundError(item_id=settlement_id)
|
||||
|
||||
# --- Granular Permission Check ---
|
||||
can_modify = False
|
||||
# 1. User is involved party (payer or payee)
|
||||
is_party = current_user.id in [settlement_db.paid_by_user_id, settlement_db.paid_to_user_id]
|
||||
if is_party:
|
||||
can_modify = True
|
||||
# 2. OR User is owner of the group the settlement belongs to
|
||||
# Note: Settlements always have a group_id based on current model
|
||||
elif settlement_db.group_id:
|
||||
try:
|
||||
await crud_group.check_user_role_in_group(db, group_id=settlement_db.group_id, user_id=current_user.id, required_role=UserRoleEnum.owner, action="modify group settlements")
|
||||
can_modify = True
|
||||
logger.info(f"Allowing update for settlement {settlement_id} by group owner {current_user.email}")
|
||||
except GroupMembershipError:
|
||||
pass
|
||||
except GroupPermissionError:
|
||||
pass
|
||||
except GroupNotFoundError:
|
||||
logger.error(f"Group {settlement_db.group_id} not found for settlement {settlement_id} during update check.")
|
||||
pass
|
||||
|
||||
if not can_modify:
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="User cannot modify this settlement (must be involved party or group owner)")
|
||||
|
||||
try:
|
||||
updated_settlement = await crud_settlement.update_settlement(db=db, settlement_db=settlement_db, settlement_in=settlement_in)
|
||||
logger.info(f"Settlement ID {settlement_id} updated successfully to version {updated_settlement.version}.")
|
||||
return updated_settlement
|
||||
except InvalidOperationError as e:
|
||||
status_code = status.HTTP_400_BAD_REQUEST
|
||||
if "version" in str(e).lower():
|
||||
status_code = status.HTTP_409_CONFLICT
|
||||
raise HTTPException(status_code=status_code, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error updating settlement {settlement_id}: {str(e)}", exc_info=True)
|
||||
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="An unexpected error occurred.")
|
||||
|
||||
@router.delete("/settlements/{settlement_id}", status_code=status.HTTP_204_NO_CONTENT, summary="Delete Settlement", tags=["Settlements"])
|
||||
async def delete_settlement_record(
|
||||
settlement_id: int,
|
||||
expected_version: Optional[int] = Query(None, description="Expected version for optimistic locking"),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: UserModel = Depends(get_current_user),
|
||||
):
|
||||
"""
|
||||
Deletes a settlement.
|
||||
Requires expected_version query parameter for optimistic locking.
|
||||
User must have permission (e.g., be involved party or group admin).
|
||||
"""
|
||||
logger.info(f"User {current_user.email} attempting to delete settlement ID {settlement_id} (expected version {expected_version})")
|
||||
settlement_db = await crud_settlement.get_settlement_by_id(db, settlement_id=settlement_id)
|
||||
if not settlement_db:
|
||||
logger.warning(f"Attempt to delete non-existent settlement ID {settlement_id}")
|
||||
return Response(status_code=status.HTTP_204_NO_CONTENT)
|
||||
|
||||
# --- Granular Permission Check ---
|
||||
can_delete = False
|
||||
# 1. User is involved party (payer or payee)
|
||||
is_party = current_user.id in [settlement_db.paid_by_user_id, settlement_db.paid_to_user_id]
|
||||
if is_party:
|
||||
can_delete = True
|
||||
# 2. OR User is owner of the group the settlement belongs to
|
||||
elif settlement_db.group_id:
|
||||
try:
|
||||
await crud_group.check_user_role_in_group(db, group_id=settlement_db.group_id, user_id=current_user.id, required_role=UserRoleEnum.owner, action="delete group settlements")
|
||||
can_delete = True
|
||||
logger.info(f"Allowing delete for settlement {settlement_id} by group owner {current_user.email}")
|
||||
except GroupMembershipError:
|
||||
pass
|
||||
except GroupPermissionError:
|
||||
pass
|
||||
except GroupNotFoundError:
|
||||
logger.error(f"Group {settlement_db.group_id} not found for settlement {settlement_id} during delete check.")
|
||||
pass
|
||||
|
||||
if not can_delete:
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="User cannot delete this settlement (must be involved party or group owner)")
|
||||
|
||||
try:
|
||||
await crud_settlement.delete_settlement(db=db, settlement_db=settlement_db, expected_version=expected_version)
|
||||
logger.info(f"Settlement ID {settlement_id} deleted successfully.")
|
||||
except InvalidOperationError as e:
|
||||
status_code = status.HTTP_400_BAD_REQUEST
|
||||
if "version" in str(e).lower():
|
||||
status_code = status.HTTP_409_CONFLICT
|
||||
raise HTTPException(status_code=status_code, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error deleting settlement {settlement_id}: {str(e)}", exc_info=True)
|
||||
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="An unexpected error occurred.")
|
||||
|
||||
return Response(status_code=status.HTTP_204_NO_CONTENT)
|
||||
|
||||
# TODO (remaining from original list):
|
||||
# (None - GET/POST/PUT/DELETE implemented for Expense/Settlement)
|
@ -88,6 +88,14 @@ class UserNotFoundError(HTTPException):
|
||||
detail=detail_msg
|
||||
)
|
||||
|
||||
class InvalidOperationError(HTTPException):
|
||||
"""Raised when an operation is invalid or disallowed by business logic."""
|
||||
def __init__(self, detail: str, status_code: int = status.HTTP_400_BAD_REQUEST):
|
||||
super().__init__(
|
||||
status_code=status_code,
|
||||
detail=detail
|
||||
)
|
||||
|
||||
class DatabaseConnectionError(HTTPException):
|
||||
"""Raised when there is an error connecting to the database."""
|
||||
def __init__(self):
|
||||
|
@ -1,83 +0,0 @@
|
||||
# be/tests/core/test_gemini.py
|
||||
import pytest
|
||||
import os
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
# Store original key if exists, then clear it for testing missing key scenario
|
||||
original_api_key = os.environ.get("GEMINI_API_KEY")
|
||||
if "GEMINI_API_KEY" in os.environ:
|
||||
del os.environ["GEMINI_API_KEY"]
|
||||
|
||||
# --- Test Module Import ---
|
||||
# This forces the module-level initialization code in gemini.py to run
|
||||
# We need to reload modules because settings might have been cached
|
||||
from importlib import reload
|
||||
from app.config import settings as app_settings
|
||||
from app.core import gemini as gemini_core
|
||||
|
||||
# Reload settings first to ensure GEMINI_API_KEY is None initially
|
||||
reload(app_settings)
|
||||
# Reload gemini core to trigger initialization logic with potentially missing key
|
||||
reload(gemini_core)
|
||||
|
||||
|
||||
def test_gemini_initialization_without_key():
|
||||
"""Verify behavior when GEMINI_API_KEY is not set."""
|
||||
# Reload modules again to ensure clean state for this specific test
|
||||
if "GEMINI_API_KEY" in os.environ:
|
||||
del os.environ["GEMINI_API_KEY"]
|
||||
reload(app_settings)
|
||||
reload(gemini_core)
|
||||
|
||||
assert gemini_core.gemini_flash_client is None
|
||||
assert gemini_core.gemini_initialization_error is not None
|
||||
assert "GEMINI_API_KEY not configured" in gemini_core.gemini_initialization_error
|
||||
|
||||
with pytest.raises(RuntimeError, match="GEMINI_API_KEY not configured"):
|
||||
gemini_core.get_gemini_client()
|
||||
|
||||
@patch('google.generativeai.configure')
|
||||
@patch('google.generativeai.GenerativeModel')
|
||||
def test_gemini_initialization_with_key(mock_generative_model: MagicMock, mock_configure: MagicMock):
|
||||
"""Verify initialization logic is called when key is present (using mocks)."""
|
||||
# Set a dummy key in the environment for this test
|
||||
test_key = "TEST_API_KEY_123"
|
||||
os.environ["GEMINI_API_KEY"] = test_key
|
||||
|
||||
# Reload settings and gemini module to pick up the new key
|
||||
reload(app_settings)
|
||||
reload(gemini_core)
|
||||
|
||||
# Assertions
|
||||
mock_configure.assert_called_once_with(api_key=test_key)
|
||||
mock_generative_model.assert_called_once_with(
|
||||
model_name="gemini-1.5-flash-latest",
|
||||
safety_settings=pytest.ANY, # Check safety settings were passed (ANY allows flexibility)
|
||||
# generation_config=pytest.ANY # Check if you added default generation config
|
||||
)
|
||||
assert gemini_core.gemini_flash_client is not None
|
||||
assert gemini_core.gemini_initialization_error is None
|
||||
|
||||
# Test get_gemini_client() success path
|
||||
client = gemini_core.get_gemini_client()
|
||||
assert client is not None # Should return the mocked client instance
|
||||
|
||||
# Clean up environment variable after test
|
||||
if original_api_key:
|
||||
os.environ["GEMINI_API_KEY"] = original_api_key
|
||||
else:
|
||||
if "GEMINI_API_KEY" in os.environ:
|
||||
del os.environ["GEMINI_API_KEY"]
|
||||
# Reload modules one last time to restore state for other tests
|
||||
reload(app_settings)
|
||||
reload(gemini_core)
|
||||
|
||||
# Restore original key after all tests in the module run (if needed)
|
||||
def teardown_module(module):
|
||||
if original_api_key:
|
||||
os.environ["GEMINI_API_KEY"] = original_api_key
|
||||
else:
|
||||
if "GEMINI_API_KEY" in os.environ:
|
||||
del os.environ["GEMINI_API_KEY"]
|
||||
reload(app_settings)
|
||||
reload(gemini_core)
|
@ -1,86 +0,0 @@
|
||||
# Example: be/tests/core/test_security.py
|
||||
import pytest
|
||||
from datetime import timedelta
|
||||
from jose import jwt, JWTError
|
||||
import time
|
||||
|
||||
from app.core.security import (
|
||||
hash_password,
|
||||
verify_password,
|
||||
create_access_token,
|
||||
verify_access_token,
|
||||
)
|
||||
from app.config import settings # Import settings for testing JWT config
|
||||
|
||||
# --- Password Hashing Tests ---
|
||||
|
||||
def test_hash_password_returns_string():
|
||||
password = "testpassword"
|
||||
hashed = hash_password(password)
|
||||
assert isinstance(hashed, str)
|
||||
assert password != hashed # Ensure it's not plain text
|
||||
|
||||
def test_verify_password_correct():
|
||||
password = "correct_password"
|
||||
hashed = hash_password(password)
|
||||
assert verify_password(password, hashed) is True
|
||||
|
||||
def test_verify_password_incorrect():
|
||||
hashed = hash_password("correct_password")
|
||||
assert verify_password("wrong_password", hashed) is False
|
||||
|
||||
def test_verify_password_invalid_hash_format():
|
||||
# Passlib's verify handles many format errors gracefully
|
||||
assert verify_password("any_password", "invalid_hash_string") is False
|
||||
|
||||
|
||||
# --- JWT Tests ---
|
||||
|
||||
def test_create_access_token():
|
||||
subject = "testuser@example.com"
|
||||
token = create_access_token(subject=subject)
|
||||
assert isinstance(token, str)
|
||||
|
||||
# Decode manually for basic check (verification done in verify_access_token tests)
|
||||
payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM])
|
||||
assert payload["sub"] == subject
|
||||
assert "exp" in payload
|
||||
assert isinstance(payload["exp"], int)
|
||||
|
||||
def test_verify_access_token_valid():
|
||||
subject = "test_subject_valid"
|
||||
token = create_access_token(subject=subject)
|
||||
payload = verify_access_token(token)
|
||||
assert payload is not None
|
||||
assert payload["sub"] == subject
|
||||
|
||||
def test_verify_access_token_invalid_signature():
|
||||
subject = "test_subject_invalid_sig"
|
||||
token = create_access_token(subject=subject)
|
||||
# Attempt to verify with a wrong key
|
||||
wrong_key = settings.SECRET_KEY + "wrong"
|
||||
with pytest.raises(JWTError): # Decoding with wrong key should raise JWTError internally
|
||||
jwt.decode(token, wrong_key, algorithms=[settings.ALGORITHM])
|
||||
# Our verify function should catch this and return None
|
||||
assert verify_access_token(token + "tamper") is None # Tampering token often invalidates sig
|
||||
# Note: Testing verify_access_token directly returning None for wrong key is tricky
|
||||
# as the error happens *during* jwt.decode. We rely on it catching JWTError.
|
||||
|
||||
def test_verify_access_token_expired():
|
||||
# Create a token that expires almost immediately
|
||||
subject = "test_subject_expired"
|
||||
expires_delta = timedelta(seconds=-1) # Expired 1 second ago
|
||||
token = create_access_token(subject=subject, expires_delta=expires_delta)
|
||||
|
||||
# Wait briefly just in case of timing issues, though negative delta should guarantee expiry
|
||||
time.sleep(0.1)
|
||||
|
||||
# Decoding expired token raises ExpiredSignatureError internally
|
||||
with pytest.raises(JWTError): # Specifically ExpiredSignatureError, but JWTError catches it
|
||||
jwt.decode(token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM])
|
||||
|
||||
# Our verify function should catch this and return None
|
||||
assert verify_access_token(token) is None
|
||||
|
||||
def test_verify_access_token_malformed():
|
||||
assert verify_access_token("this.is.not.a.valid.token") is None
|
@ -1,24 +1,49 @@
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
from sqlalchemy.orm import selectinload, joinedload
|
||||
from sqlalchemy import func as sql_func, or_
|
||||
from decimal import Decimal, ROUND_HALF_UP
|
||||
from typing import List as PyList, Dict, Set
|
||||
from typing import Dict, Optional, Sequence, List as PyList
|
||||
from collections import defaultdict
|
||||
import logging
|
||||
|
||||
from app.models import List as ListModel, Item as ItemModel, User as UserModel, UserGroup as UserGroupModel, Group as GroupModel
|
||||
from app.schemas.cost import ListCostSummary, UserCostShare
|
||||
from app.core.exceptions import ListNotFoundError, UserNotFoundError # Assuming UserNotFoundError might be useful
|
||||
from app.models import (
|
||||
List as ListModel,
|
||||
Item as ItemModel,
|
||||
User as UserModel,
|
||||
UserGroup as UserGroupModel,
|
||||
Group as GroupModel,
|
||||
Expense as ExpenseModel,
|
||||
ExpenseSplit as ExpenseSplitModel,
|
||||
Settlement as SettlementModel
|
||||
)
|
||||
from app.schemas.cost import (
|
||||
ListCostSummary,
|
||||
UserCostShare,
|
||||
GroupBalanceSummary,
|
||||
UserBalanceDetail,
|
||||
SuggestedSettlement
|
||||
)
|
||||
from app.core.exceptions import (
|
||||
ListNotFoundError,
|
||||
UserNotFoundError,
|
||||
GroupNotFoundError,
|
||||
InvalidOperationError
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
async def calculate_list_cost_summary(db: AsyncSession, list_id: int) -> ListCostSummary:
|
||||
"""
|
||||
Calculates the cost summary for a given list, splitting costs equally
|
||||
among relevant users (group members if list is in a group, or creator if personal).
|
||||
Calculates the cost summary for a given list based purely on item prices and who added them.
|
||||
This is a simpler calculation and does not involve the Expense/Settlement system.
|
||||
"""
|
||||
# 1. Fetch the list, its items (with their 'added_by_user'), and its group (with members)
|
||||
list_result = await db.execute(
|
||||
select(ListModel)
|
||||
.options(
|
||||
selectinload(ListModel.items).options(joinedload(ItemModel.added_by_user)),
|
||||
selectinload(ListModel.group).options(selectinload(GroupModel.user_associations).options(selectinload(UserGroupModel.user)))
|
||||
selectinload(ListModel.group).options(selectinload(GroupModel.user_associations).options(selectinload(UserGroupModel.user))),
|
||||
selectinload(ListModel.creator)
|
||||
)
|
||||
.where(ListModel.id == list_id)
|
||||
)
|
||||
@ -27,90 +52,215 @@ async def calculate_list_cost_summary(db: AsyncSession, list_id: int) -> ListCos
|
||||
if not db_list:
|
||||
raise ListNotFoundError(list_id)
|
||||
|
||||
# 2. Determine participating users
|
||||
participating_users: Dict[int, UserModel] = {}
|
||||
participating_users_map: Dict[int, UserModel] = {}
|
||||
if db_list.group:
|
||||
# If list is part of a group, all group members participate
|
||||
for ug_assoc in db_list.group.user_associations:
|
||||
if ug_assoc.user: # Ensure user object is loaded
|
||||
participating_users[ug_assoc.user.id] = ug_assoc.user
|
||||
else:
|
||||
# If personal list, only the creator participates (or if items were added by others somehow, include them)
|
||||
# For simplicity in MVP, if personal, only creator. If shared personal lists become a feature, this needs revisit.
|
||||
# Let's fetch the creator if not already available through relationships (though it should be via ListModel.creator)
|
||||
creator_user = await db.get(UserModel, db_list.created_by_id)
|
||||
if not creator_user:
|
||||
# This case should ideally not happen if data integrity is maintained
|
||||
raise UserNotFoundError(user_id=db_list.created_by_id) # Or a more specific error
|
||||
participating_users[creator_user.id] = creator_user
|
||||
if ug_assoc.user:
|
||||
participating_users_map[ug_assoc.user.id] = ug_assoc.user
|
||||
elif db_list.creator: # Personal list
|
||||
participating_users_map[db_list.creator.id] = db_list.creator
|
||||
|
||||
# Also ensure all users who added items are included, even if not in the group (edge case, but good for robustness)
|
||||
# Include all users who added items with prices, even if not in the primary context (group/creator)
|
||||
for item in db_list.items:
|
||||
if item.added_by_user and item.added_by_user.id not in participating_users:
|
||||
participating_users[item.added_by_user.id] = item.added_by_user
|
||||
if item.price is not None and item.price > Decimal("0") and item.added_by_user and item.added_by_user.id not in participating_users_map:
|
||||
participating_users_map[item.added_by_user.id] = item.added_by_user
|
||||
|
||||
|
||||
num_participating_users = len(participating_users)
|
||||
num_participating_users = len(participating_users_map)
|
||||
if num_participating_users == 0:
|
||||
# Handle case with no users (e.g., empty group, or personal list creator deleted - though FKs should prevent)
|
||||
# Or if list has no items and is personal, creator might not be in participating_users if logic changes.
|
||||
# For now, if no users found (e.g. group with no members and list creator somehow not added), return empty/default summary.
|
||||
return ListCostSummary(
|
||||
list_id=db_list.id,
|
||||
list_name=db_list.name,
|
||||
total_list_cost=Decimal("0.00"),
|
||||
num_participating_users=0,
|
||||
equal_share_per_user=Decimal("0.00"),
|
||||
user_balances=[]
|
||||
list_id=db_list.id, list_name=db_list.name, total_list_cost=Decimal("0.00"),
|
||||
num_participating_users=0, equal_share_per_user=Decimal("0.00"), user_balances=[]
|
||||
)
|
||||
|
||||
|
||||
# 3. Calculate total cost and items_added_value for each user
|
||||
total_list_cost = Decimal("0.00")
|
||||
user_items_added_value: Dict[int, Decimal] = {user_id: Decimal("0.00") for user_id in participating_users.keys()}
|
||||
user_items_added_value: Dict[int, Decimal] = defaultdict(Decimal)
|
||||
|
||||
for item in db_list.items:
|
||||
if item.price is not None and item.price > Decimal("0"):
|
||||
total_list_cost += item.price
|
||||
if item.added_by_id in user_items_added_value: # Item adder must be in participating users
|
||||
if item.added_by_id in participating_users_map:
|
||||
user_items_added_value[item.added_by_id] += item.price
|
||||
# If item.added_by_id is not in participating_users (e.g. user left group),
|
||||
# their contribution still counts to total cost, but they aren't part of the split.
|
||||
# The current logic adds item adders to participating_users, so this else is less likely.
|
||||
|
||||
# 4. Calculate equal share per user
|
||||
# Using ROUND_HALF_UP to handle cents appropriately.
|
||||
# Ensure division by zero is handled if num_participating_users could be 0 (already handled above)
|
||||
equal_share_per_user = (total_list_cost / Decimal(num_participating_users)).quantize(Decimal("0.01"), rounding=ROUND_HALF_UP) if num_participating_users > 0 else Decimal("0.00")
|
||||
equal_share_per_user = (total_list_cost / Decimal(num_participating_users)).quantize(Decimal("0.01"), rounding=ROUND_HALF_UP)
|
||||
remainder = total_list_cost - (equal_share_per_user * num_participating_users)
|
||||
|
||||
# 5. For each user, calculate their balance
|
||||
user_balances: PyList[UserCostShare] = []
|
||||
for user_id, user_obj in participating_users.items():
|
||||
first_user_processed = False
|
||||
for user_id, user_obj in participating_users_map.items():
|
||||
items_added = user_items_added_value.get(user_id, Decimal("0.00"))
|
||||
balance = items_added - equal_share_per_user
|
||||
|
||||
user_identifier = user_obj.name if user_obj.name else user_obj.email # Prefer name, fallback to email
|
||||
current_user_share = equal_share_per_user
|
||||
if not first_user_processed and remainder != Decimal("0"):
|
||||
current_user_share += remainder
|
||||
first_user_processed = True
|
||||
|
||||
balance = items_added - current_user_share
|
||||
user_identifier = user_obj.name if user_obj.name else user_obj.email
|
||||
user_balances.append(
|
||||
UserCostShare(
|
||||
user_id=user_id,
|
||||
user_identifier=user_identifier,
|
||||
user_id=user_id, user_identifier=user_identifier,
|
||||
items_added_value=items_added.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP),
|
||||
amount_due=equal_share_per_user, # Already quantized
|
||||
amount_due=current_user_share.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP),
|
||||
balance=balance.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP)
|
||||
)
|
||||
)
|
||||
|
||||
# Sort user_balances for consistent output, e.g., by user_id or identifier
|
||||
user_balances.sort(key=lambda x: x.user_identifier)
|
||||
|
||||
|
||||
# 6. Return the populated ListCostSummary schema
|
||||
return ListCostSummary(
|
||||
list_id=db_list.id,
|
||||
list_name=db_list.name,
|
||||
list_id=db_list.id, list_name=db_list.name,
|
||||
total_list_cost=total_list_cost.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP),
|
||||
num_participating_users=num_participating_users,
|
||||
equal_share_per_user=equal_share_per_user, # Already quantized
|
||||
equal_share_per_user=equal_share_per_user.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP),
|
||||
user_balances=user_balances
|
||||
)
|
||||
|
||||
# --- Helper for Settlement Suggestions ---
|
||||
def calculate_suggested_settlements(user_balances: PyList[UserBalanceDetail]) -> PyList[SuggestedSettlement]:
|
||||
"""
|
||||
Calculates a list of suggested settlements to resolve group debts.
|
||||
Uses a greedy algorithm to minimize the number of transactions.
|
||||
Input: List of UserBalanceDetail objects with calculated net_balances.
|
||||
Output: List of SuggestedSettlement objects.
|
||||
"""
|
||||
# Use a small tolerance for floating point comparisons with Decimal
|
||||
tolerance = Decimal("0.001")
|
||||
|
||||
debtors = sorted([ub for ub in user_balances if ub.net_balance < -tolerance], key=lambda x: x.net_balance)
|
||||
creditors = sorted([ub for ub in user_balances if ub.net_balance > tolerance], key=lambda x: x.net_balance, reverse=True)
|
||||
|
||||
settlements: PyList[SuggestedSettlement] = []
|
||||
debtor_idx = 0
|
||||
creditor_idx = 0
|
||||
|
||||
# Create mutable copies of balances to track remaining amounts
|
||||
debtor_balances = {d.user_id: d.net_balance for d in debtors}
|
||||
creditor_balances = {c.user_id: c.net_balance for c in creditors}
|
||||
user_identifiers = {ub.user_id: ub.user_identifier for ub in user_balances}
|
||||
|
||||
while debtor_idx < len(debtors) and creditor_idx < len(creditors):
|
||||
debtor = debtors[debtor_idx]
|
||||
creditor = creditors[creditor_idx]
|
||||
|
||||
debtor_remaining = debtor_balances[debtor.user_id]
|
||||
creditor_remaining = creditor_balances[creditor.user_id]
|
||||
|
||||
# Amount to transfer is the minimum of what debtor owes and what creditor is owed
|
||||
transfer_amount = min(abs(debtor_remaining), creditor_remaining)
|
||||
transfer_amount = transfer_amount.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP)
|
||||
|
||||
if transfer_amount > tolerance: # Only record meaningful transfers
|
||||
settlements.append(SuggestedSettlement(
|
||||
from_user_id=debtor.user_id,
|
||||
from_user_identifier=user_identifiers.get(debtor.user_id, "Unknown Debtor"),
|
||||
to_user_id=creditor.user_id,
|
||||
to_user_identifier=user_identifiers.get(creditor.user_id, "Unknown Creditor"),
|
||||
amount=transfer_amount
|
||||
))
|
||||
|
||||
# Update remaining balances
|
||||
debtor_balances[debtor.user_id] += transfer_amount
|
||||
creditor_balances[creditor.user_id] -= transfer_amount
|
||||
|
||||
# Move to next debtor if current one is settled (or very close)
|
||||
if abs(debtor_balances[debtor.user_id]) < tolerance:
|
||||
debtor_idx += 1
|
||||
|
||||
# Move to next creditor if current one is settled (or very close)
|
||||
if creditor_balances[creditor.user_id] < tolerance:
|
||||
creditor_idx += 1
|
||||
|
||||
# Log if lists aren't empty - indicates potential imbalance or rounding issue
|
||||
if debtor_idx < len(debtors) or creditor_idx < len(creditors):
|
||||
# Calculate remaining balances for logging
|
||||
remaining_debt = sum(bal for bal in debtor_balances.values() if bal < -tolerance)
|
||||
remaining_credit = sum(bal for bal in creditor_balances.values() if bal > tolerance)
|
||||
logger.warning(f"Settlement suggestion calculation finished with remaining balances. Debt: {remaining_debt}, Credit: {remaining_credit}. This might be due to minor rounding discrepancies.")
|
||||
|
||||
return settlements
|
||||
|
||||
# --- NEW: Detailed Group Balance Summary ---
|
||||
async def calculate_group_balance_summary(db: AsyncSession, group_id: int) -> GroupBalanceSummary:
|
||||
"""
|
||||
Calculates a detailed balance summary for all users in a group,
|
||||
considering all expenses, splits, and settlements within that group.
|
||||
Also calculates suggested settlements.
|
||||
"""
|
||||
group = await db.get(GroupModel, group_id)
|
||||
if not group:
|
||||
raise GroupNotFoundError(group_id)
|
||||
|
||||
# 1. Get all group members
|
||||
group_members_result = await db.execute(
|
||||
select(UserModel)
|
||||
.join(UserGroupModel, UserModel.id == UserGroupModel.user_id)
|
||||
.where(UserGroupModel.group_id == group_id)
|
||||
)
|
||||
group_members: Dict[int, UserModel] = {user.id: user for user in group_members_result.scalars().all()}
|
||||
if not group_members:
|
||||
return GroupBalanceSummary(
|
||||
group_id=group.id, group_name=group.name, user_balances=[], suggested_settlements=[]
|
||||
)
|
||||
|
||||
user_balances_data: Dict[int, UserBalanceDetail] = {}
|
||||
for user_id, user_obj in group_members.items():
|
||||
user_balances_data[user_id] = UserBalanceDetail(
|
||||
user_id=user_id,
|
||||
user_identifier=user_obj.name if user_obj.name else user_obj.email
|
||||
)
|
||||
|
||||
overall_total_expenses = Decimal("0.00")
|
||||
overall_total_settlements = Decimal("0.00")
|
||||
|
||||
# 2. Process Expenses and ExpenseSplits for the group
|
||||
expenses_result = await db.execute(
|
||||
select(ExpenseModel)
|
||||
.where(ExpenseModel.group_id == group_id)
|
||||
.options(selectinload(ExpenseModel.splits))
|
||||
)
|
||||
for expense in expenses_result.scalars().all():
|
||||
overall_total_expenses += expense.total_amount
|
||||
if expense.paid_by_user_id in user_balances_data:
|
||||
user_balances_data[expense.paid_by_user_id].total_paid_for_expenses += expense.total_amount
|
||||
|
||||
for split in expense.splits:
|
||||
if split.user_id in user_balances_data:
|
||||
user_balances_data[split.user_id].total_share_of_expenses += split.owed_amount
|
||||
|
||||
# 3. Process Settlements for the group
|
||||
settlements_result = await db.execute(
|
||||
select(SettlementModel).where(SettlementModel.group_id == group_id)
|
||||
)
|
||||
for settlement in settlements_result.scalars().all():
|
||||
overall_total_settlements += settlement.amount
|
||||
if settlement.paid_by_user_id in user_balances_data:
|
||||
user_balances_data[settlement.paid_by_user_id].total_settlements_paid += settlement.amount
|
||||
if settlement.paid_to_user_id in user_balances_data:
|
||||
user_balances_data[settlement.paid_to_user_id].total_settlements_received += settlement.amount
|
||||
|
||||
# 4. Calculate net balances and prepare final list
|
||||
final_user_balances: PyList[UserBalanceDetail] = []
|
||||
for user_id in group_members.keys():
|
||||
data = user_balances_data[user_id]
|
||||
data.net_balance = (
|
||||
data.total_paid_for_expenses + data.total_settlements_received
|
||||
) - (data.total_share_of_expenses + data.total_settlements_paid)
|
||||
|
||||
data.total_paid_for_expenses = data.total_paid_for_expenses.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP)
|
||||
data.total_share_of_expenses = data.total_share_of_expenses.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP)
|
||||
data.total_settlements_paid = data.total_settlements_paid.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP)
|
||||
data.total_settlements_received = data.total_settlements_received.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP)
|
||||
data.net_balance = data.net_balance.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP)
|
||||
|
||||
final_user_balances.append(data)
|
||||
|
||||
final_user_balances.sort(key=lambda x: x.user_identifier)
|
||||
|
||||
# 5. Calculate suggested settlements (NEW)
|
||||
suggested_settlements = calculate_suggested_settlements(final_user_balances)
|
||||
|
||||
return GroupBalanceSummary(
|
||||
group_id=group.id,
|
||||
group_name=group.name,
|
||||
overall_total_expenses=overall_total_expenses.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP),
|
||||
overall_total_settlements=overall_total_settlements.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP),
|
||||
user_balances=final_user_balances,
|
||||
suggested_settlements=suggested_settlements # Add suggestions to response
|
||||
)
|
592
be/app/crud/expense.py
Normal file
592
be/app/crud/expense.py
Normal file
@ -0,0 +1,592 @@
|
||||
# app/crud/expense.py
|
||||
import logging # Add logging import
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
from sqlalchemy.orm import selectinload, joinedload
|
||||
from decimal import Decimal, ROUND_HALF_UP, InvalidOperation as DecimalInvalidOperation
|
||||
from typing import Callable, List as PyList, Optional, Sequence, Dict, defaultdict
|
||||
from datetime import datetime, timezone # Added timezone
|
||||
|
||||
from app.models import (
|
||||
Expense as ExpenseModel,
|
||||
ExpenseSplit as ExpenseSplitModel,
|
||||
User as UserModel,
|
||||
List as ListModel,
|
||||
Group as GroupModel,
|
||||
UserGroup as UserGroupModel,
|
||||
SplitTypeEnum,
|
||||
Item as ItemModel
|
||||
)
|
||||
from app.schemas.expense import ExpenseCreate, ExpenseSplitCreate, ExpenseUpdate # Removed unused ExpenseUpdate
|
||||
from app.core.exceptions import (
|
||||
# Using existing specific exceptions where possible
|
||||
ListNotFoundError,
|
||||
GroupNotFoundError,
|
||||
UserNotFoundError,
|
||||
InvalidOperationError # Import the new exception
|
||||
)
|
||||
|
||||
# Placeholder for InvalidOperationError if not defined in app.core.exceptions
|
||||
# This should be a proper HTTPException subclass if used in API layer
|
||||
# class CrudInvalidOperationError(ValueError): # For internal CRUD validation logic
|
||||
# pass
|
||||
|
||||
logger = logging.getLogger(__name__) # Initialize logger
|
||||
|
||||
async def get_users_for_splitting(db: AsyncSession, expense_group_id: Optional[int], expense_list_id: Optional[int], expense_paid_by_user_id: int) -> PyList[UserModel]:
|
||||
"""
|
||||
Determines the list of users an expense should be split amongst.
|
||||
Priority: Group members (if group_id), then List's group members or creator (if list_id).
|
||||
Fallback to only the payer if no other context yields users.
|
||||
"""
|
||||
users_to_split_with: PyList[UserModel] = []
|
||||
processed_user_ids = set()
|
||||
|
||||
async def _add_user(user: Optional[UserModel]):
|
||||
if user and user.id not in processed_user_ids:
|
||||
users_to_split_with.append(user)
|
||||
processed_user_ids.add(user.id)
|
||||
|
||||
if expense_group_id:
|
||||
group_result = await db.execute(
|
||||
select(GroupModel).options(selectinload(GroupModel.member_associations).options(selectinload(UserGroupModel.user)))
|
||||
.where(GroupModel.id == expense_group_id)
|
||||
)
|
||||
group = group_result.scalars().first()
|
||||
if not group:
|
||||
raise GroupNotFoundError(expense_group_id)
|
||||
for assoc in group.member_associations:
|
||||
await _add_user(assoc.user)
|
||||
|
||||
elif expense_list_id: # Only if group_id was not primary context
|
||||
list_result = await db.execute(
|
||||
select(ListModel)
|
||||
.options(
|
||||
selectinload(ListModel.group).options(selectinload(GroupModel.member_associations).options(selectinload(UserGroupModel.user))),
|
||||
selectinload(ListModel.creator)
|
||||
)
|
||||
.where(ListModel.id == expense_list_id)
|
||||
)
|
||||
db_list = list_result.scalars().first()
|
||||
if not db_list:
|
||||
raise ListNotFoundError(expense_list_id)
|
||||
|
||||
if db_list.group:
|
||||
for assoc in db_list.group.member_associations:
|
||||
await _add_user(assoc.user)
|
||||
elif db_list.creator:
|
||||
await _add_user(db_list.creator)
|
||||
|
||||
if not users_to_split_with:
|
||||
payer_user = await db.get(UserModel, expense_paid_by_user_id)
|
||||
if not payer_user:
|
||||
# This should have been caught earlier if paid_by_user_id was validated before calling this helper
|
||||
raise UserNotFoundError(user_id=expense_paid_by_user_id)
|
||||
await _add_user(payer_user)
|
||||
|
||||
if not users_to_split_with:
|
||||
# This should ideally not be reached if payer is always a fallback
|
||||
raise InvalidOperationError("Could not determine any users for splitting the expense.")
|
||||
|
||||
return users_to_split_with
|
||||
|
||||
|
||||
async def create_expense(db: AsyncSession, expense_in: ExpenseCreate, current_user_id: int) -> ExpenseModel:
|
||||
"""Creates a new expense and its associated splits.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
expense_in: Expense creation data
|
||||
current_user_id: ID of the user creating the expense
|
||||
|
||||
Returns:
|
||||
The created expense with splits
|
||||
|
||||
Raises:
|
||||
UserNotFoundError: If payer or split users don't exist
|
||||
ListNotFoundError: If specified list doesn't exist
|
||||
GroupNotFoundError: If specified group doesn't exist
|
||||
InvalidOperationError: For various validation failures
|
||||
"""
|
||||
# Helper function to round decimals consistently
|
||||
def round_money(amount: Decimal) -> Decimal:
|
||||
return amount.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP)
|
||||
|
||||
# 1. Context Validation
|
||||
# Validate basic context requirements first
|
||||
if not expense_in.list_id and not expense_in.group_id:
|
||||
raise InvalidOperationError("Expense must be associated with a list or a group.")
|
||||
|
||||
# 2. User Validation
|
||||
payer = await db.get(UserModel, expense_in.paid_by_user_id)
|
||||
if not payer:
|
||||
raise UserNotFoundError(user_id=expense_in.paid_by_user_id)
|
||||
|
||||
# 3. List/Group Context Resolution
|
||||
final_group_id = await _resolve_expense_context(db, expense_in)
|
||||
|
||||
# 4. Create the expense object
|
||||
db_expense = ExpenseModel(
|
||||
description=expense_in.description,
|
||||
total_amount=round_money(expense_in.total_amount),
|
||||
currency=expense_in.currency or "USD",
|
||||
expense_date=expense_in.expense_date or datetime.now(timezone.utc),
|
||||
split_type=expense_in.split_type,
|
||||
list_id=expense_in.list_id,
|
||||
group_id=final_group_id,
|
||||
item_id=expense_in.item_id,
|
||||
paid_by_user_id=expense_in.paid_by_user_id,
|
||||
created_by_user_id=current_user_id # Track who created this expense
|
||||
)
|
||||
|
||||
# 5. Generate splits based on split type
|
||||
splits_to_create = await _generate_expense_splits(db, db_expense, expense_in, round_money)
|
||||
|
||||
# 6. Single transaction for expense and all splits
|
||||
try:
|
||||
db.add(db_expense)
|
||||
await db.flush() # Get expense ID without committing
|
||||
|
||||
# Update all splits with the expense ID
|
||||
for split in splits_to_create:
|
||||
split.expense_id = db_expense.id
|
||||
|
||||
db.add_all(splits_to_create)
|
||||
await db.commit()
|
||||
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(f"Failed to save expense: {str(e)}", exc_info=True)
|
||||
raise InvalidOperationError(f"Failed to save expense: {str(e)}")
|
||||
|
||||
# Refresh to get the splits relationship populated
|
||||
await db.refresh(db_expense, attribute_names=["splits"])
|
||||
return db_expense
|
||||
|
||||
|
||||
async def _resolve_expense_context(db: AsyncSession, expense_in: ExpenseCreate) -> Optional[int]:
|
||||
"""Resolves and validates the expense's context (list and group).
|
||||
|
||||
Returns the final group_id for the expense after validation.
|
||||
"""
|
||||
final_group_id = expense_in.group_id
|
||||
|
||||
# If list_id is provided, validate it and potentially derive group_id
|
||||
if expense_in.list_id:
|
||||
list_obj = await db.get(ListModel, expense_in.list_id)
|
||||
if not list_obj:
|
||||
raise ListNotFoundError(expense_in.list_id)
|
||||
|
||||
# If list belongs to a group, verify consistency or inherit group_id
|
||||
if list_obj.group_id:
|
||||
if expense_in.group_id and list_obj.group_id != expense_in.group_id:
|
||||
raise InvalidOperationError(
|
||||
f"List {expense_in.list_id} belongs to group {list_obj.group_id}, "
|
||||
f"but expense was specified for group {expense_in.group_id}."
|
||||
)
|
||||
final_group_id = list_obj.group_id # Prioritize list's group
|
||||
|
||||
# If only group_id is provided (no list_id), validate group_id
|
||||
elif final_group_id:
|
||||
group_obj = await db.get(GroupModel, final_group_id)
|
||||
if not group_obj:
|
||||
raise GroupNotFoundError(final_group_id)
|
||||
|
||||
return final_group_id
|
||||
|
||||
|
||||
async def _generate_expense_splits(
|
||||
db: AsyncSession,
|
||||
db_expense: ExpenseModel,
|
||||
expense_in: ExpenseCreate,
|
||||
round_money: Callable[[Decimal], Decimal]
|
||||
) -> PyList[ExpenseSplitModel]:
|
||||
"""Generates appropriate expense splits based on split type."""
|
||||
|
||||
splits_to_create: PyList[ExpenseSplitModel] = []
|
||||
|
||||
# Create splits based on the split type
|
||||
if expense_in.split_type == SplitTypeEnum.EQUAL:
|
||||
splits_to_create = await _create_equal_splits(
|
||||
db, db_expense, expense_in, round_money
|
||||
)
|
||||
|
||||
elif expense_in.split_type == SplitTypeEnum.EXACT_AMOUNTS:
|
||||
splits_to_create = await _create_exact_amount_splits(
|
||||
db, db_expense, expense_in, round_money
|
||||
)
|
||||
|
||||
elif expense_in.split_type == SplitTypeEnum.PERCENTAGE:
|
||||
splits_to_create = await _create_percentage_splits(
|
||||
db, db_expense, expense_in, round_money
|
||||
)
|
||||
|
||||
elif expense_in.split_type == SplitTypeEnum.SHARES:
|
||||
splits_to_create = await _create_shares_splits(
|
||||
db, db_expense, expense_in, round_money
|
||||
)
|
||||
|
||||
elif expense_in.split_type == SplitTypeEnum.ITEM_BASED:
|
||||
splits_to_create = await _create_item_based_splits(
|
||||
db, db_expense, expense_in, round_money
|
||||
)
|
||||
|
||||
else:
|
||||
raise InvalidOperationError(f"Unsupported split type: {expense_in.split_type.value}")
|
||||
|
||||
if not splits_to_create:
|
||||
raise InvalidOperationError("No expense splits were generated.")
|
||||
|
||||
return splits_to_create
|
||||
|
||||
|
||||
async def _create_equal_splits(
|
||||
db: AsyncSession,
|
||||
db_expense: ExpenseModel,
|
||||
expense_in: ExpenseCreate,
|
||||
round_money: Callable[[Decimal], Decimal]
|
||||
) -> PyList[ExpenseSplitModel]:
|
||||
"""Creates equal splits among users."""
|
||||
|
||||
users_for_splitting = await get_users_for_splitting(
|
||||
db, db_expense.group_id, expense_in.list_id, expense_in.paid_by_user_id
|
||||
)
|
||||
if not users_for_splitting:
|
||||
raise InvalidOperationError("No users found for EQUAL split.")
|
||||
|
||||
num_users = len(users_for_splitting)
|
||||
amount_per_user = round_money(db_expense.total_amount / Decimal(num_users))
|
||||
remainder = db_expense.total_amount - (amount_per_user * num_users)
|
||||
|
||||
splits = []
|
||||
for i, user in enumerate(users_for_splitting):
|
||||
split_amount = amount_per_user
|
||||
if i == 0 and remainder != Decimal('0'):
|
||||
split_amount = round_money(amount_per_user + remainder)
|
||||
|
||||
splits.append(ExpenseSplitModel(
|
||||
user_id=user.id,
|
||||
owed_amount=split_amount
|
||||
))
|
||||
|
||||
return splits
|
||||
|
||||
|
||||
async def _create_exact_amount_splits(
|
||||
db: AsyncSession,
|
||||
db_expense: ExpenseModel,
|
||||
expense_in: ExpenseCreate,
|
||||
round_money: Callable[[Decimal], Decimal]
|
||||
) -> PyList[ExpenseSplitModel]:
|
||||
"""Creates splits with exact amounts."""
|
||||
|
||||
if not expense_in.splits_in:
|
||||
raise InvalidOperationError("Splits data is required for EXACT_AMOUNTS split type.")
|
||||
|
||||
# Validate all users in splits exist
|
||||
await _validate_users_in_splits(db, expense_in.splits_in)
|
||||
|
||||
current_total = Decimal("0.00")
|
||||
splits = []
|
||||
|
||||
for split_in in expense_in.splits_in:
|
||||
if split_in.owed_amount <= Decimal('0'):
|
||||
raise InvalidOperationError(f"Owed amount for user {split_in.user_id} must be positive.")
|
||||
|
||||
rounded_amount = round_money(split_in.owed_amount)
|
||||
current_total += rounded_amount
|
||||
|
||||
splits.append(ExpenseSplitModel(
|
||||
user_id=split_in.user_id,
|
||||
owed_amount=rounded_amount
|
||||
))
|
||||
|
||||
if round_money(current_total) != db_expense.total_amount:
|
||||
raise InvalidOperationError(
|
||||
f"Sum of exact split amounts ({current_total}) != expense total ({db_expense.total_amount})."
|
||||
)
|
||||
|
||||
return splits
|
||||
|
||||
|
||||
async def _create_percentage_splits(
|
||||
db: AsyncSession,
|
||||
db_expense: ExpenseModel,
|
||||
expense_in: ExpenseCreate,
|
||||
round_money: Callable[[Decimal], Decimal]
|
||||
) -> PyList[ExpenseSplitModel]:
|
||||
"""Creates splits based on percentages."""
|
||||
|
||||
if not expense_in.splits_in:
|
||||
raise InvalidOperationError("Splits data is required for PERCENTAGE split type.")
|
||||
|
||||
# Validate all users in splits exist
|
||||
await _validate_users_in_splits(db, expense_in.splits_in)
|
||||
|
||||
total_percentage = Decimal("0.00")
|
||||
current_total = Decimal("0.00")
|
||||
splits = []
|
||||
|
||||
for split_in in expense_in.splits_in:
|
||||
if not (split_in.share_percentage and Decimal("0") < split_in.share_percentage <= Decimal("100")):
|
||||
raise InvalidOperationError(
|
||||
f"Invalid percentage {split_in.share_percentage} for user {split_in.user_id}."
|
||||
)
|
||||
|
||||
total_percentage += split_in.share_percentage
|
||||
owed_amount = round_money(db_expense.total_amount * (split_in.share_percentage / Decimal("100")))
|
||||
current_total += owed_amount
|
||||
|
||||
splits.append(ExpenseSplitModel(
|
||||
user_id=split_in.user_id,
|
||||
owed_amount=owed_amount,
|
||||
share_percentage=split_in.share_percentage
|
||||
))
|
||||
|
||||
if round_money(total_percentage) != Decimal("100.00"):
|
||||
raise InvalidOperationError(f"Sum of percentages ({total_percentage}%) is not 100%.")
|
||||
|
||||
# Adjust for rounding differences
|
||||
if current_total != db_expense.total_amount and splits:
|
||||
diff = db_expense.total_amount - current_total
|
||||
splits[-1].owed_amount = round_money(splits[-1].owed_amount + diff)
|
||||
|
||||
return splits
|
||||
|
||||
|
||||
async def _create_shares_splits(
|
||||
db: AsyncSession,
|
||||
db_expense: ExpenseModel,
|
||||
expense_in: ExpenseCreate,
|
||||
round_money: Callable[[Decimal], Decimal]
|
||||
) -> PyList[ExpenseSplitModel]:
|
||||
"""Creates splits based on shares."""
|
||||
|
||||
if not expense_in.splits_in:
|
||||
raise InvalidOperationError("Splits data is required for SHARES split type.")
|
||||
|
||||
# Validate all users in splits exist
|
||||
await _validate_users_in_splits(db, expense_in.splits_in)
|
||||
|
||||
# Calculate total shares
|
||||
total_shares = sum(s.share_units for s in expense_in.splits_in if s.share_units and s.share_units > 0)
|
||||
if total_shares == 0:
|
||||
raise InvalidOperationError("Total shares cannot be zero for SHARES split.")
|
||||
|
||||
splits = []
|
||||
current_total = Decimal("0.00")
|
||||
|
||||
for split_in in expense_in.splits_in:
|
||||
if not (split_in.share_units and split_in.share_units > 0):
|
||||
raise InvalidOperationError(f"Invalid share units for user {split_in.user_id}.")
|
||||
|
||||
share_ratio = Decimal(split_in.share_units) / Decimal(total_shares)
|
||||
owed_amount = round_money(db_expense.total_amount * share_ratio)
|
||||
current_total += owed_amount
|
||||
|
||||
splits.append(ExpenseSplitModel(
|
||||
user_id=split_in.user_id,
|
||||
owed_amount=owed_amount,
|
||||
share_units=split_in.share_units
|
||||
))
|
||||
|
||||
# Adjust for rounding differences
|
||||
if current_total != db_expense.total_amount and splits:
|
||||
diff = db_expense.total_amount - current_total
|
||||
splits[-1].owed_amount = round_money(splits[-1].owed_amount + diff)
|
||||
|
||||
return splits
|
||||
|
||||
|
||||
async def _create_item_based_splits(
|
||||
db: AsyncSession,
|
||||
db_expense: ExpenseModel,
|
||||
expense_in: ExpenseCreate,
|
||||
round_money: Callable[[Decimal], Decimal]
|
||||
) -> PyList[ExpenseSplitModel]:
|
||||
"""Creates splits based on items in a shopping list."""
|
||||
|
||||
if not expense_in.list_id:
|
||||
raise InvalidOperationError("ITEM_BASED expenses must be associated with a list_id.")
|
||||
|
||||
if expense_in.splits_in:
|
||||
logger.warning("Ignoring provided 'splits_in' data for ITEM_BASED expense.")
|
||||
|
||||
# Build query to fetch relevant items
|
||||
items_query = select(ItemModel).where(ItemModel.list_id == expense_in.list_id)
|
||||
if expense_in.item_id:
|
||||
items_query = items_query.where(ItemModel.id == expense_in.item_id)
|
||||
else:
|
||||
items_query = items_query.where(ItemModel.price.isnot(None) & (ItemModel.price > Decimal("0")))
|
||||
|
||||
# Load items with their adders
|
||||
items_result = await db.execute(items_query.options(selectinload(ItemModel.added_by_user)))
|
||||
relevant_items = items_result.scalars().all()
|
||||
|
||||
if not relevant_items:
|
||||
error_msg = (
|
||||
f"Specified item ID {expense_in.item_id} not found in list {expense_in.list_id}."
|
||||
if expense_in.item_id else
|
||||
f"List {expense_in.list_id} has no priced items to base the expense on."
|
||||
)
|
||||
raise InvalidOperationError(error_msg)
|
||||
|
||||
# Aggregate owed amounts by user
|
||||
calculated_total = Decimal("0.00")
|
||||
user_owed_amounts = defaultdict(Decimal)
|
||||
processed_items = 0
|
||||
|
||||
for item in relevant_items:
|
||||
if item.price is None or item.price <= Decimal("0"):
|
||||
if expense_in.item_id:
|
||||
raise InvalidOperationError(
|
||||
f"Item ID {expense_in.item_id} must have a positive price for ITEM_BASED expense."
|
||||
)
|
||||
continue
|
||||
|
||||
if not item.added_by_user:
|
||||
logger.error(f"Item ID {item.id} is missing added_by_user relationship.")
|
||||
raise InvalidOperationError(f"Data integrity issue: Item {item.id} is missing adder information.")
|
||||
|
||||
calculated_total += item.price
|
||||
user_owed_amounts[item.added_by_user.id] += item.price
|
||||
processed_items += 1
|
||||
|
||||
if processed_items == 0:
|
||||
raise InvalidOperationError(
|
||||
f"No items with positive prices found in list {expense_in.list_id} to create ITEM_BASED expense."
|
||||
)
|
||||
|
||||
# Validate total matches calculated total
|
||||
if round_money(calculated_total) != db_expense.total_amount:
|
||||
raise InvalidOperationError(
|
||||
f"Expense total amount ({db_expense.total_amount}) does not match the "
|
||||
f"calculated total from item prices ({calculated_total})."
|
||||
)
|
||||
|
||||
# Create splits based on aggregated amounts
|
||||
splits = []
|
||||
for user_id, owed_amount in user_owed_amounts.items():
|
||||
splits.append(ExpenseSplitModel(
|
||||
user_id=user_id,
|
||||
owed_amount=round_money(owed_amount)
|
||||
))
|
||||
|
||||
return splits
|
||||
|
||||
|
||||
async def _validate_users_in_splits(db: AsyncSession, splits_in: PyList[ExpenseSplitCreate]) -> None:
|
||||
"""Validates that all users in the splits exist."""
|
||||
|
||||
user_ids_in_split = [s.user_id for s in splits_in]
|
||||
user_results = await db.execute(select(UserModel.id).where(UserModel.id.in_(user_ids_in_split)))
|
||||
found_user_ids = {row[0] for row in user_results}
|
||||
|
||||
if len(found_user_ids) != len(user_ids_in_split):
|
||||
missing_user_ids = set(user_ids_in_split) - found_user_ids
|
||||
raise UserNotFoundError(identifier=f"users in split data: {list(missing_user_ids)}")
|
||||
|
||||
|
||||
async def get_expense_by_id(db: AsyncSession, expense_id: int) -> Optional[ExpenseModel]:
|
||||
result = await db.execute(
|
||||
select(ExpenseModel)
|
||||
.options(
|
||||
selectinload(ExpenseModel.splits).options(selectinload(ExpenseSplitModel.user)),
|
||||
selectinload(ExpenseModel.paid_by_user),
|
||||
selectinload(ExpenseModel.list),
|
||||
selectinload(ExpenseModel.group),
|
||||
selectinload(ExpenseModel.item)
|
||||
)
|
||||
.where(ExpenseModel.id == expense_id)
|
||||
)
|
||||
return result.scalars().first()
|
||||
|
||||
async def get_expenses_for_list(db: AsyncSession, list_id: int, skip: int = 0, limit: int = 100) -> Sequence[ExpenseModel]:
|
||||
result = await db.execute(
|
||||
select(ExpenseModel)
|
||||
.where(ExpenseModel.list_id == list_id)
|
||||
.order_by(ExpenseModel.expense_date.desc(), ExpenseModel.created_at.desc())
|
||||
.offset(skip).limit(limit)
|
||||
.options(selectinload(ExpenseModel.splits).options(selectinload(ExpenseSplitModel.user))) # Also load user for each split
|
||||
)
|
||||
return result.scalars().all()
|
||||
|
||||
async def get_expenses_for_group(db: AsyncSession, group_id: int, skip: int = 0, limit: int = 100) -> Sequence[ExpenseModel]:
|
||||
result = await db.execute(
|
||||
select(ExpenseModel)
|
||||
.where(ExpenseModel.group_id == group_id)
|
||||
.order_by(ExpenseModel.expense_date.desc(), ExpenseModel.created_at.desc())
|
||||
.offset(skip).limit(limit)
|
||||
.options(selectinload(ExpenseModel.splits).options(selectinload(ExpenseSplitModel.user)))
|
||||
)
|
||||
return result.scalars().all()
|
||||
|
||||
async def update_expense(db: AsyncSession, expense_db: ExpenseModel, expense_in: ExpenseUpdate) -> ExpenseModel:
|
||||
"""
|
||||
Updates an existing expense.
|
||||
Only allows updates to description, currency, and expense_date to avoid split complexities.
|
||||
Requires version matching for optimistic locking.
|
||||
"""
|
||||
if expense_db.version != expense_in.version:
|
||||
raise InvalidOperationError(
|
||||
f"Expense '{expense_db.description}' (ID: {expense_db.id}) has been modified. "
|
||||
f"Your version is {expense_in.version}, current version is {expense_db.version}. Please refresh.",
|
||||
# status_code=status.HTTP_409_CONFLICT # This would be for the API layer to set
|
||||
)
|
||||
|
||||
update_data = expense_in.model_dump(exclude_unset=True, exclude={"version"}) # Exclude version itself from data
|
||||
|
||||
# Fields that are safe to update without affecting splits or core logic
|
||||
allowed_to_update = {"description", "currency", "expense_date"}
|
||||
|
||||
updated_something = False
|
||||
for field, value in update_data.items():
|
||||
if field in allowed_to_update:
|
||||
setattr(expense_db, field, value)
|
||||
updated_something = True
|
||||
else:
|
||||
# If any other field is present in the update payload, it's an invalid operation for this simple update
|
||||
raise InvalidOperationError(f"Field '{field}' cannot be updated. Only {', '.join(allowed_to_update)} are allowed.")
|
||||
|
||||
if not updated_something and not expense_in.model_fields_set.intersection(allowed_to_update):
|
||||
# No actual updatable fields were provided in the payload, even if others (like version) were.
|
||||
# This could be a non-issue, or an indication of a misuse of the endpoint.
|
||||
# For now, if only version was sent, we still increment if it matched.
|
||||
pass # Or raise InvalidOperationError("No updatable fields provided.")
|
||||
|
||||
expense_db.version += 1
|
||||
expense_db.updated_at = datetime.now(timezone.utc) # Manually update timestamp
|
||||
|
||||
try:
|
||||
await db.commit()
|
||||
await db.refresh(expense_db)
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
# Consider specific DB error types if needed
|
||||
raise InvalidOperationError(f"Failed to update expense: {str(e)}")
|
||||
|
||||
return expense_db
|
||||
|
||||
async def delete_expense(db: AsyncSession, expense_db: ExpenseModel, expected_version: Optional[int] = None) -> None:
|
||||
"""
|
||||
Deletes an expense. Requires version matching if expected_version is provided.
|
||||
Associated ExpenseSplits are cascade deleted by the database foreign key constraint.
|
||||
"""
|
||||
if expected_version is not None and expense_db.version != expected_version:
|
||||
raise InvalidOperationError(
|
||||
f"Expense '{expense_db.description}' (ID: {expense_db.id}) cannot be deleted. "
|
||||
f"Your expected version {expected_version} does not match current version {expense_db.version}. Please refresh.",
|
||||
# status_code=status.HTTP_409_CONFLICT
|
||||
)
|
||||
|
||||
await db.delete(expense_db)
|
||||
try:
|
||||
await db.commit()
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
raise InvalidOperationError(f"Failed to delete expense: {str(e)}")
|
||||
return None
|
||||
|
||||
# Note: The InvalidOperationError is a simple ValueError placeholder.
|
||||
# For API endpoints, these should be translated to appropriate HTTPExceptions.
|
||||
# Ensure app.core.exceptions has proper HTTP error classes if needed.
|
@ -14,7 +14,9 @@ from app.core.exceptions import (
|
||||
DatabaseConnectionError,
|
||||
DatabaseIntegrityError,
|
||||
DatabaseQueryError,
|
||||
DatabaseTransactionError
|
||||
DatabaseTransactionError,
|
||||
GroupMembershipError,
|
||||
GroupPermissionError # Import GroupPermissionError
|
||||
)
|
||||
|
||||
# --- Group CRUD ---
|
||||
@ -153,3 +155,74 @@ async def get_group_member_count(db: AsyncSession, group_id: int) -> int:
|
||||
raise DatabaseConnectionError(f"Failed to connect to database: {str(e)}")
|
||||
except SQLAlchemyError as e:
|
||||
raise DatabaseQueryError(f"Failed to count group members: {str(e)}")
|
||||
|
||||
async def check_group_membership(
|
||||
db: AsyncSession,
|
||||
group_id: int,
|
||||
user_id: int,
|
||||
action: str = "access this group"
|
||||
) -> None:
|
||||
"""
|
||||
Checks if a user is a member of a group. Raises exceptions if not found or not a member.
|
||||
|
||||
Raises:
|
||||
GroupNotFoundError: If the group_id does not exist.
|
||||
GroupMembershipError: If the user_id is not a member of the group.
|
||||
"""
|
||||
try:
|
||||
async with db.begin():
|
||||
# Check group existence first
|
||||
group_exists = await db.get(GroupModel, group_id)
|
||||
if not group_exists:
|
||||
raise GroupNotFoundError(group_id)
|
||||
|
||||
# Check membership
|
||||
membership = await db.execute(
|
||||
select(UserGroupModel.id)
|
||||
.where(UserGroupModel.group_id == group_id, UserGroupModel.user_id == user_id)
|
||||
.limit(1)
|
||||
)
|
||||
if membership.scalar_one_or_none() is None:
|
||||
raise GroupMembershipError(group_id, action=action)
|
||||
# If we reach here, the user is a member
|
||||
return None
|
||||
except GroupNotFoundError: # Re-raise specific errors
|
||||
raise
|
||||
except GroupMembershipError:
|
||||
raise
|
||||
except OperationalError as e:
|
||||
raise DatabaseConnectionError(f"Failed to connect to database while checking membership: {str(e)}")
|
||||
except SQLAlchemyError as e:
|
||||
raise DatabaseQueryError(f"Failed to check group membership: {str(e)}")
|
||||
|
||||
async def check_user_role_in_group(
|
||||
db: AsyncSession,
|
||||
group_id: int,
|
||||
user_id: int,
|
||||
required_role: UserRoleEnum,
|
||||
action: str = "perform this action"
|
||||
) -> None:
|
||||
"""
|
||||
Checks if a user is a member of a group and has the required role (or higher).
|
||||
|
||||
Raises:
|
||||
GroupNotFoundError: If the group_id does not exist.
|
||||
GroupMembershipError: If the user_id is not a member of the group.
|
||||
GroupPermissionError: If the user does not have the required role.
|
||||
"""
|
||||
# First, ensure user is a member (this also checks group existence)
|
||||
await check_group_membership(db, group_id, user_id, action=f"be checked for permissions to {action}")
|
||||
|
||||
# Get the user's actual role
|
||||
actual_role = await get_user_role_in_group(db, group_id, user_id)
|
||||
|
||||
# Define role hierarchy (assuming owner > member)
|
||||
role_hierarchy = {UserRoleEnum.owner: 2, UserRoleEnum.member: 1}
|
||||
|
||||
if not actual_role or role_hierarchy.get(actual_role, 0) < role_hierarchy.get(required_role, 0):
|
||||
raise GroupPermissionError(
|
||||
group_id=group_id,
|
||||
action=f"{action} (requires at least '{required_role.value}' role)"
|
||||
)
|
||||
# If role is sufficient, return None
|
||||
return None
|
168
be/app/crud/settlement.py
Normal file
168
be/app/crud/settlement.py
Normal file
@ -0,0 +1,168 @@
|
||||
# app/crud/settlement.py
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
from sqlalchemy.orm import selectinload, joinedload
|
||||
from sqlalchemy import or_
|
||||
from decimal import Decimal
|
||||
from typing import List as PyList, Optional, Sequence
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from app.models import (
|
||||
Settlement as SettlementModel,
|
||||
User as UserModel,
|
||||
Group as GroupModel
|
||||
)
|
||||
from app.schemas.expense import SettlementCreate, SettlementUpdate # SettlementUpdate not used yet
|
||||
from app.core.exceptions import UserNotFoundError, GroupNotFoundError, InvalidOperationError
|
||||
|
||||
async def create_settlement(db: AsyncSession, settlement_in: SettlementCreate, current_user_id: int) -> SettlementModel:
|
||||
"""Creates a new settlement record."""
|
||||
|
||||
# Validate Payer, Payee, and Group exist
|
||||
payer = await db.get(UserModel, settlement_in.paid_by_user_id)
|
||||
if not payer:
|
||||
raise UserNotFoundError(user_id=settlement_in.paid_by_user_id, identifier="Payer")
|
||||
|
||||
payee = await db.get(UserModel, settlement_in.paid_to_user_id)
|
||||
if not payee:
|
||||
raise UserNotFoundError(user_id=settlement_in.paid_to_user_id, identifier="Payee")
|
||||
|
||||
if settlement_in.paid_by_user_id == settlement_in.paid_to_user_id:
|
||||
raise InvalidOperationError("Payer and Payee cannot be the same user.")
|
||||
|
||||
group = await db.get(GroupModel, settlement_in.group_id)
|
||||
if not group:
|
||||
raise GroupNotFoundError(settlement_in.group_id)
|
||||
|
||||
# Optional: Check if current_user_id is part of the group or is one of the parties involved
|
||||
# This is more of an API-level permission check but could be added here if strict.
|
||||
# For example: if current_user_id not in [settlement_in.paid_by_user_id, settlement_in.paid_to_user_id]:
|
||||
# is_in_group = await db.execute(select(UserGroupModel).where(UserGroupModel.group_id == group.id, UserGroupModel.user_id == current_user_id))
|
||||
# if not is_in_group.first():
|
||||
# raise InvalidOperationError("You can only record settlements you are part of or for groups you belong to.")
|
||||
|
||||
db_settlement = SettlementModel(
|
||||
group_id=settlement_in.group_id,
|
||||
paid_by_user_id=settlement_in.paid_by_user_id,
|
||||
paid_to_user_id=settlement_in.paid_to_user_id,
|
||||
amount=settlement_in.amount.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP),
|
||||
settlement_date=settlement_in.settlement_date if settlement_in.settlement_date else datetime.now(timezone.utc),
|
||||
description=settlement_in.description
|
||||
# created_by_user_id = current_user_id # Optional: Who recorded this settlement
|
||||
)
|
||||
db.add(db_settlement)
|
||||
try:
|
||||
await db.commit()
|
||||
await db.refresh(db_settlement, attribute_names=["payer", "payee", "group"])
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
raise InvalidOperationError(f"Failed to save settlement: {str(e)}")
|
||||
|
||||
return db_settlement
|
||||
|
||||
async def get_settlement_by_id(db: AsyncSession, settlement_id: int) -> Optional[SettlementModel]:
|
||||
result = await db.execute(
|
||||
select(SettlementModel)
|
||||
.options(
|
||||
selectinload(SettlementModel.payer),
|
||||
selectinload(SettlementModel.payee),
|
||||
selectinload(SettlementModel.group)
|
||||
)
|
||||
.where(SettlementModel.id == settlement_id)
|
||||
)
|
||||
return result.scalars().first()
|
||||
|
||||
async def get_settlements_for_group(db: AsyncSession, group_id: int, skip: int = 0, limit: int = 100) -> Sequence[SettlementModel]:
|
||||
result = await db.execute(
|
||||
select(SettlementModel)
|
||||
.where(SettlementModel.group_id == group_id)
|
||||
.order_by(SettlementModel.settlement_date.desc(), SettlementModel.created_at.desc())
|
||||
.offset(skip).limit(limit)
|
||||
.options(selectinload(SettlementModel.payer), selectinload(SettlementModel.payee))
|
||||
)
|
||||
return result.scalars().all()
|
||||
|
||||
async def get_settlements_involving_user(
|
||||
db: AsyncSession,
|
||||
user_id: int,
|
||||
group_id: Optional[int] = None,
|
||||
skip: int = 0,
|
||||
limit: int = 100
|
||||
) -> Sequence[SettlementModel]:
|
||||
query = (
|
||||
select(SettlementModel)
|
||||
.where(or_(SettlementModel.paid_by_user_id == user_id, SettlementModel.paid_to_user_id == user_id))
|
||||
.order_by(SettlementModel.settlement_date.desc(), SettlementModel.created_at.desc())
|
||||
.offset(skip).limit(limit)
|
||||
.options(selectinload(SettlementModel.payer), selectinload(SettlementModel.payee), selectinload(SettlementModel.group))
|
||||
)
|
||||
if group_id:
|
||||
query = query.where(SettlementModel.group_id == group_id)
|
||||
|
||||
result = await db.execute(query)
|
||||
return result.scalars().all()
|
||||
|
||||
async def update_settlement(db: AsyncSession, settlement_db: SettlementModel, settlement_in: SettlementUpdate) -> SettlementModel:
|
||||
"""
|
||||
Updates an existing settlement.
|
||||
Only allows updates to description and settlement_date.
|
||||
Requires version matching for optimistic locking.
|
||||
Assumes SettlementUpdate schema includes a version field.
|
||||
"""
|
||||
# Check if SettlementUpdate schema has 'version'. If not, this check needs to be adapted or version passed differently.
|
||||
if not hasattr(settlement_in, 'version') or settlement_db.version != settlement_in.version:
|
||||
raise InvalidOperationError(
|
||||
f"Settlement (ID: {settlement_db.id}) has been modified. "
|
||||
f"Your version does not match current version {settlement_db.version}. Please refresh.",
|
||||
# status_code=status.HTTP_409_CONFLICT
|
||||
)
|
||||
|
||||
update_data = settlement_in.model_dump(exclude_unset=True, exclude={"version"})
|
||||
allowed_to_update = {"description", "settlement_date"}
|
||||
updated_something = False
|
||||
|
||||
for field, value in update_data.items():
|
||||
if field in allowed_to_update:
|
||||
setattr(settlement_db, field, value)
|
||||
updated_something = True
|
||||
else:
|
||||
raise InvalidOperationError(f"Field '{field}' cannot be updated. Only {', '.join(allowed_to_update)} are allowed for settlements.")
|
||||
|
||||
if not updated_something and not settlement_in.model_fields_set.intersection(allowed_to_update):
|
||||
pass # No actual updatable fields provided, but version matched.
|
||||
|
||||
settlement_db.version += 1 # Assuming SettlementModel has a version field, add if missing
|
||||
settlement_db.updated_at = datetime.now(timezone.utc)
|
||||
|
||||
try:
|
||||
await db.commit()
|
||||
await db.refresh(settlement_db)
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
raise InvalidOperationError(f"Failed to update settlement: {str(e)}")
|
||||
|
||||
return settlement_db
|
||||
|
||||
async def delete_settlement(db: AsyncSession, settlement_db: SettlementModel, expected_version: Optional[int] = None) -> None:
|
||||
"""
|
||||
Deletes a settlement. Requires version matching if expected_version is provided.
|
||||
Assumes SettlementModel has a version field.
|
||||
"""
|
||||
if expected_version is not None:
|
||||
if not hasattr(settlement_db, 'version') or settlement_db.version != expected_version:
|
||||
raise InvalidOperationError(
|
||||
f"Settlement (ID: {settlement_db.id}) cannot be deleted. "
|
||||
f"Expected version {expected_version} does not match current version. Please refresh.",
|
||||
# status_code=status.HTTP_409_CONFLICT
|
||||
)
|
||||
|
||||
await db.delete(settlement_db)
|
||||
try:
|
||||
await db.commit()
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
raise InvalidOperationError(f"Failed to delete settlement: {str(e)}")
|
||||
return None
|
||||
|
||||
# TODO: Implement update_settlement (consider restrictions, versioning)
|
||||
# TODO: Implement delete_settlement (consider implications on balances)
|
118
be/app/models.py
118
be/app/models.py
@ -21,7 +21,7 @@ from sqlalchemy import (
|
||||
Text, # <-- Add Text for description
|
||||
Numeric # <-- Add Numeric for price
|
||||
)
|
||||
from sqlalchemy.orm import relationship
|
||||
from sqlalchemy.orm import relationship, backref
|
||||
|
||||
from .database import Base
|
||||
|
||||
@ -30,6 +30,14 @@ class UserRoleEnum(enum.Enum):
|
||||
owner = "owner"
|
||||
member = "member"
|
||||
|
||||
class SplitTypeEnum(enum.Enum):
|
||||
EQUAL = "EQUAL" # Split equally among all involved users
|
||||
EXACT_AMOUNTS = "EXACT_AMOUNTS" # Specific amounts for each user (defined in ExpenseSplit)
|
||||
PERCENTAGE = "PERCENTAGE" # Percentage for each user (defined in ExpenseSplit)
|
||||
SHARES = "SHARES" # Proportional to shares/units (defined in ExpenseSplit)
|
||||
ITEM_BASED = "ITEM_BASED" # If an expense is derived directly from item prices and who added them
|
||||
# Add more types as needed, e.g., UNPAID (for tracking debts not part of a formal expense)
|
||||
|
||||
# --- User Model ---
|
||||
class User(Base):
|
||||
__tablename__ = "users"
|
||||
@ -51,6 +59,13 @@ class User(Base):
|
||||
completed_items = relationship("Item", foreign_keys="Item.completed_by_id", back_populates="completed_by_user") # Link Item.completed_by_id -> User
|
||||
# --- End NEW Relationships ---
|
||||
|
||||
# --- Relationships for Cost Splitting ---
|
||||
expenses_paid = relationship("Expense", foreign_keys="Expense.paid_by_user_id", back_populates="paid_by_user", cascade="all, delete-orphan")
|
||||
expense_splits = relationship("ExpenseSplit", foreign_keys="ExpenseSplit.user_id", back_populates="user", cascade="all, delete-orphan")
|
||||
settlements_made = relationship("Settlement", foreign_keys="Settlement.paid_by_user_id", back_populates="payer", cascade="all, delete-orphan")
|
||||
settlements_received = relationship("Settlement", foreign_keys="Settlement.paid_to_user_id", back_populates="payee", cascade="all, delete-orphan")
|
||||
# --- End Relationships for Cost Splitting ---
|
||||
|
||||
|
||||
# --- Group Model ---
|
||||
class Group(Base):
|
||||
@ -70,6 +85,11 @@ class Group(Base):
|
||||
lists = relationship("List", back_populates="group", cascade="all, delete-orphan") # Link List.group_id -> Group
|
||||
# --- End NEW Relationship ---
|
||||
|
||||
# --- Relationships for Cost Splitting ---
|
||||
expenses = relationship("Expense", foreign_keys="Expense.group_id", back_populates="group", cascade="all, delete-orphan")
|
||||
settlements = relationship("Settlement", foreign_keys="Settlement.group_id", back_populates="group", cascade="all, delete-orphan")
|
||||
# --- End Relationships for Cost Splitting ---
|
||||
|
||||
|
||||
# --- UserGroup Association Model ---
|
||||
class UserGroup(Base):
|
||||
@ -124,6 +144,10 @@ class List(Base):
|
||||
group = relationship("Group", back_populates="lists") # Link to Group.lists
|
||||
items = relationship("Item", back_populates="list", cascade="all, delete-orphan", order_by="Item.created_at") # Link to Item.list, cascade deletes
|
||||
|
||||
# --- Relationships for Cost Splitting ---
|
||||
expenses = relationship("Expense", foreign_keys="Expense.list_id", back_populates="list", cascade="all, delete-orphan")
|
||||
# --- End Relationships for Cost Splitting ---
|
||||
|
||||
|
||||
# === NEW: Item Model ===
|
||||
class Item(Base):
|
||||
@ -145,3 +169,95 @@ class Item(Base):
|
||||
list = relationship("List", back_populates="items") # Link to List.items
|
||||
added_by_user = relationship("User", foreign_keys=[added_by_id], back_populates="added_items") # Link to User.added_items
|
||||
completed_by_user = relationship("User", foreign_keys=[completed_by_id], back_populates="completed_items") # Link to User.completed_items
|
||||
|
||||
# --- Relationships for Cost Splitting ---
|
||||
# If an item directly results in an expense, or an expense can be tied to an item.
|
||||
expenses = relationship("Expense", back_populates="item") # An item might have multiple associated expenses
|
||||
# --- End Relationships for Cost Splitting ---
|
||||
|
||||
|
||||
# === NEW Models for Advanced Cost Splitting ===
|
||||
|
||||
class Expense(Base):
|
||||
__tablename__ = "expenses"
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
description = Column(String, nullable=False)
|
||||
total_amount = Column(Numeric(10, 2), nullable=False)
|
||||
currency = Column(String, nullable=False, default="USD") # Consider making this an Enum too if few currencies
|
||||
expense_date = Column(DateTime(timezone=True), server_default=func.now(), nullable=False)
|
||||
split_type = Column(SAEnum(SplitTypeEnum, name="splittypeenum", create_type=True), nullable=False)
|
||||
|
||||
# Foreign Keys
|
||||
list_id = Column(Integer, ForeignKey("lists.id"), nullable=True)
|
||||
group_id = Column(Integer, ForeignKey("groups.id"), nullable=True) # If not list-specific but group-specific
|
||||
item_id = Column(Integer, ForeignKey("items.id"), nullable=True) # If the expense is for a specific item
|
||||
paid_by_user_id = Column(Integer, ForeignKey("users.id"), nullable=False)
|
||||
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False)
|
||||
updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), nullable=False)
|
||||
version = Column(Integer, nullable=False, default=1, server_default='1')
|
||||
|
||||
# Relationships
|
||||
paid_by_user = relationship("User", foreign_keys=[paid_by_user_id], back_populates="expenses_paid")
|
||||
list = relationship("List", foreign_keys=[list_id], back_populates="expenses")
|
||||
group = relationship("Group", foreign_keys=[group_id], back_populates="expenses")
|
||||
item = relationship("Item", foreign_keys=[item_id], back_populates="expenses")
|
||||
splits = relationship("ExpenseSplit", back_populates="expense", cascade="all, delete-orphan")
|
||||
|
||||
__table_args__ = (
|
||||
# Example: Ensure either list_id or group_id is present if item_id is null
|
||||
# CheckConstraint('(item_id IS NOT NULL) OR (list_id IS NOT NULL) OR (group_id IS NOT NULL)', name='chk_expense_context'),
|
||||
)
|
||||
|
||||
class ExpenseSplit(Base):
|
||||
__tablename__ = "expense_splits"
|
||||
__table_args__ = (UniqueConstraint('expense_id', 'user_id', name='uq_expense_user_split'),)
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
expense_id = Column(Integer, ForeignKey("expenses.id", ondelete="CASCADE"), nullable=False)
|
||||
user_id = Column(Integer, ForeignKey("users.id"), nullable=False)
|
||||
|
||||
owed_amount = Column(Numeric(10, 2), nullable=False) # For EQUAL or EXACT_AMOUNTS
|
||||
# For PERCENTAGE split (value from 0.00 to 100.00)
|
||||
share_percentage = Column(Numeric(5, 2), nullable=True)
|
||||
# For SHARES split (e.g., user A has 2 shares, user B has 3 shares)
|
||||
share_units = Column(Integer, nullable=True)
|
||||
|
||||
# is_settled might be better tracked via actual Settlement records or a reconciliation process
|
||||
# is_settled = Column(Boolean, default=False, nullable=False)
|
||||
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False)
|
||||
updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), nullable=False)
|
||||
|
||||
# Relationships
|
||||
expense = relationship("Expense", back_populates="splits")
|
||||
user = relationship("User", foreign_keys=[user_id], back_populates="expense_splits")
|
||||
|
||||
|
||||
class Settlement(Base):
|
||||
__tablename__ = "settlements"
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
group_id = Column(Integer, ForeignKey("groups.id"), nullable=False) # Settlements usually within a group
|
||||
paid_by_user_id = Column(Integer, ForeignKey("users.id"), nullable=False)
|
||||
paid_to_user_id = Column(Integer, ForeignKey("users.id"), nullable=False)
|
||||
amount = Column(Numeric(10, 2), nullable=False)
|
||||
settlement_date = Column(DateTime(timezone=True), server_default=func.now(), nullable=False)
|
||||
description = Column(Text, nullable=True)
|
||||
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False)
|
||||
updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), nullable=False)
|
||||
version = Column(Integer, nullable=False, default=1, server_default='1')
|
||||
|
||||
# Relationships
|
||||
group = relationship("Group", foreign_keys=[group_id], back_populates="settlements")
|
||||
payer = relationship("User", foreign_keys=[paid_by_user_id], back_populates="settlements_made")
|
||||
payee = relationship("User", foreign_keys=[paid_to_user_id], back_populates="settlements_received")
|
||||
|
||||
__table_args__ = (
|
||||
# Ensure payer and payee are different users
|
||||
# CheckConstraint('paid_by_user_id <> paid_to_user_id', name='chk_settlement_payer_ne_payee'),
|
||||
)
|
||||
|
||||
# Potential future: PaymentMethod model, etc.
|
@ -20,3 +20,36 @@ class ListCostSummary(BaseModel):
|
||||
user_balances: List[UserCostShare]
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
class UserBalanceDetail(BaseModel):
|
||||
user_id: int
|
||||
user_identifier: str # Name or email
|
||||
total_paid_for_expenses: Decimal = Decimal("0.00")
|
||||
total_share_of_expenses: Decimal = Decimal("0.00")
|
||||
total_settlements_paid: Decimal = Decimal("0.00")
|
||||
total_settlements_received: Decimal = Decimal("0.00")
|
||||
net_balance: Decimal = Decimal("0.00") # (paid_for_expenses + settlements_received) - (share_of_expenses + settlements_paid)
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
class SuggestedSettlement(BaseModel):
|
||||
from_user_id: int
|
||||
from_user_identifier: str # Name or email of payer
|
||||
to_user_id: int
|
||||
to_user_identifier: str # Name or email of payee
|
||||
amount: Decimal
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
class GroupBalanceSummary(BaseModel):
|
||||
group_id: int
|
||||
group_name: str
|
||||
overall_total_expenses: Decimal = Decimal("0.00")
|
||||
overall_total_settlements: Decimal = Decimal("0.00")
|
||||
user_balances: List[UserBalanceDetail]
|
||||
# Optional: Could add a list of suggested settlements to zero out balances
|
||||
suggested_settlements: Optional[List[SuggestedSettlement]] = None
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
# class SuggestedSettlement(BaseModel):
|
||||
# from_user_id: int
|
||||
# to_user_id: int
|
||||
# amount: Decimal
|
131
be/app/schemas/expense.py
Normal file
131
be/app/schemas/expense.py
Normal file
@ -0,0 +1,131 @@
|
||||
# app/schemas/expense.py
|
||||
from pydantic import BaseModel, ConfigDict, validator
|
||||
from typing import List, Optional
|
||||
from decimal import Decimal
|
||||
from datetime import datetime
|
||||
|
||||
# Assuming SplitTypeEnum is accessible here, e.g., from app.models or app.core.enums
|
||||
# For now, let's redefine it or import it if models.py is parsable by Pydantic directly
|
||||
# If it's from app.models, you might need to make app.models.SplitTypeEnum Pydantic-compatible or map it.
|
||||
# For simplicity during schema definition, I'll redefine a string enum here.
|
||||
# In a real setup, ensure this aligns with the SQLAlchemy enum in models.py.
|
||||
from app.models import SplitTypeEnum # Try importing directly
|
||||
|
||||
# --- ExpenseSplit Schemas ---
|
||||
class ExpenseSplitBase(BaseModel):
|
||||
user_id: int
|
||||
owed_amount: Decimal
|
||||
share_percentage: Optional[Decimal] = None
|
||||
share_units: Optional[int] = None
|
||||
|
||||
class ExpenseSplitCreate(ExpenseSplitBase):
|
||||
pass # All fields from base are needed for creation
|
||||
|
||||
class ExpenseSplitPublic(ExpenseSplitBase):
|
||||
id: int
|
||||
expense_id: int
|
||||
# user: Optional[UserPublic] # If we want to nest user details
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
# --- Expense Schemas ---
|
||||
class ExpenseBase(BaseModel):
|
||||
description: str
|
||||
total_amount: Decimal
|
||||
currency: Optional[str] = "USD"
|
||||
expense_date: Optional[datetime] = None
|
||||
split_type: SplitTypeEnum
|
||||
list_id: Optional[int] = None
|
||||
group_id: Optional[int] = None # Should be present if list_id is not, and vice-versa
|
||||
item_id: Optional[int] = None
|
||||
paid_by_user_id: int
|
||||
|
||||
class ExpenseCreate(ExpenseBase):
|
||||
# For EQUAL split, splits are generated. For others, they might be provided.
|
||||
# This logic will be in the CRUD: if split_type is EXACT_AMOUNTS, PERCENTAGE, SHARES,
|
||||
# then 'splits_in' should be provided.
|
||||
splits_in: Optional[List[ExpenseSplitCreate]] = None
|
||||
|
||||
@validator('total_amount')
|
||||
def total_amount_must_be_positive(cls, v):
|
||||
if v <= Decimal('0'):
|
||||
raise ValueError('Total amount must be positive')
|
||||
return v
|
||||
|
||||
# Basic validation: if list_id is None, group_id must be provided.
|
||||
# More complex cross-field validation might be needed.
|
||||
@validator('group_id', always=True)
|
||||
def check_list_or_group_id(cls, v, values):
|
||||
if values.get('list_id') is None and v is None:
|
||||
raise ValueError('Either list_id or group_id must be provided for an expense')
|
||||
return v
|
||||
|
||||
class ExpenseUpdate(BaseModel):
|
||||
description: Optional[str] = None
|
||||
total_amount: Optional[Decimal] = None
|
||||
currency: Optional[str] = None
|
||||
expense_date: Optional[datetime] = None
|
||||
split_type: Optional[SplitTypeEnum] = None
|
||||
list_id: Optional[int] = None
|
||||
group_id: Optional[int] = None
|
||||
item_id: Optional[int] = None
|
||||
# paid_by_user_id is usually not updatable directly to maintain integrity.
|
||||
# Updating splits would be a more complex operation, potentially a separate endpoint or careful logic.
|
||||
version: int # For optimistic locking
|
||||
|
||||
class ExpensePublic(ExpenseBase):
|
||||
id: int
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
version: int
|
||||
splits: List[ExpenseSplitPublic] = []
|
||||
# paid_by_user: Optional[UserPublic] # If nesting user details
|
||||
# list: Optional[ListPublic] # If nesting list details
|
||||
# group: Optional[GroupPublic] # If nesting group details
|
||||
# item: Optional[ItemPublic] # If nesting item details
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
# --- Settlement Schemas ---
|
||||
class SettlementBase(BaseModel):
|
||||
group_id: int
|
||||
paid_by_user_id: int
|
||||
paid_to_user_id: int
|
||||
amount: Decimal
|
||||
settlement_date: Optional[datetime] = None
|
||||
description: Optional[str] = None
|
||||
|
||||
class SettlementCreate(SettlementBase):
|
||||
@validator('amount')
|
||||
def amount_must_be_positive(cls, v):
|
||||
if v <= Decimal('0'):
|
||||
raise ValueError('Settlement amount must be positive')
|
||||
return v
|
||||
|
||||
@validator('paid_to_user_id')
|
||||
def payer_and_payee_must_be_different(cls, v, values):
|
||||
if 'paid_by_user_id' in values and v == values['paid_by_user_id']:
|
||||
raise ValueError('Payer and payee cannot be the same user')
|
||||
return v
|
||||
|
||||
class SettlementUpdate(BaseModel):
|
||||
amount: Optional[Decimal] = None
|
||||
settlement_date: Optional[datetime] = None
|
||||
description: Optional[str] = None
|
||||
# group_id, paid_by_user_id, paid_to_user_id are typically not updatable.
|
||||
version: int # For optimistic locking
|
||||
|
||||
class SettlementPublic(SettlementBase):
|
||||
id: int
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
# payer: Optional[UserPublic]
|
||||
# payee: Optional[UserPublic]
|
||||
# group: Optional[GroupPublic]
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
# Placeholder for nested schemas (e.g., UserPublic) if needed
|
||||
# from app.schemas.user import UserPublic
|
||||
# from app.schemas.list import ListPublic
|
||||
# from app.schemas.group import GroupPublic
|
||||
# from app.schemas.item import ItemPublic
|
373
be/tests/api/v1/endpoints/test_financials.py
Normal file
373
be/tests/api/v1/endpoints/test_financials.py
Normal file
@ -0,0 +1,373 @@
|
||||
import pytest
|
||||
from fastapi import status
|
||||
from httpx import AsyncClient
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from typing import Callable, Dict, Any
|
||||
|
||||
from app.models import User as UserModel, Group as GroupModel, List as ListModel
|
||||
from app.schemas.expense import ExpenseCreate
|
||||
from app.core.config import settings
|
||||
|
||||
# Helper to create a URL for an endpoint
|
||||
API_V1_STR = settings.API_V1_STR
|
||||
|
||||
def expense_url(endpoint: str = "") -> str:
|
||||
return f"{API_V1_STR}/financials/expenses{endpoint}"
|
||||
|
||||
def settlement_url(endpoint: str = "") -> str:
|
||||
return f"{API_V1_STR}/financials/settlements{endpoint}"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_new_expense_success_list_context(
|
||||
client: AsyncClient,
|
||||
db_session: AsyncSession, # Assuming a fixture for db session
|
||||
normal_user_token_headers: Dict[str, str], # Assuming a fixture for user auth
|
||||
test_user: UserModel, # Assuming a fixture for a test user
|
||||
test_list_user_is_member: ListModel, # Assuming a fixture for a list user is member of
|
||||
) -> None:
|
||||
"""
|
||||
Test successful creation of a new expense linked to a list.
|
||||
"""
|
||||
expense_data = ExpenseCreate(
|
||||
description="Test Expense for List",
|
||||
amount=100.00,
|
||||
currency="USD",
|
||||
paid_by_user_id=test_user.id,
|
||||
list_id=test_list_user_is_member.id,
|
||||
group_id=None, # group_id should be derived from list if list is in a group
|
||||
# category_id: Optional[int] = None # Assuming category is optional
|
||||
# expense_date: Optional[date] = None # Assuming date is optional
|
||||
# splits: Optional[List[SplitCreate]] = [] # Assuming splits are optional for now
|
||||
)
|
||||
|
||||
response = await client.post(
|
||||
expense_url(),
|
||||
headers=normal_user_token_headers,
|
||||
json=expense_data.model_dump(exclude_unset=True)
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_201_CREATED
|
||||
content = response.json()
|
||||
assert content["description"] == expense_data.description
|
||||
assert content["amount"] == expense_data.amount
|
||||
assert content["currency"] == expense_data.currency
|
||||
assert content["paid_by_user_id"] == test_user.id
|
||||
assert content["list_id"] == test_list_user_is_member.id
|
||||
# If test_list_user_is_member has a group_id, it should be set in the response
|
||||
if test_list_user_is_member.group_id:
|
||||
assert content["group_id"] == test_list_user_is_member.group_id
|
||||
else:
|
||||
assert content["group_id"] is None
|
||||
assert "id" in content
|
||||
assert "created_at" in content
|
||||
assert "updated_at" in content
|
||||
assert "version" in content
|
||||
assert content["version"] == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_new_expense_success_group_context(
|
||||
client: AsyncClient,
|
||||
normal_user_token_headers: Dict[str, str],
|
||||
test_user: UserModel,
|
||||
test_group_user_is_member: GroupModel, # Assuming a fixture for a group user is member of
|
||||
) -> None:
|
||||
"""
|
||||
Test successful creation of a new expense linked directly to a group.
|
||||
"""
|
||||
expense_data = ExpenseCreate(
|
||||
description="Test Expense for Group",
|
||||
amount=50.00,
|
||||
currency="EUR",
|
||||
paid_by_user_id=test_user.id,
|
||||
group_id=test_group_user_is_member.id,
|
||||
list_id=None,
|
||||
)
|
||||
|
||||
response = await client.post(
|
||||
expense_url(),
|
||||
headers=normal_user_token_headers,
|
||||
json=expense_data.model_dump(exclude_unset=True)
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_201_CREATED
|
||||
content = response.json()
|
||||
assert content["description"] == expense_data.description
|
||||
assert content["paid_by_user_id"] == test_user.id
|
||||
assert content["group_id"] == test_group_user_is_member.id
|
||||
assert content["list_id"] is None
|
||||
assert content["version"] == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_new_expense_fail_no_list_or_group(
|
||||
client: AsyncClient,
|
||||
normal_user_token_headers: Dict[str, str],
|
||||
test_user: UserModel,
|
||||
) -> None:
|
||||
"""
|
||||
Test expense creation fails if neither list_id nor group_id is provided.
|
||||
"""
|
||||
expense_data = ExpenseCreate(
|
||||
description="Test Invalid Expense",
|
||||
amount=10.00,
|
||||
currency="USD",
|
||||
paid_by_user_id=test_user.id,
|
||||
list_id=None,
|
||||
group_id=None,
|
||||
)
|
||||
|
||||
response = await client.post(
|
||||
expense_url(),
|
||||
headers=normal_user_token_headers,
|
||||
json=expense_data.model_dump(exclude_unset=True)
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_400_BAD_REQUEST
|
||||
content = response.json()
|
||||
assert "Expense must be linked to a list_id or group_id" in content["detail"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_new_expense_fail_paid_by_other_not_owner(
|
||||
client: AsyncClient,
|
||||
normal_user_token_headers: Dict[str, str], # User is member, not owner
|
||||
test_user: UserModel, # This is the current_user (member)
|
||||
test_group_user_is_member: GroupModel, # Group the current_user is a member of
|
||||
another_user_in_group: UserModel, # Another user in the same group
|
||||
# Ensure test_user is NOT an owner of test_group_user_is_member for this test
|
||||
) -> None:
|
||||
"""
|
||||
Test creation fails if paid_by_user_id is another user, and current_user is not a group owner.
|
||||
Assumes normal_user_token_headers belongs to a user who is a member but not an owner of test_group_user_is_member.
|
||||
"""
|
||||
expense_data = ExpenseCreate(
|
||||
description="Expense paid by other",
|
||||
amount=75.00,
|
||||
currency="GBP",
|
||||
paid_by_user_id=another_user_in_group.id, # Paid by someone else
|
||||
group_id=test_group_user_is_member.id,
|
||||
list_id=None,
|
||||
)
|
||||
|
||||
response = await client.post(
|
||||
expense_url(),
|
||||
headers=normal_user_token_headers, # Current user is a member, not owner
|
||||
json=expense_data.model_dump(exclude_unset=True)
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_403_FORBIDDEN
|
||||
content = response.json()
|
||||
assert "Only group owners can create expenses paid by others" in content["detail"]
|
||||
|
||||
# --- Add tests for other endpoints below ---
|
||||
# GET /expenses/{expense_id}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_expense_success(
|
||||
client: AsyncClient,
|
||||
normal_user_token_headers: Dict[str, str],
|
||||
test_user: UserModel,
|
||||
# Assume an existing expense created by test_user or in a group/list they have access to
|
||||
# This would typically be created by another test or a fixture
|
||||
created_expense: ExpensePublic, # Assuming a fixture that provides a created expense
|
||||
) -> None:
|
||||
"""
|
||||
Test successfully retrieving an existing expense.
|
||||
User has access either by being the payer, or via list/group membership.
|
||||
"""
|
||||
response = await client.get(
|
||||
expense_url(f"/{created_expense.id}"),
|
||||
headers=normal_user_token_headers
|
||||
)
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
content = response.json()
|
||||
assert content["id"] == created_expense.id
|
||||
assert content["description"] == created_expense.description
|
||||
assert content["amount"] == created_expense.amount
|
||||
assert content["paid_by_user_id"] == created_expense.paid_by_user_id
|
||||
if created_expense.list_id:
|
||||
assert content["list_id"] == created_expense.list_id
|
||||
if created_expense.group_id:
|
||||
assert content["group_id"] == created_expense.group_id
|
||||
|
||||
# TODO: Add more tests for get_expense:
|
||||
# - expense not found -> 404
|
||||
# - user has no access (not payer, not in list, not in group if applicable) -> 403
|
||||
# - expense in list, user has list access
|
||||
# - expense in group, user has group access
|
||||
# - expense personal (no list, no group), user is payer
|
||||
# - expense personal (no list, no group), user is NOT payer -> 403
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_expense_not_found(
|
||||
client: AsyncClient,
|
||||
normal_user_token_headers: Dict[str, str],
|
||||
) -> None:
|
||||
"""
|
||||
Test retrieving a non-existent expense results in 404.
|
||||
"""
|
||||
non_existent_expense_id = 9999999
|
||||
response = await client.get(
|
||||
expense_url(f"/{non_existent_expense_id}"),
|
||||
headers=normal_user_token_headers
|
||||
)
|
||||
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||
content = response.json()
|
||||
assert "not found" in content["detail"].lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_expense_forbidden_personal_expense_other_user(
|
||||
client: AsyncClient,
|
||||
normal_user_token_headers: Dict[str, str], # Belongs to test_user
|
||||
# Fixture for an expense paid by another_user, not linked to any list/group test_user has access to
|
||||
personal_expense_of_another_user: ExpensePublic
|
||||
) -> None:
|
||||
"""
|
||||
Test retrieving a personal expense of another user (no shared list/group) results in 403.
|
||||
"""
|
||||
response = await client.get(
|
||||
expense_url(f"/{personal_expense_of_another_user.id}"),
|
||||
headers=normal_user_token_headers # Current user querying
|
||||
)
|
||||
assert response.status_code == status.HTTP_403_FORBIDDEN
|
||||
content = response.json()
|
||||
assert "Not authorized to view this expense" in content["detail"]
|
||||
|
||||
# GET /lists/{list_id}/expenses
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_list_expenses_success(
|
||||
client: AsyncClient,
|
||||
normal_user_token_headers: Dict[str, str],
|
||||
test_user: UserModel,
|
||||
test_list_user_is_member: ListModel, # List the user is a member of
|
||||
# Assume some expenses have been created for this list by a fixture or previous tests
|
||||
) -> None:
|
||||
"""
|
||||
Test successfully listing expenses for a list the user has access to.
|
||||
"""
|
||||
response = await client.get(
|
||||
f"{API_V1_STR}/financials/lists/{test_list_user_is_member.id}/expenses",
|
||||
headers=normal_user_token_headers
|
||||
)
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
content = response.json()
|
||||
assert isinstance(content, list)
|
||||
for expense_item in content: # Renamed from expense to avoid conflict if a fixture is named expense
|
||||
assert expense_item["list_id"] == test_list_user_is_member.id
|
||||
|
||||
# TODO: Add more tests for list_list_expenses:
|
||||
# - list not found -> 404 (ListNotFoundError from check_list_access_for_financials)
|
||||
# - user has no access to list -> 403 (ListPermissionError from check_list_access_for_financials)
|
||||
# - list exists but has no expenses -> empty list, 200 OK
|
||||
# - test pagination (skip, limit)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_list_expenses_list_not_found(
|
||||
client: AsyncClient,
|
||||
normal_user_token_headers: Dict[str, str],
|
||||
) -> None:
|
||||
"""
|
||||
Test listing expenses for a non-existent list results in 404 (or appropriate error from permission check).
|
||||
The check_list_access_for_financials raises ListNotFoundError, which might be caught and raised as 404.
|
||||
The endpoint itself also has a get for ListModel, which would 404 first if permission check passed (not possible here).
|
||||
Based on financials.py, ListNotFoundError is raised by check_list_access_for_financials.
|
||||
This should translate to a 404 or a 403 if ListPermissionError wraps it with an action.
|
||||
The current ListPermissionError in check_list_access_for_financials re-raises ListNotFoundError if that's the cause.
|
||||
ListNotFoundError is a custom exception often mapped to 404.
|
||||
Let's assume ListNotFoundError results in a 404 response from an exception handler.
|
||||
"""
|
||||
non_existent_list_id = 99999
|
||||
response = await client.get(
|
||||
f"{API_V1_STR}/financials/lists/{non_existent_list_id}/expenses",
|
||||
headers=normal_user_token_headers
|
||||
)
|
||||
# The ListNotFoundError is raised by the check_list_access_for_financials helper,
|
||||
# which is then re-raised. FastAPI default exception handlers or custom ones
|
||||
# would convert this to an HTTP response. Typically NotFoundError -> 404.
|
||||
# If ListPermissionError catches it and re-raises it specifically, it might be 403.
|
||||
# From the code: `except ListNotFoundError: raise` means it propagates.
|
||||
# Let's assume a global handler for NotFoundError derived exceptions leads to 404.
|
||||
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||
# The actual detail might vary based on how ListNotFoundError is handled by FastAPI
|
||||
# For now, we check the status code. If financials.py maps it differently, this will need adjustment.
|
||||
# Based on `raise ListNotFoundError(expense_in.list_id)` in create_new_expense, and if that leads to 400,
|
||||
# this might be inconsistent. However, `check_list_access_for_financials` just re-raises ListNotFoundError.
|
||||
# Let's stick to expecting 404 for a direct not found error from a path parameter.
|
||||
content = response.json()
|
||||
assert "list not found" in content["detail"].lower() # Common detail for not found errors
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_list_expenses_no_access(
|
||||
client: AsyncClient,
|
||||
normal_user_token_headers: Dict[str, str], # User who will attempt access
|
||||
test_list_user_not_member: ListModel, # A list current user is NOT a member of
|
||||
) -> None:
|
||||
"""
|
||||
Test listing expenses for a list the user does not have access to (403 Forbidden).
|
||||
"""
|
||||
response = await client.get(
|
||||
f"{API_V1_STR}/financials/lists/{test_list_user_not_member.id}/expenses",
|
||||
headers=normal_user_token_headers
|
||||
)
|
||||
assert response.status_code == status.HTTP_403_FORBIDDEN
|
||||
content = response.json()
|
||||
assert f"User does not have permission to access financial data for list {test_list_user_not_member.id}" in content["detail"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_list_expenses_empty(
|
||||
client: AsyncClient,
|
||||
normal_user_token_headers: Dict[str, str],
|
||||
test_list_user_is_member_no_expenses: ListModel, # List user is member of, but has no expenses
|
||||
) -> None:
|
||||
"""
|
||||
Test listing expenses for an accessible list that has no expenses (empty list, 200 OK).
|
||||
"""
|
||||
response = await client.get(
|
||||
f"{API_V1_STR}/financials/lists/{test_list_user_is_member_no_expenses.id}/expenses",
|
||||
headers=normal_user_token_headers
|
||||
)
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
content = response.json()
|
||||
assert isinstance(content, list)
|
||||
assert len(content) == 0
|
||||
|
||||
# GET /groups/{group_id}/expenses
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_group_expenses_success(
|
||||
client: AsyncClient,
|
||||
normal_user_token_headers: Dict[str, str],
|
||||
test_user: UserModel,
|
||||
test_group_user_is_member: GroupModel, # Group the user is a member of
|
||||
# Assume some expenses have been created for this group by a fixture or previous tests
|
||||
) -> None:
|
||||
"""
|
||||
Test successfully listing expenses for a group the user has access to.
|
||||
"""
|
||||
response = await client.get(
|
||||
f"{API_V1_STR}/financials/groups/{test_group_user_is_member.id}/expenses",
|
||||
headers=normal_user_token_headers
|
||||
)
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
content = response.json()
|
||||
assert isinstance(content, list)
|
||||
# Further assertions can be made here, e.g., checking if all expenses belong to the group
|
||||
for expense_item in content:
|
||||
assert expense_item["group_id"] == test_group_user_is_member.id
|
||||
# Expenses in a group might also have a list_id if they were added via a list belonging to that group
|
||||
|
||||
# TODO: Add more tests for list_group_expenses:
|
||||
# - group not found -> 404 (GroupNotFoundError from check_group_membership)
|
||||
# - user has no access to group (not a member) -> 403 (GroupMembershipError from check_group_membership)
|
||||
# - group exists but has no expenses -> empty list, 200 OK
|
||||
# - test pagination (skip, limit)
|
||||
|
||||
# PUT /expenses/{expense_id}
|
||||
# DELETE /expenses/{expense_id}
|
||||
|
||||
# GET /settlements/{settlement_id}
|
||||
# POST /settlements
|
||||
# GET /groups/{group_id}/settlements
|
||||
# PUT /settlements/{settlement_id}
|
||||
# DELETE /settlements/{settlement_id}
|
||||
|
||||
pytest.skip("Still implementing other tests", allow_module_level=True)
|
1
be/tests/core/__init__.py
Normal file
1
be/tests/core/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
|
345
be/tests/core/test_exceptions.py
Normal file
345
be/tests/core/test_exceptions.py
Normal file
@ -0,0 +1,345 @@
|
||||
import pytest
|
||||
from fastapi import status
|
||||
from app.core.exceptions import (
|
||||
ListNotFoundError,
|
||||
ListPermissionError,
|
||||
ListCreatorRequiredError,
|
||||
GroupNotFoundError,
|
||||
GroupPermissionError,
|
||||
GroupMembershipError,
|
||||
GroupOperationError,
|
||||
GroupValidationError,
|
||||
ItemNotFoundError,
|
||||
UserNotFoundError,
|
||||
InvalidOperationError,
|
||||
DatabaseConnectionError,
|
||||
DatabaseIntegrityError,
|
||||
DatabaseTransactionError,
|
||||
DatabaseQueryError,
|
||||
OCRServiceUnavailableError,
|
||||
OCRServiceConfigError,
|
||||
OCRUnexpectedError,
|
||||
OCRQuotaExceededError,
|
||||
InvalidFileTypeError,
|
||||
FileTooLargeError,
|
||||
OCRProcessingError,
|
||||
EmailAlreadyRegisteredError,
|
||||
UserCreationError,
|
||||
InviteNotFoundError,
|
||||
InviteExpiredError,
|
||||
InviteAlreadyUsedError,
|
||||
InviteCreationError,
|
||||
ListStatusNotFoundError,
|
||||
ConflictError,
|
||||
InvalidCredentialsError,
|
||||
NotAuthenticatedError,
|
||||
JWTError,
|
||||
JWTUnexpectedError
|
||||
)
|
||||
# TODO: It seems like settings are used in some exceptions.
|
||||
# You will need to mock app.config.settings for these tests to pass.
|
||||
# Consider using pytest-mock or unittest.mock.patch.
|
||||
# Example: from app.config import settings
|
||||
|
||||
|
||||
def test_list_not_found_error():
|
||||
list_id = 123
|
||||
with pytest.raises(ListNotFoundError) as excinfo:
|
||||
raise ListNotFoundError(list_id=list_id)
|
||||
assert excinfo.value.status_code == status.HTTP_404_NOT_FOUND
|
||||
assert excinfo.value.detail == f"List {list_id} not found"
|
||||
|
||||
def test_list_permission_error():
|
||||
list_id = 456
|
||||
action = "delete"
|
||||
with pytest.raises(ListPermissionError) as excinfo:
|
||||
raise ListPermissionError(list_id=list_id, action=action)
|
||||
assert excinfo.value.status_code == status.HTTP_403_FORBIDDEN
|
||||
assert excinfo.value.detail == f"You do not have permission to {action} list {list_id}"
|
||||
|
||||
def test_list_permission_error_default_action():
|
||||
list_id = 789
|
||||
with pytest.raises(ListPermissionError) as excinfo:
|
||||
raise ListPermissionError(list_id=list_id)
|
||||
assert excinfo.value.status_code == status.HTTP_403_FORBIDDEN
|
||||
assert excinfo.value.detail == f"You do not have permission to access list {list_id}"
|
||||
|
||||
def test_list_creator_required_error():
|
||||
list_id = 101
|
||||
action = "update"
|
||||
with pytest.raises(ListCreatorRequiredError) as excinfo:
|
||||
raise ListCreatorRequiredError(list_id=list_id, action=action)
|
||||
assert excinfo.value.status_code == status.HTTP_403_FORBIDDEN
|
||||
assert excinfo.value.detail == f"Only the list creator can {action} list {list_id}"
|
||||
|
||||
def test_group_not_found_error():
|
||||
group_id = 202
|
||||
with pytest.raises(GroupNotFoundError) as excinfo:
|
||||
raise GroupNotFoundError(group_id=group_id)
|
||||
assert excinfo.value.status_code == status.HTTP_404_NOT_FOUND
|
||||
assert excinfo.value.detail == f"Group {group_id} not found"
|
||||
|
||||
def test_group_permission_error():
|
||||
group_id = 303
|
||||
action = "invite"
|
||||
with pytest.raises(GroupPermissionError) as excinfo:
|
||||
raise GroupPermissionError(group_id=group_id, action=action)
|
||||
assert excinfo.value.status_code == status.HTTP_403_FORBIDDEN
|
||||
assert excinfo.value.detail == f"You do not have permission to {action} in group {group_id}"
|
||||
|
||||
def test_group_membership_error():
|
||||
group_id = 404
|
||||
action = "post"
|
||||
with pytest.raises(GroupMembershipError) as excinfo:
|
||||
raise GroupMembershipError(group_id=group_id, action=action)
|
||||
assert excinfo.value.status_code == status.HTTP_403_FORBIDDEN
|
||||
assert excinfo.value.detail == f"You must be a member of group {group_id} to {action}"
|
||||
|
||||
def test_group_membership_error_default_action():
|
||||
group_id = 505
|
||||
with pytest.raises(GroupMembershipError) as excinfo:
|
||||
raise GroupMembershipError(group_id=group_id)
|
||||
assert excinfo.value.status_code == status.HTTP_403_FORBIDDEN
|
||||
assert excinfo.value.detail == f"You must be a member of group {group_id} to access"
|
||||
|
||||
def test_group_operation_error():
|
||||
detail_msg = "Failed to perform group operation."
|
||||
with pytest.raises(GroupOperationError) as excinfo:
|
||||
raise GroupOperationError(detail=detail_msg)
|
||||
assert excinfo.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
|
||||
assert excinfo.value.detail == detail_msg
|
||||
|
||||
def test_group_validation_error():
|
||||
detail_msg = "Invalid group data."
|
||||
with pytest.raises(GroupValidationError) as excinfo:
|
||||
raise GroupValidationError(detail=detail_msg)
|
||||
assert excinfo.value.status_code == status.HTTP_400_BAD_REQUEST
|
||||
assert excinfo.value.detail == detail_msg
|
||||
|
||||
def test_item_not_found_error():
|
||||
item_id = 606
|
||||
with pytest.raises(ItemNotFoundError) as excinfo:
|
||||
raise ItemNotFoundError(item_id=item_id)
|
||||
assert excinfo.value.status_code == status.HTTP_404_NOT_FOUND
|
||||
assert excinfo.value.detail == f"Item {item_id} not found"
|
||||
|
||||
def test_user_not_found_error_no_identifier():
|
||||
with pytest.raises(UserNotFoundError) as excinfo:
|
||||
raise UserNotFoundError()
|
||||
assert excinfo.value.status_code == status.HTTP_404_NOT_FOUND
|
||||
assert excinfo.value.detail == "User not found."
|
||||
|
||||
def test_user_not_found_error_with_id():
|
||||
user_id = 707
|
||||
with pytest.raises(UserNotFoundError) as excinfo:
|
||||
raise UserNotFoundError(user_id=user_id)
|
||||
assert excinfo.value.status_code == status.HTTP_404_NOT_FOUND
|
||||
assert excinfo.value.detail == f"User with ID {user_id} not found."
|
||||
|
||||
def test_user_not_found_error_with_identifier_string():
|
||||
identifier = "test_user"
|
||||
with pytest.raises(UserNotFoundError) as excinfo:
|
||||
raise UserNotFoundError(identifier=identifier)
|
||||
assert excinfo.value.status_code == status.HTTP_404_NOT_FOUND
|
||||
assert excinfo.value.detail == f"User with identifier '{identifier}' not found."
|
||||
|
||||
def test_invalid_operation_error():
|
||||
detail_msg = "This operation is not allowed."
|
||||
with pytest.raises(InvalidOperationError) as excinfo:
|
||||
raise InvalidOperationError(detail=detail_msg)
|
||||
assert excinfo.value.status_code == status.HTTP_400_BAD_REQUEST
|
||||
assert excinfo.value.detail == detail_msg
|
||||
|
||||
def test_invalid_operation_error_custom_status():
|
||||
detail_msg = "This operation is forbidden."
|
||||
custom_status = status.HTTP_403_FORBIDDEN
|
||||
with pytest.raises(InvalidOperationError) as excinfo:
|
||||
raise InvalidOperationError(detail=detail_msg, status_code=custom_status)
|
||||
assert excinfo.value.status_code == custom_status
|
||||
assert excinfo.value.detail == detail_msg
|
||||
|
||||
# The following exceptions depend on `settings`
|
||||
# We need to mock `app.config.settings` for these tests.
|
||||
# For now, I will add placeholder tests that would fail without mocking.
|
||||
# Consider using pytest-mock or unittest.mock.patch for this.
|
||||
|
||||
# def test_database_connection_error(mocker):
|
||||
# mocker.patch("app.core.exceptions.settings.DB_CONNECTION_ERROR", "Test DB connection error")
|
||||
# with pytest.raises(DatabaseConnectionError) as excinfo:
|
||||
# raise DatabaseConnectionError()
|
||||
# assert excinfo.value.status_code == status.HTTP_503_SERVICE_UNAVAILABLE
|
||||
# assert excinfo.value.detail == "Test DB connection error" # settings.DB_CONNECTION_ERROR
|
||||
|
||||
# def test_database_integrity_error(mocker):
|
||||
# mocker.patch("app.core.exceptions.settings.DB_INTEGRITY_ERROR", "Test DB integrity error")
|
||||
# with pytest.raises(DatabaseIntegrityError) as excinfo:
|
||||
# raise DatabaseIntegrityError()
|
||||
# assert excinfo.value.status_code == status.HTTP_400_BAD_REQUEST
|
||||
# assert excinfo.value.detail == "Test DB integrity error" # settings.DB_INTEGRITY_ERROR
|
||||
|
||||
# def test_database_transaction_error(mocker):
|
||||
# mocker.patch("app.core.exceptions.settings.DB_TRANSACTION_ERROR", "Test DB transaction error")
|
||||
# with pytest.raises(DatabaseTransactionError) as excinfo:
|
||||
# raise DatabaseTransactionError()
|
||||
# assert excinfo.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
|
||||
# assert excinfo.value.detail == "Test DB transaction error" # settings.DB_TRANSACTION_ERROR
|
||||
|
||||
# def test_database_query_error(mocker):
|
||||
# mocker.patch("app.core.exceptions.settings.DB_QUERY_ERROR", "Test DB query error")
|
||||
# with pytest.raises(DatabaseQueryError) as excinfo:
|
||||
# raise DatabaseQueryError()
|
||||
# assert excinfo.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
|
||||
# assert excinfo.value.detail == "Test DB query error" # settings.DB_QUERY_ERROR
|
||||
|
||||
# def test_ocr_service_unavailable_error(mocker):
|
||||
# mocker.patch("app.core.exceptions.settings.OCR_SERVICE_UNAVAILABLE", "Test OCR unavailable")
|
||||
# with pytest.raises(OCRServiceUnavailableError) as excinfo:
|
||||
# raise OCRServiceUnavailableError()
|
||||
# assert excinfo.value.status_code == status.HTTP_503_SERVICE_UNAVAILABLE
|
||||
# assert excinfo.value.detail == "Test OCR unavailable" # settings.OCR_SERVICE_UNAVAILABLE
|
||||
|
||||
# def test_ocr_service_config_error(mocker):
|
||||
# mocker.patch("app.core.exceptions.settings.OCR_SERVICE_CONFIG_ERROR", "Test OCR config error")
|
||||
# with pytest.raises(OCRServiceConfigError) as excinfo:
|
||||
# raise OCRServiceConfigError()
|
||||
# assert excinfo.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
|
||||
# assert excinfo.value.detail == "Test OCR config error" # settings.OCR_SERVICE_CONFIG_ERROR
|
||||
|
||||
# def test_ocr_unexpected_error(mocker):
|
||||
# mocker.patch("app.core.exceptions.settings.OCR_UNEXPECTED_ERROR", "Test OCR unexpected error")
|
||||
# with pytest.raises(OCRUnexpectedError) as excinfo:
|
||||
# raise OCRUnexpectedError()
|
||||
# assert excinfo.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
|
||||
# assert excinfo.value.detail == "Test OCR unexpected error" # settings.OCR_UNEXPECTED_ERROR
|
||||
|
||||
# def test_ocr_quota_exceeded_error(mocker):
|
||||
# mocker.patch("app.core.exceptions.settings.OCR_QUOTA_EXCEEDED", "Test OCR quota exceeded")
|
||||
# with pytest.raises(OCRQuotaExceededError) as excinfo:
|
||||
# raise OCRQuotaExceededError()
|
||||
# assert excinfo.value.status_code == status.HTTP_429_TOO_MANY_REQUESTS
|
||||
# assert excinfo.value.detail == "Test OCR quota exceeded" # settings.OCR_QUOTA_EXCEEDED
|
||||
|
||||
# def test_invalid_file_type_error(mocker):
|
||||
# test_types = ["png", "jpg"]
|
||||
# mocker.patch("app.core.exceptions.settings.ALLOWED_IMAGE_TYPES", test_types)
|
||||
# mocker.patch("app.core.exceptions.settings.OCR_INVALID_FILE_TYPE", "Invalid type: {types}")
|
||||
# with pytest.raises(InvalidFileTypeError) as excinfo:
|
||||
# raise InvalidFileTypeError()
|
||||
# assert excinfo.value.status_code == status.HTTP_400_BAD_REQUEST
|
||||
# assert excinfo.value.detail == f"Invalid type: {', '.join(test_types)}" # settings.OCR_INVALID_FILE_TYPE.format(types=", ".join(settings.ALLOWED_IMAGE_TYPES))
|
||||
|
||||
# def test_file_too_large_error(mocker):
|
||||
# max_size = 10
|
||||
# mocker.patch("app.core.exceptions.settings.MAX_FILE_SIZE_MB", max_size)
|
||||
# mocker.patch("app.core.exceptions.settings.OCR_FILE_TOO_LARGE", "File too large: {size}MB")
|
||||
# with pytest.raises(FileTooLargeError) as excinfo:
|
||||
# raise FileTooLargeError()
|
||||
# assert excinfo.value.status_code == status.HTTP_400_BAD_REQUEST
|
||||
# assert excinfo.value.detail == f"File too large: {max_size}MB" # settings.OCR_FILE_TOO_LARGE.format(size=settings.MAX_FILE_SIZE_MB)
|
||||
|
||||
# def test_ocr_processing_error(mocker):
|
||||
# error_detail = "Specific OCR error"
|
||||
# mocker.patch("app.core.exceptions.settings.OCR_PROCESSING_ERROR", "OCR processing failed: {detail}")
|
||||
# with pytest.raises(OCRProcessingError) as excinfo:
|
||||
# raise OCRProcessingError(detail=error_detail)
|
||||
# assert excinfo.value.status_code == status.HTTP_400_BAD_REQUEST
|
||||
# assert excinfo.value.detail == f"OCR processing failed: {error_detail}" # settings.OCR_PROCESSING_ERROR.format(detail=detail)
|
||||
|
||||
|
||||
def test_email_already_registered_error():
|
||||
with pytest.raises(EmailAlreadyRegisteredError) as excinfo:
|
||||
raise EmailAlreadyRegisteredError()
|
||||
assert excinfo.value.status_code == status.HTTP_400_BAD_REQUEST
|
||||
assert excinfo.value.detail == "Email already registered."
|
||||
|
||||
def test_user_creation_error():
|
||||
with pytest.raises(UserCreationError) as excinfo:
|
||||
raise UserCreationError()
|
||||
assert excinfo.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
|
||||
assert excinfo.value.detail == "An error occurred during user creation."
|
||||
|
||||
def test_invite_not_found_error():
|
||||
invite_code = "TESTCODE123"
|
||||
with pytest.raises(InviteNotFoundError) as excinfo:
|
||||
raise InviteNotFoundError(invite_code=invite_code)
|
||||
assert excinfo.value.status_code == status.HTTP_404_NOT_FOUND
|
||||
assert excinfo.value.detail == f"Invite code {invite_code} not found"
|
||||
|
||||
def test_invite_expired_error():
|
||||
invite_code = "EXPIREDCODE"
|
||||
with pytest.raises(InviteExpiredError) as excinfo:
|
||||
raise InviteExpiredError(invite_code=invite_code)
|
||||
assert excinfo.value.status_code == status.HTTP_410_GONE
|
||||
assert excinfo.value.detail == f"Invite code {invite_code} has expired"
|
||||
|
||||
def test_invite_already_used_error():
|
||||
invite_code = "USEDCODE"
|
||||
with pytest.raises(InviteAlreadyUsedError) as excinfo:
|
||||
raise InviteAlreadyUsedError(invite_code=invite_code)
|
||||
assert excinfo.value.status_code == status.HTTP_410_GONE
|
||||
assert excinfo.value.detail == f"Invite code {invite_code} has already been used"
|
||||
|
||||
def test_invite_creation_error():
|
||||
group_id = 909
|
||||
with pytest.raises(InviteCreationError) as excinfo:
|
||||
raise InviteCreationError(group_id=group_id)
|
||||
assert excinfo.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
|
||||
assert excinfo.value.detail == f"Failed to create invite for group {group_id}"
|
||||
|
||||
def test_list_status_not_found_error():
|
||||
list_id = 808
|
||||
with pytest.raises(ListStatusNotFoundError) as excinfo:
|
||||
raise ListStatusNotFoundError(list_id=list_id)
|
||||
assert excinfo.value.status_code == status.HTTP_404_NOT_FOUND
|
||||
assert excinfo.value.detail == f"Status for list {list_id} not found"
|
||||
|
||||
def test_conflict_error():
|
||||
detail_msg = "Resource version mismatch."
|
||||
with pytest.raises(ConflictError) as excinfo:
|
||||
raise ConflictError(detail=detail_msg)
|
||||
assert excinfo.value.status_code == status.HTTP_409_CONFLICT
|
||||
assert excinfo.value.detail == detail_msg
|
||||
|
||||
# Tests for auth-related exceptions that likely require mocking app.config.settings
|
||||
|
||||
# def test_invalid_credentials_error(mocker):
|
||||
# mocker.patch("app.core.exceptions.settings.AUTH_INVALID_CREDENTIALS", "Invalid test credentials")
|
||||
# mocker.patch("app.core.exceptions.settings.AUTH_HEADER_NAME", "X-Test-Auth")
|
||||
# mocker.patch("app.core.exceptions.settings.AUTH_HEADER_PREFIX", "TestBearer")
|
||||
# with pytest.raises(InvalidCredentialsError) as excinfo:
|
||||
# raise InvalidCredentialsError()
|
||||
# assert excinfo.value.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
# assert excinfo.value.detail == "Invalid test credentials"
|
||||
# assert excinfo.value.headers == {"X-Test-Auth": "TestBearer error=\"invalid_credentials\""}
|
||||
|
||||
# def test_not_authenticated_error(mocker):
|
||||
# mocker.patch("app.core.exceptions.settings.AUTH_NOT_AUTHENTICATED", "Not authenticated test")
|
||||
# mocker.patch("app.core.exceptions.settings.AUTH_HEADER_NAME", "X-Test-Auth")
|
||||
# mocker.patch("app.core.exceptions.settings.AUTH_HEADER_PREFIX", "TestBearer")
|
||||
# with pytest.raises(NotAuthenticatedError) as excinfo:
|
||||
# raise NotAuthenticatedError()
|
||||
# assert excinfo.value.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
# assert excinfo.value.detail == "Not authenticated test"
|
||||
# assert excinfo.value.headers == {"X-Test-Auth": "TestBearer error=\"not_authenticated\""}
|
||||
|
||||
# def test_jwt_error(mocker):
|
||||
# error_msg = "Test JWT issue"
|
||||
# mocker.patch("app.core.exceptions.settings.JWT_ERROR", "JWT error: {error}")
|
||||
# mocker.patch("app.core.exceptions.settings.AUTH_HEADER_NAME", "X-Test-Auth")
|
||||
# mocker.patch("app.core.exceptions.settings.AUTH_HEADER_PREFIX", "TestBearer")
|
||||
# with pytest.raises(JWTError) as excinfo:
|
||||
# raise JWTError(error=error_msg)
|
||||
# assert excinfo.value.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
# assert excinfo.value.detail == f"JWT error: {error_msg}"
|
||||
# assert excinfo.value.headers == {"X-Test-Auth": "TestBearer error=\"invalid_token\""}
|
||||
|
||||
# def test_jwt_unexpected_error(mocker):
|
||||
# error_msg = "Unexpected test JWT issue"
|
||||
# mocker.patch("app.core.exceptions.settings.JWT_UNEXPECTED_ERROR", "Unexpected JWT error: {error}")
|
||||
# mocker.patch("app.core.exceptions.settings.AUTH_HEADER_NAME", "X-Test-Auth")
|
||||
# mocker.patch("app.core.exceptions.settings.AUTH_HEADER_PREFIX", "TestBearer")
|
||||
# with pytest.raises(JWTUnexpectedError) as excinfo:
|
||||
# raise JWTUnexpectedError(error=error_msg)
|
||||
# assert excinfo.value.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
# assert excinfo.value.detail == f"Unexpected JWT error: {error_msg}"
|
||||
# assert excinfo.value.headers == {"X-Test-Auth": "TestBearer error=\"invalid_token\""}
|
276
be/tests/core/test_gemini.py
Normal file
276
be/tests/core/test_gemini.py
Normal file
@ -0,0 +1,276 @@
|
||||
import pytest
|
||||
import asyncio
|
||||
from unittest.mock import patch, MagicMock, AsyncMock
|
||||
|
||||
import google.generativeai as genai
|
||||
from google.api_core import exceptions as google_exceptions
|
||||
|
||||
# Modules to test
|
||||
from app.core import gemini
|
||||
from app.core.exceptions import (
|
||||
OCRServiceUnavailableError,
|
||||
OCRServiceConfigError,
|
||||
OCRUnexpectedError,
|
||||
OCRQuotaExceededError
|
||||
)
|
||||
|
||||
# Default Mock Settings
|
||||
@pytest.fixture
|
||||
def mock_gemini_settings():
|
||||
settings_mock = MagicMock()
|
||||
settings_mock.GEMINI_API_KEY = "test_api_key"
|
||||
settings_mock.GEMINI_MODEL_NAME = "gemini-pro-vision"
|
||||
settings_mock.GEMINI_SAFETY_SETTINGS = {
|
||||
"HARM_CATEGORY_HARASSMENT": "BLOCK_NONE",
|
||||
"HARM_CATEGORY_HATE_SPEECH": "BLOCK_NONE",
|
||||
}
|
||||
settings_mock.GEMINI_GENERATION_CONFIG = {"temperature": 0.7}
|
||||
settings_mock.OCR_ITEM_EXTRACTION_PROMPT = "Extract items:"
|
||||
return settings_mock
|
||||
|
||||
@pytest.fixture
|
||||
def mock_generative_model_instance():
|
||||
model_instance = MagicMock(spec=genai.GenerativeModel)
|
||||
model_instance.generate_content_async = AsyncMock()
|
||||
return model_instance
|
||||
|
||||
@pytest.fixture
|
||||
@patch('google.generativeai.GenerativeModel')
|
||||
@patch('google.generativeai.configure')
|
||||
def patch_google_ai_client(mock_configure, mock_generative_model, mock_generative_model_instance):
|
||||
mock_generative_model.return_value = mock_generative_model_instance
|
||||
return mock_configure, mock_generative_model, mock_generative_model_instance
|
||||
|
||||
|
||||
# --- Test Gemini Client Initialization (Global Client) ---
|
||||
|
||||
# Parametrize to test different scenarios for the global client init
|
||||
@pytest.mark.parametrize(
|
||||
"api_key_present, configure_raises, model_init_raises, expected_error_message_part",
|
||||
[
|
||||
(True, None, None, None), # Success
|
||||
(False, None, None, "GEMINI_API_KEY not configured"), # API key missing
|
||||
(True, Exception("Config error"), None, "Failed to initialize Gemini AI client: Config error"), # genai.configure error
|
||||
(True, None, Exception("Model init error"), "Failed to initialize Gemini AI client: Model init error"), # GenerativeModel error
|
||||
]
|
||||
)
|
||||
@patch('app.core.gemini.genai') # Patch genai within the gemini module
|
||||
def test_global_gemini_client_initialization(
|
||||
mock_genai_module,
|
||||
mock_gemini_settings,
|
||||
api_key_present,
|
||||
configure_raises,
|
||||
model_init_raises,
|
||||
expected_error_message_part
|
||||
):
|
||||
"""Tests the global gemini_flash_client initialization logic in app.core.gemini."""
|
||||
# We need to reload the module to re-trigger its top-level initialization code.
|
||||
# This is a bit tricky. A common pattern is to put init logic in a function.
|
||||
# For now, we'll try to simulate it by controlling mocks before module access.
|
||||
|
||||
with patch('app.core.gemini.settings', mock_gemini_settings):
|
||||
if not api_key_present:
|
||||
mock_gemini_settings.GEMINI_API_KEY = None
|
||||
|
||||
mock_genai_module.configure = MagicMock()
|
||||
mock_genai_module.GenerativeModel = MagicMock()
|
||||
mock_genai_module.types = genai.types # Keep original types
|
||||
mock_genai_module.HarmCategory = genai.HarmCategory
|
||||
mock_genai_module.HarmBlockThreshold = genai.HarmBlockThreshold
|
||||
|
||||
if configure_raises:
|
||||
mock_genai_module.configure.side_effect = configure_raises
|
||||
if model_init_raises:
|
||||
mock_genai_module.GenerativeModel.side_effect = model_init_raises
|
||||
|
||||
# Python modules are singletons. To re-run top-level code, we need to unload and reload.
|
||||
# This is generally discouraged. It's better to have an explicit init function.
|
||||
# For this test, we'll check the state variables set by the module's import-time code.
|
||||
import importlib
|
||||
importlib.reload(gemini) # This re-runs the try-except block at the top of gemini.py
|
||||
|
||||
if expected_error_message_part:
|
||||
assert gemini.gemini_initialization_error is not None
|
||||
assert expected_error_message_part in gemini.gemini_initialization_error
|
||||
assert gemini.gemini_flash_client is None
|
||||
else:
|
||||
assert gemini.gemini_initialization_error is None
|
||||
assert gemini.gemini_flash_client is not None
|
||||
mock_genai_module.configure.assert_called_once_with(api_key="test_api_key")
|
||||
mock_genai_module.GenerativeModel.assert_called_once()
|
||||
# Could add more assertions about safety_settings and generation_config here
|
||||
|
||||
# Clean up after reload for other tests
|
||||
importlib.reload(gemini)
|
||||
|
||||
# --- Test get_gemini_client ---
|
||||
# Assuming the global client tests above set the stage for these
|
||||
|
||||
@patch('app.core.gemini.gemini_flash_client', new_callable=MagicMock)
|
||||
@patch('app.core.gemini.gemini_initialization_error', None)
|
||||
def test_get_gemini_client_success(mock_client_var, mock_error_var):
|
||||
mock_client_var.return_value = MagicMock(spec=genai.GenerativeModel) # Simulate an initialized client
|
||||
gemini.gemini_flash_client = mock_client_var # Assign the mock
|
||||
gemini.gemini_initialization_error = None
|
||||
client = gemini.get_gemini_client()
|
||||
assert client is not None
|
||||
|
||||
@patch('app.core.gemini.gemini_flash_client', None)
|
||||
@patch('app.core.gemini.gemini_initialization_error', "Test init error")
|
||||
def test_get_gemini_client_init_error(mock_client_var, mock_error_var):
|
||||
gemini.gemini_flash_client = None
|
||||
gemini.gemini_initialization_error = "Test init error"
|
||||
with pytest.raises(RuntimeError, match="Gemini client could not be initialized: Test init error"):
|
||||
gemini.get_gemini_client()
|
||||
|
||||
@patch('app.core.gemini.gemini_flash_client', None)
|
||||
@patch('app.core.gemini.gemini_initialization_error', None) # No init error, but client is None
|
||||
def test_get_gemini_client_none_client_unknown_issue(mock_client_var, mock_error_var):
|
||||
gemini.gemini_flash_client = None
|
||||
gemini.gemini_initialization_error = None
|
||||
with pytest.raises(RuntimeError, match="Gemini client is not available \(unknown initialization issue\)."):
|
||||
gemini.get_gemini_client()
|
||||
|
||||
|
||||
# --- Tests for extract_items_from_image_gemini --- (Simplified for brevity, needs more cases)
|
||||
@pytest.mark.asyncio
|
||||
async def test_extract_items_from_image_gemini_success(
|
||||
mock_gemini_settings,
|
||||
mock_generative_model_instance,
|
||||
patch_google_ai_client # This fixture patches google.generativeai for the module
|
||||
):
|
||||
""" Test successful item extraction """
|
||||
# Ensure the global client is mocked to be the one we control
|
||||
with patch('app.core.gemini.settings', mock_gemini_settings), \
|
||||
patch('app.core.gemini.gemini_flash_client', mock_generative_model_instance), \
|
||||
patch('app.core.gemini.gemini_initialization_error', None):
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.text = "Item 1\nItem 2\n Item 3 \n\nAnother Item"
|
||||
# Simulate the structure for safety checks if needed
|
||||
mock_candidate = MagicMock()
|
||||
mock_candidate.content.parts = [MagicMock(text=mock_response.text)]
|
||||
mock_candidate.finish_reason = 'STOP' # Or whatever is appropriate for success
|
||||
mock_candidate.safety_ratings = []
|
||||
mock_response.candidates = [mock_candidate]
|
||||
|
||||
mock_generative_model_instance.generate_content_async.return_value = mock_response
|
||||
|
||||
image_bytes = b"dummy_image_bytes"
|
||||
mime_type = "image/png"
|
||||
|
||||
items = await gemini.extract_items_from_image_gemini(image_bytes, mime_type)
|
||||
|
||||
mock_generative_model_instance.generate_content_async.assert_called_once_with([
|
||||
mock_gemini_settings.OCR_ITEM_EXTRACTION_PROMPT,
|
||||
{"mime_type": mime_type, "data": image_bytes}
|
||||
])
|
||||
assert items == ["Item 1", "Item 2", "Item 3", "Another Item"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extract_items_from_image_gemini_client_not_init(
|
||||
mock_gemini_settings
|
||||
):
|
||||
with patch('app.core.gemini.settings', mock_gemini_settings), \
|
||||
patch('app.core.gemini.gemini_flash_client', None), \
|
||||
patch('app.core.gemini.gemini_initialization_error', "Initialization failed explicitly"):
|
||||
|
||||
image_bytes = b"dummy_image_bytes"
|
||||
with pytest.raises(RuntimeError, match="Gemini client could not be initialized: Initialization failed explicitly"):
|
||||
await gemini.extract_items_from_image_gemini(image_bytes)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('app.core.gemini.get_gemini_client') # Mock the getter to control the client directly
|
||||
async def test_extract_items_from_image_gemini_api_quota_error(
|
||||
mock_get_client,
|
||||
mock_gemini_settings,
|
||||
mock_generative_model_instance
|
||||
):
|
||||
mock_get_client.return_value = mock_generative_model_instance
|
||||
mock_generative_model_instance.generate_content_async.side_effect = google_exceptions.ResourceExhausted("Quota exceeded")
|
||||
|
||||
with patch('app.core.gemini.settings', mock_gemini_settings):
|
||||
image_bytes = b"dummy_image_bytes"
|
||||
with pytest.raises(google_exceptions.ResourceExhausted, match="Quota exceeded"):
|
||||
await gemini.extract_items_from_image_gemini(image_bytes)
|
||||
|
||||
|
||||
# --- Tests for GeminiOCRService --- (Example tests, more needed)
|
||||
|
||||
@patch('app.core.gemini.genai.configure')
|
||||
@patch('app.core.gemini.genai.GenerativeModel')
|
||||
def test_gemini_ocr_service_init_success(MockGenerativeModel, MockConfigure, mock_gemini_settings, mock_generative_model_instance):
|
||||
MockGenerativeModel.return_value = mock_generative_model_instance
|
||||
with patch('app.core.gemini.settings', mock_gemini_settings):
|
||||
service = gemini.GeminiOCRService()
|
||||
MockConfigure.assert_called_once_with(api_key=mock_gemini_settings.GEMINI_API_KEY)
|
||||
MockGenerativeModel.assert_called_once_with(mock_gemini_settings.GEMINI_MODEL_NAME)
|
||||
assert service.model == mock_generative_model_instance
|
||||
# Could add assertions for safety_settings and generation_config if they are set directly on model
|
||||
|
||||
@patch('app.core.gemini.genai.configure')
|
||||
@patch('app.core.gemini.genai.GenerativeModel', side_effect=Exception("Init model failed"))
|
||||
def test_gemini_ocr_service_init_failure(MockGenerativeModel, MockConfigure, mock_gemini_settings):
|
||||
with patch('app.core.gemini.settings', mock_gemini_settings):
|
||||
with pytest.raises(OCRServiceConfigError):
|
||||
gemini.GeminiOCRService()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_gemini_ocr_service_extract_items_success(mock_gemini_settings, mock_generative_model_instance):
|
||||
mock_response = MagicMock()
|
||||
mock_response.text = "Apple\nBanana\nOrange\nExample output should be ignored"
|
||||
mock_generative_model_instance.generate_content_async.return_value = mock_response
|
||||
|
||||
with patch('app.core.gemini.settings', mock_gemini_settings):
|
||||
# Patch the model instance within the service for this test
|
||||
with patch.object(genai, 'GenerativeModel', return_value=mock_generative_model_instance) as patched_model_class,
|
||||
patch.object(genai, 'configure') as patched_configure:
|
||||
|
||||
service = gemini.GeminiOCRService() # Re-init to use the patched model
|
||||
items = await service.extract_items(b"dummy_image")
|
||||
|
||||
expected_call_args = [
|
||||
mock_gemini_settings.OCR_ITEM_EXTRACTION_PROMPT,
|
||||
{"mime_type": "image/jpeg", "data": b"dummy_image"}
|
||||
]
|
||||
service.model.generate_content_async.assert_called_once_with(contents=expected_call_args)
|
||||
assert items == ["Apple", "Banana", "Orange"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_gemini_ocr_service_extract_items_quota_error(mock_gemini_settings, mock_generative_model_instance):
|
||||
mock_generative_model_instance.generate_content_async.side_effect = google_exceptions.ResourceExhausted("Quota limits exceeded.")
|
||||
|
||||
with patch('app.core.gemini.settings', mock_gemini_settings), \
|
||||
patch.object(genai, 'GenerativeModel', return_value=mock_generative_model_instance), \
|
||||
patch.object(genai, 'configure'):
|
||||
|
||||
service = gemini.GeminiOCRService()
|
||||
with pytest.raises(OCRQuotaExceededError):
|
||||
await service.extract_items(b"dummy_image")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_gemini_ocr_service_extract_items_api_unavailable(mock_gemini_settings, mock_generative_model_instance):
|
||||
# Simulate a generic API error that isn't quota related
|
||||
mock_generative_model_instance.generate_content_async.side_effect = google_exceptions.InternalServerError("Service unavailable")
|
||||
|
||||
with patch('app.core.gemini.settings', mock_gemini_settings), \
|
||||
patch.object(genai, 'GenerativeModel', return_value=mock_generative_model_instance), \
|
||||
patch.object(genai, 'configure'):
|
||||
|
||||
service = gemini.GeminiOCRService()
|
||||
with pytest.raises(OCRServiceUnavailableError):
|
||||
await service.extract_items(b"dummy_image")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_gemini_ocr_service_extract_items_no_text_response(mock_gemini_settings, mock_generative_model_instance):
|
||||
mock_response = MagicMock()
|
||||
mock_response.text = None # Simulate no text in response
|
||||
mock_generative_model_instance.generate_content_async.return_value = mock_response
|
||||
|
||||
with patch('app.core.gemini.settings', mock_gemini_settings), \
|
||||
patch.object(genai, 'GenerativeModel', return_value=mock_generative_model_instance), \
|
||||
patch.object(genai, 'configure'):
|
||||
|
||||
service = gemini.GeminiOCRService()
|
||||
with pytest.raises(OCRUnexpectedError):
|
||||
await service.extract_items(b"dummy_image")
|
216
be/tests/core/test_security.py
Normal file
216
be/tests/core/test_security.py
Normal file
@ -0,0 +1,216 @@
|
||||
import pytest
|
||||
from unittest.mock import patch, MagicMock
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
from jose import jwt, JWTError
|
||||
from passlib.context import CryptContext
|
||||
|
||||
from app.core.security import (
|
||||
verify_password,
|
||||
hash_password,
|
||||
create_access_token,
|
||||
create_refresh_token,
|
||||
verify_access_token,
|
||||
verify_refresh_token,
|
||||
pwd_context, # Import for direct testing if needed, or to check its config
|
||||
)
|
||||
# Assuming app.config.settings will be mocked
|
||||
# from app.config import settings
|
||||
|
||||
# --- Tests for Password Hashing ---
|
||||
|
||||
def test_hash_password():
|
||||
password = "securepassword123"
|
||||
hashed = hash_password(password)
|
||||
assert isinstance(hashed, str)
|
||||
assert hashed != password
|
||||
# Check that the default scheme (bcrypt) is used by verifying the hash prefix
|
||||
# bcrypt hashes typically start with $2b$ or $2a$ or $2y$
|
||||
assert hashed.startswith("$2b$") or hashed.startswith("$2a$") or hashed.startswith("$2y$")
|
||||
|
||||
def test_verify_password_correct():
|
||||
password = "testpassword"
|
||||
hashed_password = pwd_context.hash(password) # Use the same context for consistency
|
||||
assert verify_password(password, hashed_password) is True
|
||||
|
||||
def test_verify_password_incorrect():
|
||||
password = "testpassword"
|
||||
wrong_password = "wrongpassword"
|
||||
hashed_password = pwd_context.hash(password)
|
||||
assert verify_password(wrong_password, hashed_password) is False
|
||||
|
||||
def test_verify_password_invalid_hash_format():
|
||||
password = "testpassword"
|
||||
invalid_hash = "notarealhash"
|
||||
assert verify_password(password, invalid_hash) is False
|
||||
|
||||
# --- Tests for JWT Creation ---
|
||||
# Mock settings for JWT tests
|
||||
@pytest.fixture(scope="module")
|
||||
def mock_jwt_settings():
|
||||
mock_settings = MagicMock()
|
||||
mock_settings.SECRET_KEY = "testsecretkey"
|
||||
mock_settings.ALGORITHM = "HS256"
|
||||
mock_settings.ACCESS_TOKEN_EXPIRE_MINUTES = 30
|
||||
mock_settings.REFRESH_TOKEN_EXPIRE_MINUTES = 10080 # 7 days
|
||||
return mock_settings
|
||||
|
||||
@patch('app.core.security.settings')
|
||||
def test_create_access_token_default_expiry(mock_settings_global, mock_jwt_settings):
|
||||
mock_settings_global.SECRET_KEY = mock_jwt_settings.SECRET_KEY
|
||||
mock_settings_global.ALGORITHM = mock_jwt_settings.ALGORITHM
|
||||
mock_settings_global.ACCESS_TOKEN_EXPIRE_MINUTES = mock_jwt_settings.ACCESS_TOKEN_EXPIRE_MINUTES
|
||||
|
||||
subject = "user@example.com"
|
||||
token = create_access_token(subject)
|
||||
assert isinstance(token, str)
|
||||
|
||||
decoded_payload = jwt.decode(token, mock_jwt_settings.SECRET_KEY, algorithms=[mock_jwt_settings.ALGORITHM])
|
||||
assert decoded_payload["sub"] == subject
|
||||
assert decoded_payload["type"] == "access"
|
||||
assert "exp" in decoded_payload
|
||||
# Check if expiry is roughly correct (within a small delta)
|
||||
expected_expiry = datetime.now(timezone.utc) + timedelta(minutes=mock_jwt_settings.ACCESS_TOKEN_EXPIRE_MINUTES)
|
||||
assert abs(datetime.fromtimestamp(decoded_payload["exp"], timezone.utc) - expected_expiry) < timedelta(seconds=5)
|
||||
|
||||
@patch('app.core.security.settings')
|
||||
def test_create_access_token_custom_expiry(mock_settings_global, mock_jwt_settings):
|
||||
mock_settings_global.SECRET_KEY = mock_jwt_settings.SECRET_KEY
|
||||
mock_settings_global.ALGORITHM = mock_jwt_settings.ALGORITHM
|
||||
# ACCESS_TOKEN_EXPIRE_MINUTES is not used here due to custom delta
|
||||
|
||||
subject = 123 # Subject can be int
|
||||
custom_delta = timedelta(hours=1)
|
||||
token = create_access_token(subject, expires_delta=custom_delta)
|
||||
assert isinstance(token, str)
|
||||
|
||||
decoded_payload = jwt.decode(token, mock_jwt_settings.SECRET_KEY, algorithms=[mock_jwt_settings.ALGORITHM])
|
||||
assert decoded_payload["sub"] == str(subject)
|
||||
assert decoded_payload["type"] == "access"
|
||||
expected_expiry = datetime.now(timezone.utc) + custom_delta
|
||||
assert abs(datetime.fromtimestamp(decoded_payload["exp"], timezone.utc) - expected_expiry) < timedelta(seconds=5)
|
||||
|
||||
@patch('app.core.security.settings')
|
||||
def test_create_refresh_token_default_expiry(mock_settings_global, mock_jwt_settings):
|
||||
mock_settings_global.SECRET_KEY = mock_jwt_settings.SECRET_KEY
|
||||
mock_settings_global.ALGORITHM = mock_jwt_settings.ALGORITHM
|
||||
mock_settings_global.REFRESH_TOKEN_EXPIRE_MINUTES = mock_jwt_settings.REFRESH_TOKEN_EXPIRE_MINUTES
|
||||
|
||||
subject = "refresh_subject"
|
||||
token = create_refresh_token(subject)
|
||||
assert isinstance(token, str)
|
||||
|
||||
decoded_payload = jwt.decode(token, mock_jwt_settings.SECRET_KEY, algorithms=[mock_jwt_settings.ALGORITHM])
|
||||
assert decoded_payload["sub"] == subject
|
||||
assert decoded_payload["type"] == "refresh"
|
||||
expected_expiry = datetime.now(timezone.utc) + timedelta(minutes=mock_jwt_settings.REFRESH_TOKEN_EXPIRE_MINUTES)
|
||||
assert abs(datetime.fromtimestamp(decoded_payload["exp"], timezone.utc) - expected_expiry) < timedelta(seconds=5)
|
||||
|
||||
# --- Tests for JWT Verification --- (More tests to be added here)
|
||||
|
||||
@patch('app.core.security.settings')
|
||||
def test_verify_access_token_valid(mock_settings_global, mock_jwt_settings):
|
||||
mock_settings_global.SECRET_KEY = mock_jwt_settings.SECRET_KEY
|
||||
mock_settings_global.ALGORITHM = mock_jwt_settings.ALGORITHM
|
||||
mock_settings_global.ACCESS_TOKEN_EXPIRE_MINUTES = mock_jwt_settings.ACCESS_TOKEN_EXPIRE_MINUTES
|
||||
|
||||
subject = "test_user_valid_access"
|
||||
token = create_access_token(subject)
|
||||
payload = verify_access_token(token)
|
||||
assert payload is not None
|
||||
assert payload["sub"] == subject
|
||||
assert payload["type"] == "access"
|
||||
|
||||
@patch('app.core.security.settings')
|
||||
def test_verify_access_token_invalid_signature(mock_settings_global, mock_jwt_settings):
|
||||
mock_settings_global.SECRET_KEY = mock_jwt_settings.SECRET_KEY
|
||||
mock_settings_global.ALGORITHM = mock_jwt_settings.ALGORITHM
|
||||
mock_settings_global.ACCESS_TOKEN_EXPIRE_MINUTES = mock_jwt_settings.ACCESS_TOKEN_EXPIRE_MINUTES
|
||||
|
||||
subject = "test_user_invalid_sig"
|
||||
# Create token with correct key
|
||||
token = create_access_token(subject)
|
||||
|
||||
# Try to verify with wrong key
|
||||
mock_settings_global.SECRET_KEY = "wrongsecretkey"
|
||||
payload = verify_access_token(token)
|
||||
assert payload is None
|
||||
|
||||
@patch('app.core.security.settings')
|
||||
@patch('app.core.security.datetime') # Mock datetime to control token expiry
|
||||
def test_verify_access_token_expired(mock_datetime, mock_settings_global, mock_jwt_settings):
|
||||
mock_settings_global.SECRET_KEY = mock_jwt_settings.SECRET_KEY
|
||||
mock_settings_global.ALGORITHM = mock_jwt_settings.ALGORITHM
|
||||
mock_settings_global.ACCESS_TOKEN_EXPIRE_MINUTES = 1 # Expire in 1 minute
|
||||
|
||||
# Set current time for token creation
|
||||
now = datetime.now(timezone.utc)
|
||||
mock_datetime.now.return_value = now
|
||||
mock_datetime.fromtimestamp = datetime.fromtimestamp # Ensure original fromtimestamp is used by jwt.decode
|
||||
mock_datetime.timedelta = timedelta # Ensure original timedelta is used
|
||||
|
||||
subject = "test_user_expired"
|
||||
token = create_access_token(subject)
|
||||
|
||||
# Advance time beyond expiry for verification
|
||||
mock_datetime.now.return_value = now + timedelta(minutes=5)
|
||||
payload = verify_access_token(token)
|
||||
assert payload is None
|
||||
|
||||
@patch('app.core.security.settings')
|
||||
def test_verify_access_token_wrong_type(mock_settings_global, mock_jwt_settings):
|
||||
mock_settings_global.SECRET_KEY = mock_jwt_settings.SECRET_KEY
|
||||
mock_settings_global.ALGORITHM = mock_jwt_settings.ALGORITHM
|
||||
mock_settings_global.REFRESH_TOKEN_EXPIRE_MINUTES = mock_jwt_settings.REFRESH_TOKEN_EXPIRE_MINUTES # For refresh token creation
|
||||
|
||||
subject = "test_user_wrong_type"
|
||||
# Create a refresh token
|
||||
refresh_token = create_refresh_token(subject)
|
||||
|
||||
# Try to verify it as an access token
|
||||
payload = verify_access_token(refresh_token)
|
||||
assert payload is None
|
||||
|
||||
@patch('app.core.security.settings')
|
||||
def test_verify_refresh_token_valid(mock_settings_global, mock_jwt_settings):
|
||||
mock_settings_global.SECRET_KEY = mock_jwt_settings.SECRET_KEY
|
||||
mock_settings_global.ALGORITHM = mock_jwt_settings.ALGORITHM
|
||||
mock_settings_global.REFRESH_TOKEN_EXPIRE_MINUTES = mock_jwt_settings.REFRESH_TOKEN_EXPIRE_MINUTES
|
||||
|
||||
subject = "test_user_valid_refresh"
|
||||
token = create_refresh_token(subject)
|
||||
payload = verify_refresh_token(token)
|
||||
assert payload is not None
|
||||
assert payload["sub"] == subject
|
||||
assert payload["type"] == "refresh"
|
||||
|
||||
@patch('app.core.security.settings')
|
||||
@patch('app.core.security.datetime')
|
||||
def test_verify_refresh_token_expired(mock_datetime, mock_settings_global, mock_jwt_settings):
|
||||
mock_settings_global.SECRET_KEY = mock_jwt_settings.SECRET_KEY
|
||||
mock_settings_global.ALGORITHM = mock_jwt_settings.ALGORITHM
|
||||
mock_settings_global.REFRESH_TOKEN_EXPIRE_MINUTES = 1 # Expire in 1 minute
|
||||
|
||||
now = datetime.now(timezone.utc)
|
||||
mock_datetime.now.return_value = now
|
||||
mock_datetime.fromtimestamp = datetime.fromtimestamp
|
||||
mock_datetime.timedelta = timedelta
|
||||
|
||||
subject = "test_user_expired_refresh"
|
||||
token = create_refresh_token(subject)
|
||||
|
||||
mock_datetime.now.return_value = now + timedelta(minutes=5)
|
||||
payload = verify_refresh_token(token)
|
||||
assert payload is None
|
||||
|
||||
@patch('app.core.security.settings')
|
||||
def test_verify_refresh_token_wrong_type(mock_settings_global, mock_jwt_settings):
|
||||
mock_settings_global.SECRET_KEY = mock_jwt_settings.SECRET_KEY
|
||||
mock_settings_global.ALGORITHM = mock_jwt_settings.ALGORITHM
|
||||
mock_settings_global.ACCESS_TOKEN_EXPIRE_MINUTES = mock_jwt_settings.ACCESS_TOKEN_EXPIRE_MINUTES
|
||||
|
||||
subject = "test_user_wrong_type_refresh"
|
||||
access_token = create_access_token(subject)
|
||||
|
||||
payload = verify_refresh_token(access_token)
|
||||
assert payload is None
|
1
be/tests/crud/__init__.py
Normal file
1
be/tests/crud/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
|
254
be/tests/crud/test_cost.py
Normal file
254
be/tests/crud/test_cost.py
Normal file
@ -0,0 +1,254 @@
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from decimal import Decimal, ROUND_HALF_UP
|
||||
from typing import List as PyList, Dict
|
||||
import logging # For mocking the logger
|
||||
|
||||
from app.crud.cost import (
|
||||
calculate_list_cost_summary,
|
||||
calculate_suggested_settlements,
|
||||
calculate_group_balance_summary
|
||||
)
|
||||
from app.schemas.cost import (
|
||||
ListCostSummary,
|
||||
UserCostShare,
|
||||
GroupBalanceSummary,
|
||||
UserBalanceDetail,
|
||||
SuggestedSettlement
|
||||
)
|
||||
from app.models import (
|
||||
List as ListModel,
|
||||
Item as ItemModel,
|
||||
User as UserModel,
|
||||
Group as GroupModel,
|
||||
UserGroup as UserGroupModel,
|
||||
Expense as ExpenseModel,
|
||||
ExpenseSplit as ExpenseSplitModel,
|
||||
Settlement as SettlementModel
|
||||
)
|
||||
from app.core.exceptions import ListNotFoundError, GroupNotFoundError
|
||||
|
||||
# Fixtures
|
||||
@pytest.fixture
|
||||
def mock_db_session():
|
||||
session = AsyncMock()
|
||||
session.execute = AsyncMock()
|
||||
session.get = AsyncMock()
|
||||
return session
|
||||
|
||||
@pytest.fixture
|
||||
def user_a():
|
||||
return UserModel(id=1, name="User A", email="a@example.com")
|
||||
|
||||
@pytest.fixture
|
||||
def user_b():
|
||||
return UserModel(id=2, name="User B", email="b@example.com")
|
||||
|
||||
@pytest.fixture
|
||||
def user_c():
|
||||
return UserModel(id=3, name="User C", email="c@example.com")
|
||||
|
||||
@pytest.fixture
|
||||
def basic_list_model(user_a):
|
||||
return ListModel(id=1, name="Test List", creator_id=user_a.id, creator=user_a, items=[])
|
||||
|
||||
@pytest.fixture
|
||||
def group_model_with_members(user_a, user_b):
|
||||
group = GroupModel(id=1, name="Test Group")
|
||||
# Simulate user_associations for calculate_list_cost_summary if list is group-based
|
||||
group.user_associations = [
|
||||
UserGroupModel(user_id=user_a.id, group_id=group.id, user=user_a),
|
||||
UserGroupModel(user_id=user_b.id, group_id=group.id, user=user_b)
|
||||
]
|
||||
return group
|
||||
|
||||
@pytest.fixture
|
||||
def list_model_with_group(group_model_with_members):
|
||||
return ListModel(id=2, name="Group List", group_id=group_model_with_members.id, group=group_model_with_members, items=[])
|
||||
|
||||
# --- calculate_list_cost_summary Tests ---
|
||||
@pytest.mark.asyncio
|
||||
async def test_calculate_list_cost_summary_personal_list_no_items(mock_db_session, basic_list_model, user_a):
|
||||
mock_db_session.execute.return_value.scalars.return_value.first.return_value = basic_list_model
|
||||
basic_list_model.items = [] # Ensure no items
|
||||
basic_list_model.group = None # Ensure it's a personal list
|
||||
basic_list_model.creator = user_a
|
||||
|
||||
summary = await calculate_list_cost_summary(mock_db_session, basic_list_model.id)
|
||||
|
||||
assert summary.list_id == basic_list_model.id
|
||||
assert summary.total_list_cost == Decimal("0.00")
|
||||
assert summary.num_participating_users == 1
|
||||
assert summary.equal_share_per_user == Decimal("0.00")
|
||||
assert len(summary.user_balances) == 1
|
||||
assert summary.user_balances[0].user_id == user_a.id
|
||||
assert summary.user_balances[0].balance == Decimal("0.00")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_calculate_list_cost_summary_group_list_with_items(mock_db_session, list_model_with_group, group_model_with_members, user_a, user_b):
|
||||
item1 = ItemModel(id=1, name="Milk", price=Decimal("3.00"), added_by_id=user_a.id, added_by_user=user_a, list_id=list_model_with_group.id)
|
||||
item2 = ItemModel(id=2, name="Bread", price=Decimal("2.00"), added_by_id=user_b.id, added_by_user=user_b, list_id=list_model_with_group.id)
|
||||
list_model_with_group.items = [item1, item2]
|
||||
|
||||
mock_db_session.execute.return_value.scalars.return_value.first.return_value = list_model_with_group
|
||||
|
||||
summary = await calculate_list_cost_summary(mock_db_session, list_model_with_group.id)
|
||||
|
||||
assert summary.total_list_cost == Decimal("5.00")
|
||||
assert summary.num_participating_users == 2
|
||||
assert summary.equal_share_per_user == Decimal("2.50") # 5.00 / 2
|
||||
|
||||
balances_map = {ub.user_id: ub for ub in summary.user_balances}
|
||||
assert balances_map[user_a.id].items_added_value == Decimal("3.00")
|
||||
assert balances_map[user_a.id].amount_due == Decimal("2.50")
|
||||
assert balances_map[user_a.id].balance == Decimal("0.50") # 3.00 - 2.50
|
||||
|
||||
assert balances_map[user_b.id].items_added_value == Decimal("2.00")
|
||||
assert balances_map[user_b.id].amount_due == Decimal("2.50")
|
||||
assert balances_map[user_b.id].balance == Decimal("-0.50") # 2.00 - 2.50
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_calculate_list_cost_summary_list_not_found(mock_db_session):
|
||||
mock_db_session.execute.return_value.scalars.return_value.first.return_value = None
|
||||
with pytest.raises(ListNotFoundError):
|
||||
await calculate_list_cost_summary(mock_db_session, 999)
|
||||
|
||||
# --- calculate_suggested_settlements Tests ---
|
||||
@patch('app.crud.cost.logger') # Mock the logger used in the function
|
||||
def test_calculate_suggested_settlements_simple_case(mock_logger):
|
||||
user_balances = [
|
||||
UserBalanceDetail(user_id=1, user_identifier="User A", net_balance=Decimal("-10.00")),
|
||||
UserBalanceDetail(user_id=2, user_identifier="User B", net_balance=Decimal("10.00")),
|
||||
]
|
||||
settlements = calculate_suggested_settlements(user_balances)
|
||||
assert len(settlements) == 1
|
||||
assert settlements[0].from_user_id == 1
|
||||
assert settlements[0].to_user_id == 2
|
||||
assert settlements[0].amount == Decimal("10.00")
|
||||
mock_logger.warning.assert_not_called()
|
||||
|
||||
@patch('app.crud.cost.logger')
|
||||
def test_calculate_suggested_settlements_multiple(mock_logger):
|
||||
user_balances = [
|
||||
UserBalanceDetail(user_id=1, user_identifier="User A", net_balance=Decimal("-30.00")),
|
||||
UserBalanceDetail(user_id=2, user_identifier="User B", net_balance=Decimal("10.00")),
|
||||
UserBalanceDetail(user_id=3, user_identifier="User C", net_balance=Decimal("20.00")),
|
||||
]
|
||||
settlements = calculate_suggested_settlements(user_balances)
|
||||
# Expected: A owes B 10, A owes C 20
|
||||
assert len(settlements) == 2
|
||||
s_map = {(s.from_user_id, s.to_user_id): s.amount for s in settlements}
|
||||
assert s_map[(1, 3)] == Decimal("20.00") # A -> C (largest creditor first)
|
||||
assert s_map[(1, 2)] == Decimal("10.00") # A -> B
|
||||
mock_logger.warning.assert_not_called()
|
||||
|
||||
@patch('app.crud.cost.logger')
|
||||
def test_calculate_suggested_settlements_complex_case_with_rounding_check(mock_logger):
|
||||
user_balances = [
|
||||
UserBalanceDetail(user_id=1, user_identifier="User A", net_balance=Decimal("-33.33")),
|
||||
UserBalanceDetail(user_id=2, user_identifier="User B", net_balance=Decimal("-33.33")),
|
||||
UserBalanceDetail(user_id=3, user_identifier="User C", net_balance=Decimal("66.66")),
|
||||
]
|
||||
settlements = calculate_suggested_settlements(user_balances)
|
||||
# A -> C 33.33, B -> C 33.33
|
||||
assert len(settlements) == 2
|
||||
total_settled_to_C = sum(s.amount for s in settlements if s.to_user_id == 3)
|
||||
assert total_settled_to_C == Decimal("66.66")
|
||||
mock_logger.warning.assert_not_called() # Assuming exact match after quantization
|
||||
|
||||
# --- calculate_group_balance_summary Tests ---
|
||||
@pytest.mark.asyncio
|
||||
async def test_calculate_group_balance_summary_no_activity(mock_db_session, group_model_with_members, user_a, user_b):
|
||||
mock_db_session.get.return_value = group_model_with_members # For group fetch
|
||||
|
||||
# Mock group members query
|
||||
mock_members_result = AsyncMock()
|
||||
mock_members_result.scalars.return_value.all.return_value = [user_a, user_b]
|
||||
|
||||
# Mock expenses query (no expenses)
|
||||
mock_expenses_result = AsyncMock()
|
||||
mock_expenses_result.scalars.return_value.all.return_value = []
|
||||
|
||||
# Mock settlements query (no settlements)
|
||||
mock_settlements_result = AsyncMock()
|
||||
mock_settlements_result.scalars.return_value.all.return_value = []
|
||||
|
||||
mock_db_session.execute.side_effect = [mock_members_result, mock_expenses_result, mock_settlements_result]
|
||||
|
||||
summary = await calculate_group_balance_summary(mock_db_session, group_model_with_members.id)
|
||||
|
||||
assert summary.group_id == group_model_with_members.id
|
||||
assert len(summary.user_balances) == 2
|
||||
assert len(summary.suggested_settlements) == 0
|
||||
for ub in summary.user_balances:
|
||||
assert ub.net_balance == Decimal("0.00")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_calculate_group_balance_summary_with_expenses_and_settlements(mock_db_session, group_model_with_members, user_a, user_b, user_c):
|
||||
# Group members for this test: A, B, C
|
||||
group_model_with_members.user_associations.append(UserGroupModel(user_id=user_c.id, group_id=group_model_with_members.id, user=user_c))
|
||||
all_users = [user_a, user_b, user_c]
|
||||
|
||||
mock_db_session.get.return_value = group_model_with_members
|
||||
|
||||
mock_members_result = AsyncMock()
|
||||
mock_members_result.scalars.return_value.all.return_value = all_users
|
||||
|
||||
# Expenses: A paid 90, split equally (A, B, C -> 30 each)
|
||||
# A: paid 90, share 30 -> +60
|
||||
# B: paid 0, share 30 -> -30
|
||||
# C: paid 0, share 30 -> -30
|
||||
expense1 = ExpenseModel(id=1, group_id=group_model_with_members.id, paid_by_user_id=user_a.id, total_amount=Decimal("90.00"))
|
||||
expense1.splits = [
|
||||
ExpenseSplitModel(expense_id=1, user_id=user_a.id, owed_amount=Decimal("30.00")),
|
||||
ExpenseSplitModel(expense_id=1, user_id=user_b.id, owed_amount=Decimal("30.00")),
|
||||
ExpenseSplitModel(expense_id=1, user_id=user_c.id, owed_amount=Decimal("30.00")),
|
||||
]
|
||||
mock_expenses_result = AsyncMock()
|
||||
mock_expenses_result.scalars.return_value.all.return_value = [expense1]
|
||||
|
||||
# Settlements: B paid A 10
|
||||
# A: received 10 -> +60 + 10 = +70
|
||||
# B: paid 10 -> -30 - 10 = -40
|
||||
# C: no settlement -> -30
|
||||
settlement1 = SettlementModel(id=1, group_id=group_model_with_members.id, paid_by_user_id=user_b.id, paid_to_user_id=user_a.id, amount=Decimal("10.00"))
|
||||
mock_settlements_result = AsyncMock()
|
||||
mock_settlements_result.scalars.return_value.all.return_value = [settlement1]
|
||||
|
||||
mock_db_session.execute.side_effect = [mock_members_result, mock_expenses_result, mock_settlements_result]
|
||||
|
||||
with patch('app.crud.cost.logger') as mock_settlement_logger: # Patch logger for calculate_suggested_settlements
|
||||
summary = await calculate_group_balance_summary(mock_db_session, group_model_with_members.id)
|
||||
|
||||
balances_map = {ub.user_id: ub for ub in summary.user_balances}
|
||||
assert balances_map[user_a.id].net_balance == Decimal("70.00")
|
||||
assert balances_map[user_b.id].net_balance == Decimal("-40.00")
|
||||
assert balances_map[user_c.id].net_balance == Decimal("-30.00")
|
||||
|
||||
# Suggested settlements: B owes A 30 (40-10), C owes A 30
|
||||
# Net balances: A: +70, B: -40, C: -30
|
||||
# Algorithm should suggest: B -> A (40), C -> A (30)
|
||||
assert len(summary.suggested_settlements) == 2
|
||||
settlement_map = {(s.from_user_id, s.to_user_id): s.amount for s in summary.suggested_settlements}
|
||||
|
||||
# Order might vary, check presence and amounts
|
||||
assert (user_b.id, user_a.id) in settlement_map
|
||||
assert settlement_map[(user_b.id, user_a.id)] == Decimal("40.00")
|
||||
|
||||
assert (user_c.id, user_a.id) in settlement_map
|
||||
assert settlement_map[(user_c.id, user_a.id)] == Decimal("30.00")
|
||||
mock_settlement_logger.warning.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_calculate_group_balance_summary_group_not_found(mock_db_session):
|
||||
mock_db_session.get.return_value = None
|
||||
with pytest.raises(GroupNotFoundError):
|
||||
await calculate_group_balance_summary(mock_db_session, 999)
|
||||
|
||||
# TODO:
|
||||
# - Test calculate_list_cost_summary with item added by user not in group/not creator.
|
||||
# - Test calculate_list_cost_summary with remainder distribution for equal share.
|
||||
# - Test calculate_suggested_settlements with zero balances, one debtor, one creditor, etc.
|
||||
# - Test calculate_suggested_settlements with logger warning for remaining balances.
|
||||
# - Test calculate_group_balance_summary with more complex expense/settlement scenarios.
|
||||
# - Test calculate_group_balance_summary when a user in splits/settlements is not in group_members (should ideally not happen if data is clean).
|
317
be/tests/crud/test_expense.py
Normal file
317
be/tests/crud/test_expense.py
Normal file
@ -0,0 +1,317 @@
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from sqlalchemy.exc import IntegrityError, OperationalError
|
||||
from decimal import Decimal, ROUND_HALF_UP
|
||||
from datetime import datetime, timezone
|
||||
from typing import List as PyList, Optional
|
||||
|
||||
from app.crud.expense import (
|
||||
create_expense,
|
||||
get_expense_by_id,
|
||||
get_expenses_for_list,
|
||||
get_expenses_for_group,
|
||||
update_expense, # Assuming update_expense exists
|
||||
delete_expense, # Assuming delete_expense exists
|
||||
get_users_for_splitting # Helper, might test indirectly
|
||||
)
|
||||
from app.schemas.expense import ExpenseCreate, ExpenseUpdate, ExpenseSplitCreate
|
||||
from app.models import (
|
||||
Expense as ExpenseModel,
|
||||
ExpenseSplit as ExpenseSplitModel,
|
||||
User as UserModel,
|
||||
List as ListModel,
|
||||
Group as GroupModel,
|
||||
UserGroup as UserGroupModel,
|
||||
Item as ItemModel,
|
||||
SplitTypeEnum
|
||||
)
|
||||
from app.core.exceptions import (
|
||||
ListNotFoundError,
|
||||
GroupNotFoundError,
|
||||
UserNotFoundError,
|
||||
InvalidOperationError
|
||||
)
|
||||
|
||||
# General Fixtures
|
||||
@pytest.fixture
|
||||
def mock_db_session():
|
||||
session = AsyncMock()
|
||||
session.commit = AsyncMock()
|
||||
session.rollback = AsyncMock()
|
||||
session.refresh = AsyncMock()
|
||||
session.add = MagicMock()
|
||||
session.delete = MagicMock()
|
||||
session.execute = AsyncMock()
|
||||
session.get = AsyncMock()
|
||||
session.flush = AsyncMock() # create_expense uses flush
|
||||
return session
|
||||
|
||||
@pytest.fixture
|
||||
def basic_user_model():
|
||||
return UserModel(id=1, name="Test User", email="test@example.com")
|
||||
|
||||
@pytest.fixture
|
||||
def another_user_model():
|
||||
return UserModel(id=2, name="Another User", email="another@example.com")
|
||||
|
||||
@pytest.fixture
|
||||
def basic_group_model():
|
||||
group = GroupModel(id=1, name="Test Group")
|
||||
# Simulate member_associations for get_users_for_splitting if needed directly
|
||||
# group.member_associations = [UserGroupModel(user_id=1, group_id=1, user=basic_user_model()), UserGroupModel(user_id=2, group_id=1, user=another_user_model())]
|
||||
return group
|
||||
|
||||
@pytest.fixture
|
||||
def basic_list_model(basic_group_model, basic_user_model):
|
||||
return ListModel(id=1, name="Test List", group_id=basic_group_model.id, group=basic_group_model, creator_id=basic_user_model.id, creator=basic_user_model)
|
||||
|
||||
@pytest.fixture
|
||||
def expense_create_data_equal_split_list_ctx(basic_list_model, basic_user_model):
|
||||
return ExpenseCreate(
|
||||
description="Grocery run",
|
||||
total_amount=Decimal("30.00"),
|
||||
currency="USD",
|
||||
expense_date=datetime.now(timezone.utc),
|
||||
split_type=SplitTypeEnum.EQUAL,
|
||||
list_id=basic_list_model.id,
|
||||
group_id=None, # Derived from list
|
||||
item_id=None,
|
||||
paid_by_user_id=basic_user_model.id,
|
||||
splits_in=None
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def expense_create_data_equal_split_group_ctx(basic_group_model, basic_user_model):
|
||||
return ExpenseCreate(
|
||||
description="Movies",
|
||||
total_amount=Decimal("50.00"),
|
||||
currency="USD",
|
||||
expense_date=datetime.now(timezone.utc),
|
||||
split_type=SplitTypeEnum.EQUAL,
|
||||
list_id=None,
|
||||
group_id=basic_group_model.id,
|
||||
item_id=None,
|
||||
paid_by_user_id=basic_user_model.id,
|
||||
splits_in=None
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def expense_create_data_exact_split(basic_group_model, basic_user_model, another_user_model):
|
||||
return ExpenseCreate(
|
||||
description="Dinner",
|
||||
total_amount=Decimal("100.00"),
|
||||
split_type=SplitTypeEnum.EXACT_AMOUNTS,
|
||||
group_id=basic_group_model.id,
|
||||
paid_by_user_id=basic_user_model.id,
|
||||
splits_in=[
|
||||
ExpenseSplitCreate(user_id=basic_user_model.id, owed_amount=Decimal("60.00")),
|
||||
ExpenseSplitCreate(user_id=another_user_model.id, owed_amount=Decimal("40.00")),
|
||||
]
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def db_expense_model(expense_create_data_equal_split_group_ctx, basic_user_model):
|
||||
return ExpenseModel(
|
||||
id=1,
|
||||
description=expense_create_data_equal_split_group_ctx.description,
|
||||
total_amount=expense_create_data_equal_split_group_ctx.total_amount,
|
||||
currency=expense_create_data_equal_split_group_ctx.currency,
|
||||
expense_date=expense_create_data_equal_split_group_ctx.expense_date,
|
||||
split_type=expense_create_data_equal_split_group_ctx.split_type,
|
||||
list_id=expense_create_data_equal_split_group_ctx.list_id,
|
||||
group_id=expense_create_data_equal_split_group_ctx.group_id,
|
||||
item_id=expense_create_data_equal_split_group_ctx.item_id,
|
||||
paid_by_user_id=expense_create_data_equal_split_group_ctx.paid_by_user_id,
|
||||
paid_by=basic_user_model, # Assuming paid_by relation is loaded
|
||||
# splits would be populated after creation usually
|
||||
version=1
|
||||
)
|
||||
|
||||
# Tests for get_users_for_splitting (indirectly tested via create_expense, but stubs for direct if needed)
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_users_for_splitting_group_context(mock_db_session, basic_group_model, basic_user_model, another_user_model):
|
||||
# Setup group with members
|
||||
user_group_assoc1 = UserGroupModel(user=basic_user_model, user_id=basic_user_model.id)
|
||||
user_group_assoc2 = UserGroupModel(user=another_user_model, user_id=another_user_model.id)
|
||||
basic_group_model.member_associations = [user_group_assoc1, user_group_assoc2]
|
||||
|
||||
mock_execute = AsyncMock()
|
||||
mock_execute.scalars.return_value.first.return_value = basic_group_model
|
||||
mock_db_session.execute.return_value = mock_execute
|
||||
|
||||
users = await get_users_for_splitting(mock_db_session, expense_group_id=1, expense_list_id=None, expense_paid_by_user_id=1)
|
||||
assert len(users) == 2
|
||||
assert basic_user_model in users
|
||||
assert another_user_model in users
|
||||
|
||||
# --- create_expense Tests ---
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_expense_equal_split_group_success(mock_db_session, expense_create_data_equal_split_group_ctx, basic_user_model, basic_group_model, another_user_model):
|
||||
mock_db_session.get.side_effect = [basic_user_model, basic_group_model] # Payer, Group
|
||||
|
||||
# Mock get_users_for_splitting call within create_expense
|
||||
# This is a bit tricky as it's an internal call. Patching is an option.
|
||||
with patch('app.crud.expense.get_users_for_splitting', new_callable=AsyncMock) as mock_get_users:
|
||||
mock_get_users.return_value = [basic_user_model, another_user_model]
|
||||
|
||||
created_expense = await create_expense(mock_db_session, expense_create_data_equal_split_group_ctx, current_user_id=1)
|
||||
|
||||
mock_db_session.add.assert_called()
|
||||
mock_db_session.flush.assert_called_once()
|
||||
# mock_db_session.commit.assert_called_once() # create_expense does not commit itself
|
||||
# mock_db_session.refresh.assert_called_once() # create_expense does not refresh itself
|
||||
|
||||
assert created_expense is not None
|
||||
assert created_expense.total_amount == expense_create_data_equal_split_group_ctx.total_amount
|
||||
assert created_expense.split_type == SplitTypeEnum.EQUAL
|
||||
assert len(created_expense.splits) == 2 # Expect splits to be added to the model instance
|
||||
|
||||
# Check split amounts
|
||||
expected_amount_per_user = (expense_create_data_equal_split_group_ctx.total_amount / 2).quantize(Decimal("0.01"), rounding=ROUND_HALF_UP)
|
||||
for split in created_expense.splits:
|
||||
assert split.owed_amount == expected_amount_per_user
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_expense_exact_split_success(mock_db_session, expense_create_data_exact_split, basic_user_model, basic_group_model, another_user_model):
|
||||
mock_db_session.get.side_effect = [basic_user_model, basic_group_model] # Payer, Group
|
||||
|
||||
# Mock the select for user validation in exact splits
|
||||
mock_user_select_result = AsyncMock()
|
||||
mock_user_select_result.all.return_value = [(basic_user_model.id,), (another_user_model.id,)] # Simulate (id,) tuples
|
||||
# To make it behave like scalars().all() that returns a list of IDs:
|
||||
# We need to mock the scalars().all() part, or the whole execute chain for user validation.
|
||||
# A simpler way for this specific case might be to mock the select for User.id
|
||||
mock_execute_user_ids = AsyncMock()
|
||||
# Construct a mock result that mimics what `await db.execute(select(UserModel.id)...)` would give then process
|
||||
# It's a bit involved, usually `found_user_ids = {row[0] for row in user_results}`
|
||||
# Let's assume the select returns a list of Row objects or tuples with one element
|
||||
mock_user_ids_result_proxy = MagicMock()
|
||||
mock_user_ids_result_proxy.__iter__.return_value = iter([(basic_user_model.id,), (another_user_model.id,)])
|
||||
mock_db_session.execute.return_value = mock_user_ids_result_proxy
|
||||
|
||||
created_expense = await create_expense(mock_db_session, expense_create_data_exact_split, current_user_id=1)
|
||||
|
||||
mock_db_session.add.assert_called()
|
||||
mock_db_session.flush.assert_called_once()
|
||||
assert created_expense is not None
|
||||
assert created_expense.split_type == SplitTypeEnum.EXACT_AMOUNTS
|
||||
assert len(created_expense.splits) == 2
|
||||
assert created_expense.splits[0].owed_amount == Decimal("60.00")
|
||||
assert created_expense.splits[1].owed_amount == Decimal("40.00")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_expense_payer_not_found(mock_db_session, expense_create_data_equal_split_group_ctx):
|
||||
mock_db_session.get.return_value = None # Payer not found
|
||||
with pytest.raises(UserNotFoundError):
|
||||
await create_expense(mock_db_session, expense_create_data_equal_split_group_ctx, 1)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_expense_no_list_or_group(mock_db_session, expense_create_data_equal_split_group_ctx, basic_user_model):
|
||||
mock_db_session.get.return_value = basic_user_model # Payer found
|
||||
expense_data = expense_create_data_equal_split_group_ctx.model_copy()
|
||||
expense_data.list_id = None
|
||||
expense_data.group_id = None
|
||||
with pytest.raises(InvalidOperationError, match="Expense must be associated with a list or a group"):
|
||||
await create_expense(mock_db_session, expense_data, 1)
|
||||
|
||||
# --- get_expense_by_id Tests ---
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_expense_by_id_found(mock_db_session, db_expense_model):
|
||||
mock_result = AsyncMock()
|
||||
mock_result.scalars.return_value.first.return_value = db_expense_model
|
||||
mock_db_session.execute.return_value = mock_result
|
||||
|
||||
expense = await get_expense_by_id(mock_db_session, 1)
|
||||
assert expense is not None
|
||||
assert expense.id == 1
|
||||
mock_db_session.execute.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_expense_by_id_not_found(mock_db_session):
|
||||
mock_result = AsyncMock()
|
||||
mock_result.scalars.return_value.first.return_value = None
|
||||
mock_db_session.execute.return_value = mock_result
|
||||
|
||||
expense = await get_expense_by_id(mock_db_session, 999)
|
||||
assert expense is None
|
||||
|
||||
# --- get_expenses_for_list Tests ---
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_expenses_for_list_success(mock_db_session, db_expense_model):
|
||||
mock_result = AsyncMock()
|
||||
mock_result.scalars.return_value.all.return_value = [db_expense_model]
|
||||
mock_db_session.execute.return_value = mock_result
|
||||
|
||||
expenses = await get_expenses_for_list(mock_db_session, list_id=1)
|
||||
assert len(expenses) == 1
|
||||
assert expenses[0].id == db_expense_model.id
|
||||
mock_db_session.execute.assert_called_once()
|
||||
|
||||
# --- get_expenses_for_group Tests ---
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_expenses_for_group_success(mock_db_session, db_expense_model):
|
||||
mock_result = AsyncMock()
|
||||
mock_result.scalars.return_value.all.return_value = [db_expense_model]
|
||||
mock_db_session.execute.return_value = mock_result
|
||||
|
||||
expenses = await get_expenses_for_group(mock_db_session, group_id=1)
|
||||
assert len(expenses) == 1
|
||||
assert expenses[0].id == db_expense_model.id
|
||||
mock_db_session.execute.assert_called_once()
|
||||
|
||||
# --- Stubs for update_expense and delete_expense ---
|
||||
# These will need more details once the actual implementation of update/delete is clear
|
||||
# For example, how splits are handled on update, versioning, etc.
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_expense_stub(mock_db_session):
|
||||
# Placeholder: Test logic for update_expense will be more complex
|
||||
# Needs ExpenseUpdate schema, existing expense object, and mocking of commit/refresh
|
||||
# Also depends on what fields are updatable and how splits are managed.
|
||||
expense_to_update = MagicMock(spec=ExpenseModel)
|
||||
expense_to_update.version = 1
|
||||
update_payload = ExpenseUpdate(description="New description", version=1) # Add other fields as per schema definition
|
||||
|
||||
# Simulate the update_expense function behavior
|
||||
# For example, if it loads the expense, modifies, commits, refreshes:
|
||||
# mock_db_session.get.return_value = expense_to_update
|
||||
# updated_expense = await update_expense(mock_db_session, expense_to_update, update_payload)
|
||||
# assert updated_expense.description == "New description"
|
||||
# mock_db_session.commit.assert_called_once()
|
||||
# mock_db_session.refresh.assert_called_once()
|
||||
pass # Replace with actual test logic
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_expense_stub(mock_db_session):
|
||||
# Placeholder: Test logic for delete_expense
|
||||
# Needs an existing expense object and mocking of delete/commit
|
||||
# Also, consider implications (e.g., are splits deleted?)
|
||||
expense_to_delete = MagicMock(spec=ExpenseModel)
|
||||
expense_to_delete.id = 1
|
||||
expense_to_delete.version = 1
|
||||
|
||||
# Simulate delete_expense behavior
|
||||
# mock_db_session.get.return_value = expense_to_delete # If it re-fetches
|
||||
# await delete_expense(mock_db_session, expense_to_delete, expected_version=1)
|
||||
# mock_db_session.delete.assert_called_once_with(expense_to_delete)
|
||||
# mock_db_session.commit.assert_called_once()
|
||||
pass # Replace with actual test logic
|
||||
|
||||
# TODO: Add more tests for create_expense covering:
|
||||
# - List context success
|
||||
# - Percentage, Shares, Item-based splits
|
||||
# - Error cases for each split type (e.g., total mismatch, invalid inputs)
|
||||
# - Validation of list_id/group_id consistency
|
||||
# - User not found in splits_in
|
||||
# - Item not found for ITEM_BASED split
|
||||
|
||||
# TODO: Flesh out update_expense tests:
|
||||
# - Success case
|
||||
# - Version mismatch
|
||||
# - Trying to update immutable fields
|
||||
# - How splits are handled (recalculated, deleted/recreated, or not changeable)
|
||||
|
||||
# TODO: Flesh out delete_expense tests:
|
||||
# - Success case
|
||||
# - Version mismatch (if applicable)
|
||||
# - Ensure associated splits are also deleted (cascade behavior)
|
270
be/tests/crud/test_group.py
Normal file
270
be/tests/crud/test_group.py
Normal file
@ -0,0 +1,270 @@
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from sqlalchemy.exc import IntegrityError, OperationalError, SQLAlchemyError
|
||||
from sqlalchemy.future import select
|
||||
from sqlalchemy import delete, func # For remove_user_from_group and get_group_member_count
|
||||
|
||||
from app.crud.group import (
|
||||
create_group,
|
||||
get_user_groups,
|
||||
get_group_by_id,
|
||||
is_user_member,
|
||||
get_user_role_in_group,
|
||||
add_user_to_group,
|
||||
remove_user_from_group,
|
||||
get_group_member_count,
|
||||
check_group_membership,
|
||||
check_user_role_in_group
|
||||
)
|
||||
from app.schemas.group import GroupCreate
|
||||
from app.models import User as UserModel, Group as GroupModel, UserGroup as UserGroupModel, UserRoleEnum
|
||||
from app.core.exceptions import (
|
||||
GroupOperationError,
|
||||
GroupNotFoundError,
|
||||
DatabaseConnectionError,
|
||||
DatabaseIntegrityError,
|
||||
DatabaseQueryError,
|
||||
DatabaseTransactionError,
|
||||
GroupMembershipError,
|
||||
GroupPermissionError
|
||||
)
|
||||
|
||||
# Fixtures
|
||||
@pytest.fixture
|
||||
def mock_db_session():
|
||||
session = AsyncMock()
|
||||
# Patch begin_nested for SQLAlchemy 1.4+ if used, or just begin() if that's the pattern
|
||||
# For simplicity, assuming `async with db.begin():` translates to db.begin() and db.commit()/rollback()
|
||||
session.begin = AsyncMock() # Mock the begin call used in async with db.begin()
|
||||
session.commit = AsyncMock()
|
||||
session.rollback = AsyncMock()
|
||||
session.refresh = AsyncMock()
|
||||
session.add = MagicMock()
|
||||
session.delete = MagicMock() # For remove_user_from_group (if it uses session.delete)
|
||||
session.execute = AsyncMock()
|
||||
session.get = AsyncMock()
|
||||
session.flush = AsyncMock()
|
||||
return session
|
||||
|
||||
@pytest.fixture
|
||||
def group_create_data():
|
||||
return GroupCreate(name="Test Group")
|
||||
|
||||
@pytest.fixture
|
||||
def creator_user_model():
|
||||
return UserModel(id=1, name="Creator User", email="creator@example.com")
|
||||
|
||||
@pytest.fixture
|
||||
def member_user_model():
|
||||
return UserModel(id=2, name="Member User", email="member@example.com")
|
||||
|
||||
@pytest.fixture
|
||||
def db_group_model(creator_user_model):
|
||||
return GroupModel(id=1, name="Test Group", created_by_id=creator_user_model.id, creator=creator_user_model)
|
||||
|
||||
@pytest.fixture
|
||||
def db_user_group_owner_assoc(db_group_model, creator_user_model):
|
||||
return UserGroupModel(user_id=creator_user_model.id, group_id=db_group_model.id, role=UserRoleEnum.owner, user=creator_user_model, group=db_group_model)
|
||||
|
||||
@pytest.fixture
|
||||
def db_user_group_member_assoc(db_group_model, member_user_model):
|
||||
return UserGroupModel(user_id=member_user_model.id, group_id=db_group_model.id, role=UserRoleEnum.member, user=member_user_model, group=db_group_model)
|
||||
|
||||
# --- create_group Tests ---
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_group_success(mock_db_session, group_create_data, creator_user_model):
|
||||
async def mock_refresh(instance):
|
||||
instance.id = 1 # Simulate ID assignment by DB
|
||||
return None
|
||||
mock_db_session.refresh = AsyncMock(side_effect=mock_refresh)
|
||||
|
||||
created_group = await create_group(mock_db_session, group_create_data, creator_user_model.id)
|
||||
|
||||
assert mock_db_session.add.call_count == 2 # Group and UserGroup
|
||||
mock_db_session.flush.assert_called() # Called multiple times
|
||||
mock_db_session.refresh.assert_called_once_with(created_group)
|
||||
assert created_group is not None
|
||||
assert created_group.name == group_create_data.name
|
||||
assert created_group.created_by_id == creator_user_model.id
|
||||
# Further check if UserGroup was created correctly by inspecting mock_db_session.add calls or by fetching
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_group_integrity_error(mock_db_session, group_create_data, creator_user_model):
|
||||
mock_db_session.flush.side_effect = IntegrityError("mock integrity error", "params", "orig")
|
||||
with pytest.raises(DatabaseIntegrityError):
|
||||
await create_group(mock_db_session, group_create_data, creator_user_model.id)
|
||||
mock_db_session.rollback.assert_called_once() # Assuming rollback within the except block of create_group
|
||||
|
||||
# --- get_user_groups Tests ---
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_user_groups_success(mock_db_session, db_group_model, creator_user_model):
|
||||
mock_result = AsyncMock()
|
||||
mock_result.scalars.return_value.all.return_value = [db_group_model]
|
||||
mock_db_session.execute.return_value = mock_result
|
||||
|
||||
groups = await get_user_groups(mock_db_session, creator_user_model.id)
|
||||
assert len(groups) == 1
|
||||
assert groups[0].name == db_group_model.name
|
||||
mock_db_session.execute.assert_called_once()
|
||||
|
||||
# --- get_group_by_id Tests ---
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_group_by_id_found(mock_db_session, db_group_model):
|
||||
mock_result = AsyncMock()
|
||||
mock_result.scalars.return_value.first.return_value = db_group_model
|
||||
mock_db_session.execute.return_value = mock_result
|
||||
|
||||
group = await get_group_by_id(mock_db_session, db_group_model.id)
|
||||
assert group is not None
|
||||
assert group.id == db_group_model.id
|
||||
# Add assertions for eager loaded members if applicable and mocked
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_group_by_id_not_found(mock_db_session):
|
||||
mock_result = AsyncMock()
|
||||
mock_result.scalars.return_value.first.return_value = None
|
||||
mock_db_session.execute.return_value = mock_result
|
||||
group = await get_group_by_id(mock_db_session, 999)
|
||||
assert group is None
|
||||
|
||||
# --- is_user_member Tests ---
|
||||
@pytest.mark.asyncio
|
||||
async def test_is_user_member_true(mock_db_session, db_group_model, creator_user_model):
|
||||
mock_result = AsyncMock()
|
||||
mock_result.scalar_one_or_none.return_value = 1 # Simulate UserGroup.id found
|
||||
mock_db_session.execute.return_value = mock_result
|
||||
is_member = await is_user_member(mock_db_session, db_group_model.id, creator_user_model.id)
|
||||
assert is_member is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_is_user_member_false(mock_db_session, db_group_model, member_user_model):
|
||||
mock_result = AsyncMock()
|
||||
mock_result.scalar_one_or_none.return_value = None # Simulate no UserGroup.id found
|
||||
mock_db_session.execute.return_value = mock_result
|
||||
is_member = await is_user_member(mock_db_session, db_group_model.id, member_user_model.id + 1) # Non-member
|
||||
assert is_member is False
|
||||
|
||||
# --- get_user_role_in_group Tests ---
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_user_role_in_group_owner(mock_db_session, db_group_model, creator_user_model):
|
||||
mock_result = AsyncMock()
|
||||
mock_result.scalar_one_or_none.return_value = UserRoleEnum.owner
|
||||
mock_db_session.execute.return_value = mock_result
|
||||
role = await get_user_role_in_group(mock_db_session, db_group_model.id, creator_user_model.id)
|
||||
assert role == UserRoleEnum.owner
|
||||
|
||||
# --- add_user_to_group Tests ---
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_user_to_group_new_member(mock_db_session, db_group_model, member_user_model):
|
||||
# First execute call for checking existing membership returns None
|
||||
mock_existing_check_result = AsyncMock()
|
||||
mock_existing_check_result.scalar_one_or_none.return_value = None
|
||||
mock_db_session.execute.return_value = mock_existing_check_result
|
||||
|
||||
async def mock_refresh_user_group(instance):
|
||||
instance.id = 100 # Simulate ID for UserGroupModel
|
||||
return None
|
||||
mock_db_session.refresh = AsyncMock(side_effect=mock_refresh_user_group)
|
||||
|
||||
user_group_assoc = await add_user_to_group(mock_db_session, db_group_model.id, member_user_model.id, UserRoleEnum.member)
|
||||
|
||||
mock_db_session.add.assert_called_once()
|
||||
mock_db_session.flush.assert_called_once()
|
||||
mock_db_session.refresh.assert_called_once()
|
||||
assert user_group_assoc is not None
|
||||
assert user_group_assoc.user_id == member_user_model.id
|
||||
assert user_group_assoc.group_id == db_group_model.id
|
||||
assert user_group_assoc.role == UserRoleEnum.member
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_user_to_group_already_member(mock_db_session, db_group_model, creator_user_model, db_user_group_owner_assoc):
|
||||
mock_existing_check_result = AsyncMock()
|
||||
mock_existing_check_result.scalar_one_or_none.return_value = db_user_group_owner_assoc # User is already a member
|
||||
mock_db_session.execute.return_value = mock_existing_check_result
|
||||
|
||||
user_group_assoc = await add_user_to_group(mock_db_session, db_group_model.id, creator_user_model.id)
|
||||
assert user_group_assoc is None
|
||||
mock_db_session.add.assert_not_called()
|
||||
|
||||
# --- remove_user_from_group Tests ---
|
||||
@pytest.mark.asyncio
|
||||
async def test_remove_user_from_group_success(mock_db_session, db_group_model, member_user_model):
|
||||
mock_delete_result = AsyncMock()
|
||||
mock_delete_result.scalar_one_or_none.return_value = 1 # Simulate a row was deleted (returning ID)
|
||||
mock_db_session.execute.return_value = mock_delete_result
|
||||
|
||||
removed = await remove_user_from_group(mock_db_session, db_group_model.id, member_user_model.id)
|
||||
assert removed is True
|
||||
# Assert that db.execute was called with a delete statement
|
||||
# This requires inspecting the call args of mock_db_session.execute
|
||||
# For simplicity, we check it was called. A deeper check would validate the SQL query itself.
|
||||
mock_db_session.execute.assert_called_once()
|
||||
|
||||
# --- get_group_member_count Tests ---
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_group_member_count_success(mock_db_session, db_group_model):
|
||||
mock_count_result = AsyncMock()
|
||||
mock_count_result.scalar_one.return_value = 5
|
||||
mock_db_session.execute.return_value = mock_count_result
|
||||
count = await get_group_member_count(mock_db_session, db_group_model.id)
|
||||
assert count == 5
|
||||
|
||||
# --- check_group_membership Tests ---
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_group_membership_is_member(mock_db_session, db_group_model, creator_user_model):
|
||||
mock_db_session.get.return_value = db_group_model # Group exists
|
||||
mock_membership_result = AsyncMock()
|
||||
mock_membership_result.scalar_one_or_none.return_value = 1 # User is a member
|
||||
mock_db_session.execute.return_value = mock_membership_result
|
||||
|
||||
await check_group_membership(mock_db_session, db_group_model.id, creator_user_model.id)
|
||||
# No exception means success
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_group_membership_group_not_found(mock_db_session, creator_user_model):
|
||||
mock_db_session.get.return_value = None # Group does not exist
|
||||
with pytest.raises(GroupNotFoundError):
|
||||
await check_group_membership(mock_db_session, 999, creator_user_model.id)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_group_membership_not_member(mock_db_session, db_group_model, member_user_model):
|
||||
mock_db_session.get.return_value = db_group_model # Group exists
|
||||
mock_membership_result = AsyncMock()
|
||||
mock_membership_result.scalar_one_or_none.return_value = None # User is not a member
|
||||
mock_db_session.execute.return_value = mock_membership_result
|
||||
with pytest.raises(GroupMembershipError):
|
||||
await check_group_membership(mock_db_session, db_group_model.id, member_user_model.id)
|
||||
|
||||
# --- check_user_role_in_group Tests ---
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_user_role_in_group_sufficient_role(mock_db_session, db_group_model, creator_user_model):
|
||||
# Mock check_group_membership (implicitly called)
|
||||
mock_db_session.get.return_value = db_group_model
|
||||
mock_membership_check = AsyncMock()
|
||||
mock_membership_check.scalar_one_or_none.return_value = 1 # User is member
|
||||
|
||||
# Mock get_user_role_in_group
|
||||
mock_role_check = AsyncMock()
|
||||
mock_role_check.scalar_one_or_none.return_value = UserRoleEnum.owner
|
||||
|
||||
mock_db_session.execute.side_effect = [mock_membership_check, mock_role_check]
|
||||
|
||||
await check_user_role_in_group(mock_db_session, db_group_model.id, creator_user_model.id, UserRoleEnum.member)
|
||||
# No exception means success
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_user_role_in_group_insufficient_role(mock_db_session, db_group_model, member_user_model):
|
||||
mock_db_session.get.return_value = db_group_model # Group exists
|
||||
mock_membership_check = AsyncMock()
|
||||
mock_membership_check.scalar_one_or_none.return_value = 1 # User is member (for check_group_membership call)
|
||||
|
||||
mock_role_check = AsyncMock()
|
||||
mock_role_check.scalar_one_or_none.return_value = UserRoleEnum.member # User's actual role
|
||||
|
||||
mock_db_session.execute.side_effect = [mock_membership_check, mock_role_check]
|
||||
|
||||
with pytest.raises(GroupPermissionError):
|
||||
await check_user_role_in_group(mock_db_session, db_group_model.id, member_user_model.id, UserRoleEnum.owner)
|
||||
|
||||
# TODO: Add tests for DB operational/SQLAlchemy errors for each function similar to create_group_integrity_error
|
||||
# TODO: Test edge cases like trying to add user to non-existent group (should be caught by FK constraints or prior checks)
|
174
be/tests/crud/test_invite.py
Normal file
174
be/tests/crud/test_invite.py
Normal file
@ -0,0 +1,174 @@
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from sqlalchemy.exc import IntegrityError, OperationalError, SQLAlchemyError # Assuming these might be raised
|
||||
from datetime import datetime, timedelta, timezone
|
||||
import secrets
|
||||
|
||||
from app.crud.invite import (
|
||||
create_invite,
|
||||
get_active_invite_by_code,
|
||||
deactivate_invite,
|
||||
MAX_CODE_GENERATION_ATTEMPTS
|
||||
)
|
||||
from app.models import Invite as InviteModel, User as UserModel, Group as GroupModel # For context
|
||||
# No specific schemas for invite CRUD usually, but models are used.
|
||||
|
||||
# Fixtures
|
||||
@pytest.fixture
|
||||
def mock_db_session():
|
||||
session = AsyncMock()
|
||||
session.commit = AsyncMock()
|
||||
session.rollback = AsyncMock()
|
||||
session.refresh = AsyncMock()
|
||||
session.add = MagicMock()
|
||||
session.execute = AsyncMock()
|
||||
return session
|
||||
|
||||
@pytest.fixture
|
||||
def group_model():
|
||||
return GroupModel(id=1, name="Test Group")
|
||||
|
||||
@pytest.fixture
|
||||
def user_model(): # Creator
|
||||
return UserModel(id=1, name="Creator User")
|
||||
|
||||
@pytest.fixture
|
||||
def db_invite_model(group_model, user_model):
|
||||
return InviteModel(
|
||||
id=1,
|
||||
code="test_invite_code_123",
|
||||
group_id=group_model.id,
|
||||
created_by_id=user_model.id,
|
||||
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
|
||||
is_active=True
|
||||
)
|
||||
|
||||
# --- create_invite Tests ---
|
||||
@pytest.mark.asyncio
|
||||
@patch('app.crud.invite.secrets.token_urlsafe') # Patch secrets.token_urlsafe
|
||||
async def test_create_invite_success_first_attempt(mock_token_urlsafe, mock_db_session, group_model, user_model):
|
||||
generated_code = "unique_code_123"
|
||||
mock_token_urlsafe.return_value = generated_code
|
||||
|
||||
# Mock DB execute for checking existing code (first attempt, no existing code)
|
||||
mock_existing_check_result = AsyncMock()
|
||||
mock_existing_check_result.scalar_one_or_none.return_value = None
|
||||
mock_db_session.execute.return_value = mock_existing_check_result
|
||||
|
||||
invite = await create_invite(mock_db_session, group_model.id, user_model.id, expires_in_days=5)
|
||||
|
||||
mock_token_urlsafe.assert_called_once_with(16)
|
||||
mock_db_session.execute.assert_called_once() # For the uniqueness check
|
||||
mock_db_session.add.assert_called_once()
|
||||
mock_db_session.commit.assert_called_once()
|
||||
mock_db_session.refresh.assert_called_once_with(invite)
|
||||
|
||||
assert invite is not None
|
||||
assert invite.code == generated_code
|
||||
assert invite.group_id == group_model.id
|
||||
assert invite.created_by_id == user_model.id
|
||||
assert invite.is_active is True
|
||||
assert invite.expires_at > datetime.now(timezone.utc) + timedelta(days=4) # Check expiry is roughly correct
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('app.crud.invite.secrets.token_urlsafe')
|
||||
async def test_create_invite_success_after_collision(mock_token_urlsafe, mock_db_session, group_model, user_model):
|
||||
colliding_code = "colliding_code"
|
||||
unique_code = "finally_unique_code"
|
||||
mock_token_urlsafe.side_effect = [colliding_code, unique_code] # First call collides, second is unique
|
||||
|
||||
# Mock DB execute for checking existing code
|
||||
mock_collision_check_result = AsyncMock()
|
||||
mock_collision_check_result.scalar_one_or_none.return_value = 1 # Simulate collision (ID found)
|
||||
|
||||
mock_no_collision_check_result = AsyncMock()
|
||||
mock_no_collision_check_result.scalar_one_or_none.return_value = None # No collision
|
||||
|
||||
mock_db_session.execute.side_effect = [mock_collision_check_result, mock_no_collision_check_result]
|
||||
|
||||
invite = await create_invite(mock_db_session, group_model.id, user_model.id)
|
||||
|
||||
assert mock_token_urlsafe.call_count == 2
|
||||
assert mock_db_session.execute.call_count == 2
|
||||
assert invite is not None
|
||||
assert invite.code == unique_code
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('app.crud.invite.secrets.token_urlsafe')
|
||||
async def test_create_invite_fails_after_max_attempts(mock_token_urlsafe, mock_db_session, group_model, user_model):
|
||||
mock_token_urlsafe.return_value = "always_colliding_code"
|
||||
|
||||
mock_collision_check_result = AsyncMock()
|
||||
mock_collision_check_result.scalar_one_or_none.return_value = 1 # Always collide
|
||||
mock_db_session.execute.return_value = mock_collision_check_result
|
||||
|
||||
invite = await create_invite(mock_db_session, group_model.id, user_model.id)
|
||||
|
||||
assert invite is None
|
||||
assert mock_token_urlsafe.call_count == MAX_CODE_GENERATION_ATTEMPTS
|
||||
assert mock_db_session.execute.call_count == MAX_CODE_GENERATION_ATTEMPTS
|
||||
mock_db_session.add.assert_not_called()
|
||||
|
||||
# --- get_active_invite_by_code Tests ---
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_active_invite_by_code_found_active(mock_db_session, db_invite_model):
|
||||
db_invite_model.is_active = True
|
||||
db_invite_model.expires_at = datetime.now(timezone.utc) + timedelta(days=1)
|
||||
|
||||
mock_result = AsyncMock()
|
||||
mock_result.scalars.return_value.first.return_value = db_invite_model
|
||||
mock_db_session.execute.return_value = mock_result
|
||||
|
||||
invite = await get_active_invite_by_code(mock_db_session, db_invite_model.code)
|
||||
assert invite is not None
|
||||
assert invite.code == db_invite_model.code
|
||||
mock_db_session.execute.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_active_invite_by_code_not_found(mock_db_session):
|
||||
mock_result = AsyncMock()
|
||||
mock_result.scalars.return_value.first.return_value = None
|
||||
mock_db_session.execute.return_value = mock_result
|
||||
invite = await get_active_invite_by_code(mock_db_session, "non_existent_code")
|
||||
assert invite is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_active_invite_by_code_inactive(mock_db_session, db_invite_model):
|
||||
db_invite_model.is_active = False # Inactive
|
||||
db_invite_model.expires_at = datetime.now(timezone.utc) + timedelta(days=1)
|
||||
|
||||
mock_result = AsyncMock()
|
||||
mock_result.scalars.return_value.first.return_value = None # Should not be found by query
|
||||
mock_db_session.execute.return_value = mock_result
|
||||
|
||||
invite = await get_active_invite_by_code(mock_db_session, db_invite_model.code)
|
||||
assert invite is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_active_invite_by_code_expired(mock_db_session, db_invite_model):
|
||||
db_invite_model.is_active = True
|
||||
db_invite_model.expires_at = datetime.now(timezone.utc) - timedelta(days=1) # Expired
|
||||
|
||||
mock_result = AsyncMock()
|
||||
mock_result.scalars.return_value.first.return_value = None # Should not be found by query
|
||||
mock_db_session.execute.return_value = mock_result
|
||||
|
||||
invite = await get_active_invite_by_code(mock_db_session, db_invite_model.code)
|
||||
assert invite is None
|
||||
|
||||
# --- deactivate_invite Tests ---
|
||||
@pytest.mark.asyncio
|
||||
async def test_deactivate_invite_success(mock_db_session, db_invite_model):
|
||||
db_invite_model.is_active = True # Ensure it starts active
|
||||
|
||||
deactivated_invite = await deactivate_invite(mock_db_session, db_invite_model)
|
||||
|
||||
mock_db_session.add.assert_called_once_with(db_invite_model)
|
||||
mock_db_session.commit.assert_called_once()
|
||||
mock_db_session.refresh.assert_called_once_with(db_invite_model)
|
||||
assert deactivated_invite.is_active is False
|
||||
|
||||
# It might be useful to test DB error cases (OperationalError, etc.) for each function
|
||||
# if they have specific try-except blocks, but invite.py seems to rely on caller/framework for some of that.
|
||||
# create_invite has its own DB interaction within the loop, so that's covered.
|
||||
# get_active_invite_by_code and deactivate_invite are simpler DB ops.
|
184
be/tests/crud/test_item.py
Normal file
184
be/tests/crud/test_item.py
Normal file
@ -0,0 +1,184 @@
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from sqlalchemy.exc import IntegrityError, OperationalError, SQLAlchemyError
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from app.crud.item import (
|
||||
create_item,
|
||||
get_items_by_list_id,
|
||||
get_item_by_id,
|
||||
update_item,
|
||||
delete_item
|
||||
)
|
||||
from app.schemas.item import ItemCreate, ItemUpdate
|
||||
from app.models import Item as ItemModel, User as UserModel, List as ListModel
|
||||
from app.core.exceptions import (
|
||||
ItemNotFoundError, # Not directly raised by CRUD but good for API layer tests
|
||||
DatabaseConnectionError,
|
||||
DatabaseIntegrityError,
|
||||
DatabaseQueryError,
|
||||
DatabaseTransactionError,
|
||||
ConflictError
|
||||
)
|
||||
|
||||
# Fixtures
|
||||
@pytest.fixture
|
||||
def mock_db_session():
|
||||
session = AsyncMock()
|
||||
session.begin = AsyncMock()
|
||||
session.commit = AsyncMock()
|
||||
session.rollback = AsyncMock()
|
||||
session.refresh = AsyncMock()
|
||||
session.add = MagicMock()
|
||||
session.delete = MagicMock()
|
||||
session.execute = AsyncMock()
|
||||
session.get = AsyncMock() # Though not directly used in item.py, good for consistency
|
||||
session.flush = AsyncMock()
|
||||
return session
|
||||
|
||||
@pytest.fixture
|
||||
def item_create_data():
|
||||
return ItemCreate(name="Test Item", quantity="1 pack")
|
||||
|
||||
@pytest.fixture
|
||||
def item_update_data():
|
||||
return ItemUpdate(name="Updated Test Item", quantity="2 packs", version=1, is_complete=False)
|
||||
|
||||
@pytest.fixture
|
||||
def user_model():
|
||||
return UserModel(id=1, name="Test User", email="test@example.com")
|
||||
|
||||
@pytest.fixture
|
||||
def list_model():
|
||||
return ListModel(id=1, name="Test List")
|
||||
|
||||
@pytest.fixture
|
||||
def db_item_model(list_model, user_model):
|
||||
return ItemModel(
|
||||
id=1,
|
||||
name="Existing Item",
|
||||
quantity="1 unit",
|
||||
list_id=list_model.id,
|
||||
added_by_id=user_model.id,
|
||||
is_complete=False,
|
||||
version=1,
|
||||
created_at=datetime.now(timezone.utc),
|
||||
updated_at=datetime.now(timezone.utc)
|
||||
)
|
||||
|
||||
# --- create_item Tests ---
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_item_success(mock_db_session, item_create_data, list_model, user_model):
|
||||
async def mock_refresh(instance):
|
||||
instance.id = 10 # Simulate ID assignment
|
||||
instance.version = 1 # Simulate version init
|
||||
instance.created_at = datetime.now(timezone.utc)
|
||||
instance.updated_at = datetime.now(timezone.utc)
|
||||
return None
|
||||
mock_db_session.refresh = AsyncMock(side_effect=mock_refresh)
|
||||
|
||||
created_item = await create_item(mock_db_session, item_create_data, list_model.id, user_model.id)
|
||||
|
||||
mock_db_session.add.assert_called_once()
|
||||
mock_db_session.flush.assert_called_once()
|
||||
mock_db_session.refresh.assert_called_once_with(created_item)
|
||||
assert created_item is not None
|
||||
assert created_item.name == item_create_data.name
|
||||
assert created_item.list_id == list_model.id
|
||||
assert created_item.added_by_id == user_model.id
|
||||
assert created_item.is_complete is False
|
||||
assert created_item.version == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_item_integrity_error(mock_db_session, item_create_data, list_model, user_model):
|
||||
mock_db_session.flush.side_effect = IntegrityError("mock integrity error", "params", "orig")
|
||||
with pytest.raises(DatabaseIntegrityError):
|
||||
await create_item(mock_db_session, item_create_data, list_model.id, user_model.id)
|
||||
mock_db_session.rollback.assert_called_once()
|
||||
|
||||
# --- get_items_by_list_id Tests ---
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_items_by_list_id_success(mock_db_session, db_item_model, list_model):
|
||||
mock_result = AsyncMock()
|
||||
mock_result.scalars.return_value.all.return_value = [db_item_model]
|
||||
mock_db_session.execute.return_value = mock_result
|
||||
|
||||
items = await get_items_by_list_id(mock_db_session, list_model.id)
|
||||
assert len(items) == 1
|
||||
assert items[0].id == db_item_model.id
|
||||
mock_db_session.execute.assert_called_once()
|
||||
|
||||
# --- get_item_by_id Tests ---
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_item_by_id_found(mock_db_session, db_item_model):
|
||||
mock_result = AsyncMock()
|
||||
mock_result.scalars.return_value.first.return_value = db_item_model
|
||||
mock_db_session.execute.return_value = mock_result
|
||||
item = await get_item_by_id(mock_db_session, db_item_model.id)
|
||||
assert item is not None
|
||||
assert item.id == db_item_model.id
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_item_by_id_not_found(mock_db_session):
|
||||
mock_result = AsyncMock()
|
||||
mock_result.scalars.return_value.first.return_value = None
|
||||
mock_db_session.execute.return_value = mock_result
|
||||
item = await get_item_by_id(mock_db_session, 999)
|
||||
assert item is None
|
||||
|
||||
# --- update_item Tests ---
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_item_success(mock_db_session, db_item_model, item_update_data, user_model):
|
||||
item_update_data.version = db_item_model.version # Match versions for successful update
|
||||
item_update_data.name = "Newly Updated Name"
|
||||
item_update_data.is_complete = True # Test completion logic
|
||||
|
||||
updated_item = await update_item(mock_db_session, db_item_model, item_update_data, user_model.id)
|
||||
|
||||
mock_db_session.add.assert_called_once_with(db_item_model) # add is used for existing objects too
|
||||
mock_db_session.flush.assert_called_once()
|
||||
mock_db_session.refresh.assert_called_once_with(db_item_model)
|
||||
assert updated_item.name == "Newly Updated Name"
|
||||
assert updated_item.version == db_item_model.version # Check version increment logic in test
|
||||
assert updated_item.is_complete is True
|
||||
assert updated_item.completed_by_id == user_model.id
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_item_version_conflict(mock_db_session, db_item_model, item_update_data, user_model):
|
||||
item_update_data.version = db_item_model.version + 1 # Create a version mismatch
|
||||
with pytest.raises(ConflictError):
|
||||
await update_item(mock_db_session, db_item_model, item_update_data, user_model.id)
|
||||
mock_db_session.rollback.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_item_set_incomplete(mock_db_session, db_item_model, item_update_data, user_model):
|
||||
db_item_model.is_complete = True # Start as complete
|
||||
db_item_model.completed_by_id = user_model.id
|
||||
db_item_model.version = 1
|
||||
|
||||
item_update_data.version = 1
|
||||
item_update_data.is_complete = False
|
||||
item_update_data.name = db_item_model.name # No name change for this test
|
||||
item_update_data.quantity = db_item_model.quantity
|
||||
|
||||
updated_item = await update_item(mock_db_session, db_item_model, item_update_data, user_model.id)
|
||||
assert updated_item.is_complete is False
|
||||
assert updated_item.completed_by_id is None
|
||||
assert updated_item.version == 2
|
||||
|
||||
# --- delete_item Tests ---
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_item_success(mock_db_session, db_item_model):
|
||||
result = await delete_item(mock_db_session, db_item_model)
|
||||
assert result is None
|
||||
mock_db_session.delete.assert_called_once_with(db_item_model)
|
||||
mock_db_session.commit.assert_called_once() # Commit happens in the `async with db.begin()` context manager
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_item_db_error(mock_db_session, db_item_model):
|
||||
mock_db_session.delete.side_effect = OperationalError("mock op error", "params", "orig")
|
||||
with pytest.raises(DatabaseConnectionError):
|
||||
await delete_item(mock_db_session, db_item_model)
|
||||
mock_db_session.rollback.assert_called_once()
|
||||
|
||||
# TODO: Add more specific DB error tests (Operational, SQLAlchemyError) for each function.
|
259
be/tests/crud/test_list.py
Normal file
259
be/tests/crud/test_list.py
Normal file
@ -0,0 +1,259 @@
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from sqlalchemy.exc import IntegrityError, OperationalError, SQLAlchemyError
|
||||
from sqlalchemy.future import select
|
||||
from sqlalchemy import func as sql_func # For get_list_status
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from app.crud.list import (
|
||||
create_list,
|
||||
get_lists_for_user,
|
||||
get_list_by_id,
|
||||
update_list,
|
||||
delete_list,
|
||||
check_list_permission,
|
||||
get_list_status
|
||||
)
|
||||
from app.schemas.list import ListCreate, ListUpdate, ListStatus
|
||||
from app.models import List as ListModel, User as UserModel, Group as GroupModel, UserGroup as UserGroupModel, Item as ItemModel
|
||||
from app.core.exceptions import (
|
||||
ListNotFoundError,
|
||||
ListPermissionError,
|
||||
ListCreatorRequiredError,
|
||||
DatabaseConnectionError,
|
||||
DatabaseIntegrityError,
|
||||
DatabaseQueryError,
|
||||
DatabaseTransactionError,
|
||||
ConflictError
|
||||
)
|
||||
|
||||
# Fixtures
|
||||
@pytest.fixture
|
||||
def mock_db_session():
|
||||
session = AsyncMock()
|
||||
session.begin = AsyncMock()
|
||||
session.commit = AsyncMock()
|
||||
session.rollback = AsyncMock()
|
||||
session.refresh = AsyncMock()
|
||||
session.add = MagicMock()
|
||||
session.delete = MagicMock()
|
||||
session.execute = AsyncMock()
|
||||
session.get = AsyncMock() # Used by check_list_permission via get_list_by_id
|
||||
session.flush = AsyncMock()
|
||||
return session
|
||||
|
||||
@pytest.fixture
|
||||
def list_create_data():
|
||||
return ListCreate(name="New Shopping List", description="Groceries for the week")
|
||||
|
||||
@pytest.fixture
|
||||
def list_update_data():
|
||||
return ListUpdate(name="Updated Shopping List", description="Weekend Groceries", version=1)
|
||||
|
||||
@pytest.fixture
|
||||
def user_model():
|
||||
return UserModel(id=1, name="Test User", email="test@example.com")
|
||||
|
||||
@pytest.fixture
|
||||
def another_user_model():
|
||||
return UserModel(id=2, name="Another User", email="another@example.com")
|
||||
|
||||
@pytest.fixture
|
||||
def group_model():
|
||||
return GroupModel(id=1, name="Test Group")
|
||||
|
||||
@pytest.fixture
|
||||
def db_list_personal_model(user_model):
|
||||
return ListModel(
|
||||
id=1, name="Personal List", created_by_id=user_model.id, creator=user_model,
|
||||
version=1, updated_at=datetime.now(timezone.utc), items=[]
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def db_list_group_model(user_model, group_model):
|
||||
return ListModel(
|
||||
id=2, name="Group List", created_by_id=user_model.id, creator=user_model,
|
||||
group_id=group_model.id, group=group_model, version=1, updated_at=datetime.now(timezone.utc), items=[]
|
||||
)
|
||||
|
||||
# --- create_list Tests ---
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_list_success(mock_db_session, list_create_data, user_model):
|
||||
async def mock_refresh(instance):
|
||||
instance.id = 100
|
||||
instance.version = 1
|
||||
instance.updated_at = datetime.now(timezone.utc)
|
||||
return None
|
||||
mock_db_session.refresh.return_value = None
|
||||
mock_db_session.refresh.side_effect = mock_refresh
|
||||
|
||||
created_list = await create_list(mock_db_session, list_create_data, user_model.id)
|
||||
mock_db_session.add.assert_called_once()
|
||||
mock_db_session.flush.assert_called_once()
|
||||
mock_db_session.refresh.assert_called_once()
|
||||
assert created_list.name == list_create_data.name
|
||||
assert created_list.created_by_id == user_model.id
|
||||
|
||||
# --- get_lists_for_user Tests ---
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_lists_for_user_mix(mock_db_session, user_model, db_list_personal_model, db_list_group_model):
|
||||
# Simulate user is part of group for db_list_group_model
|
||||
mock_group_ids_result = AsyncMock()
|
||||
mock_group_ids_result.scalars.return_value.all.return_value = [db_list_group_model.group_id]
|
||||
|
||||
mock_lists_result = AsyncMock()
|
||||
# Order should be personal list (created by user_id) then group list
|
||||
mock_lists_result.scalars.return_value.all.return_value = [db_list_personal_model, db_list_group_model]
|
||||
|
||||
mock_db_session.execute.side_effect = [mock_group_ids_result, mock_lists_result]
|
||||
|
||||
lists = await get_lists_for_user(mock_db_session, user_model.id)
|
||||
assert len(lists) == 2
|
||||
assert db_list_personal_model in lists
|
||||
assert db_list_group_model in lists
|
||||
assert mock_db_session.execute.call_count == 2
|
||||
|
||||
# --- get_list_by_id Tests ---
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_list_by_id_found_no_items(mock_db_session, db_list_personal_model):
|
||||
mock_result = AsyncMock()
|
||||
mock_result.scalars.return_value.first.return_value = db_list_personal_model
|
||||
mock_db_session.execute.return_value = mock_result
|
||||
|
||||
found_list = await get_list_by_id(mock_db_session, db_list_personal_model.id, load_items=False)
|
||||
assert found_list is not None
|
||||
assert found_list.id == db_list_personal_model.id
|
||||
# query options should not include selectinload for items
|
||||
# (difficult to assert directly without inspecting query object in detail)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_list_by_id_found_with_items(mock_db_session, db_list_personal_model):
|
||||
# Simulate items loaded for the list
|
||||
db_list_personal_model.items = [ItemModel(id=1, name="Test Item")]
|
||||
mock_result = AsyncMock()
|
||||
mock_result.scalars.return_value.first.return_value = db_list_personal_model
|
||||
mock_db_session.execute.return_value = mock_result
|
||||
|
||||
found_list = await get_list_by_id(mock_db_session, db_list_personal_model.id, load_items=True)
|
||||
assert found_list is not None
|
||||
assert len(found_list.items) == 1
|
||||
# query options should include selectinload for items
|
||||
|
||||
# --- update_list Tests ---
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_list_success(mock_db_session, db_list_personal_model, list_update_data):
|
||||
list_update_data.version = db_list_personal_model.version # Match version
|
||||
|
||||
updated_list = await update_list(mock_db_session, db_list_personal_model, list_update_data)
|
||||
assert updated_list.name == list_update_data.name
|
||||
assert updated_list.version == db_list_personal_model.version # version incremented in db_list_personal_model
|
||||
mock_db_session.add.assert_called_once_with(db_list_personal_model)
|
||||
mock_db_session.flush.assert_called_once()
|
||||
mock_db_session.refresh.assert_called_once_with(db_list_personal_model)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_list_conflict(mock_db_session, db_list_personal_model, list_update_data):
|
||||
list_update_data.version = db_list_personal_model.version + 1 # Version mismatch
|
||||
with pytest.raises(ConflictError):
|
||||
await update_list(mock_db_session, db_list_personal_model, list_update_data)
|
||||
mock_db_session.rollback.assert_called_once()
|
||||
|
||||
# --- delete_list Tests ---
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_list_success(mock_db_session, db_list_personal_model):
|
||||
await delete_list(mock_db_session, db_list_personal_model)
|
||||
mock_db_session.delete.assert_called_once_with(db_list_personal_model)
|
||||
mock_db_session.commit.assert_called_once() # from async with db.begin()
|
||||
|
||||
# --- check_list_permission Tests ---
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_list_permission_creator_access_personal_list(mock_db_session, db_list_personal_model, user_model):
|
||||
# get_list_by_id (called by check_list_permission) will mock execute
|
||||
mock_list_fetch_result = AsyncMock()
|
||||
mock_list_fetch_result.scalars.return_value.first.return_value = db_list_personal_model
|
||||
mock_db_session.execute.return_value = mock_list_fetch_result
|
||||
|
||||
ret_list = await check_list_permission(mock_db_session, db_list_personal_model.id, user_model.id)
|
||||
assert ret_list.id == db_list_personal_model.id
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_list_permission_group_member_access_group_list(mock_db_session, db_list_group_model, another_user_model, group_model):
|
||||
# User `another_user_model` is not creator but member of the group
|
||||
db_list_group_model.creator_id = user_model.id # Original creator is user_model
|
||||
db_list_group_model.creator = user_model
|
||||
|
||||
# Mock get_list_by_id internal call
|
||||
mock_list_fetch_result = AsyncMock()
|
||||
mock_list_fetch_result.scalars.return_value.first.return_value = db_list_group_model
|
||||
|
||||
# Mock is_user_member call
|
||||
with patch('app.crud.list.is_user_member', new_callable=AsyncMock) as mock_is_member:
|
||||
mock_is_member.return_value = True # another_user_model is a member
|
||||
mock_db_session.execute.return_value = mock_list_fetch_result
|
||||
|
||||
ret_list = await check_list_permission(mock_db_session, db_list_group_model.id, another_user_model.id)
|
||||
assert ret_list.id == db_list_group_model.id
|
||||
mock_is_member.assert_called_once_with(mock_db_session, group_id=group_model.id, user_id=another_user_model.id)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_list_permission_non_member_no_access_group_list(mock_db_session, db_list_group_model, another_user_model):
|
||||
db_list_group_model.creator_id = user_model.id # Creator is not another_user_model
|
||||
|
||||
mock_list_fetch_result = AsyncMock()
|
||||
mock_list_fetch_result.scalars.return_value.first.return_value = db_list_group_model
|
||||
|
||||
with patch('app.crud.list.is_user_member', new_callable=AsyncMock) as mock_is_member:
|
||||
mock_is_member.return_value = False # another_user_model is NOT a member
|
||||
mock_db_session.execute.return_value = mock_list_fetch_result
|
||||
|
||||
with pytest.raises(ListPermissionError):
|
||||
await check_list_permission(mock_db_session, db_list_group_model.id, another_user_model.id)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_list_permission_list_not_found(mock_db_session, user_model):
|
||||
mock_list_fetch_result = AsyncMock()
|
||||
mock_list_fetch_result.scalars.return_value.first.return_value = None # List not found
|
||||
mock_db_session.execute.return_value = mock_list_fetch_result
|
||||
|
||||
with pytest.raises(ListNotFoundError):
|
||||
await check_list_permission(mock_db_session, 999, user_model.id)
|
||||
|
||||
# --- get_list_status Tests ---
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_list_status_success(mock_db_session, db_list_personal_model):
|
||||
list_updated_at = datetime.now(timezone.utc) - timezone.timedelta(hours=1)
|
||||
item_updated_at = datetime.now(timezone.utc)
|
||||
item_count = 5
|
||||
|
||||
db_list_personal_model.updated_at = list_updated_at
|
||||
|
||||
# Mock for ListModel.updated_at query
|
||||
mock_list_updated_result = AsyncMock()
|
||||
mock_list_updated_result.scalar_one_or_none.return_value = list_updated_at
|
||||
|
||||
# Mock for ItemModel status query
|
||||
mock_item_status_result = AsyncMock()
|
||||
# SQLAlchemy query for func.max and func.count returns a Row-like object or None
|
||||
mock_item_status_row = MagicMock()
|
||||
mock_item_status_row.latest_item_updated_at = item_updated_at
|
||||
mock_item_status_row.item_count = item_count
|
||||
mock_item_status_result.first.return_value = mock_item_status_row
|
||||
|
||||
mock_db_session.execute.side_effect = [mock_list_updated_result, mock_item_status_result]
|
||||
|
||||
status = await get_list_status(mock_db_session, db_list_personal_model.id)
|
||||
assert status.list_updated_at == list_updated_at
|
||||
assert status.latest_item_updated_at == item_updated_at
|
||||
assert status.item_count == item_count
|
||||
assert mock_db_session.execute.call_count == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_list_status_list_not_found(mock_db_session):
|
||||
mock_list_updated_result = AsyncMock()
|
||||
mock_list_updated_result.scalar_one_or_none.return_value = None # List not found
|
||||
mock_db_session.execute.return_value = mock_list_updated_result
|
||||
with pytest.raises(ListNotFoundError):
|
||||
await get_list_status(mock_db_session, 999)
|
||||
|
||||
# TODO: Add more specific DB error tests (Operational, SQLAlchemyError, IntegrityError) for each function.
|
||||
# TODO: Test check_list_permission with require_creator=True cases.
|
250
be/tests/crud/test_settlement.py
Normal file
250
be/tests/crud/test_settlement.py
Normal file
@ -0,0 +1,250 @@
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from sqlalchemy.exc import IntegrityError, OperationalError
|
||||
from sqlalchemy.future import select
|
||||
from decimal import Decimal, ROUND_HALF_UP
|
||||
from datetime import datetime, timezone
|
||||
from typing import List as PyList
|
||||
|
||||
from app.crud.settlement import (
|
||||
create_settlement,
|
||||
get_settlement_by_id,
|
||||
get_settlements_for_group,
|
||||
get_settlements_involving_user,
|
||||
update_settlement,
|
||||
delete_settlement
|
||||
)
|
||||
from app.schemas.expense import SettlementCreate, SettlementUpdate
|
||||
from app.models import Settlement as SettlementModel, User as UserModel, Group as GroupModel
|
||||
from app.core.exceptions import UserNotFoundError, GroupNotFoundError, InvalidOperationError
|
||||
|
||||
# Fixtures
|
||||
@pytest.fixture
|
||||
def mock_db_session():
|
||||
session = AsyncMock()
|
||||
session.commit = AsyncMock()
|
||||
session.rollback = AsyncMock()
|
||||
session.refresh = AsyncMock()
|
||||
session.add = MagicMock()
|
||||
session.delete = MagicMock()
|
||||
session.execute = AsyncMock()
|
||||
session.get = AsyncMock()
|
||||
return session
|
||||
|
||||
@pytest.fixture
|
||||
def settlement_create_data():
|
||||
return SettlementCreate(
|
||||
group_id=1,
|
||||
paid_by_user_id=1,
|
||||
paid_to_user_id=2,
|
||||
amount=Decimal("10.50"),
|
||||
settlement_date=datetime.now(timezone.utc),
|
||||
description="Test settlement"
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def settlement_update_data():
|
||||
return SettlementUpdate(
|
||||
description="Updated settlement description",
|
||||
settlement_date=datetime.now(timezone.utc),
|
||||
version=1 # Assuming version is required for update
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def db_settlement_model():
|
||||
return SettlementModel(
|
||||
id=1,
|
||||
group_id=1,
|
||||
paid_by_user_id=1,
|
||||
paid_to_user_id=2,
|
||||
amount=Decimal("10.50"),
|
||||
settlement_date=datetime.now(timezone.utc),
|
||||
description="Original settlement",
|
||||
version=1, # Initial version
|
||||
created_at=datetime.now(timezone.utc),
|
||||
updated_at=datetime.now(timezone.utc),
|
||||
payer=UserModel(id=1, name="Payer User"),
|
||||
payee=UserModel(id=2, name="Payee User"),
|
||||
group=GroupModel(id=1, name="Test Group")
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def payer_user_model():
|
||||
return UserModel(id=1, name="Payer User", email="payer@example.com")
|
||||
|
||||
@pytest.fixture
|
||||
def payee_user_model():
|
||||
return UserModel(id=2, name="Payee User", email="payee@example.com")
|
||||
|
||||
@pytest.fixture
|
||||
def group_model():
|
||||
return GroupModel(id=1, name="Test Group")
|
||||
|
||||
# Tests for create_settlement
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_settlement_success(mock_db_session, settlement_create_data, payer_user_model, payee_user_model, group_model):
|
||||
mock_db_session.get.side_effect = [payer_user_model, payee_user_model, group_model] # Order of gets
|
||||
|
||||
created_settlement = await create_settlement(mock_db_session, settlement_create_data, current_user_id=1)
|
||||
|
||||
mock_db_session.add.assert_called_once()
|
||||
mock_db_session.commit.assert_called_once()
|
||||
mock_db_session.refresh.assert_called_once()
|
||||
assert created_settlement is not None
|
||||
assert created_settlement.amount == settlement_create_data.amount.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP)
|
||||
assert created_settlement.paid_by_user_id == settlement_create_data.paid_by_user_id
|
||||
assert created_settlement.paid_to_user_id == settlement_create_data.paid_to_user_id
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_settlement_payer_not_found(mock_db_session, settlement_create_data):
|
||||
mock_db_session.get.side_effect = [None, payee_user_model, group_model]
|
||||
with pytest.raises(UserNotFoundError) as excinfo:
|
||||
await create_settlement(mock_db_session, settlement_create_data, 1)
|
||||
assert "Payer" in str(excinfo.value)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_settlement_payee_not_found(mock_db_session, settlement_create_data, payer_user_model):
|
||||
mock_db_session.get.side_effect = [payer_user_model, None, group_model]
|
||||
with pytest.raises(UserNotFoundError) as excinfo:
|
||||
await create_settlement(mock_db_session, settlement_create_data, 1)
|
||||
assert "Payee" in str(excinfo.value)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_settlement_group_not_found(mock_db_session, settlement_create_data, payer_user_model, payee_user_model):
|
||||
mock_db_session.get.side_effect = [payer_user_model, payee_user_model, None]
|
||||
with pytest.raises(GroupNotFoundError):
|
||||
await create_settlement(mock_db_session, settlement_create_data, 1)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_settlement_payer_equals_payee(mock_db_session, settlement_create_data, payer_user_model, group_model):
|
||||
settlement_create_data.paid_to_user_id = settlement_create_data.paid_by_user_id
|
||||
mock_db_session.get.side_effect = [payer_user_model, payer_user_model, group_model]
|
||||
with pytest.raises(InvalidOperationError) as excinfo:
|
||||
await create_settlement(mock_db_session, settlement_create_data, 1)
|
||||
assert "Payer and Payee cannot be the same user" in str(excinfo.value)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_settlement_commit_failure(mock_db_session, settlement_create_data, payer_user_model, payee_user_model, group_model):
|
||||
mock_db_session.get.side_effect = [payer_user_model, payee_user_model, group_model]
|
||||
mock_db_session.commit.side_effect = Exception("DB commit failed")
|
||||
with pytest.raises(InvalidOperationError) as excinfo:
|
||||
await create_settlement(mock_db_session, settlement_create_data, 1)
|
||||
assert "Failed to save settlement" in str(excinfo.value)
|
||||
mock_db_session.rollback.assert_called_once()
|
||||
|
||||
|
||||
# Tests for get_settlement_by_id
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_settlement_by_id_found(mock_db_session, db_settlement_model):
|
||||
mock_db_session.execute.return_value.scalars.return_value.first.return_value = db_settlement_model
|
||||
settlement = await get_settlement_by_id(mock_db_session, 1)
|
||||
assert settlement is not None
|
||||
assert settlement.id == 1
|
||||
mock_db_session.execute.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_settlement_by_id_not_found(mock_db_session):
|
||||
mock_db_session.execute.return_value.scalars.return_value.first.return_value = None
|
||||
settlement = await get_settlement_by_id(mock_db_session, 999)
|
||||
assert settlement is None
|
||||
|
||||
# Tests for get_settlements_for_group
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_settlements_for_group_success(mock_db_session, db_settlement_model):
|
||||
mock_db_session.execute.return_value.scalars.return_value.all.return_value = [db_settlement_model]
|
||||
settlements = await get_settlements_for_group(mock_db_session, group_id=1)
|
||||
assert len(settlements) == 1
|
||||
assert settlements[0].group_id == 1
|
||||
mock_db_session.execute.assert_called_once()
|
||||
|
||||
# Tests for get_settlements_involving_user
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_settlements_involving_user_success(mock_db_session, db_settlement_model):
|
||||
mock_db_session.execute.return_value.scalars.return_value.all.return_value = [db_settlement_model]
|
||||
settlements = await get_settlements_involving_user(mock_db_session, user_id=1)
|
||||
assert len(settlements) == 1
|
||||
assert settlements[0].paid_by_user_id == 1 or settlements[0].paid_to_user_id == 1
|
||||
mock_db_session.execute.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_settlements_involving_user_with_group_filter(mock_db_session, db_settlement_model):
|
||||
mock_db_session.execute.return_value.scalars.return_value.all.return_value = [db_settlement_model]
|
||||
settlements = await get_settlements_involving_user(mock_db_session, user_id=1, group_id=1)
|
||||
assert len(settlements) == 1
|
||||
# More specific assertions about the query would require deeper mocking of SQLAlchemy query construction
|
||||
mock_db_session.execute.assert_called_once()
|
||||
|
||||
|
||||
# Tests for update_settlement
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_settlement_success(mock_db_session, db_settlement_model, settlement_update_data):
|
||||
# Ensure settlement_update_data.version matches db_settlement_model.version
|
||||
settlement_update_data.version = db_settlement_model.version
|
||||
|
||||
# Mock datetime.now()
|
||||
fixed_datetime_now = datetime.now(timezone.utc)
|
||||
with patch('app.crud.settlement.datetime', wraps=datetime) as mock_datetime:
|
||||
mock_datetime.now.return_value = fixed_datetime_now
|
||||
|
||||
updated_settlement = await update_settlement(mock_db_session, db_settlement_model, settlement_update_data)
|
||||
|
||||
mock_db_session.commit.assert_called_once()
|
||||
mock_db_session.refresh.assert_called_once()
|
||||
assert updated_settlement.description == settlement_update_data.description
|
||||
assert updated_settlement.settlement_date == settlement_update_data.settlement_date
|
||||
assert updated_settlement.version == db_settlement_model.version + 1 # Version incremented
|
||||
assert updated_settlement.updated_at == fixed_datetime_now
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_settlement_version_mismatch(mock_db_session, db_settlement_model, settlement_update_data):
|
||||
settlement_update_data.version = db_settlement_model.version + 1 # Mismatched version
|
||||
with pytest.raises(InvalidOperationError) as excinfo:
|
||||
await update_settlement(mock_db_session, db_settlement_model, settlement_update_data)
|
||||
assert "version does not match" in str(excinfo.value)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_settlement_disallowed_field(mock_db_session, db_settlement_model):
|
||||
# Create an update payload with a field not allowed to be updated, e.g., 'amount'
|
||||
invalid_update_data = SettlementUpdate(amount=Decimal("100.00"), version=db_settlement_model.version)
|
||||
with pytest.raises(InvalidOperationError) as excinfo:
|
||||
await update_settlement(mock_db_session, db_settlement_model, invalid_update_data)
|
||||
assert "Field 'amount' cannot be updated" in str(excinfo.value)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_settlement_commit_failure(mock_db_session, db_settlement_model, settlement_update_data):
|
||||
settlement_update_data.version = db_settlement_model.version
|
||||
mock_db_session.commit.side_effect = Exception("DB commit failed")
|
||||
with pytest.raises(InvalidOperationError) as excinfo:
|
||||
await update_settlement(mock_db_session, db_settlement_model, settlement_update_data)
|
||||
assert "Failed to update settlement" in str(excinfo.value)
|
||||
mock_db_session.rollback.assert_called_once()
|
||||
|
||||
# Tests for delete_settlement
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_settlement_success(mock_db_session, db_settlement_model):
|
||||
await delete_settlement(mock_db_session, db_settlement_model)
|
||||
mock_db_session.delete.assert_called_once_with(db_settlement_model)
|
||||
mock_db_session.commit.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_settlement_success_with_version_check(mock_db_session, db_settlement_model):
|
||||
await delete_settlement(mock_db_session, db_settlement_model, expected_version=db_settlement_model.version)
|
||||
mock_db_session.delete.assert_called_once_with(db_settlement_model)
|
||||
mock_db_session.commit.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_settlement_version_mismatch(mock_db_session, db_settlement_model):
|
||||
with pytest.raises(InvalidOperationError) as excinfo:
|
||||
await delete_settlement(mock_db_session, db_settlement_model, expected_version=db_settlement_model.version + 1)
|
||||
assert "Expected version" in str(excinfo.value)
|
||||
assert "does not match current version" in str(excinfo.value)
|
||||
mock_db_session.delete.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_settlement_commit_failure(mock_db_session, db_settlement_model):
|
||||
mock_db_session.commit.side_effect = Exception("DB commit failed")
|
||||
with pytest.raises(InvalidOperationError) as excinfo:
|
||||
await delete_settlement(mock_db_session, db_settlement_model)
|
||||
assert "Failed to delete settlement" in str(excinfo.value)
|
||||
mock_db_session.rollback.assert_called_once()
|
117
be/tests/crud/test_user.py
Normal file
117
be/tests/crud/test_user.py
Normal file
@ -0,0 +1,117 @@
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
from sqlalchemy.exc import IntegrityError, OperationalError
|
||||
|
||||
from app.crud.user import get_user_by_email, create_user
|
||||
from app.schemas.user import UserCreate
|
||||
from app.models import User as UserModel
|
||||
from app.core.exceptions import (
|
||||
UserCreationError,
|
||||
EmailAlreadyRegisteredError,
|
||||
DatabaseConnectionError,
|
||||
DatabaseIntegrityError,
|
||||
DatabaseQueryError,
|
||||
DatabaseTransactionError
|
||||
)
|
||||
|
||||
# Fixtures
|
||||
@pytest.fixture
|
||||
def mock_db_session():
|
||||
return AsyncMock()
|
||||
|
||||
@pytest.fixture
|
||||
def user_create_data():
|
||||
return UserCreate(email="test@example.com", password="password123", name="Test User")
|
||||
|
||||
@pytest.fixture
|
||||
def existing_user_data():
|
||||
return UserModel(id=1, email="exists@example.com", password_hash="hashed_password", name="Existing User")
|
||||
|
||||
# Tests for get_user_by_email
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_user_by_email_found(mock_db_session, existing_user_data):
|
||||
mock_db_session.execute.return_value.scalars.return_value.first.return_value = existing_user_data
|
||||
user = await get_user_by_email(mock_db_session, "exists@example.com")
|
||||
assert user is not None
|
||||
assert user.email == "exists@example.com"
|
||||
mock_db_session.execute.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_user_by_email_not_found(mock_db_session):
|
||||
mock_db_session.execute.return_value.scalars.return_value.first.return_value = None
|
||||
user = await get_user_by_email(mock_db_session, "nonexistent@example.com")
|
||||
assert user is None
|
||||
mock_db_session.execute.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_user_by_email_db_connection_error(mock_db_session):
|
||||
mock_db_session.execute.side_effect = OperationalError("mock_op_error", "params", "orig")
|
||||
with pytest.raises(DatabaseConnectionError):
|
||||
await get_user_by_email(mock_db_session, "test@example.com")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_user_by_email_db_query_error(mock_db_session):
|
||||
# Simulate a generic SQLAlchemyError that is not OperationalError
|
||||
mock_db_session.execute.side_effect = IntegrityError("mock_sql_error", "params", "orig") # Using IntegrityError as an example of SQLAlchemyError
|
||||
with pytest.raises(DatabaseQueryError):
|
||||
await get_user_by_email(mock_db_session, "test@example.com")
|
||||
|
||||
|
||||
# Tests for create_user
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_user_success(mock_db_session, user_create_data):
|
||||
# The actual user object returned would be created by SQLAlchemy based on db_user
|
||||
# We mock the process: db.add is called, then db.flush, then db.refresh updates db_user
|
||||
async def mock_refresh(user_model_instance):
|
||||
user_model_instance.id = 1 # Simulate DB assigning an ID
|
||||
# Simulate other db-generated fields if necessary
|
||||
return None
|
||||
|
||||
mock_db_session.refresh = AsyncMock(side_effect=mock_refresh)
|
||||
mock_db_session.flush = AsyncMock()
|
||||
mock_db_session.add = MagicMock()
|
||||
|
||||
created_user = await create_user(mock_db_session, user_create_data)
|
||||
|
||||
mock_db_session.add.assert_called_once()
|
||||
mock_db_session.flush.assert_called_once()
|
||||
mock_db_session.refresh.assert_called_once()
|
||||
|
||||
assert created_user is not None
|
||||
assert created_user.email == user_create_data.email
|
||||
assert created_user.name == user_create_data.name
|
||||
assert hasattr(created_user, 'id') # Check if ID was assigned (simulated by mock_refresh)
|
||||
# Password hash check would be more involved, ensure hash_password was called correctly
|
||||
# For now, we assume hash_password works as intended and is tested elsewhere.
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_user_email_already_registered(mock_db_session, user_create_data):
|
||||
mock_db_session.flush.side_effect = IntegrityError("mock error (unique constraint)", "params", "orig")
|
||||
with pytest.raises(EmailAlreadyRegisteredError):
|
||||
await create_user(mock_db_session, user_create_data)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_user_db_integrity_error_not_unique(mock_db_session, user_create_data):
|
||||
# Simulate an IntegrityError that is not related to a unique constraint
|
||||
mock_db_session.flush.side_effect = IntegrityError("mock error (not unique constraint)", "params", "orig")
|
||||
with pytest.raises(DatabaseIntegrityError):
|
||||
await create_user(mock_db_session, user_create_data)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_user_db_connection_error(mock_db_session, user_create_data):
|
||||
mock_db_session.begin.side_effect = OperationalError("mock_op_error", "params", "orig")
|
||||
with pytest.raises(DatabaseConnectionError):
|
||||
await create_user(mock_db_session, user_create_data)
|
||||
# also test OperationalError on flush
|
||||
mock_db_session.begin.side_effect = None # reset side effect
|
||||
mock_db_session.flush.side_effect = OperationalError("mock_op_error", "params", "orig")
|
||||
with pytest.raises(DatabaseConnectionError):
|
||||
await create_user(mock_db_session, user_create_data)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_user_db_transaction_error(mock_db_session, user_create_data):
|
||||
# Simulate a generic SQLAlchemyError on flush that is not IntegrityError or OperationalError
|
||||
mock_db_session.flush.side_effect = UserCreationError("Simulated non-specific SQLAlchemyError") # Or any other SQLAlchemyError
|
||||
with pytest.raises(DatabaseTransactionError):
|
||||
await create_user(mock_db_session, user_create_data)
|
Loading…
Reference in New Issue
Block a user