This commit is contained in:
mohamad 2025-05-08 00:56:26 +02:00
parent 423d345fdf
commit bbb3c3b7df
31 changed files with 4998 additions and 245 deletions

View File

@ -0,0 +1,28 @@
"""add_version_to_settlements
Revision ID: 071ac4268ccb
Revises: be770eea8ec2
Create Date: 2025-05-07 23:47:27.788572
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = '071ac4268ccb'
down_revision: Union[str, None] = 'be770eea8ec2'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
"""Upgrade schema."""
pass
def downgrade() -> None:
"""Downgrade schema."""
pass

View File

@ -0,0 +1,28 @@
"""add_version_to_lists_table
Revision ID: 64a6614cb156
Revises: 071ac4268ccb
Create Date: 2025-05-08 00:48:49.027570
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = '64a6614cb156'
down_revision: Union[str, None] = '071ac4268ccb'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
"""Upgrade schema."""
pass
def downgrade() -> None:
"""Downgrade schema."""
pass

View File

@ -0,0 +1,28 @@
"""add_expense_split_settlement_tables_and_relations
Revision ID: 8c2c0f83e2b9
Revises: d53eedd151b7
Create Date: 2025-05-07 23:30:48.621512
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = '8c2c0f83e2b9'
down_revision: Union[str, None] = 'd53eedd151b7'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
"""Upgrade schema."""
pass
def downgrade() -> None:
"""Downgrade schema."""
pass

View File

@ -0,0 +1,28 @@
"""add_version_to_settlements
Revision ID: be770eea8ec2
Revises: 8c2c0f83e2b9
Create Date: 2025-05-07 23:41:26.669049
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = 'be770eea8ec2'
down_revision: Union[str, None] = '8c2c0f83e2b9'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
"""Upgrade schema."""
pass
def downgrade() -> None:
"""Downgrade schema."""
pass

View File

@ -1,4 +1,3 @@
# app/api/v1/api.py
from fastapi import APIRouter
from app.api.v1.endpoints import health
@ -10,10 +9,11 @@ from app.api.v1.endpoints import lists
from app.api.v1.endpoints import items
from app.api.v1.endpoints import ocr
from app.api.v1.endpoints import costs
from app.api.v1.endpoints import financials
api_router_v1 = APIRouter()
api_router_v1.include_router(health.router) # Path /health defined inside
api_router_v1.include_router(health.router)
api_router_v1.include_router(auth.router, prefix="/auth", tags=["Authentication"])
api_router_v1.include_router(users.router, prefix="/users", tags=["Users"])
api_router_v1.include_router(groups.router, prefix="/groups", tags=["Groups"])
@ -22,5 +22,6 @@ api_router_v1.include_router(lists.router, prefix="/lists", tags=["Lists"])
api_router_v1.include_router(items.router, tags=["Items"])
api_router_v1.include_router(ocr.router, prefix="/ocr", tags=["OCR"])
api_router_v1.include_router(costs.router, tags=["Costs"])
api_router_v1.include_router(financials.router)
# Add other v1 endpoint routers here later
# e.g., api_router_v1.include_router(users.router, prefix="/users", tags=["Users"])

View File

@ -1,15 +1,17 @@
# app/api/v1/endpoints/costs.py
import logging
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session, selectinload
from app.database import get_db
from app.api.dependencies import get_current_user
from app.models import User as UserModel # For get_current_user dependency
from app.schemas.cost import ListCostSummary
from app.models import User as UserModel, Group as GroupModel # For get_current_user dependency and Group model
from app.schemas.cost import ListCostSummary, GroupBalanceSummary
from app.crud import cost as crud_cost
from app.crud import list as crud_list # For permission checking
from app.core.exceptions import ListNotFoundError, ListPermissionError, UserNotFoundError
from app.core.exceptions import ListNotFoundError, ListPermissionError, UserNotFoundError, GroupNotFoundError
logger = logging.getLogger(__name__)
router = APIRouter()
@ -66,4 +68,61 @@ async def get_list_cost_summary(
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e))
except Exception as e:
logger.error(f"Unexpected error generating cost summary for list {list_id} for user {current_user.email}: {str(e)}", exc_info=True)
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="An unexpected error occurred while generating the cost summary.")
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="An unexpected error occurred while generating the cost summary.")
@router.get(
"/groups/{group_id}/balance-summary",
response_model=GroupBalanceSummary,
summary="Get Detailed Balance Summary for a Group",
tags=["Costs", "Groups"],
responses={
status.HTTP_403_FORBIDDEN: {"description": "User does not have permission to access this group"},
status.HTTP_404_NOT_FOUND: {"description": "Group not found"}
}
)
async def get_group_balance_summary(
group_id: int,
db: AsyncSession = Depends(get_db),
current_user: UserModel = Depends(get_current_user),
):
"""
Retrieves a detailed financial balance summary for all users within a specific group.
It considers all expenses, their splits, and all settlements recorded for the group.
The user must be a member of the group to view its balance summary.
"""
logger.info(f"User {current_user.email} requesting balance summary for group {group_id}")
# 1. Verify user is a member of the target group (using crud_group.check_group_membership or similar)
# Assuming a function like this exists in app.crud.group or we add it.
# For now, let's placeholder this check logic.
# await crud_group.check_group_membership(db=db, group_id=group_id, user_id=current_user.id)
# A simpler check for now: fetch the group and see if user is part of member_associations
group_check = await db.execute(
select(GroupModel)
.options(selectinload(GroupModel.member_associations))
.where(GroupModel.id == group_id)
)
db_group_for_check = group_check.scalars().first()
if not db_group_for_check:
raise GroupNotFoundError(group_id)
user_is_member = any(assoc.user_id == current_user.id for assoc in db_group_for_check.member_associations)
if not user_is_member:
# If ListPermissionError is generic enough for "access resource", use it, or a new GroupPermissionError
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=f"User not a member of group {group_id}")
# 2. Calculate the group balance summary
try:
balance_summary = await crud_cost.calculate_group_balance_summary(db=db, group_id=group_id)
logger.info(f"Successfully generated balance summary for group {group_id} for user {current_user.email}")
return balance_summary
except GroupNotFoundError as e:
logger.warning(f"Group {group_id} not found during balance summary calculation: {str(e)}")
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e))
except UserNotFoundError as e: # Should not happen if group members are correctly fetched
logger.error(f"User not found during balance summary for group {group_id}: {str(e)}")
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="An internal error occurred finding a user for the summary.")
except Exception as e:
logger.error(f"Unexpected error generating balance summary for group {group_id}: {str(e)}", exc_info=True)
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="An unexpected error occurred while generating the group balance summary.")

View File

@ -0,0 +1,442 @@
# app/api/v1/endpoints/financials.py
import logging
from fastapi import APIRouter, Depends, HTTPException, status, Query, Response
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
from typing import List as PyList, Optional, Sequence
from app.database import get_db
from app.api.dependencies import get_current_user
from app.models import User as UserModel, Group as GroupModel, List as ListModel, UserGroup as UserGroupModel, UserRoleEnum
from app.schemas.expense import (
ExpenseCreate, ExpensePublic,
SettlementCreate, SettlementPublic,
ExpenseUpdate, SettlementUpdate
)
from app.crud import expense as crud_expense
from app.crud import settlement as crud_settlement
from app.crud import group as crud_group
from app.crud import list as crud_list
from app.core.exceptions import (
ListNotFoundError, GroupNotFoundError, UserNotFoundError,
InvalidOperationError, GroupPermissionError, ListPermissionError,
ItemNotFoundError, GroupMembershipError
)
logger = logging.getLogger(__name__)
router = APIRouter()
# --- Helper for permissions ---
async def check_list_access_for_financials(db: AsyncSession, list_id: int, user_id: int, action: str = "access financial data for"):
try:
await crud_list.check_list_permission(db=db, list_id=list_id, user_id=user_id, require_member=True)
except ListPermissionError as e:
logger.warning(f"ListPermissionError in check_list_access_for_financials for list {list_id}, user {user_id}, action '{action}': {e.detail}")
raise ListPermissionError(list_id, action=action)
except ListNotFoundError:
raise
# --- Expense Endpoints ---
@router.post(
"/expenses",
response_model=ExpensePublic,
status_code=status.HTTP_201_CREATED,
summary="Create New Expense",
tags=["Expenses"]
)
async def create_new_expense(
expense_in: ExpenseCreate,
db: AsyncSession = Depends(get_db),
current_user: UserModel = Depends(get_current_user),
):
logger.info(f"User {current_user.email} creating expense: {expense_in.description}")
effective_group_id = expense_in.group_id
is_group_context = False
if expense_in.list_id:
# Check basic access to list (implies membership if list is in group)
await check_list_access_for_financials(db, expense_in.list_id, current_user.id, action="create expenses for")
list_obj = await db.get(ListModel, expense_in.list_id)
if not list_obj:
raise ListNotFoundError(expense_in.list_id)
if list_obj.group_id:
if expense_in.group_id and list_obj.group_id != expense_in.group_id:
raise InvalidOperationError(f"List {list_obj.id} belongs to group {list_obj.group_id}, not group {expense_in.group_id} specified in expense.")
effective_group_id = list_obj.group_id
is_group_context = True # Expense is tied to a group via the list
elif expense_in.group_id:
raise InvalidOperationError(f"Personal list {list_obj.id} cannot have expense associated with group {expense_in.group_id}.")
# If list is personal, no group check needed yet, handled by payer check below.
elif effective_group_id: # Only group_id provided for expense
is_group_context = True
# Ensure user is at least a member to create expense in group context
await crud_group.check_group_membership(db, group_id=effective_group_id, user_id=current_user.id, action="create expenses for")
else:
# This case should ideally be caught by earlier checks if list_id was present but list was personal.
# If somehow reached, it means no list_id and no group_id.
raise InvalidOperationError("Expense must be linked to a list_id or group_id.")
# Finalize expense payload with correct group_id if derived
expense_in_final = expense_in.model_copy(update={"group_id": effective_group_id})
# --- Granular Permission Check for Payer ---
if expense_in_final.paid_by_user_id != current_user.id:
logger.warning(f"User {current_user.email} attempting to create expense paid by other user {expense_in_final.paid_by_user_id}")
# If creating expense paid by someone else, user MUST be owner IF in group context
if is_group_context and effective_group_id:
try:
await crud_group.check_user_role_in_group(db, group_id=effective_group_id, user_id=current_user.id, required_role=UserRoleEnum.owner, action="create expense paid by another user")
except GroupPermissionError as e:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=f"Only group owners can create expenses paid by others. {str(e)}")
else:
# Cannot create expense paid by someone else for a personal list (no group context)
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Cannot create expense paid by another user for a personal list.")
# If paying for self, basic list/group membership check above is sufficient.
try:
created_expense = await crud_expense.create_expense(db=db, expense_in=expense_in_final, current_user_id=current_user.id)
logger.info(f"Expense '{created_expense.description}' (ID: {created_expense.id}) created successfully.")
return created_expense
except (UserNotFoundError, ListNotFoundError, GroupNotFoundError, InvalidOperationError, GroupMembershipError) as e:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
except NotImplementedError as e:
raise HTTPException(status_code=status.HTTP_501_NOT_IMPLEMENTED, detail=str(e))
except Exception as e:
logger.error(f"Unexpected error creating expense: {str(e)}", exc_info=True)
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="An unexpected error occurred.")
@router.get("/expenses/{expense_id}", response_model=ExpensePublic, summary="Get Expense by ID", tags=["Expenses"])
async def get_expense(
expense_id: int,
db: AsyncSession = Depends(get_db),
current_user: UserModel = Depends(get_current_user),
):
logger.info(f"User {current_user.email} requesting expense ID {expense_id}")
expense = await crud_expense.get_expense_by_id(db, expense_id=expense_id)
if not expense:
raise ItemNotFoundError(item_id=expense_id)
if expense.list_id:
await check_list_access_for_financials(db, expense.list_id, current_user.id)
elif expense.group_id:
await crud_group.check_group_membership(db, group_id=expense.group_id, user_id=current_user.id)
elif expense.paid_by_user_id != current_user.id:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Not authorized to view this expense")
return expense
@router.get("/lists/{list_id}/expenses", response_model=PyList[ExpensePublic], summary="List Expenses for a List", tags=["Expenses", "Lists"])
async def list_list_expenses(
list_id: int,
skip: int = Query(0, ge=0),
limit: int = Query(100, ge=1, le=200),
db: AsyncSession = Depends(get_db),
current_user: UserModel = Depends(get_current_user),
):
logger.info(f"User {current_user.email} listing expenses for list ID {list_id}")
await check_list_access_for_financials(db, list_id, current_user.id)
expenses = await crud_expense.get_expenses_for_list(db, list_id=list_id, skip=skip, limit=limit)
return expenses
@router.get("/groups/{group_id}/expenses", response_model=PyList[ExpensePublic], summary="List Expenses for a Group", tags=["Expenses", "Groups"])
async def list_group_expenses(
group_id: int,
skip: int = Query(0, ge=0),
limit: int = Query(100, ge=1, le=200),
db: AsyncSession = Depends(get_db),
current_user: UserModel = Depends(get_current_user),
):
logger.info(f"User {current_user.email} listing expenses for group ID {group_id}")
await crud_group.check_group_membership(db, group_id=group_id, user_id=current_user.id, action="list expenses for")
expenses = await crud_expense.get_expenses_for_group(db, group_id=group_id, skip=skip, limit=limit)
return expenses
@router.put("/expenses/{expense_id}", response_model=ExpensePublic, summary="Update Expense", tags=["Expenses"])
async def update_expense_details(
expense_id: int,
expense_in: ExpenseUpdate,
db: AsyncSession = Depends(get_db),
current_user: UserModel = Depends(get_current_user),
):
"""
Updates an existing expense (description, currency, expense_date only).
Requires the current version number for optimistic locking.
User must have permission to modify the expense (e.g., be the payer or group admin).
"""
logger.info(f"User {current_user.email} attempting to update expense ID {expense_id} (version {expense_in.version})")
expense_db = await crud_expense.get_expense_by_id(db, expense_id=expense_id)
if not expense_db:
raise ItemNotFoundError(item_id=expense_id)
# --- Granular Permission Check ---
can_modify = False
# 1. User paid for the expense
if expense_db.paid_by_user_id == current_user.id:
can_modify = True
# 2. OR User is owner of the group the expense belongs to
elif expense_db.group_id:
try:
await crud_group.check_user_role_in_group(db, group_id=expense_db.group_id, user_id=current_user.id, required_role=UserRoleEnum.owner, action="modify group expenses")
can_modify = True
logger.info(f"Allowing update for expense {expense_id} by group owner {current_user.email}")
except GroupMembershipError: # User not even a member
pass # Keep can_modify as False
except GroupPermissionError: # User is member but not owner
pass # Keep can_modify as False
except GroupNotFoundError: # Group doesn't exist (data integrity issue)
logger.error(f"Group {expense_db.group_id} not found for expense {expense_id} during update check.")
pass # Keep can_modify as False
# Note: If expense is only linked to a personal list (no group), only payer can modify.
if not can_modify:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="User cannot modify this expense (must be payer or group owner)")
try:
updated_expense = await crud_expense.update_expense(db=db, expense_db=expense_db, expense_in=expense_in)
logger.info(f"Expense ID {expense_id} updated successfully to version {updated_expense.version}.")
return updated_expense
except InvalidOperationError as e:
# Check if it's a version conflict (409) or other validation error (400)
status_code = status.HTTP_400_BAD_REQUEST
if "version" in str(e).lower():
status_code = status.HTTP_409_CONFLICT
raise HTTPException(status_code=status_code, detail=str(e))
except Exception as e:
logger.error(f"Unexpected error updating expense {expense_id}: {str(e)}", exc_info=True)
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="An unexpected error occurred.")
@router.delete("/expenses/{expense_id}", status_code=status.HTTP_204_NO_CONTENT, summary="Delete Expense", tags=["Expenses"])
async def delete_expense_record(
expense_id: int,
expected_version: Optional[int] = Query(None, description="Expected version for optimistic locking"),
db: AsyncSession = Depends(get_db),
current_user: UserModel = Depends(get_current_user),
):
"""
Deletes an expense and its associated splits.
Requires expected_version query parameter for optimistic locking.
User must have permission to delete the expense (e.g., be the payer or group admin).
"""
logger.info(f"User {current_user.email} attempting to delete expense ID {expense_id} (expected version {expected_version})")
expense_db = await crud_expense.get_expense_by_id(db, expense_id=expense_id)
if not expense_db:
# Return 204 even if not found, as the end state is achieved (item is gone)
logger.warning(f"Attempt to delete non-existent expense ID {expense_id}")
return Response(status_code=status.HTTP_204_NO_CONTENT)
# Alternatively, raise NotFoundError(detail=f"Expense {expense_id} not found") -> 404
# --- Granular Permission Check ---
can_delete = False
# 1. User paid for the expense
if expense_db.paid_by_user_id == current_user.id:
can_delete = True
# 2. OR User is owner of the group the expense belongs to
elif expense_db.group_id:
try:
await crud_group.check_user_role_in_group(db, group_id=expense_db.group_id, user_id=current_user.id, required_role=UserRoleEnum.owner, action="delete group expenses")
can_delete = True
logger.info(f"Allowing delete for expense {expense_id} by group owner {current_user.email}")
except GroupMembershipError:
pass
except GroupPermissionError:
pass
except GroupNotFoundError:
logger.error(f"Group {expense_db.group_id} not found for expense {expense_id} during delete check.")
pass
if not can_delete:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="User cannot delete this expense (must be payer or group owner)")
try:
await crud_expense.delete_expense(db=db, expense_db=expense_db, expected_version=expected_version)
logger.info(f"Expense ID {expense_id} deleted successfully.")
# No need to return content on 204
except InvalidOperationError as e:
# Check if it's a version conflict (409) or other validation error (400)
status_code = status.HTTP_400_BAD_REQUEST
if "version" in str(e).lower():
status_code = status.HTTP_409_CONFLICT
raise HTTPException(status_code=status_code, detail=str(e))
except Exception as e:
logger.error(f"Unexpected error deleting expense {expense_id}: {str(e)}", exc_info=True)
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="An unexpected error occurred.")
return Response(status_code=status.HTTP_204_NO_CONTENT)
# --- Settlement Endpoints ---
@router.post(
"/settlements",
response_model=SettlementPublic,
status_code=status.HTTP_201_CREATED,
summary="Record New Settlement",
tags=["Settlements"]
)
async def create_new_settlement(
settlement_in: SettlementCreate,
db: AsyncSession = Depends(get_db),
current_user: UserModel = Depends(get_current_user),
):
logger.info(f"User {current_user.email} recording settlement in group {settlement_in.group_id}")
await crud_group.check_group_membership(db, group_id=settlement_in.group_id, user_id=current_user.id, action="record settlements in")
try:
await crud_group.check_group_membership(db, group_id=settlement_in.group_id, user_id=settlement_in.paid_by_user_id, action="be a payer in this group's settlement")
await crud_group.check_group_membership(db, group_id=settlement_in.group_id, user_id=settlement_in.paid_to_user_id, action="be a payee in this group's settlement")
except GroupMembershipError as e:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"Payer or payee issue: {str(e)}")
except GroupNotFoundError as e:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e))
try:
created_settlement = await crud_settlement.create_settlement(db=db, settlement_in=settlement_in, current_user_id=current_user.id)
logger.info(f"Settlement ID {created_settlement.id} recorded successfully in group {settlement_in.group_id}.")
return created_settlement
except (UserNotFoundError, GroupNotFoundError, InvalidOperationError, GroupMembershipError) as e:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
except Exception as e:
logger.error(f"Unexpected error recording settlement: {str(e)}", exc_info=True)
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="An unexpected error occurred.")
@router.get("/settlements/{settlement_id}", response_model=SettlementPublic, summary="Get Settlement by ID", tags=["Settlements"])
async def get_settlement(
settlement_id: int,
db: AsyncSession = Depends(get_db),
current_user: UserModel = Depends(get_current_user),
):
logger.info(f"User {current_user.email} requesting settlement ID {settlement_id}")
settlement = await crud_settlement.get_settlement_by_id(db, settlement_id=settlement_id)
if not settlement:
raise ItemNotFoundError(item_id=settlement_id)
is_party_to_settlement = current_user.id in [settlement.paid_by_user_id, settlement.paid_to_user_id]
try:
await crud_group.check_group_membership(db, group_id=settlement.group_id, user_id=current_user.id)
except GroupMembershipError:
if not is_party_to_settlement:
raise GroupMembershipError(settlement.group_id, action="view this settlement's details")
logger.info(f"User {current_user.email} (party to settlement) viewing settlement {settlement_id} for group {settlement.group_id}.")
return settlement
@router.get("/groups/{group_id}/settlements", response_model=PyList[SettlementPublic], summary="List Settlements for a Group", tags=["Settlements", "Groups"])
async def list_group_settlements(
group_id: int,
skip: int = Query(0, ge=0),
limit: int = Query(100, ge=1, le=200),
db: AsyncSession = Depends(get_db),
current_user: UserModel = Depends(get_current_user),
):
logger.info(f"User {current_user.email} listing settlements for group ID {group_id}")
await crud_group.check_group_membership(db, group_id=group_id, user_id=current_user.id, action="list settlements for this group")
settlements = await crud_settlement.get_settlements_for_group(db, group_id=group_id, skip=skip, limit=limit)
return settlements
@router.put("/settlements/{settlement_id}", response_model=SettlementPublic, summary="Update Settlement", tags=["Settlements"])
async def update_settlement_details(
settlement_id: int,
settlement_in: SettlementUpdate,
db: AsyncSession = Depends(get_db),
current_user: UserModel = Depends(get_current_user),
):
"""
Updates an existing settlement (description, settlement_date only).
Requires the current version number for optimistic locking.
User must have permission (e.g., be involved party or group admin).
"""
logger.info(f"User {current_user.email} attempting to update settlement ID {settlement_id} (version {settlement_in.version})")
settlement_db = await crud_settlement.get_settlement_by_id(db, settlement_id=settlement_id)
if not settlement_db:
raise ItemNotFoundError(item_id=settlement_id)
# --- Granular Permission Check ---
can_modify = False
# 1. User is involved party (payer or payee)
is_party = current_user.id in [settlement_db.paid_by_user_id, settlement_db.paid_to_user_id]
if is_party:
can_modify = True
# 2. OR User is owner of the group the settlement belongs to
# Note: Settlements always have a group_id based on current model
elif settlement_db.group_id:
try:
await crud_group.check_user_role_in_group(db, group_id=settlement_db.group_id, user_id=current_user.id, required_role=UserRoleEnum.owner, action="modify group settlements")
can_modify = True
logger.info(f"Allowing update for settlement {settlement_id} by group owner {current_user.email}")
except GroupMembershipError:
pass
except GroupPermissionError:
pass
except GroupNotFoundError:
logger.error(f"Group {settlement_db.group_id} not found for settlement {settlement_id} during update check.")
pass
if not can_modify:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="User cannot modify this settlement (must be involved party or group owner)")
try:
updated_settlement = await crud_settlement.update_settlement(db=db, settlement_db=settlement_db, settlement_in=settlement_in)
logger.info(f"Settlement ID {settlement_id} updated successfully to version {updated_settlement.version}.")
return updated_settlement
except InvalidOperationError as e:
status_code = status.HTTP_400_BAD_REQUEST
if "version" in str(e).lower():
status_code = status.HTTP_409_CONFLICT
raise HTTPException(status_code=status_code, detail=str(e))
except Exception as e:
logger.error(f"Unexpected error updating settlement {settlement_id}: {str(e)}", exc_info=True)
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="An unexpected error occurred.")
@router.delete("/settlements/{settlement_id}", status_code=status.HTTP_204_NO_CONTENT, summary="Delete Settlement", tags=["Settlements"])
async def delete_settlement_record(
settlement_id: int,
expected_version: Optional[int] = Query(None, description="Expected version for optimistic locking"),
db: AsyncSession = Depends(get_db),
current_user: UserModel = Depends(get_current_user),
):
"""
Deletes a settlement.
Requires expected_version query parameter for optimistic locking.
User must have permission (e.g., be involved party or group admin).
"""
logger.info(f"User {current_user.email} attempting to delete settlement ID {settlement_id} (expected version {expected_version})")
settlement_db = await crud_settlement.get_settlement_by_id(db, settlement_id=settlement_id)
if not settlement_db:
logger.warning(f"Attempt to delete non-existent settlement ID {settlement_id}")
return Response(status_code=status.HTTP_204_NO_CONTENT)
# --- Granular Permission Check ---
can_delete = False
# 1. User is involved party (payer or payee)
is_party = current_user.id in [settlement_db.paid_by_user_id, settlement_db.paid_to_user_id]
if is_party:
can_delete = True
# 2. OR User is owner of the group the settlement belongs to
elif settlement_db.group_id:
try:
await crud_group.check_user_role_in_group(db, group_id=settlement_db.group_id, user_id=current_user.id, required_role=UserRoleEnum.owner, action="delete group settlements")
can_delete = True
logger.info(f"Allowing delete for settlement {settlement_id} by group owner {current_user.email}")
except GroupMembershipError:
pass
except GroupPermissionError:
pass
except GroupNotFoundError:
logger.error(f"Group {settlement_db.group_id} not found for settlement {settlement_id} during delete check.")
pass
if not can_delete:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="User cannot delete this settlement (must be involved party or group owner)")
try:
await crud_settlement.delete_settlement(db=db, settlement_db=settlement_db, expected_version=expected_version)
logger.info(f"Settlement ID {settlement_id} deleted successfully.")
except InvalidOperationError as e:
status_code = status.HTTP_400_BAD_REQUEST
if "version" in str(e).lower():
status_code = status.HTTP_409_CONFLICT
raise HTTPException(status_code=status_code, detail=str(e))
except Exception as e:
logger.error(f"Unexpected error deleting settlement {settlement_id}: {str(e)}", exc_info=True)
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="An unexpected error occurred.")
return Response(status_code=status.HTTP_204_NO_CONTENT)
# TODO (remaining from original list):
# (None - GET/POST/PUT/DELETE implemented for Expense/Settlement)

View File

@ -88,6 +88,14 @@ class UserNotFoundError(HTTPException):
detail=detail_msg
)
class InvalidOperationError(HTTPException):
"""Raised when an operation is invalid or disallowed by business logic."""
def __init__(self, detail: str, status_code: int = status.HTTP_400_BAD_REQUEST):
super().__init__(
status_code=status_code,
detail=detail
)
class DatabaseConnectionError(HTTPException):
"""Raised when there is an error connecting to the database."""
def __init__(self):

View File

@ -1,83 +0,0 @@
# be/tests/core/test_gemini.py
import pytest
import os
from unittest.mock import patch, MagicMock
# Store original key if exists, then clear it for testing missing key scenario
original_api_key = os.environ.get("GEMINI_API_KEY")
if "GEMINI_API_KEY" in os.environ:
del os.environ["GEMINI_API_KEY"]
# --- Test Module Import ---
# This forces the module-level initialization code in gemini.py to run
# We need to reload modules because settings might have been cached
from importlib import reload
from app.config import settings as app_settings
from app.core import gemini as gemini_core
# Reload settings first to ensure GEMINI_API_KEY is None initially
reload(app_settings)
# Reload gemini core to trigger initialization logic with potentially missing key
reload(gemini_core)
def test_gemini_initialization_without_key():
"""Verify behavior when GEMINI_API_KEY is not set."""
# Reload modules again to ensure clean state for this specific test
if "GEMINI_API_KEY" in os.environ:
del os.environ["GEMINI_API_KEY"]
reload(app_settings)
reload(gemini_core)
assert gemini_core.gemini_flash_client is None
assert gemini_core.gemini_initialization_error is not None
assert "GEMINI_API_KEY not configured" in gemini_core.gemini_initialization_error
with pytest.raises(RuntimeError, match="GEMINI_API_KEY not configured"):
gemini_core.get_gemini_client()
@patch('google.generativeai.configure')
@patch('google.generativeai.GenerativeModel')
def test_gemini_initialization_with_key(mock_generative_model: MagicMock, mock_configure: MagicMock):
"""Verify initialization logic is called when key is present (using mocks)."""
# Set a dummy key in the environment for this test
test_key = "TEST_API_KEY_123"
os.environ["GEMINI_API_KEY"] = test_key
# Reload settings and gemini module to pick up the new key
reload(app_settings)
reload(gemini_core)
# Assertions
mock_configure.assert_called_once_with(api_key=test_key)
mock_generative_model.assert_called_once_with(
model_name="gemini-1.5-flash-latest",
safety_settings=pytest.ANY, # Check safety settings were passed (ANY allows flexibility)
# generation_config=pytest.ANY # Check if you added default generation config
)
assert gemini_core.gemini_flash_client is not None
assert gemini_core.gemini_initialization_error is None
# Test get_gemini_client() success path
client = gemini_core.get_gemini_client()
assert client is not None # Should return the mocked client instance
# Clean up environment variable after test
if original_api_key:
os.environ["GEMINI_API_KEY"] = original_api_key
else:
if "GEMINI_API_KEY" in os.environ:
del os.environ["GEMINI_API_KEY"]
# Reload modules one last time to restore state for other tests
reload(app_settings)
reload(gemini_core)
# Restore original key after all tests in the module run (if needed)
def teardown_module(module):
if original_api_key:
os.environ["GEMINI_API_KEY"] = original_api_key
else:
if "GEMINI_API_KEY" in os.environ:
del os.environ["GEMINI_API_KEY"]
reload(app_settings)
reload(gemini_core)

View File

@ -1,86 +0,0 @@
# Example: be/tests/core/test_security.py
import pytest
from datetime import timedelta
from jose import jwt, JWTError
import time
from app.core.security import (
hash_password,
verify_password,
create_access_token,
verify_access_token,
)
from app.config import settings # Import settings for testing JWT config
# --- Password Hashing Tests ---
def test_hash_password_returns_string():
password = "testpassword"
hashed = hash_password(password)
assert isinstance(hashed, str)
assert password != hashed # Ensure it's not plain text
def test_verify_password_correct():
password = "correct_password"
hashed = hash_password(password)
assert verify_password(password, hashed) is True
def test_verify_password_incorrect():
hashed = hash_password("correct_password")
assert verify_password("wrong_password", hashed) is False
def test_verify_password_invalid_hash_format():
# Passlib's verify handles many format errors gracefully
assert verify_password("any_password", "invalid_hash_string") is False
# --- JWT Tests ---
def test_create_access_token():
subject = "testuser@example.com"
token = create_access_token(subject=subject)
assert isinstance(token, str)
# Decode manually for basic check (verification done in verify_access_token tests)
payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM])
assert payload["sub"] == subject
assert "exp" in payload
assert isinstance(payload["exp"], int)
def test_verify_access_token_valid():
subject = "test_subject_valid"
token = create_access_token(subject=subject)
payload = verify_access_token(token)
assert payload is not None
assert payload["sub"] == subject
def test_verify_access_token_invalid_signature():
subject = "test_subject_invalid_sig"
token = create_access_token(subject=subject)
# Attempt to verify with a wrong key
wrong_key = settings.SECRET_KEY + "wrong"
with pytest.raises(JWTError): # Decoding with wrong key should raise JWTError internally
jwt.decode(token, wrong_key, algorithms=[settings.ALGORITHM])
# Our verify function should catch this and return None
assert verify_access_token(token + "tamper") is None # Tampering token often invalidates sig
# Note: Testing verify_access_token directly returning None for wrong key is tricky
# as the error happens *during* jwt.decode. We rely on it catching JWTError.
def test_verify_access_token_expired():
# Create a token that expires almost immediately
subject = "test_subject_expired"
expires_delta = timedelta(seconds=-1) # Expired 1 second ago
token = create_access_token(subject=subject, expires_delta=expires_delta)
# Wait briefly just in case of timing issues, though negative delta should guarantee expiry
time.sleep(0.1)
# Decoding expired token raises ExpiredSignatureError internally
with pytest.raises(JWTError): # Specifically ExpiredSignatureError, but JWTError catches it
jwt.decode(token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM])
# Our verify function should catch this and return None
assert verify_access_token(token) is None
def test_verify_access_token_malformed():
assert verify_access_token("this.is.not.a.valid.token") is None

View File

@ -1,24 +1,49 @@
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy.orm import selectinload, joinedload
from sqlalchemy import func as sql_func, or_
from decimal import Decimal, ROUND_HALF_UP
from typing import List as PyList, Dict, Set
from typing import Dict, Optional, Sequence, List as PyList
from collections import defaultdict
import logging
from app.models import List as ListModel, Item as ItemModel, User as UserModel, UserGroup as UserGroupModel, Group as GroupModel
from app.schemas.cost import ListCostSummary, UserCostShare
from app.core.exceptions import ListNotFoundError, UserNotFoundError # Assuming UserNotFoundError might be useful
from app.models import (
List as ListModel,
Item as ItemModel,
User as UserModel,
UserGroup as UserGroupModel,
Group as GroupModel,
Expense as ExpenseModel,
ExpenseSplit as ExpenseSplitModel,
Settlement as SettlementModel
)
from app.schemas.cost import (
ListCostSummary,
UserCostShare,
GroupBalanceSummary,
UserBalanceDetail,
SuggestedSettlement
)
from app.core.exceptions import (
ListNotFoundError,
UserNotFoundError,
GroupNotFoundError,
InvalidOperationError
)
logger = logging.getLogger(__name__)
async def calculate_list_cost_summary(db: AsyncSession, list_id: int) -> ListCostSummary:
"""
Calculates the cost summary for a given list, splitting costs equally
among relevant users (group members if list is in a group, or creator if personal).
Calculates the cost summary for a given list based purely on item prices and who added them.
This is a simpler calculation and does not involve the Expense/Settlement system.
"""
# 1. Fetch the list, its items (with their 'added_by_user'), and its group (with members)
list_result = await db.execute(
select(ListModel)
.options(
selectinload(ListModel.items).options(joinedload(ItemModel.added_by_user)),
selectinload(ListModel.group).options(selectinload(GroupModel.user_associations).options(selectinload(UserGroupModel.user)))
selectinload(ListModel.group).options(selectinload(GroupModel.user_associations).options(selectinload(UserGroupModel.user))),
selectinload(ListModel.creator)
)
.where(ListModel.id == list_id)
)
@ -27,90 +52,215 @@ async def calculate_list_cost_summary(db: AsyncSession, list_id: int) -> ListCos
if not db_list:
raise ListNotFoundError(list_id)
# 2. Determine participating users
participating_users: Dict[int, UserModel] = {}
participating_users_map: Dict[int, UserModel] = {}
if db_list.group:
# If list is part of a group, all group members participate
for ug_assoc in db_list.group.user_associations:
if ug_assoc.user: # Ensure user object is loaded
participating_users[ug_assoc.user.id] = ug_assoc.user
else:
# If personal list, only the creator participates (or if items were added by others somehow, include them)
# For simplicity in MVP, if personal, only creator. If shared personal lists become a feature, this needs revisit.
# Let's fetch the creator if not already available through relationships (though it should be via ListModel.creator)
creator_user = await db.get(UserModel, db_list.created_by_id)
if not creator_user:
# This case should ideally not happen if data integrity is maintained
raise UserNotFoundError(user_id=db_list.created_by_id) # Or a more specific error
participating_users[creator_user.id] = creator_user
# Also ensure all users who added items are included, even if not in the group (edge case, but good for robustness)
if ug_assoc.user:
participating_users_map[ug_assoc.user.id] = ug_assoc.user
elif db_list.creator: # Personal list
participating_users_map[db_list.creator.id] = db_list.creator
# Include all users who added items with prices, even if not in the primary context (group/creator)
for item in db_list.items:
if item.added_by_user and item.added_by_user.id not in participating_users:
participating_users[item.added_by_user.id] = item.added_by_user
if item.price is not None and item.price > Decimal("0") and item.added_by_user and item.added_by_user.id not in participating_users_map:
participating_users_map[item.added_by_user.id] = item.added_by_user
num_participating_users = len(participating_users)
num_participating_users = len(participating_users_map)
if num_participating_users == 0:
# Handle case with no users (e.g., empty group, or personal list creator deleted - though FKs should prevent)
# Or if list has no items and is personal, creator might not be in participating_users if logic changes.
# For now, if no users found (e.g. group with no members and list creator somehow not added), return empty/default summary.
return ListCostSummary(
list_id=db_list.id,
list_name=db_list.name,
total_list_cost=Decimal("0.00"),
num_participating_users=0,
equal_share_per_user=Decimal("0.00"),
user_balances=[]
list_id=db_list.id, list_name=db_list.name, total_list_cost=Decimal("0.00"),
num_participating_users=0, equal_share_per_user=Decimal("0.00"), user_balances=[]
)
# 3. Calculate total cost and items_added_value for each user
total_list_cost = Decimal("0.00")
user_items_added_value: Dict[int, Decimal] = {user_id: Decimal("0.00") for user_id in participating_users.keys()}
user_items_added_value: Dict[int, Decimal] = defaultdict(Decimal)
for item in db_list.items:
if item.price is not None and item.price > Decimal("0"):
total_list_cost += item.price
if item.added_by_id in user_items_added_value: # Item adder must be in participating users
if item.added_by_id in participating_users_map:
user_items_added_value[item.added_by_id] += item.price
# If item.added_by_id is not in participating_users (e.g. user left group),
# their contribution still counts to total cost, but they aren't part of the split.
# The current logic adds item adders to participating_users, so this else is less likely.
equal_share_per_user = (total_list_cost / Decimal(num_participating_users)).quantize(Decimal("0.01"), rounding=ROUND_HALF_UP)
remainder = total_list_cost - (equal_share_per_user * num_participating_users)
# 4. Calculate equal share per user
# Using ROUND_HALF_UP to handle cents appropriately.
# Ensure division by zero is handled if num_participating_users could be 0 (already handled above)
equal_share_per_user = (total_list_cost / Decimal(num_participating_users)).quantize(Decimal("0.01"), rounding=ROUND_HALF_UP) if num_participating_users > 0 else Decimal("0.00")
# 5. For each user, calculate their balance
user_balances: PyList[UserCostShare] = []
for user_id, user_obj in participating_users.items():
first_user_processed = False
for user_id, user_obj in participating_users_map.items():
items_added = user_items_added_value.get(user_id, Decimal("0.00"))
balance = items_added - equal_share_per_user
current_user_share = equal_share_per_user
if not first_user_processed and remainder != Decimal("0"):
current_user_share += remainder
first_user_processed = True
user_identifier = user_obj.name if user_obj.name else user_obj.email # Prefer name, fallback to email
balance = items_added - current_user_share
user_identifier = user_obj.name if user_obj.name else user_obj.email
user_balances.append(
UserCostShare(
user_id=user_id,
user_identifier=user_identifier,
user_id=user_id, user_identifier=user_identifier,
items_added_value=items_added.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP),
amount_due=equal_share_per_user, # Already quantized
amount_due=current_user_share.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP),
balance=balance.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP)
)
)
# Sort user_balances for consistent output, e.g., by user_id or identifier
user_balances.sort(key=lambda x: x.user_identifier)
# 6. Return the populated ListCostSummary schema
return ListCostSummary(
list_id=db_list.id,
list_name=db_list.name,
list_id=db_list.id, list_name=db_list.name,
total_list_cost=total_list_cost.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP),
num_participating_users=num_participating_users,
equal_share_per_user=equal_share_per_user, # Already quantized
equal_share_per_user=equal_share_per_user.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP),
user_balances=user_balances
)
# --- Helper for Settlement Suggestions ---
def calculate_suggested_settlements(user_balances: PyList[UserBalanceDetail]) -> PyList[SuggestedSettlement]:
"""
Calculates a list of suggested settlements to resolve group debts.
Uses a greedy algorithm to minimize the number of transactions.
Input: List of UserBalanceDetail objects with calculated net_balances.
Output: List of SuggestedSettlement objects.
"""
# Use a small tolerance for floating point comparisons with Decimal
tolerance = Decimal("0.001")
debtors = sorted([ub for ub in user_balances if ub.net_balance < -tolerance], key=lambda x: x.net_balance)
creditors = sorted([ub for ub in user_balances if ub.net_balance > tolerance], key=lambda x: x.net_balance, reverse=True)
settlements: PyList[SuggestedSettlement] = []
debtor_idx = 0
creditor_idx = 0
# Create mutable copies of balances to track remaining amounts
debtor_balances = {d.user_id: d.net_balance for d in debtors}
creditor_balances = {c.user_id: c.net_balance for c in creditors}
user_identifiers = {ub.user_id: ub.user_identifier for ub in user_balances}
while debtor_idx < len(debtors) and creditor_idx < len(creditors):
debtor = debtors[debtor_idx]
creditor = creditors[creditor_idx]
debtor_remaining = debtor_balances[debtor.user_id]
creditor_remaining = creditor_balances[creditor.user_id]
# Amount to transfer is the minimum of what debtor owes and what creditor is owed
transfer_amount = min(abs(debtor_remaining), creditor_remaining)
transfer_amount = transfer_amount.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP)
if transfer_amount > tolerance: # Only record meaningful transfers
settlements.append(SuggestedSettlement(
from_user_id=debtor.user_id,
from_user_identifier=user_identifiers.get(debtor.user_id, "Unknown Debtor"),
to_user_id=creditor.user_id,
to_user_identifier=user_identifiers.get(creditor.user_id, "Unknown Creditor"),
amount=transfer_amount
))
# Update remaining balances
debtor_balances[debtor.user_id] += transfer_amount
creditor_balances[creditor.user_id] -= transfer_amount
# Move to next debtor if current one is settled (or very close)
if abs(debtor_balances[debtor.user_id]) < tolerance:
debtor_idx += 1
# Move to next creditor if current one is settled (or very close)
if creditor_balances[creditor.user_id] < tolerance:
creditor_idx += 1
# Log if lists aren't empty - indicates potential imbalance or rounding issue
if debtor_idx < len(debtors) or creditor_idx < len(creditors):
# Calculate remaining balances for logging
remaining_debt = sum(bal for bal in debtor_balances.values() if bal < -tolerance)
remaining_credit = sum(bal for bal in creditor_balances.values() if bal > tolerance)
logger.warning(f"Settlement suggestion calculation finished with remaining balances. Debt: {remaining_debt}, Credit: {remaining_credit}. This might be due to minor rounding discrepancies.")
return settlements
# --- NEW: Detailed Group Balance Summary ---
async def calculate_group_balance_summary(db: AsyncSession, group_id: int) -> GroupBalanceSummary:
"""
Calculates a detailed balance summary for all users in a group,
considering all expenses, splits, and settlements within that group.
Also calculates suggested settlements.
"""
group = await db.get(GroupModel, group_id)
if not group:
raise GroupNotFoundError(group_id)
# 1. Get all group members
group_members_result = await db.execute(
select(UserModel)
.join(UserGroupModel, UserModel.id == UserGroupModel.user_id)
.where(UserGroupModel.group_id == group_id)
)
group_members: Dict[int, UserModel] = {user.id: user for user in group_members_result.scalars().all()}
if not group_members:
return GroupBalanceSummary(
group_id=group.id, group_name=group.name, user_balances=[], suggested_settlements=[]
)
user_balances_data: Dict[int, UserBalanceDetail] = {}
for user_id, user_obj in group_members.items():
user_balances_data[user_id] = UserBalanceDetail(
user_id=user_id,
user_identifier=user_obj.name if user_obj.name else user_obj.email
)
overall_total_expenses = Decimal("0.00")
overall_total_settlements = Decimal("0.00")
# 2. Process Expenses and ExpenseSplits for the group
expenses_result = await db.execute(
select(ExpenseModel)
.where(ExpenseModel.group_id == group_id)
.options(selectinload(ExpenseModel.splits))
)
for expense in expenses_result.scalars().all():
overall_total_expenses += expense.total_amount
if expense.paid_by_user_id in user_balances_data:
user_balances_data[expense.paid_by_user_id].total_paid_for_expenses += expense.total_amount
for split in expense.splits:
if split.user_id in user_balances_data:
user_balances_data[split.user_id].total_share_of_expenses += split.owed_amount
# 3. Process Settlements for the group
settlements_result = await db.execute(
select(SettlementModel).where(SettlementModel.group_id == group_id)
)
for settlement in settlements_result.scalars().all():
overall_total_settlements += settlement.amount
if settlement.paid_by_user_id in user_balances_data:
user_balances_data[settlement.paid_by_user_id].total_settlements_paid += settlement.amount
if settlement.paid_to_user_id in user_balances_data:
user_balances_data[settlement.paid_to_user_id].total_settlements_received += settlement.amount
# 4. Calculate net balances and prepare final list
final_user_balances: PyList[UserBalanceDetail] = []
for user_id in group_members.keys():
data = user_balances_data[user_id]
data.net_balance = (
data.total_paid_for_expenses + data.total_settlements_received
) - (data.total_share_of_expenses + data.total_settlements_paid)
data.total_paid_for_expenses = data.total_paid_for_expenses.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP)
data.total_share_of_expenses = data.total_share_of_expenses.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP)
data.total_settlements_paid = data.total_settlements_paid.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP)
data.total_settlements_received = data.total_settlements_received.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP)
data.net_balance = data.net_balance.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP)
final_user_balances.append(data)
final_user_balances.sort(key=lambda x: x.user_identifier)
# 5. Calculate suggested settlements (NEW)
suggested_settlements = calculate_suggested_settlements(final_user_balances)
return GroupBalanceSummary(
group_id=group.id,
group_name=group.name,
overall_total_expenses=overall_total_expenses.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP),
overall_total_settlements=overall_total_settlements.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP),
user_balances=final_user_balances,
suggested_settlements=suggested_settlements # Add suggestions to response
)

592
be/app/crud/expense.py Normal file
View File

@ -0,0 +1,592 @@
# app/crud/expense.py
import logging # Add logging import
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy.orm import selectinload, joinedload
from decimal import Decimal, ROUND_HALF_UP, InvalidOperation as DecimalInvalidOperation
from typing import Callable, List as PyList, Optional, Sequence, Dict, defaultdict
from datetime import datetime, timezone # Added timezone
from app.models import (
Expense as ExpenseModel,
ExpenseSplit as ExpenseSplitModel,
User as UserModel,
List as ListModel,
Group as GroupModel,
UserGroup as UserGroupModel,
SplitTypeEnum,
Item as ItemModel
)
from app.schemas.expense import ExpenseCreate, ExpenseSplitCreate, ExpenseUpdate # Removed unused ExpenseUpdate
from app.core.exceptions import (
# Using existing specific exceptions where possible
ListNotFoundError,
GroupNotFoundError,
UserNotFoundError,
InvalidOperationError # Import the new exception
)
# Placeholder for InvalidOperationError if not defined in app.core.exceptions
# This should be a proper HTTPException subclass if used in API layer
# class CrudInvalidOperationError(ValueError): # For internal CRUD validation logic
# pass
logger = logging.getLogger(__name__) # Initialize logger
async def get_users_for_splitting(db: AsyncSession, expense_group_id: Optional[int], expense_list_id: Optional[int], expense_paid_by_user_id: int) -> PyList[UserModel]:
"""
Determines the list of users an expense should be split amongst.
Priority: Group members (if group_id), then List's group members or creator (if list_id).
Fallback to only the payer if no other context yields users.
"""
users_to_split_with: PyList[UserModel] = []
processed_user_ids = set()
async def _add_user(user: Optional[UserModel]):
if user and user.id not in processed_user_ids:
users_to_split_with.append(user)
processed_user_ids.add(user.id)
if expense_group_id:
group_result = await db.execute(
select(GroupModel).options(selectinload(GroupModel.member_associations).options(selectinload(UserGroupModel.user)))
.where(GroupModel.id == expense_group_id)
)
group = group_result.scalars().first()
if not group:
raise GroupNotFoundError(expense_group_id)
for assoc in group.member_associations:
await _add_user(assoc.user)
elif expense_list_id: # Only if group_id was not primary context
list_result = await db.execute(
select(ListModel)
.options(
selectinload(ListModel.group).options(selectinload(GroupModel.member_associations).options(selectinload(UserGroupModel.user))),
selectinload(ListModel.creator)
)
.where(ListModel.id == expense_list_id)
)
db_list = list_result.scalars().first()
if not db_list:
raise ListNotFoundError(expense_list_id)
if db_list.group:
for assoc in db_list.group.member_associations:
await _add_user(assoc.user)
elif db_list.creator:
await _add_user(db_list.creator)
if not users_to_split_with:
payer_user = await db.get(UserModel, expense_paid_by_user_id)
if not payer_user:
# This should have been caught earlier if paid_by_user_id was validated before calling this helper
raise UserNotFoundError(user_id=expense_paid_by_user_id)
await _add_user(payer_user)
if not users_to_split_with:
# This should ideally not be reached if payer is always a fallback
raise InvalidOperationError("Could not determine any users for splitting the expense.")
return users_to_split_with
async def create_expense(db: AsyncSession, expense_in: ExpenseCreate, current_user_id: int) -> ExpenseModel:
"""Creates a new expense and its associated splits.
Args:
db: Database session
expense_in: Expense creation data
current_user_id: ID of the user creating the expense
Returns:
The created expense with splits
Raises:
UserNotFoundError: If payer or split users don't exist
ListNotFoundError: If specified list doesn't exist
GroupNotFoundError: If specified group doesn't exist
InvalidOperationError: For various validation failures
"""
# Helper function to round decimals consistently
def round_money(amount: Decimal) -> Decimal:
return amount.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP)
# 1. Context Validation
# Validate basic context requirements first
if not expense_in.list_id and not expense_in.group_id:
raise InvalidOperationError("Expense must be associated with a list or a group.")
# 2. User Validation
payer = await db.get(UserModel, expense_in.paid_by_user_id)
if not payer:
raise UserNotFoundError(user_id=expense_in.paid_by_user_id)
# 3. List/Group Context Resolution
final_group_id = await _resolve_expense_context(db, expense_in)
# 4. Create the expense object
db_expense = ExpenseModel(
description=expense_in.description,
total_amount=round_money(expense_in.total_amount),
currency=expense_in.currency or "USD",
expense_date=expense_in.expense_date or datetime.now(timezone.utc),
split_type=expense_in.split_type,
list_id=expense_in.list_id,
group_id=final_group_id,
item_id=expense_in.item_id,
paid_by_user_id=expense_in.paid_by_user_id,
created_by_user_id=current_user_id # Track who created this expense
)
# 5. Generate splits based on split type
splits_to_create = await _generate_expense_splits(db, db_expense, expense_in, round_money)
# 6. Single transaction for expense and all splits
try:
db.add(db_expense)
await db.flush() # Get expense ID without committing
# Update all splits with the expense ID
for split in splits_to_create:
split.expense_id = db_expense.id
db.add_all(splits_to_create)
await db.commit()
except Exception as e:
await db.rollback()
logger.error(f"Failed to save expense: {str(e)}", exc_info=True)
raise InvalidOperationError(f"Failed to save expense: {str(e)}")
# Refresh to get the splits relationship populated
await db.refresh(db_expense, attribute_names=["splits"])
return db_expense
async def _resolve_expense_context(db: AsyncSession, expense_in: ExpenseCreate) -> Optional[int]:
"""Resolves and validates the expense's context (list and group).
Returns the final group_id for the expense after validation.
"""
final_group_id = expense_in.group_id
# If list_id is provided, validate it and potentially derive group_id
if expense_in.list_id:
list_obj = await db.get(ListModel, expense_in.list_id)
if not list_obj:
raise ListNotFoundError(expense_in.list_id)
# If list belongs to a group, verify consistency or inherit group_id
if list_obj.group_id:
if expense_in.group_id and list_obj.group_id != expense_in.group_id:
raise InvalidOperationError(
f"List {expense_in.list_id} belongs to group {list_obj.group_id}, "
f"but expense was specified for group {expense_in.group_id}."
)
final_group_id = list_obj.group_id # Prioritize list's group
# If only group_id is provided (no list_id), validate group_id
elif final_group_id:
group_obj = await db.get(GroupModel, final_group_id)
if not group_obj:
raise GroupNotFoundError(final_group_id)
return final_group_id
async def _generate_expense_splits(
db: AsyncSession,
db_expense: ExpenseModel,
expense_in: ExpenseCreate,
round_money: Callable[[Decimal], Decimal]
) -> PyList[ExpenseSplitModel]:
"""Generates appropriate expense splits based on split type."""
splits_to_create: PyList[ExpenseSplitModel] = []
# Create splits based on the split type
if expense_in.split_type == SplitTypeEnum.EQUAL:
splits_to_create = await _create_equal_splits(
db, db_expense, expense_in, round_money
)
elif expense_in.split_type == SplitTypeEnum.EXACT_AMOUNTS:
splits_to_create = await _create_exact_amount_splits(
db, db_expense, expense_in, round_money
)
elif expense_in.split_type == SplitTypeEnum.PERCENTAGE:
splits_to_create = await _create_percentage_splits(
db, db_expense, expense_in, round_money
)
elif expense_in.split_type == SplitTypeEnum.SHARES:
splits_to_create = await _create_shares_splits(
db, db_expense, expense_in, round_money
)
elif expense_in.split_type == SplitTypeEnum.ITEM_BASED:
splits_to_create = await _create_item_based_splits(
db, db_expense, expense_in, round_money
)
else:
raise InvalidOperationError(f"Unsupported split type: {expense_in.split_type.value}")
if not splits_to_create:
raise InvalidOperationError("No expense splits were generated.")
return splits_to_create
async def _create_equal_splits(
db: AsyncSession,
db_expense: ExpenseModel,
expense_in: ExpenseCreate,
round_money: Callable[[Decimal], Decimal]
) -> PyList[ExpenseSplitModel]:
"""Creates equal splits among users."""
users_for_splitting = await get_users_for_splitting(
db, db_expense.group_id, expense_in.list_id, expense_in.paid_by_user_id
)
if not users_for_splitting:
raise InvalidOperationError("No users found for EQUAL split.")
num_users = len(users_for_splitting)
amount_per_user = round_money(db_expense.total_amount / Decimal(num_users))
remainder = db_expense.total_amount - (amount_per_user * num_users)
splits = []
for i, user in enumerate(users_for_splitting):
split_amount = amount_per_user
if i == 0 and remainder != Decimal('0'):
split_amount = round_money(amount_per_user + remainder)
splits.append(ExpenseSplitModel(
user_id=user.id,
owed_amount=split_amount
))
return splits
async def _create_exact_amount_splits(
db: AsyncSession,
db_expense: ExpenseModel,
expense_in: ExpenseCreate,
round_money: Callable[[Decimal], Decimal]
) -> PyList[ExpenseSplitModel]:
"""Creates splits with exact amounts."""
if not expense_in.splits_in:
raise InvalidOperationError("Splits data is required for EXACT_AMOUNTS split type.")
# Validate all users in splits exist
await _validate_users_in_splits(db, expense_in.splits_in)
current_total = Decimal("0.00")
splits = []
for split_in in expense_in.splits_in:
if split_in.owed_amount <= Decimal('0'):
raise InvalidOperationError(f"Owed amount for user {split_in.user_id} must be positive.")
rounded_amount = round_money(split_in.owed_amount)
current_total += rounded_amount
splits.append(ExpenseSplitModel(
user_id=split_in.user_id,
owed_amount=rounded_amount
))
if round_money(current_total) != db_expense.total_amount:
raise InvalidOperationError(
f"Sum of exact split amounts ({current_total}) != expense total ({db_expense.total_amount})."
)
return splits
async def _create_percentage_splits(
db: AsyncSession,
db_expense: ExpenseModel,
expense_in: ExpenseCreate,
round_money: Callable[[Decimal], Decimal]
) -> PyList[ExpenseSplitModel]:
"""Creates splits based on percentages."""
if not expense_in.splits_in:
raise InvalidOperationError("Splits data is required for PERCENTAGE split type.")
# Validate all users in splits exist
await _validate_users_in_splits(db, expense_in.splits_in)
total_percentage = Decimal("0.00")
current_total = Decimal("0.00")
splits = []
for split_in in expense_in.splits_in:
if not (split_in.share_percentage and Decimal("0") < split_in.share_percentage <= Decimal("100")):
raise InvalidOperationError(
f"Invalid percentage {split_in.share_percentage} for user {split_in.user_id}."
)
total_percentage += split_in.share_percentage
owed_amount = round_money(db_expense.total_amount * (split_in.share_percentage / Decimal("100")))
current_total += owed_amount
splits.append(ExpenseSplitModel(
user_id=split_in.user_id,
owed_amount=owed_amount,
share_percentage=split_in.share_percentage
))
if round_money(total_percentage) != Decimal("100.00"):
raise InvalidOperationError(f"Sum of percentages ({total_percentage}%) is not 100%.")
# Adjust for rounding differences
if current_total != db_expense.total_amount and splits:
diff = db_expense.total_amount - current_total
splits[-1].owed_amount = round_money(splits[-1].owed_amount + diff)
return splits
async def _create_shares_splits(
db: AsyncSession,
db_expense: ExpenseModel,
expense_in: ExpenseCreate,
round_money: Callable[[Decimal], Decimal]
) -> PyList[ExpenseSplitModel]:
"""Creates splits based on shares."""
if not expense_in.splits_in:
raise InvalidOperationError("Splits data is required for SHARES split type.")
# Validate all users in splits exist
await _validate_users_in_splits(db, expense_in.splits_in)
# Calculate total shares
total_shares = sum(s.share_units for s in expense_in.splits_in if s.share_units and s.share_units > 0)
if total_shares == 0:
raise InvalidOperationError("Total shares cannot be zero for SHARES split.")
splits = []
current_total = Decimal("0.00")
for split_in in expense_in.splits_in:
if not (split_in.share_units and split_in.share_units > 0):
raise InvalidOperationError(f"Invalid share units for user {split_in.user_id}.")
share_ratio = Decimal(split_in.share_units) / Decimal(total_shares)
owed_amount = round_money(db_expense.total_amount * share_ratio)
current_total += owed_amount
splits.append(ExpenseSplitModel(
user_id=split_in.user_id,
owed_amount=owed_amount,
share_units=split_in.share_units
))
# Adjust for rounding differences
if current_total != db_expense.total_amount and splits:
diff = db_expense.total_amount - current_total
splits[-1].owed_amount = round_money(splits[-1].owed_amount + diff)
return splits
async def _create_item_based_splits(
db: AsyncSession,
db_expense: ExpenseModel,
expense_in: ExpenseCreate,
round_money: Callable[[Decimal], Decimal]
) -> PyList[ExpenseSplitModel]:
"""Creates splits based on items in a shopping list."""
if not expense_in.list_id:
raise InvalidOperationError("ITEM_BASED expenses must be associated with a list_id.")
if expense_in.splits_in:
logger.warning("Ignoring provided 'splits_in' data for ITEM_BASED expense.")
# Build query to fetch relevant items
items_query = select(ItemModel).where(ItemModel.list_id == expense_in.list_id)
if expense_in.item_id:
items_query = items_query.where(ItemModel.id == expense_in.item_id)
else:
items_query = items_query.where(ItemModel.price.isnot(None) & (ItemModel.price > Decimal("0")))
# Load items with their adders
items_result = await db.execute(items_query.options(selectinload(ItemModel.added_by_user)))
relevant_items = items_result.scalars().all()
if not relevant_items:
error_msg = (
f"Specified item ID {expense_in.item_id} not found in list {expense_in.list_id}."
if expense_in.item_id else
f"List {expense_in.list_id} has no priced items to base the expense on."
)
raise InvalidOperationError(error_msg)
# Aggregate owed amounts by user
calculated_total = Decimal("0.00")
user_owed_amounts = defaultdict(Decimal)
processed_items = 0
for item in relevant_items:
if item.price is None or item.price <= Decimal("0"):
if expense_in.item_id:
raise InvalidOperationError(
f"Item ID {expense_in.item_id} must have a positive price for ITEM_BASED expense."
)
continue
if not item.added_by_user:
logger.error(f"Item ID {item.id} is missing added_by_user relationship.")
raise InvalidOperationError(f"Data integrity issue: Item {item.id} is missing adder information.")
calculated_total += item.price
user_owed_amounts[item.added_by_user.id] += item.price
processed_items += 1
if processed_items == 0:
raise InvalidOperationError(
f"No items with positive prices found in list {expense_in.list_id} to create ITEM_BASED expense."
)
# Validate total matches calculated total
if round_money(calculated_total) != db_expense.total_amount:
raise InvalidOperationError(
f"Expense total amount ({db_expense.total_amount}) does not match the "
f"calculated total from item prices ({calculated_total})."
)
# Create splits based on aggregated amounts
splits = []
for user_id, owed_amount in user_owed_amounts.items():
splits.append(ExpenseSplitModel(
user_id=user_id,
owed_amount=round_money(owed_amount)
))
return splits
async def _validate_users_in_splits(db: AsyncSession, splits_in: PyList[ExpenseSplitCreate]) -> None:
"""Validates that all users in the splits exist."""
user_ids_in_split = [s.user_id for s in splits_in]
user_results = await db.execute(select(UserModel.id).where(UserModel.id.in_(user_ids_in_split)))
found_user_ids = {row[0] for row in user_results}
if len(found_user_ids) != len(user_ids_in_split):
missing_user_ids = set(user_ids_in_split) - found_user_ids
raise UserNotFoundError(identifier=f"users in split data: {list(missing_user_ids)}")
async def get_expense_by_id(db: AsyncSession, expense_id: int) -> Optional[ExpenseModel]:
result = await db.execute(
select(ExpenseModel)
.options(
selectinload(ExpenseModel.splits).options(selectinload(ExpenseSplitModel.user)),
selectinload(ExpenseModel.paid_by_user),
selectinload(ExpenseModel.list),
selectinload(ExpenseModel.group),
selectinload(ExpenseModel.item)
)
.where(ExpenseModel.id == expense_id)
)
return result.scalars().first()
async def get_expenses_for_list(db: AsyncSession, list_id: int, skip: int = 0, limit: int = 100) -> Sequence[ExpenseModel]:
result = await db.execute(
select(ExpenseModel)
.where(ExpenseModel.list_id == list_id)
.order_by(ExpenseModel.expense_date.desc(), ExpenseModel.created_at.desc())
.offset(skip).limit(limit)
.options(selectinload(ExpenseModel.splits).options(selectinload(ExpenseSplitModel.user))) # Also load user for each split
)
return result.scalars().all()
async def get_expenses_for_group(db: AsyncSession, group_id: int, skip: int = 0, limit: int = 100) -> Sequence[ExpenseModel]:
result = await db.execute(
select(ExpenseModel)
.where(ExpenseModel.group_id == group_id)
.order_by(ExpenseModel.expense_date.desc(), ExpenseModel.created_at.desc())
.offset(skip).limit(limit)
.options(selectinload(ExpenseModel.splits).options(selectinload(ExpenseSplitModel.user)))
)
return result.scalars().all()
async def update_expense(db: AsyncSession, expense_db: ExpenseModel, expense_in: ExpenseUpdate) -> ExpenseModel:
"""
Updates an existing expense.
Only allows updates to description, currency, and expense_date to avoid split complexities.
Requires version matching for optimistic locking.
"""
if expense_db.version != expense_in.version:
raise InvalidOperationError(
f"Expense '{expense_db.description}' (ID: {expense_db.id}) has been modified. "
f"Your version is {expense_in.version}, current version is {expense_db.version}. Please refresh.",
# status_code=status.HTTP_409_CONFLICT # This would be for the API layer to set
)
update_data = expense_in.model_dump(exclude_unset=True, exclude={"version"}) # Exclude version itself from data
# Fields that are safe to update without affecting splits or core logic
allowed_to_update = {"description", "currency", "expense_date"}
updated_something = False
for field, value in update_data.items():
if field in allowed_to_update:
setattr(expense_db, field, value)
updated_something = True
else:
# If any other field is present in the update payload, it's an invalid operation for this simple update
raise InvalidOperationError(f"Field '{field}' cannot be updated. Only {', '.join(allowed_to_update)} are allowed.")
if not updated_something and not expense_in.model_fields_set.intersection(allowed_to_update):
# No actual updatable fields were provided in the payload, even if others (like version) were.
# This could be a non-issue, or an indication of a misuse of the endpoint.
# For now, if only version was sent, we still increment if it matched.
pass # Or raise InvalidOperationError("No updatable fields provided.")
expense_db.version += 1
expense_db.updated_at = datetime.now(timezone.utc) # Manually update timestamp
try:
await db.commit()
await db.refresh(expense_db)
except Exception as e:
await db.rollback()
# Consider specific DB error types if needed
raise InvalidOperationError(f"Failed to update expense: {str(e)}")
return expense_db
async def delete_expense(db: AsyncSession, expense_db: ExpenseModel, expected_version: Optional[int] = None) -> None:
"""
Deletes an expense. Requires version matching if expected_version is provided.
Associated ExpenseSplits are cascade deleted by the database foreign key constraint.
"""
if expected_version is not None and expense_db.version != expected_version:
raise InvalidOperationError(
f"Expense '{expense_db.description}' (ID: {expense_db.id}) cannot be deleted. "
f"Your expected version {expected_version} does not match current version {expense_db.version}. Please refresh.",
# status_code=status.HTTP_409_CONFLICT
)
await db.delete(expense_db)
try:
await db.commit()
except Exception as e:
await db.rollback()
raise InvalidOperationError(f"Failed to delete expense: {str(e)}")
return None
# Note: The InvalidOperationError is a simple ValueError placeholder.
# For API endpoints, these should be translated to appropriate HTTPExceptions.
# Ensure app.core.exceptions has proper HTTP error classes if needed.

View File

@ -14,7 +14,9 @@ from app.core.exceptions import (
DatabaseConnectionError,
DatabaseIntegrityError,
DatabaseQueryError,
DatabaseTransactionError
DatabaseTransactionError,
GroupMembershipError,
GroupPermissionError # Import GroupPermissionError
)
# --- Group CRUD ---
@ -152,4 +154,75 @@ async def get_group_member_count(db: AsyncSession, group_id: int) -> int:
except OperationalError as e:
raise DatabaseConnectionError(f"Failed to connect to database: {str(e)}")
except SQLAlchemyError as e:
raise DatabaseQueryError(f"Failed to count group members: {str(e)}")
raise DatabaseQueryError(f"Failed to count group members: {str(e)}")
async def check_group_membership(
db: AsyncSession,
group_id: int,
user_id: int,
action: str = "access this group"
) -> None:
"""
Checks if a user is a member of a group. Raises exceptions if not found or not a member.
Raises:
GroupNotFoundError: If the group_id does not exist.
GroupMembershipError: If the user_id is not a member of the group.
"""
try:
async with db.begin():
# Check group existence first
group_exists = await db.get(GroupModel, group_id)
if not group_exists:
raise GroupNotFoundError(group_id)
# Check membership
membership = await db.execute(
select(UserGroupModel.id)
.where(UserGroupModel.group_id == group_id, UserGroupModel.user_id == user_id)
.limit(1)
)
if membership.scalar_one_or_none() is None:
raise GroupMembershipError(group_id, action=action)
# If we reach here, the user is a member
return None
except GroupNotFoundError: # Re-raise specific errors
raise
except GroupMembershipError:
raise
except OperationalError as e:
raise DatabaseConnectionError(f"Failed to connect to database while checking membership: {str(e)}")
except SQLAlchemyError as e:
raise DatabaseQueryError(f"Failed to check group membership: {str(e)}")
async def check_user_role_in_group(
db: AsyncSession,
group_id: int,
user_id: int,
required_role: UserRoleEnum,
action: str = "perform this action"
) -> None:
"""
Checks if a user is a member of a group and has the required role (or higher).
Raises:
GroupNotFoundError: If the group_id does not exist.
GroupMembershipError: If the user_id is not a member of the group.
GroupPermissionError: If the user does not have the required role.
"""
# First, ensure user is a member (this also checks group existence)
await check_group_membership(db, group_id, user_id, action=f"be checked for permissions to {action}")
# Get the user's actual role
actual_role = await get_user_role_in_group(db, group_id, user_id)
# Define role hierarchy (assuming owner > member)
role_hierarchy = {UserRoleEnum.owner: 2, UserRoleEnum.member: 1}
if not actual_role or role_hierarchy.get(actual_role, 0) < role_hierarchy.get(required_role, 0):
raise GroupPermissionError(
group_id=group_id,
action=f"{action} (requires at least '{required_role.value}' role)"
)
# If role is sufficient, return None
return None

168
be/app/crud/settlement.py Normal file
View File

@ -0,0 +1,168 @@
# app/crud/settlement.py
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy.orm import selectinload, joinedload
from sqlalchemy import or_
from decimal import Decimal
from typing import List as PyList, Optional, Sequence
from datetime import datetime, timezone
from app.models import (
Settlement as SettlementModel,
User as UserModel,
Group as GroupModel
)
from app.schemas.expense import SettlementCreate, SettlementUpdate # SettlementUpdate not used yet
from app.core.exceptions import UserNotFoundError, GroupNotFoundError, InvalidOperationError
async def create_settlement(db: AsyncSession, settlement_in: SettlementCreate, current_user_id: int) -> SettlementModel:
"""Creates a new settlement record."""
# Validate Payer, Payee, and Group exist
payer = await db.get(UserModel, settlement_in.paid_by_user_id)
if not payer:
raise UserNotFoundError(user_id=settlement_in.paid_by_user_id, identifier="Payer")
payee = await db.get(UserModel, settlement_in.paid_to_user_id)
if not payee:
raise UserNotFoundError(user_id=settlement_in.paid_to_user_id, identifier="Payee")
if settlement_in.paid_by_user_id == settlement_in.paid_to_user_id:
raise InvalidOperationError("Payer and Payee cannot be the same user.")
group = await db.get(GroupModel, settlement_in.group_id)
if not group:
raise GroupNotFoundError(settlement_in.group_id)
# Optional: Check if current_user_id is part of the group or is one of the parties involved
# This is more of an API-level permission check but could be added here if strict.
# For example: if current_user_id not in [settlement_in.paid_by_user_id, settlement_in.paid_to_user_id]:
# is_in_group = await db.execute(select(UserGroupModel).where(UserGroupModel.group_id == group.id, UserGroupModel.user_id == current_user_id))
# if not is_in_group.first():
# raise InvalidOperationError("You can only record settlements you are part of or for groups you belong to.")
db_settlement = SettlementModel(
group_id=settlement_in.group_id,
paid_by_user_id=settlement_in.paid_by_user_id,
paid_to_user_id=settlement_in.paid_to_user_id,
amount=settlement_in.amount.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP),
settlement_date=settlement_in.settlement_date if settlement_in.settlement_date else datetime.now(timezone.utc),
description=settlement_in.description
# created_by_user_id = current_user_id # Optional: Who recorded this settlement
)
db.add(db_settlement)
try:
await db.commit()
await db.refresh(db_settlement, attribute_names=["payer", "payee", "group"])
except Exception as e:
await db.rollback()
raise InvalidOperationError(f"Failed to save settlement: {str(e)}")
return db_settlement
async def get_settlement_by_id(db: AsyncSession, settlement_id: int) -> Optional[SettlementModel]:
result = await db.execute(
select(SettlementModel)
.options(
selectinload(SettlementModel.payer),
selectinload(SettlementModel.payee),
selectinload(SettlementModel.group)
)
.where(SettlementModel.id == settlement_id)
)
return result.scalars().first()
async def get_settlements_for_group(db: AsyncSession, group_id: int, skip: int = 0, limit: int = 100) -> Sequence[SettlementModel]:
result = await db.execute(
select(SettlementModel)
.where(SettlementModel.group_id == group_id)
.order_by(SettlementModel.settlement_date.desc(), SettlementModel.created_at.desc())
.offset(skip).limit(limit)
.options(selectinload(SettlementModel.payer), selectinload(SettlementModel.payee))
)
return result.scalars().all()
async def get_settlements_involving_user(
db: AsyncSession,
user_id: int,
group_id: Optional[int] = None,
skip: int = 0,
limit: int = 100
) -> Sequence[SettlementModel]:
query = (
select(SettlementModel)
.where(or_(SettlementModel.paid_by_user_id == user_id, SettlementModel.paid_to_user_id == user_id))
.order_by(SettlementModel.settlement_date.desc(), SettlementModel.created_at.desc())
.offset(skip).limit(limit)
.options(selectinload(SettlementModel.payer), selectinload(SettlementModel.payee), selectinload(SettlementModel.group))
)
if group_id:
query = query.where(SettlementModel.group_id == group_id)
result = await db.execute(query)
return result.scalars().all()
async def update_settlement(db: AsyncSession, settlement_db: SettlementModel, settlement_in: SettlementUpdate) -> SettlementModel:
"""
Updates an existing settlement.
Only allows updates to description and settlement_date.
Requires version matching for optimistic locking.
Assumes SettlementUpdate schema includes a version field.
"""
# Check if SettlementUpdate schema has 'version'. If not, this check needs to be adapted or version passed differently.
if not hasattr(settlement_in, 'version') or settlement_db.version != settlement_in.version:
raise InvalidOperationError(
f"Settlement (ID: {settlement_db.id}) has been modified. "
f"Your version does not match current version {settlement_db.version}. Please refresh.",
# status_code=status.HTTP_409_CONFLICT
)
update_data = settlement_in.model_dump(exclude_unset=True, exclude={"version"})
allowed_to_update = {"description", "settlement_date"}
updated_something = False
for field, value in update_data.items():
if field in allowed_to_update:
setattr(settlement_db, field, value)
updated_something = True
else:
raise InvalidOperationError(f"Field '{field}' cannot be updated. Only {', '.join(allowed_to_update)} are allowed for settlements.")
if not updated_something and not settlement_in.model_fields_set.intersection(allowed_to_update):
pass # No actual updatable fields provided, but version matched.
settlement_db.version += 1 # Assuming SettlementModel has a version field, add if missing
settlement_db.updated_at = datetime.now(timezone.utc)
try:
await db.commit()
await db.refresh(settlement_db)
except Exception as e:
await db.rollback()
raise InvalidOperationError(f"Failed to update settlement: {str(e)}")
return settlement_db
async def delete_settlement(db: AsyncSession, settlement_db: SettlementModel, expected_version: Optional[int] = None) -> None:
"""
Deletes a settlement. Requires version matching if expected_version is provided.
Assumes SettlementModel has a version field.
"""
if expected_version is not None:
if not hasattr(settlement_db, 'version') or settlement_db.version != expected_version:
raise InvalidOperationError(
f"Settlement (ID: {settlement_db.id}) cannot be deleted. "
f"Expected version {expected_version} does not match current version. Please refresh.",
# status_code=status.HTTP_409_CONFLICT
)
await db.delete(settlement_db)
try:
await db.commit()
except Exception as e:
await db.rollback()
raise InvalidOperationError(f"Failed to delete settlement: {str(e)}")
return None
# TODO: Implement update_settlement (consider restrictions, versioning)
# TODO: Implement delete_settlement (consider implications on balances)

View File

@ -21,7 +21,7 @@ from sqlalchemy import (
Text, # <-- Add Text for description
Numeric # <-- Add Numeric for price
)
from sqlalchemy.orm import relationship
from sqlalchemy.orm import relationship, backref
from .database import Base
@ -30,6 +30,14 @@ class UserRoleEnum(enum.Enum):
owner = "owner"
member = "member"
class SplitTypeEnum(enum.Enum):
EQUAL = "EQUAL" # Split equally among all involved users
EXACT_AMOUNTS = "EXACT_AMOUNTS" # Specific amounts for each user (defined in ExpenseSplit)
PERCENTAGE = "PERCENTAGE" # Percentage for each user (defined in ExpenseSplit)
SHARES = "SHARES" # Proportional to shares/units (defined in ExpenseSplit)
ITEM_BASED = "ITEM_BASED" # If an expense is derived directly from item prices and who added them
# Add more types as needed, e.g., UNPAID (for tracking debts not part of a formal expense)
# --- User Model ---
class User(Base):
__tablename__ = "users"
@ -51,6 +59,13 @@ class User(Base):
completed_items = relationship("Item", foreign_keys="Item.completed_by_id", back_populates="completed_by_user") # Link Item.completed_by_id -> User
# --- End NEW Relationships ---
# --- Relationships for Cost Splitting ---
expenses_paid = relationship("Expense", foreign_keys="Expense.paid_by_user_id", back_populates="paid_by_user", cascade="all, delete-orphan")
expense_splits = relationship("ExpenseSplit", foreign_keys="ExpenseSplit.user_id", back_populates="user", cascade="all, delete-orphan")
settlements_made = relationship("Settlement", foreign_keys="Settlement.paid_by_user_id", back_populates="payer", cascade="all, delete-orphan")
settlements_received = relationship("Settlement", foreign_keys="Settlement.paid_to_user_id", back_populates="payee", cascade="all, delete-orphan")
# --- End Relationships for Cost Splitting ---
# --- Group Model ---
class Group(Base):
@ -70,6 +85,11 @@ class Group(Base):
lists = relationship("List", back_populates="group", cascade="all, delete-orphan") # Link List.group_id -> Group
# --- End NEW Relationship ---
# --- Relationships for Cost Splitting ---
expenses = relationship("Expense", foreign_keys="Expense.group_id", back_populates="group", cascade="all, delete-orphan")
settlements = relationship("Settlement", foreign_keys="Settlement.group_id", back_populates="group", cascade="all, delete-orphan")
# --- End Relationships for Cost Splitting ---
# --- UserGroup Association Model ---
class UserGroup(Base):
@ -124,6 +144,10 @@ class List(Base):
group = relationship("Group", back_populates="lists") # Link to Group.lists
items = relationship("Item", back_populates="list", cascade="all, delete-orphan", order_by="Item.created_at") # Link to Item.list, cascade deletes
# --- Relationships for Cost Splitting ---
expenses = relationship("Expense", foreign_keys="Expense.list_id", back_populates="list", cascade="all, delete-orphan")
# --- End Relationships for Cost Splitting ---
# === NEW: Item Model ===
class Item(Base):
@ -144,4 +168,96 @@ class Item(Base):
# --- Relationships ---
list = relationship("List", back_populates="items") # Link to List.items
added_by_user = relationship("User", foreign_keys=[added_by_id], back_populates="added_items") # Link to User.added_items
completed_by_user = relationship("User", foreign_keys=[completed_by_id], back_populates="completed_items") # Link to User.completed_items
completed_by_user = relationship("User", foreign_keys=[completed_by_id], back_populates="completed_items") # Link to User.completed_items
# --- Relationships for Cost Splitting ---
# If an item directly results in an expense, or an expense can be tied to an item.
expenses = relationship("Expense", back_populates="item") # An item might have multiple associated expenses
# --- End Relationships for Cost Splitting ---
# === NEW Models for Advanced Cost Splitting ===
class Expense(Base):
__tablename__ = "expenses"
id = Column(Integer, primary_key=True, index=True)
description = Column(String, nullable=False)
total_amount = Column(Numeric(10, 2), nullable=False)
currency = Column(String, nullable=False, default="USD") # Consider making this an Enum too if few currencies
expense_date = Column(DateTime(timezone=True), server_default=func.now(), nullable=False)
split_type = Column(SAEnum(SplitTypeEnum, name="splittypeenum", create_type=True), nullable=False)
# Foreign Keys
list_id = Column(Integer, ForeignKey("lists.id"), nullable=True)
group_id = Column(Integer, ForeignKey("groups.id"), nullable=True) # If not list-specific but group-specific
item_id = Column(Integer, ForeignKey("items.id"), nullable=True) # If the expense is for a specific item
paid_by_user_id = Column(Integer, ForeignKey("users.id"), nullable=False)
created_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False)
updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), nullable=False)
version = Column(Integer, nullable=False, default=1, server_default='1')
# Relationships
paid_by_user = relationship("User", foreign_keys=[paid_by_user_id], back_populates="expenses_paid")
list = relationship("List", foreign_keys=[list_id], back_populates="expenses")
group = relationship("Group", foreign_keys=[group_id], back_populates="expenses")
item = relationship("Item", foreign_keys=[item_id], back_populates="expenses")
splits = relationship("ExpenseSplit", back_populates="expense", cascade="all, delete-orphan")
__table_args__ = (
# Example: Ensure either list_id or group_id is present if item_id is null
# CheckConstraint('(item_id IS NOT NULL) OR (list_id IS NOT NULL) OR (group_id IS NOT NULL)', name='chk_expense_context'),
)
class ExpenseSplit(Base):
__tablename__ = "expense_splits"
__table_args__ = (UniqueConstraint('expense_id', 'user_id', name='uq_expense_user_split'),)
id = Column(Integer, primary_key=True, index=True)
expense_id = Column(Integer, ForeignKey("expenses.id", ondelete="CASCADE"), nullable=False)
user_id = Column(Integer, ForeignKey("users.id"), nullable=False)
owed_amount = Column(Numeric(10, 2), nullable=False) # For EQUAL or EXACT_AMOUNTS
# For PERCENTAGE split (value from 0.00 to 100.00)
share_percentage = Column(Numeric(5, 2), nullable=True)
# For SHARES split (e.g., user A has 2 shares, user B has 3 shares)
share_units = Column(Integer, nullable=True)
# is_settled might be better tracked via actual Settlement records or a reconciliation process
# is_settled = Column(Boolean, default=False, nullable=False)
created_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False)
updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), nullable=False)
# Relationships
expense = relationship("Expense", back_populates="splits")
user = relationship("User", foreign_keys=[user_id], back_populates="expense_splits")
class Settlement(Base):
__tablename__ = "settlements"
id = Column(Integer, primary_key=True, index=True)
group_id = Column(Integer, ForeignKey("groups.id"), nullable=False) # Settlements usually within a group
paid_by_user_id = Column(Integer, ForeignKey("users.id"), nullable=False)
paid_to_user_id = Column(Integer, ForeignKey("users.id"), nullable=False)
amount = Column(Numeric(10, 2), nullable=False)
settlement_date = Column(DateTime(timezone=True), server_default=func.now(), nullable=False)
description = Column(Text, nullable=True)
created_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False)
updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), nullable=False)
version = Column(Integer, nullable=False, default=1, server_default='1')
# Relationships
group = relationship("Group", foreign_keys=[group_id], back_populates="settlements")
payer = relationship("User", foreign_keys=[paid_by_user_id], back_populates="settlements_made")
payee = relationship("User", foreign_keys=[paid_to_user_id], back_populates="settlements_received")
__table_args__ = (
# Ensure payer and payee are different users
# CheckConstraint('paid_by_user_id <> paid_to_user_id', name='chk_settlement_payer_ne_payee'),
)
# Potential future: PaymentMethod model, etc.

View File

@ -19,4 +19,37 @@ class ListCostSummary(BaseModel):
equal_share_per_user: Decimal
user_balances: List[UserCostShare]
model_config = ConfigDict(from_attributes=True)
model_config = ConfigDict(from_attributes=True)
class UserBalanceDetail(BaseModel):
user_id: int
user_identifier: str # Name or email
total_paid_for_expenses: Decimal = Decimal("0.00")
total_share_of_expenses: Decimal = Decimal("0.00")
total_settlements_paid: Decimal = Decimal("0.00")
total_settlements_received: Decimal = Decimal("0.00")
net_balance: Decimal = Decimal("0.00") # (paid_for_expenses + settlements_received) - (share_of_expenses + settlements_paid)
model_config = ConfigDict(from_attributes=True)
class SuggestedSettlement(BaseModel):
from_user_id: int
from_user_identifier: str # Name or email of payer
to_user_id: int
to_user_identifier: str # Name or email of payee
amount: Decimal
model_config = ConfigDict(from_attributes=True)
class GroupBalanceSummary(BaseModel):
group_id: int
group_name: str
overall_total_expenses: Decimal = Decimal("0.00")
overall_total_settlements: Decimal = Decimal("0.00")
user_balances: List[UserBalanceDetail]
# Optional: Could add a list of suggested settlements to zero out balances
suggested_settlements: Optional[List[SuggestedSettlement]] = None
model_config = ConfigDict(from_attributes=True)
# class SuggestedSettlement(BaseModel):
# from_user_id: int
# to_user_id: int
# amount: Decimal

131
be/app/schemas/expense.py Normal file
View File

@ -0,0 +1,131 @@
# app/schemas/expense.py
from pydantic import BaseModel, ConfigDict, validator
from typing import List, Optional
from decimal import Decimal
from datetime import datetime
# Assuming SplitTypeEnum is accessible here, e.g., from app.models or app.core.enums
# For now, let's redefine it or import it if models.py is parsable by Pydantic directly
# If it's from app.models, you might need to make app.models.SplitTypeEnum Pydantic-compatible or map it.
# For simplicity during schema definition, I'll redefine a string enum here.
# In a real setup, ensure this aligns with the SQLAlchemy enum in models.py.
from app.models import SplitTypeEnum # Try importing directly
# --- ExpenseSplit Schemas ---
class ExpenseSplitBase(BaseModel):
user_id: int
owed_amount: Decimal
share_percentage: Optional[Decimal] = None
share_units: Optional[int] = None
class ExpenseSplitCreate(ExpenseSplitBase):
pass # All fields from base are needed for creation
class ExpenseSplitPublic(ExpenseSplitBase):
id: int
expense_id: int
# user: Optional[UserPublic] # If we want to nest user details
created_at: datetime
updated_at: datetime
model_config = ConfigDict(from_attributes=True)
# --- Expense Schemas ---
class ExpenseBase(BaseModel):
description: str
total_amount: Decimal
currency: Optional[str] = "USD"
expense_date: Optional[datetime] = None
split_type: SplitTypeEnum
list_id: Optional[int] = None
group_id: Optional[int] = None # Should be present if list_id is not, and vice-versa
item_id: Optional[int] = None
paid_by_user_id: int
class ExpenseCreate(ExpenseBase):
# For EQUAL split, splits are generated. For others, they might be provided.
# This logic will be in the CRUD: if split_type is EXACT_AMOUNTS, PERCENTAGE, SHARES,
# then 'splits_in' should be provided.
splits_in: Optional[List[ExpenseSplitCreate]] = None
@validator('total_amount')
def total_amount_must_be_positive(cls, v):
if v <= Decimal('0'):
raise ValueError('Total amount must be positive')
return v
# Basic validation: if list_id is None, group_id must be provided.
# More complex cross-field validation might be needed.
@validator('group_id', always=True)
def check_list_or_group_id(cls, v, values):
if values.get('list_id') is None and v is None:
raise ValueError('Either list_id or group_id must be provided for an expense')
return v
class ExpenseUpdate(BaseModel):
description: Optional[str] = None
total_amount: Optional[Decimal] = None
currency: Optional[str] = None
expense_date: Optional[datetime] = None
split_type: Optional[SplitTypeEnum] = None
list_id: Optional[int] = None
group_id: Optional[int] = None
item_id: Optional[int] = None
# paid_by_user_id is usually not updatable directly to maintain integrity.
# Updating splits would be a more complex operation, potentially a separate endpoint or careful logic.
version: int # For optimistic locking
class ExpensePublic(ExpenseBase):
id: int
created_at: datetime
updated_at: datetime
version: int
splits: List[ExpenseSplitPublic] = []
# paid_by_user: Optional[UserPublic] # If nesting user details
# list: Optional[ListPublic] # If nesting list details
# group: Optional[GroupPublic] # If nesting group details
# item: Optional[ItemPublic] # If nesting item details
model_config = ConfigDict(from_attributes=True)
# --- Settlement Schemas ---
class SettlementBase(BaseModel):
group_id: int
paid_by_user_id: int
paid_to_user_id: int
amount: Decimal
settlement_date: Optional[datetime] = None
description: Optional[str] = None
class SettlementCreate(SettlementBase):
@validator('amount')
def amount_must_be_positive(cls, v):
if v <= Decimal('0'):
raise ValueError('Settlement amount must be positive')
return v
@validator('paid_to_user_id')
def payer_and_payee_must_be_different(cls, v, values):
if 'paid_by_user_id' in values and v == values['paid_by_user_id']:
raise ValueError('Payer and payee cannot be the same user')
return v
class SettlementUpdate(BaseModel):
amount: Optional[Decimal] = None
settlement_date: Optional[datetime] = None
description: Optional[str] = None
# group_id, paid_by_user_id, paid_to_user_id are typically not updatable.
version: int # For optimistic locking
class SettlementPublic(SettlementBase):
id: int
created_at: datetime
updated_at: datetime
# payer: Optional[UserPublic]
# payee: Optional[UserPublic]
# group: Optional[GroupPublic]
model_config = ConfigDict(from_attributes=True)
# Placeholder for nested schemas (e.g., UserPublic) if needed
# from app.schemas.user import UserPublic
# from app.schemas.list import ListPublic
# from app.schemas.group import GroupPublic
# from app.schemas.item import ItemPublic

View File

@ -0,0 +1,373 @@
import pytest
from fastapi import status
from httpx import AsyncClient
from sqlalchemy.ext.asyncio import AsyncSession
from typing import Callable, Dict, Any
from app.models import User as UserModel, Group as GroupModel, List as ListModel
from app.schemas.expense import ExpenseCreate
from app.core.config import settings
# Helper to create a URL for an endpoint
API_V1_STR = settings.API_V1_STR
def expense_url(endpoint: str = "") -> str:
return f"{API_V1_STR}/financials/expenses{endpoint}"
def settlement_url(endpoint: str = "") -> str:
return f"{API_V1_STR}/financials/settlements{endpoint}"
@pytest.mark.asyncio
async def test_create_new_expense_success_list_context(
client: AsyncClient,
db_session: AsyncSession, # Assuming a fixture for db session
normal_user_token_headers: Dict[str, str], # Assuming a fixture for user auth
test_user: UserModel, # Assuming a fixture for a test user
test_list_user_is_member: ListModel, # Assuming a fixture for a list user is member of
) -> None:
"""
Test successful creation of a new expense linked to a list.
"""
expense_data = ExpenseCreate(
description="Test Expense for List",
amount=100.00,
currency="USD",
paid_by_user_id=test_user.id,
list_id=test_list_user_is_member.id,
group_id=None, # group_id should be derived from list if list is in a group
# category_id: Optional[int] = None # Assuming category is optional
# expense_date: Optional[date] = None # Assuming date is optional
# splits: Optional[List[SplitCreate]] = [] # Assuming splits are optional for now
)
response = await client.post(
expense_url(),
headers=normal_user_token_headers,
json=expense_data.model_dump(exclude_unset=True)
)
assert response.status_code == status.HTTP_201_CREATED
content = response.json()
assert content["description"] == expense_data.description
assert content["amount"] == expense_data.amount
assert content["currency"] == expense_data.currency
assert content["paid_by_user_id"] == test_user.id
assert content["list_id"] == test_list_user_is_member.id
# If test_list_user_is_member has a group_id, it should be set in the response
if test_list_user_is_member.group_id:
assert content["group_id"] == test_list_user_is_member.group_id
else:
assert content["group_id"] is None
assert "id" in content
assert "created_at" in content
assert "updated_at" in content
assert "version" in content
assert content["version"] == 1
@pytest.mark.asyncio
async def test_create_new_expense_success_group_context(
client: AsyncClient,
normal_user_token_headers: Dict[str, str],
test_user: UserModel,
test_group_user_is_member: GroupModel, # Assuming a fixture for a group user is member of
) -> None:
"""
Test successful creation of a new expense linked directly to a group.
"""
expense_data = ExpenseCreate(
description="Test Expense for Group",
amount=50.00,
currency="EUR",
paid_by_user_id=test_user.id,
group_id=test_group_user_is_member.id,
list_id=None,
)
response = await client.post(
expense_url(),
headers=normal_user_token_headers,
json=expense_data.model_dump(exclude_unset=True)
)
assert response.status_code == status.HTTP_201_CREATED
content = response.json()
assert content["description"] == expense_data.description
assert content["paid_by_user_id"] == test_user.id
assert content["group_id"] == test_group_user_is_member.id
assert content["list_id"] is None
assert content["version"] == 1
@pytest.mark.asyncio
async def test_create_new_expense_fail_no_list_or_group(
client: AsyncClient,
normal_user_token_headers: Dict[str, str],
test_user: UserModel,
) -> None:
"""
Test expense creation fails if neither list_id nor group_id is provided.
"""
expense_data = ExpenseCreate(
description="Test Invalid Expense",
amount=10.00,
currency="USD",
paid_by_user_id=test_user.id,
list_id=None,
group_id=None,
)
response = await client.post(
expense_url(),
headers=normal_user_token_headers,
json=expense_data.model_dump(exclude_unset=True)
)
assert response.status_code == status.HTTP_400_BAD_REQUEST
content = response.json()
assert "Expense must be linked to a list_id or group_id" in content["detail"]
@pytest.mark.asyncio
async def test_create_new_expense_fail_paid_by_other_not_owner(
client: AsyncClient,
normal_user_token_headers: Dict[str, str], # User is member, not owner
test_user: UserModel, # This is the current_user (member)
test_group_user_is_member: GroupModel, # Group the current_user is a member of
another_user_in_group: UserModel, # Another user in the same group
# Ensure test_user is NOT an owner of test_group_user_is_member for this test
) -> None:
"""
Test creation fails if paid_by_user_id is another user, and current_user is not a group owner.
Assumes normal_user_token_headers belongs to a user who is a member but not an owner of test_group_user_is_member.
"""
expense_data = ExpenseCreate(
description="Expense paid by other",
amount=75.00,
currency="GBP",
paid_by_user_id=another_user_in_group.id, # Paid by someone else
group_id=test_group_user_is_member.id,
list_id=None,
)
response = await client.post(
expense_url(),
headers=normal_user_token_headers, # Current user is a member, not owner
json=expense_data.model_dump(exclude_unset=True)
)
assert response.status_code == status.HTTP_403_FORBIDDEN
content = response.json()
assert "Only group owners can create expenses paid by others" in content["detail"]
# --- Add tests for other endpoints below ---
# GET /expenses/{expense_id}
@pytest.mark.asyncio
async def test_get_expense_success(
client: AsyncClient,
normal_user_token_headers: Dict[str, str],
test_user: UserModel,
# Assume an existing expense created by test_user or in a group/list they have access to
# This would typically be created by another test or a fixture
created_expense: ExpensePublic, # Assuming a fixture that provides a created expense
) -> None:
"""
Test successfully retrieving an existing expense.
User has access either by being the payer, or via list/group membership.
"""
response = await client.get(
expense_url(f"/{created_expense.id}"),
headers=normal_user_token_headers
)
assert response.status_code == status.HTTP_200_OK
content = response.json()
assert content["id"] == created_expense.id
assert content["description"] == created_expense.description
assert content["amount"] == created_expense.amount
assert content["paid_by_user_id"] == created_expense.paid_by_user_id
if created_expense.list_id:
assert content["list_id"] == created_expense.list_id
if created_expense.group_id:
assert content["group_id"] == created_expense.group_id
# TODO: Add more tests for get_expense:
# - expense not found -> 404
# - user has no access (not payer, not in list, not in group if applicable) -> 403
# - expense in list, user has list access
# - expense in group, user has group access
# - expense personal (no list, no group), user is payer
# - expense personal (no list, no group), user is NOT payer -> 403
@pytest.mark.asyncio
async def test_get_expense_not_found(
client: AsyncClient,
normal_user_token_headers: Dict[str, str],
) -> None:
"""
Test retrieving a non-existent expense results in 404.
"""
non_existent_expense_id = 9999999
response = await client.get(
expense_url(f"/{non_existent_expense_id}"),
headers=normal_user_token_headers
)
assert response.status_code == status.HTTP_404_NOT_FOUND
content = response.json()
assert "not found" in content["detail"].lower()
@pytest.mark.asyncio
async def test_get_expense_forbidden_personal_expense_other_user(
client: AsyncClient,
normal_user_token_headers: Dict[str, str], # Belongs to test_user
# Fixture for an expense paid by another_user, not linked to any list/group test_user has access to
personal_expense_of_another_user: ExpensePublic
) -> None:
"""
Test retrieving a personal expense of another user (no shared list/group) results in 403.
"""
response = await client.get(
expense_url(f"/{personal_expense_of_another_user.id}"),
headers=normal_user_token_headers # Current user querying
)
assert response.status_code == status.HTTP_403_FORBIDDEN
content = response.json()
assert "Not authorized to view this expense" in content["detail"]
# GET /lists/{list_id}/expenses
@pytest.mark.asyncio
async def test_list_list_expenses_success(
client: AsyncClient,
normal_user_token_headers: Dict[str, str],
test_user: UserModel,
test_list_user_is_member: ListModel, # List the user is a member of
# Assume some expenses have been created for this list by a fixture or previous tests
) -> None:
"""
Test successfully listing expenses for a list the user has access to.
"""
response = await client.get(
f"{API_V1_STR}/financials/lists/{test_list_user_is_member.id}/expenses",
headers=normal_user_token_headers
)
assert response.status_code == status.HTTP_200_OK
content = response.json()
assert isinstance(content, list)
for expense_item in content: # Renamed from expense to avoid conflict if a fixture is named expense
assert expense_item["list_id"] == test_list_user_is_member.id
# TODO: Add more tests for list_list_expenses:
# - list not found -> 404 (ListNotFoundError from check_list_access_for_financials)
# - user has no access to list -> 403 (ListPermissionError from check_list_access_for_financials)
# - list exists but has no expenses -> empty list, 200 OK
# - test pagination (skip, limit)
@pytest.mark.asyncio
async def test_list_list_expenses_list_not_found(
client: AsyncClient,
normal_user_token_headers: Dict[str, str],
) -> None:
"""
Test listing expenses for a non-existent list results in 404 (or appropriate error from permission check).
The check_list_access_for_financials raises ListNotFoundError, which might be caught and raised as 404.
The endpoint itself also has a get for ListModel, which would 404 first if permission check passed (not possible here).
Based on financials.py, ListNotFoundError is raised by check_list_access_for_financials.
This should translate to a 404 or a 403 if ListPermissionError wraps it with an action.
The current ListPermissionError in check_list_access_for_financials re-raises ListNotFoundError if that's the cause.
ListNotFoundError is a custom exception often mapped to 404.
Let's assume ListNotFoundError results in a 404 response from an exception handler.
"""
non_existent_list_id = 99999
response = await client.get(
f"{API_V1_STR}/financials/lists/{non_existent_list_id}/expenses",
headers=normal_user_token_headers
)
# The ListNotFoundError is raised by the check_list_access_for_financials helper,
# which is then re-raised. FastAPI default exception handlers or custom ones
# would convert this to an HTTP response. Typically NotFoundError -> 404.
# If ListPermissionError catches it and re-raises it specifically, it might be 403.
# From the code: `except ListNotFoundError: raise` means it propagates.
# Let's assume a global handler for NotFoundError derived exceptions leads to 404.
assert response.status_code == status.HTTP_404_NOT_FOUND
# The actual detail might vary based on how ListNotFoundError is handled by FastAPI
# For now, we check the status code. If financials.py maps it differently, this will need adjustment.
# Based on `raise ListNotFoundError(expense_in.list_id)` in create_new_expense, and if that leads to 400,
# this might be inconsistent. However, `check_list_access_for_financials` just re-raises ListNotFoundError.
# Let's stick to expecting 404 for a direct not found error from a path parameter.
content = response.json()
assert "list not found" in content["detail"].lower() # Common detail for not found errors
@pytest.mark.asyncio
async def test_list_list_expenses_no_access(
client: AsyncClient,
normal_user_token_headers: Dict[str, str], # User who will attempt access
test_list_user_not_member: ListModel, # A list current user is NOT a member of
) -> None:
"""
Test listing expenses for a list the user does not have access to (403 Forbidden).
"""
response = await client.get(
f"{API_V1_STR}/financials/lists/{test_list_user_not_member.id}/expenses",
headers=normal_user_token_headers
)
assert response.status_code == status.HTTP_403_FORBIDDEN
content = response.json()
assert f"User does not have permission to access financial data for list {test_list_user_not_member.id}" in content["detail"]
@pytest.mark.asyncio
async def test_list_list_expenses_empty(
client: AsyncClient,
normal_user_token_headers: Dict[str, str],
test_list_user_is_member_no_expenses: ListModel, # List user is member of, but has no expenses
) -> None:
"""
Test listing expenses for an accessible list that has no expenses (empty list, 200 OK).
"""
response = await client.get(
f"{API_V1_STR}/financials/lists/{test_list_user_is_member_no_expenses.id}/expenses",
headers=normal_user_token_headers
)
assert response.status_code == status.HTTP_200_OK
content = response.json()
assert isinstance(content, list)
assert len(content) == 0
# GET /groups/{group_id}/expenses
@pytest.mark.asyncio
async def test_list_group_expenses_success(
client: AsyncClient,
normal_user_token_headers: Dict[str, str],
test_user: UserModel,
test_group_user_is_member: GroupModel, # Group the user is a member of
# Assume some expenses have been created for this group by a fixture or previous tests
) -> None:
"""
Test successfully listing expenses for a group the user has access to.
"""
response = await client.get(
f"{API_V1_STR}/financials/groups/{test_group_user_is_member.id}/expenses",
headers=normal_user_token_headers
)
assert response.status_code == status.HTTP_200_OK
content = response.json()
assert isinstance(content, list)
# Further assertions can be made here, e.g., checking if all expenses belong to the group
for expense_item in content:
assert expense_item["group_id"] == test_group_user_is_member.id
# Expenses in a group might also have a list_id if they were added via a list belonging to that group
# TODO: Add more tests for list_group_expenses:
# - group not found -> 404 (GroupNotFoundError from check_group_membership)
# - user has no access to group (not a member) -> 403 (GroupMembershipError from check_group_membership)
# - group exists but has no expenses -> empty list, 200 OK
# - test pagination (skip, limit)
# PUT /expenses/{expense_id}
# DELETE /expenses/{expense_id}
# GET /settlements/{settlement_id}
# POST /settlements
# GET /groups/{group_id}/settlements
# PUT /settlements/{settlement_id}
# DELETE /settlements/{settlement_id}
pytest.skip("Still implementing other tests", allow_module_level=True)

View File

@ -0,0 +1 @@

View File

@ -0,0 +1,345 @@
import pytest
from fastapi import status
from app.core.exceptions import (
ListNotFoundError,
ListPermissionError,
ListCreatorRequiredError,
GroupNotFoundError,
GroupPermissionError,
GroupMembershipError,
GroupOperationError,
GroupValidationError,
ItemNotFoundError,
UserNotFoundError,
InvalidOperationError,
DatabaseConnectionError,
DatabaseIntegrityError,
DatabaseTransactionError,
DatabaseQueryError,
OCRServiceUnavailableError,
OCRServiceConfigError,
OCRUnexpectedError,
OCRQuotaExceededError,
InvalidFileTypeError,
FileTooLargeError,
OCRProcessingError,
EmailAlreadyRegisteredError,
UserCreationError,
InviteNotFoundError,
InviteExpiredError,
InviteAlreadyUsedError,
InviteCreationError,
ListStatusNotFoundError,
ConflictError,
InvalidCredentialsError,
NotAuthenticatedError,
JWTError,
JWTUnexpectedError
)
# TODO: It seems like settings are used in some exceptions.
# You will need to mock app.config.settings for these tests to pass.
# Consider using pytest-mock or unittest.mock.patch.
# Example: from app.config import settings
def test_list_not_found_error():
list_id = 123
with pytest.raises(ListNotFoundError) as excinfo:
raise ListNotFoundError(list_id=list_id)
assert excinfo.value.status_code == status.HTTP_404_NOT_FOUND
assert excinfo.value.detail == f"List {list_id} not found"
def test_list_permission_error():
list_id = 456
action = "delete"
with pytest.raises(ListPermissionError) as excinfo:
raise ListPermissionError(list_id=list_id, action=action)
assert excinfo.value.status_code == status.HTTP_403_FORBIDDEN
assert excinfo.value.detail == f"You do not have permission to {action} list {list_id}"
def test_list_permission_error_default_action():
list_id = 789
with pytest.raises(ListPermissionError) as excinfo:
raise ListPermissionError(list_id=list_id)
assert excinfo.value.status_code == status.HTTP_403_FORBIDDEN
assert excinfo.value.detail == f"You do not have permission to access list {list_id}"
def test_list_creator_required_error():
list_id = 101
action = "update"
with pytest.raises(ListCreatorRequiredError) as excinfo:
raise ListCreatorRequiredError(list_id=list_id, action=action)
assert excinfo.value.status_code == status.HTTP_403_FORBIDDEN
assert excinfo.value.detail == f"Only the list creator can {action} list {list_id}"
def test_group_not_found_error():
group_id = 202
with pytest.raises(GroupNotFoundError) as excinfo:
raise GroupNotFoundError(group_id=group_id)
assert excinfo.value.status_code == status.HTTP_404_NOT_FOUND
assert excinfo.value.detail == f"Group {group_id} not found"
def test_group_permission_error():
group_id = 303
action = "invite"
with pytest.raises(GroupPermissionError) as excinfo:
raise GroupPermissionError(group_id=group_id, action=action)
assert excinfo.value.status_code == status.HTTP_403_FORBIDDEN
assert excinfo.value.detail == f"You do not have permission to {action} in group {group_id}"
def test_group_membership_error():
group_id = 404
action = "post"
with pytest.raises(GroupMembershipError) as excinfo:
raise GroupMembershipError(group_id=group_id, action=action)
assert excinfo.value.status_code == status.HTTP_403_FORBIDDEN
assert excinfo.value.detail == f"You must be a member of group {group_id} to {action}"
def test_group_membership_error_default_action():
group_id = 505
with pytest.raises(GroupMembershipError) as excinfo:
raise GroupMembershipError(group_id=group_id)
assert excinfo.value.status_code == status.HTTP_403_FORBIDDEN
assert excinfo.value.detail == f"You must be a member of group {group_id} to access"
def test_group_operation_error():
detail_msg = "Failed to perform group operation."
with pytest.raises(GroupOperationError) as excinfo:
raise GroupOperationError(detail=detail_msg)
assert excinfo.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
assert excinfo.value.detail == detail_msg
def test_group_validation_error():
detail_msg = "Invalid group data."
with pytest.raises(GroupValidationError) as excinfo:
raise GroupValidationError(detail=detail_msg)
assert excinfo.value.status_code == status.HTTP_400_BAD_REQUEST
assert excinfo.value.detail == detail_msg
def test_item_not_found_error():
item_id = 606
with pytest.raises(ItemNotFoundError) as excinfo:
raise ItemNotFoundError(item_id=item_id)
assert excinfo.value.status_code == status.HTTP_404_NOT_FOUND
assert excinfo.value.detail == f"Item {item_id} not found"
def test_user_not_found_error_no_identifier():
with pytest.raises(UserNotFoundError) as excinfo:
raise UserNotFoundError()
assert excinfo.value.status_code == status.HTTP_404_NOT_FOUND
assert excinfo.value.detail == "User not found."
def test_user_not_found_error_with_id():
user_id = 707
with pytest.raises(UserNotFoundError) as excinfo:
raise UserNotFoundError(user_id=user_id)
assert excinfo.value.status_code == status.HTTP_404_NOT_FOUND
assert excinfo.value.detail == f"User with ID {user_id} not found."
def test_user_not_found_error_with_identifier_string():
identifier = "test_user"
with pytest.raises(UserNotFoundError) as excinfo:
raise UserNotFoundError(identifier=identifier)
assert excinfo.value.status_code == status.HTTP_404_NOT_FOUND
assert excinfo.value.detail == f"User with identifier '{identifier}' not found."
def test_invalid_operation_error():
detail_msg = "This operation is not allowed."
with pytest.raises(InvalidOperationError) as excinfo:
raise InvalidOperationError(detail=detail_msg)
assert excinfo.value.status_code == status.HTTP_400_BAD_REQUEST
assert excinfo.value.detail == detail_msg
def test_invalid_operation_error_custom_status():
detail_msg = "This operation is forbidden."
custom_status = status.HTTP_403_FORBIDDEN
with pytest.raises(InvalidOperationError) as excinfo:
raise InvalidOperationError(detail=detail_msg, status_code=custom_status)
assert excinfo.value.status_code == custom_status
assert excinfo.value.detail == detail_msg
# The following exceptions depend on `settings`
# We need to mock `app.config.settings` for these tests.
# For now, I will add placeholder tests that would fail without mocking.
# Consider using pytest-mock or unittest.mock.patch for this.
# def test_database_connection_error(mocker):
# mocker.patch("app.core.exceptions.settings.DB_CONNECTION_ERROR", "Test DB connection error")
# with pytest.raises(DatabaseConnectionError) as excinfo:
# raise DatabaseConnectionError()
# assert excinfo.value.status_code == status.HTTP_503_SERVICE_UNAVAILABLE
# assert excinfo.value.detail == "Test DB connection error" # settings.DB_CONNECTION_ERROR
# def test_database_integrity_error(mocker):
# mocker.patch("app.core.exceptions.settings.DB_INTEGRITY_ERROR", "Test DB integrity error")
# with pytest.raises(DatabaseIntegrityError) as excinfo:
# raise DatabaseIntegrityError()
# assert excinfo.value.status_code == status.HTTP_400_BAD_REQUEST
# assert excinfo.value.detail == "Test DB integrity error" # settings.DB_INTEGRITY_ERROR
# def test_database_transaction_error(mocker):
# mocker.patch("app.core.exceptions.settings.DB_TRANSACTION_ERROR", "Test DB transaction error")
# with pytest.raises(DatabaseTransactionError) as excinfo:
# raise DatabaseTransactionError()
# assert excinfo.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
# assert excinfo.value.detail == "Test DB transaction error" # settings.DB_TRANSACTION_ERROR
# def test_database_query_error(mocker):
# mocker.patch("app.core.exceptions.settings.DB_QUERY_ERROR", "Test DB query error")
# with pytest.raises(DatabaseQueryError) as excinfo:
# raise DatabaseQueryError()
# assert excinfo.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
# assert excinfo.value.detail == "Test DB query error" # settings.DB_QUERY_ERROR
# def test_ocr_service_unavailable_error(mocker):
# mocker.patch("app.core.exceptions.settings.OCR_SERVICE_UNAVAILABLE", "Test OCR unavailable")
# with pytest.raises(OCRServiceUnavailableError) as excinfo:
# raise OCRServiceUnavailableError()
# assert excinfo.value.status_code == status.HTTP_503_SERVICE_UNAVAILABLE
# assert excinfo.value.detail == "Test OCR unavailable" # settings.OCR_SERVICE_UNAVAILABLE
# def test_ocr_service_config_error(mocker):
# mocker.patch("app.core.exceptions.settings.OCR_SERVICE_CONFIG_ERROR", "Test OCR config error")
# with pytest.raises(OCRServiceConfigError) as excinfo:
# raise OCRServiceConfigError()
# assert excinfo.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
# assert excinfo.value.detail == "Test OCR config error" # settings.OCR_SERVICE_CONFIG_ERROR
# def test_ocr_unexpected_error(mocker):
# mocker.patch("app.core.exceptions.settings.OCR_UNEXPECTED_ERROR", "Test OCR unexpected error")
# with pytest.raises(OCRUnexpectedError) as excinfo:
# raise OCRUnexpectedError()
# assert excinfo.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
# assert excinfo.value.detail == "Test OCR unexpected error" # settings.OCR_UNEXPECTED_ERROR
# def test_ocr_quota_exceeded_error(mocker):
# mocker.patch("app.core.exceptions.settings.OCR_QUOTA_EXCEEDED", "Test OCR quota exceeded")
# with pytest.raises(OCRQuotaExceededError) as excinfo:
# raise OCRQuotaExceededError()
# assert excinfo.value.status_code == status.HTTP_429_TOO_MANY_REQUESTS
# assert excinfo.value.detail == "Test OCR quota exceeded" # settings.OCR_QUOTA_EXCEEDED
# def test_invalid_file_type_error(mocker):
# test_types = ["png", "jpg"]
# mocker.patch("app.core.exceptions.settings.ALLOWED_IMAGE_TYPES", test_types)
# mocker.patch("app.core.exceptions.settings.OCR_INVALID_FILE_TYPE", "Invalid type: {types}")
# with pytest.raises(InvalidFileTypeError) as excinfo:
# raise InvalidFileTypeError()
# assert excinfo.value.status_code == status.HTTP_400_BAD_REQUEST
# assert excinfo.value.detail == f"Invalid type: {', '.join(test_types)}" # settings.OCR_INVALID_FILE_TYPE.format(types=", ".join(settings.ALLOWED_IMAGE_TYPES))
# def test_file_too_large_error(mocker):
# max_size = 10
# mocker.patch("app.core.exceptions.settings.MAX_FILE_SIZE_MB", max_size)
# mocker.patch("app.core.exceptions.settings.OCR_FILE_TOO_LARGE", "File too large: {size}MB")
# with pytest.raises(FileTooLargeError) as excinfo:
# raise FileTooLargeError()
# assert excinfo.value.status_code == status.HTTP_400_BAD_REQUEST
# assert excinfo.value.detail == f"File too large: {max_size}MB" # settings.OCR_FILE_TOO_LARGE.format(size=settings.MAX_FILE_SIZE_MB)
# def test_ocr_processing_error(mocker):
# error_detail = "Specific OCR error"
# mocker.patch("app.core.exceptions.settings.OCR_PROCESSING_ERROR", "OCR processing failed: {detail}")
# with pytest.raises(OCRProcessingError) as excinfo:
# raise OCRProcessingError(detail=error_detail)
# assert excinfo.value.status_code == status.HTTP_400_BAD_REQUEST
# assert excinfo.value.detail == f"OCR processing failed: {error_detail}" # settings.OCR_PROCESSING_ERROR.format(detail=detail)
def test_email_already_registered_error():
with pytest.raises(EmailAlreadyRegisteredError) as excinfo:
raise EmailAlreadyRegisteredError()
assert excinfo.value.status_code == status.HTTP_400_BAD_REQUEST
assert excinfo.value.detail == "Email already registered."
def test_user_creation_error():
with pytest.raises(UserCreationError) as excinfo:
raise UserCreationError()
assert excinfo.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
assert excinfo.value.detail == "An error occurred during user creation."
def test_invite_not_found_error():
invite_code = "TESTCODE123"
with pytest.raises(InviteNotFoundError) as excinfo:
raise InviteNotFoundError(invite_code=invite_code)
assert excinfo.value.status_code == status.HTTP_404_NOT_FOUND
assert excinfo.value.detail == f"Invite code {invite_code} not found"
def test_invite_expired_error():
invite_code = "EXPIREDCODE"
with pytest.raises(InviteExpiredError) as excinfo:
raise InviteExpiredError(invite_code=invite_code)
assert excinfo.value.status_code == status.HTTP_410_GONE
assert excinfo.value.detail == f"Invite code {invite_code} has expired"
def test_invite_already_used_error():
invite_code = "USEDCODE"
with pytest.raises(InviteAlreadyUsedError) as excinfo:
raise InviteAlreadyUsedError(invite_code=invite_code)
assert excinfo.value.status_code == status.HTTP_410_GONE
assert excinfo.value.detail == f"Invite code {invite_code} has already been used"
def test_invite_creation_error():
group_id = 909
with pytest.raises(InviteCreationError) as excinfo:
raise InviteCreationError(group_id=group_id)
assert excinfo.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
assert excinfo.value.detail == f"Failed to create invite for group {group_id}"
def test_list_status_not_found_error():
list_id = 808
with pytest.raises(ListStatusNotFoundError) as excinfo:
raise ListStatusNotFoundError(list_id=list_id)
assert excinfo.value.status_code == status.HTTP_404_NOT_FOUND
assert excinfo.value.detail == f"Status for list {list_id} not found"
def test_conflict_error():
detail_msg = "Resource version mismatch."
with pytest.raises(ConflictError) as excinfo:
raise ConflictError(detail=detail_msg)
assert excinfo.value.status_code == status.HTTP_409_CONFLICT
assert excinfo.value.detail == detail_msg
# Tests for auth-related exceptions that likely require mocking app.config.settings
# def test_invalid_credentials_error(mocker):
# mocker.patch("app.core.exceptions.settings.AUTH_INVALID_CREDENTIALS", "Invalid test credentials")
# mocker.patch("app.core.exceptions.settings.AUTH_HEADER_NAME", "X-Test-Auth")
# mocker.patch("app.core.exceptions.settings.AUTH_HEADER_PREFIX", "TestBearer")
# with pytest.raises(InvalidCredentialsError) as excinfo:
# raise InvalidCredentialsError()
# assert excinfo.value.status_code == status.HTTP_401_UNAUTHORIZED
# assert excinfo.value.detail == "Invalid test credentials"
# assert excinfo.value.headers == {"X-Test-Auth": "TestBearer error=\"invalid_credentials\""}
# def test_not_authenticated_error(mocker):
# mocker.patch("app.core.exceptions.settings.AUTH_NOT_AUTHENTICATED", "Not authenticated test")
# mocker.patch("app.core.exceptions.settings.AUTH_HEADER_NAME", "X-Test-Auth")
# mocker.patch("app.core.exceptions.settings.AUTH_HEADER_PREFIX", "TestBearer")
# with pytest.raises(NotAuthenticatedError) as excinfo:
# raise NotAuthenticatedError()
# assert excinfo.value.status_code == status.HTTP_401_UNAUTHORIZED
# assert excinfo.value.detail == "Not authenticated test"
# assert excinfo.value.headers == {"X-Test-Auth": "TestBearer error=\"not_authenticated\""}
# def test_jwt_error(mocker):
# error_msg = "Test JWT issue"
# mocker.patch("app.core.exceptions.settings.JWT_ERROR", "JWT error: {error}")
# mocker.patch("app.core.exceptions.settings.AUTH_HEADER_NAME", "X-Test-Auth")
# mocker.patch("app.core.exceptions.settings.AUTH_HEADER_PREFIX", "TestBearer")
# with pytest.raises(JWTError) as excinfo:
# raise JWTError(error=error_msg)
# assert excinfo.value.status_code == status.HTTP_401_UNAUTHORIZED
# assert excinfo.value.detail == f"JWT error: {error_msg}"
# assert excinfo.value.headers == {"X-Test-Auth": "TestBearer error=\"invalid_token\""}
# def test_jwt_unexpected_error(mocker):
# error_msg = "Unexpected test JWT issue"
# mocker.patch("app.core.exceptions.settings.JWT_UNEXPECTED_ERROR", "Unexpected JWT error: {error}")
# mocker.patch("app.core.exceptions.settings.AUTH_HEADER_NAME", "X-Test-Auth")
# mocker.patch("app.core.exceptions.settings.AUTH_HEADER_PREFIX", "TestBearer")
# with pytest.raises(JWTUnexpectedError) as excinfo:
# raise JWTUnexpectedError(error=error_msg)
# assert excinfo.value.status_code == status.HTTP_401_UNAUTHORIZED
# assert excinfo.value.detail == f"Unexpected JWT error: {error_msg}"
# assert excinfo.value.headers == {"X-Test-Auth": "TestBearer error=\"invalid_token\""}

View File

@ -0,0 +1,276 @@
import pytest
import asyncio
from unittest.mock import patch, MagicMock, AsyncMock
import google.generativeai as genai
from google.api_core import exceptions as google_exceptions
# Modules to test
from app.core import gemini
from app.core.exceptions import (
OCRServiceUnavailableError,
OCRServiceConfigError,
OCRUnexpectedError,
OCRQuotaExceededError
)
# Default Mock Settings
@pytest.fixture
def mock_gemini_settings():
settings_mock = MagicMock()
settings_mock.GEMINI_API_KEY = "test_api_key"
settings_mock.GEMINI_MODEL_NAME = "gemini-pro-vision"
settings_mock.GEMINI_SAFETY_SETTINGS = {
"HARM_CATEGORY_HARASSMENT": "BLOCK_NONE",
"HARM_CATEGORY_HATE_SPEECH": "BLOCK_NONE",
}
settings_mock.GEMINI_GENERATION_CONFIG = {"temperature": 0.7}
settings_mock.OCR_ITEM_EXTRACTION_PROMPT = "Extract items:"
return settings_mock
@pytest.fixture
def mock_generative_model_instance():
model_instance = MagicMock(spec=genai.GenerativeModel)
model_instance.generate_content_async = AsyncMock()
return model_instance
@pytest.fixture
@patch('google.generativeai.GenerativeModel')
@patch('google.generativeai.configure')
def patch_google_ai_client(mock_configure, mock_generative_model, mock_generative_model_instance):
mock_generative_model.return_value = mock_generative_model_instance
return mock_configure, mock_generative_model, mock_generative_model_instance
# --- Test Gemini Client Initialization (Global Client) ---
# Parametrize to test different scenarios for the global client init
@pytest.mark.parametrize(
"api_key_present, configure_raises, model_init_raises, expected_error_message_part",
[
(True, None, None, None), # Success
(False, None, None, "GEMINI_API_KEY not configured"), # API key missing
(True, Exception("Config error"), None, "Failed to initialize Gemini AI client: Config error"), # genai.configure error
(True, None, Exception("Model init error"), "Failed to initialize Gemini AI client: Model init error"), # GenerativeModel error
]
)
@patch('app.core.gemini.genai') # Patch genai within the gemini module
def test_global_gemini_client_initialization(
mock_genai_module,
mock_gemini_settings,
api_key_present,
configure_raises,
model_init_raises,
expected_error_message_part
):
"""Tests the global gemini_flash_client initialization logic in app.core.gemini."""
# We need to reload the module to re-trigger its top-level initialization code.
# This is a bit tricky. A common pattern is to put init logic in a function.
# For now, we'll try to simulate it by controlling mocks before module access.
with patch('app.core.gemini.settings', mock_gemini_settings):
if not api_key_present:
mock_gemini_settings.GEMINI_API_KEY = None
mock_genai_module.configure = MagicMock()
mock_genai_module.GenerativeModel = MagicMock()
mock_genai_module.types = genai.types # Keep original types
mock_genai_module.HarmCategory = genai.HarmCategory
mock_genai_module.HarmBlockThreshold = genai.HarmBlockThreshold
if configure_raises:
mock_genai_module.configure.side_effect = configure_raises
if model_init_raises:
mock_genai_module.GenerativeModel.side_effect = model_init_raises
# Python modules are singletons. To re-run top-level code, we need to unload and reload.
# This is generally discouraged. It's better to have an explicit init function.
# For this test, we'll check the state variables set by the module's import-time code.
import importlib
importlib.reload(gemini) # This re-runs the try-except block at the top of gemini.py
if expected_error_message_part:
assert gemini.gemini_initialization_error is not None
assert expected_error_message_part in gemini.gemini_initialization_error
assert gemini.gemini_flash_client is None
else:
assert gemini.gemini_initialization_error is None
assert gemini.gemini_flash_client is not None
mock_genai_module.configure.assert_called_once_with(api_key="test_api_key")
mock_genai_module.GenerativeModel.assert_called_once()
# Could add more assertions about safety_settings and generation_config here
# Clean up after reload for other tests
importlib.reload(gemini)
# --- Test get_gemini_client ---
# Assuming the global client tests above set the stage for these
@patch('app.core.gemini.gemini_flash_client', new_callable=MagicMock)
@patch('app.core.gemini.gemini_initialization_error', None)
def test_get_gemini_client_success(mock_client_var, mock_error_var):
mock_client_var.return_value = MagicMock(spec=genai.GenerativeModel) # Simulate an initialized client
gemini.gemini_flash_client = mock_client_var # Assign the mock
gemini.gemini_initialization_error = None
client = gemini.get_gemini_client()
assert client is not None
@patch('app.core.gemini.gemini_flash_client', None)
@patch('app.core.gemini.gemini_initialization_error', "Test init error")
def test_get_gemini_client_init_error(mock_client_var, mock_error_var):
gemini.gemini_flash_client = None
gemini.gemini_initialization_error = "Test init error"
with pytest.raises(RuntimeError, match="Gemini client could not be initialized: Test init error"):
gemini.get_gemini_client()
@patch('app.core.gemini.gemini_flash_client', None)
@patch('app.core.gemini.gemini_initialization_error', None) # No init error, but client is None
def test_get_gemini_client_none_client_unknown_issue(mock_client_var, mock_error_var):
gemini.gemini_flash_client = None
gemini.gemini_initialization_error = None
with pytest.raises(RuntimeError, match="Gemini client is not available \(unknown initialization issue\)."):
gemini.get_gemini_client()
# --- Tests for extract_items_from_image_gemini --- (Simplified for brevity, needs more cases)
@pytest.mark.asyncio
async def test_extract_items_from_image_gemini_success(
mock_gemini_settings,
mock_generative_model_instance,
patch_google_ai_client # This fixture patches google.generativeai for the module
):
""" Test successful item extraction """
# Ensure the global client is mocked to be the one we control
with patch('app.core.gemini.settings', mock_gemini_settings), \
patch('app.core.gemini.gemini_flash_client', mock_generative_model_instance), \
patch('app.core.gemini.gemini_initialization_error', None):
mock_response = MagicMock()
mock_response.text = "Item 1\nItem 2\n Item 3 \n\nAnother Item"
# Simulate the structure for safety checks if needed
mock_candidate = MagicMock()
mock_candidate.content.parts = [MagicMock(text=mock_response.text)]
mock_candidate.finish_reason = 'STOP' # Or whatever is appropriate for success
mock_candidate.safety_ratings = []
mock_response.candidates = [mock_candidate]
mock_generative_model_instance.generate_content_async.return_value = mock_response
image_bytes = b"dummy_image_bytes"
mime_type = "image/png"
items = await gemini.extract_items_from_image_gemini(image_bytes, mime_type)
mock_generative_model_instance.generate_content_async.assert_called_once_with([
mock_gemini_settings.OCR_ITEM_EXTRACTION_PROMPT,
{"mime_type": mime_type, "data": image_bytes}
])
assert items == ["Item 1", "Item 2", "Item 3", "Another Item"]
@pytest.mark.asyncio
async def test_extract_items_from_image_gemini_client_not_init(
mock_gemini_settings
):
with patch('app.core.gemini.settings', mock_gemini_settings), \
patch('app.core.gemini.gemini_flash_client', None), \
patch('app.core.gemini.gemini_initialization_error', "Initialization failed explicitly"):
image_bytes = b"dummy_image_bytes"
with pytest.raises(RuntimeError, match="Gemini client could not be initialized: Initialization failed explicitly"):
await gemini.extract_items_from_image_gemini(image_bytes)
@pytest.mark.asyncio
@patch('app.core.gemini.get_gemini_client') # Mock the getter to control the client directly
async def test_extract_items_from_image_gemini_api_quota_error(
mock_get_client,
mock_gemini_settings,
mock_generative_model_instance
):
mock_get_client.return_value = mock_generative_model_instance
mock_generative_model_instance.generate_content_async.side_effect = google_exceptions.ResourceExhausted("Quota exceeded")
with patch('app.core.gemini.settings', mock_gemini_settings):
image_bytes = b"dummy_image_bytes"
with pytest.raises(google_exceptions.ResourceExhausted, match="Quota exceeded"):
await gemini.extract_items_from_image_gemini(image_bytes)
# --- Tests for GeminiOCRService --- (Example tests, more needed)
@patch('app.core.gemini.genai.configure')
@patch('app.core.gemini.genai.GenerativeModel')
def test_gemini_ocr_service_init_success(MockGenerativeModel, MockConfigure, mock_gemini_settings, mock_generative_model_instance):
MockGenerativeModel.return_value = mock_generative_model_instance
with patch('app.core.gemini.settings', mock_gemini_settings):
service = gemini.GeminiOCRService()
MockConfigure.assert_called_once_with(api_key=mock_gemini_settings.GEMINI_API_KEY)
MockGenerativeModel.assert_called_once_with(mock_gemini_settings.GEMINI_MODEL_NAME)
assert service.model == mock_generative_model_instance
# Could add assertions for safety_settings and generation_config if they are set directly on model
@patch('app.core.gemini.genai.configure')
@patch('app.core.gemini.genai.GenerativeModel', side_effect=Exception("Init model failed"))
def test_gemini_ocr_service_init_failure(MockGenerativeModel, MockConfigure, mock_gemini_settings):
with patch('app.core.gemini.settings', mock_gemini_settings):
with pytest.raises(OCRServiceConfigError):
gemini.GeminiOCRService()
@pytest.mark.asyncio
async def test_gemini_ocr_service_extract_items_success(mock_gemini_settings, mock_generative_model_instance):
mock_response = MagicMock()
mock_response.text = "Apple\nBanana\nOrange\nExample output should be ignored"
mock_generative_model_instance.generate_content_async.return_value = mock_response
with patch('app.core.gemini.settings', mock_gemini_settings):
# Patch the model instance within the service for this test
with patch.object(genai, 'GenerativeModel', return_value=mock_generative_model_instance) as patched_model_class,
patch.object(genai, 'configure') as patched_configure:
service = gemini.GeminiOCRService() # Re-init to use the patched model
items = await service.extract_items(b"dummy_image")
expected_call_args = [
mock_gemini_settings.OCR_ITEM_EXTRACTION_PROMPT,
{"mime_type": "image/jpeg", "data": b"dummy_image"}
]
service.model.generate_content_async.assert_called_once_with(contents=expected_call_args)
assert items == ["Apple", "Banana", "Orange"]
@pytest.mark.asyncio
async def test_gemini_ocr_service_extract_items_quota_error(mock_gemini_settings, mock_generative_model_instance):
mock_generative_model_instance.generate_content_async.side_effect = google_exceptions.ResourceExhausted("Quota limits exceeded.")
with patch('app.core.gemini.settings', mock_gemini_settings), \
patch.object(genai, 'GenerativeModel', return_value=mock_generative_model_instance), \
patch.object(genai, 'configure'):
service = gemini.GeminiOCRService()
with pytest.raises(OCRQuotaExceededError):
await service.extract_items(b"dummy_image")
@pytest.mark.asyncio
async def test_gemini_ocr_service_extract_items_api_unavailable(mock_gemini_settings, mock_generative_model_instance):
# Simulate a generic API error that isn't quota related
mock_generative_model_instance.generate_content_async.side_effect = google_exceptions.InternalServerError("Service unavailable")
with patch('app.core.gemini.settings', mock_gemini_settings), \
patch.object(genai, 'GenerativeModel', return_value=mock_generative_model_instance), \
patch.object(genai, 'configure'):
service = gemini.GeminiOCRService()
with pytest.raises(OCRServiceUnavailableError):
await service.extract_items(b"dummy_image")
@pytest.mark.asyncio
async def test_gemini_ocr_service_extract_items_no_text_response(mock_gemini_settings, mock_generative_model_instance):
mock_response = MagicMock()
mock_response.text = None # Simulate no text in response
mock_generative_model_instance.generate_content_async.return_value = mock_response
with patch('app.core.gemini.settings', mock_gemini_settings), \
patch.object(genai, 'GenerativeModel', return_value=mock_generative_model_instance), \
patch.object(genai, 'configure'):
service = gemini.GeminiOCRService()
with pytest.raises(OCRUnexpectedError):
await service.extract_items(b"dummy_image")

View File

@ -0,0 +1,216 @@
import pytest
from unittest.mock import patch, MagicMock
from datetime import datetime, timedelta, timezone
from jose import jwt, JWTError
from passlib.context import CryptContext
from app.core.security import (
verify_password,
hash_password,
create_access_token,
create_refresh_token,
verify_access_token,
verify_refresh_token,
pwd_context, # Import for direct testing if needed, or to check its config
)
# Assuming app.config.settings will be mocked
# from app.config import settings
# --- Tests for Password Hashing ---
def test_hash_password():
password = "securepassword123"
hashed = hash_password(password)
assert isinstance(hashed, str)
assert hashed != password
# Check that the default scheme (bcrypt) is used by verifying the hash prefix
# bcrypt hashes typically start with $2b$ or $2a$ or $2y$
assert hashed.startswith("$2b$") or hashed.startswith("$2a$") or hashed.startswith("$2y$")
def test_verify_password_correct():
password = "testpassword"
hashed_password = pwd_context.hash(password) # Use the same context for consistency
assert verify_password(password, hashed_password) is True
def test_verify_password_incorrect():
password = "testpassword"
wrong_password = "wrongpassword"
hashed_password = pwd_context.hash(password)
assert verify_password(wrong_password, hashed_password) is False
def test_verify_password_invalid_hash_format():
password = "testpassword"
invalid_hash = "notarealhash"
assert verify_password(password, invalid_hash) is False
# --- Tests for JWT Creation ---
# Mock settings for JWT tests
@pytest.fixture(scope="module")
def mock_jwt_settings():
mock_settings = MagicMock()
mock_settings.SECRET_KEY = "testsecretkey"
mock_settings.ALGORITHM = "HS256"
mock_settings.ACCESS_TOKEN_EXPIRE_MINUTES = 30
mock_settings.REFRESH_TOKEN_EXPIRE_MINUTES = 10080 # 7 days
return mock_settings
@patch('app.core.security.settings')
def test_create_access_token_default_expiry(mock_settings_global, mock_jwt_settings):
mock_settings_global.SECRET_KEY = mock_jwt_settings.SECRET_KEY
mock_settings_global.ALGORITHM = mock_jwt_settings.ALGORITHM
mock_settings_global.ACCESS_TOKEN_EXPIRE_MINUTES = mock_jwt_settings.ACCESS_TOKEN_EXPIRE_MINUTES
subject = "user@example.com"
token = create_access_token(subject)
assert isinstance(token, str)
decoded_payload = jwt.decode(token, mock_jwt_settings.SECRET_KEY, algorithms=[mock_jwt_settings.ALGORITHM])
assert decoded_payload["sub"] == subject
assert decoded_payload["type"] == "access"
assert "exp" in decoded_payload
# Check if expiry is roughly correct (within a small delta)
expected_expiry = datetime.now(timezone.utc) + timedelta(minutes=mock_jwt_settings.ACCESS_TOKEN_EXPIRE_MINUTES)
assert abs(datetime.fromtimestamp(decoded_payload["exp"], timezone.utc) - expected_expiry) < timedelta(seconds=5)
@patch('app.core.security.settings')
def test_create_access_token_custom_expiry(mock_settings_global, mock_jwt_settings):
mock_settings_global.SECRET_KEY = mock_jwt_settings.SECRET_KEY
mock_settings_global.ALGORITHM = mock_jwt_settings.ALGORITHM
# ACCESS_TOKEN_EXPIRE_MINUTES is not used here due to custom delta
subject = 123 # Subject can be int
custom_delta = timedelta(hours=1)
token = create_access_token(subject, expires_delta=custom_delta)
assert isinstance(token, str)
decoded_payload = jwt.decode(token, mock_jwt_settings.SECRET_KEY, algorithms=[mock_jwt_settings.ALGORITHM])
assert decoded_payload["sub"] == str(subject)
assert decoded_payload["type"] == "access"
expected_expiry = datetime.now(timezone.utc) + custom_delta
assert abs(datetime.fromtimestamp(decoded_payload["exp"], timezone.utc) - expected_expiry) < timedelta(seconds=5)
@patch('app.core.security.settings')
def test_create_refresh_token_default_expiry(mock_settings_global, mock_jwt_settings):
mock_settings_global.SECRET_KEY = mock_jwt_settings.SECRET_KEY
mock_settings_global.ALGORITHM = mock_jwt_settings.ALGORITHM
mock_settings_global.REFRESH_TOKEN_EXPIRE_MINUTES = mock_jwt_settings.REFRESH_TOKEN_EXPIRE_MINUTES
subject = "refresh_subject"
token = create_refresh_token(subject)
assert isinstance(token, str)
decoded_payload = jwt.decode(token, mock_jwt_settings.SECRET_KEY, algorithms=[mock_jwt_settings.ALGORITHM])
assert decoded_payload["sub"] == subject
assert decoded_payload["type"] == "refresh"
expected_expiry = datetime.now(timezone.utc) + timedelta(minutes=mock_jwt_settings.REFRESH_TOKEN_EXPIRE_MINUTES)
assert abs(datetime.fromtimestamp(decoded_payload["exp"], timezone.utc) - expected_expiry) < timedelta(seconds=5)
# --- Tests for JWT Verification --- (More tests to be added here)
@patch('app.core.security.settings')
def test_verify_access_token_valid(mock_settings_global, mock_jwt_settings):
mock_settings_global.SECRET_KEY = mock_jwt_settings.SECRET_KEY
mock_settings_global.ALGORITHM = mock_jwt_settings.ALGORITHM
mock_settings_global.ACCESS_TOKEN_EXPIRE_MINUTES = mock_jwt_settings.ACCESS_TOKEN_EXPIRE_MINUTES
subject = "test_user_valid_access"
token = create_access_token(subject)
payload = verify_access_token(token)
assert payload is not None
assert payload["sub"] == subject
assert payload["type"] == "access"
@patch('app.core.security.settings')
def test_verify_access_token_invalid_signature(mock_settings_global, mock_jwt_settings):
mock_settings_global.SECRET_KEY = mock_jwt_settings.SECRET_KEY
mock_settings_global.ALGORITHM = mock_jwt_settings.ALGORITHM
mock_settings_global.ACCESS_TOKEN_EXPIRE_MINUTES = mock_jwt_settings.ACCESS_TOKEN_EXPIRE_MINUTES
subject = "test_user_invalid_sig"
# Create token with correct key
token = create_access_token(subject)
# Try to verify with wrong key
mock_settings_global.SECRET_KEY = "wrongsecretkey"
payload = verify_access_token(token)
assert payload is None
@patch('app.core.security.settings')
@patch('app.core.security.datetime') # Mock datetime to control token expiry
def test_verify_access_token_expired(mock_datetime, mock_settings_global, mock_jwt_settings):
mock_settings_global.SECRET_KEY = mock_jwt_settings.SECRET_KEY
mock_settings_global.ALGORITHM = mock_jwt_settings.ALGORITHM
mock_settings_global.ACCESS_TOKEN_EXPIRE_MINUTES = 1 # Expire in 1 minute
# Set current time for token creation
now = datetime.now(timezone.utc)
mock_datetime.now.return_value = now
mock_datetime.fromtimestamp = datetime.fromtimestamp # Ensure original fromtimestamp is used by jwt.decode
mock_datetime.timedelta = timedelta # Ensure original timedelta is used
subject = "test_user_expired"
token = create_access_token(subject)
# Advance time beyond expiry for verification
mock_datetime.now.return_value = now + timedelta(minutes=5)
payload = verify_access_token(token)
assert payload is None
@patch('app.core.security.settings')
def test_verify_access_token_wrong_type(mock_settings_global, mock_jwt_settings):
mock_settings_global.SECRET_KEY = mock_jwt_settings.SECRET_KEY
mock_settings_global.ALGORITHM = mock_jwt_settings.ALGORITHM
mock_settings_global.REFRESH_TOKEN_EXPIRE_MINUTES = mock_jwt_settings.REFRESH_TOKEN_EXPIRE_MINUTES # For refresh token creation
subject = "test_user_wrong_type"
# Create a refresh token
refresh_token = create_refresh_token(subject)
# Try to verify it as an access token
payload = verify_access_token(refresh_token)
assert payload is None
@patch('app.core.security.settings')
def test_verify_refresh_token_valid(mock_settings_global, mock_jwt_settings):
mock_settings_global.SECRET_KEY = mock_jwt_settings.SECRET_KEY
mock_settings_global.ALGORITHM = mock_jwt_settings.ALGORITHM
mock_settings_global.REFRESH_TOKEN_EXPIRE_MINUTES = mock_jwt_settings.REFRESH_TOKEN_EXPIRE_MINUTES
subject = "test_user_valid_refresh"
token = create_refresh_token(subject)
payload = verify_refresh_token(token)
assert payload is not None
assert payload["sub"] == subject
assert payload["type"] == "refresh"
@patch('app.core.security.settings')
@patch('app.core.security.datetime')
def test_verify_refresh_token_expired(mock_datetime, mock_settings_global, mock_jwt_settings):
mock_settings_global.SECRET_KEY = mock_jwt_settings.SECRET_KEY
mock_settings_global.ALGORITHM = mock_jwt_settings.ALGORITHM
mock_settings_global.REFRESH_TOKEN_EXPIRE_MINUTES = 1 # Expire in 1 minute
now = datetime.now(timezone.utc)
mock_datetime.now.return_value = now
mock_datetime.fromtimestamp = datetime.fromtimestamp
mock_datetime.timedelta = timedelta
subject = "test_user_expired_refresh"
token = create_refresh_token(subject)
mock_datetime.now.return_value = now + timedelta(minutes=5)
payload = verify_refresh_token(token)
assert payload is None
@patch('app.core.security.settings')
def test_verify_refresh_token_wrong_type(mock_settings_global, mock_jwt_settings):
mock_settings_global.SECRET_KEY = mock_jwt_settings.SECRET_KEY
mock_settings_global.ALGORITHM = mock_jwt_settings.ALGORITHM
mock_settings_global.ACCESS_TOKEN_EXPIRE_MINUTES = mock_jwt_settings.ACCESS_TOKEN_EXPIRE_MINUTES
subject = "test_user_wrong_type_refresh"
access_token = create_access_token(subject)
payload = verify_refresh_token(access_token)
assert payload is None

View File

@ -0,0 +1 @@

254
be/tests/crud/test_cost.py Normal file
View File

@ -0,0 +1,254 @@
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from decimal import Decimal, ROUND_HALF_UP
from typing import List as PyList, Dict
import logging # For mocking the logger
from app.crud.cost import (
calculate_list_cost_summary,
calculate_suggested_settlements,
calculate_group_balance_summary
)
from app.schemas.cost import (
ListCostSummary,
UserCostShare,
GroupBalanceSummary,
UserBalanceDetail,
SuggestedSettlement
)
from app.models import (
List as ListModel,
Item as ItemModel,
User as UserModel,
Group as GroupModel,
UserGroup as UserGroupModel,
Expense as ExpenseModel,
ExpenseSplit as ExpenseSplitModel,
Settlement as SettlementModel
)
from app.core.exceptions import ListNotFoundError, GroupNotFoundError
# Fixtures
@pytest.fixture
def mock_db_session():
session = AsyncMock()
session.execute = AsyncMock()
session.get = AsyncMock()
return session
@pytest.fixture
def user_a():
return UserModel(id=1, name="User A", email="a@example.com")
@pytest.fixture
def user_b():
return UserModel(id=2, name="User B", email="b@example.com")
@pytest.fixture
def user_c():
return UserModel(id=3, name="User C", email="c@example.com")
@pytest.fixture
def basic_list_model(user_a):
return ListModel(id=1, name="Test List", creator_id=user_a.id, creator=user_a, items=[])
@pytest.fixture
def group_model_with_members(user_a, user_b):
group = GroupModel(id=1, name="Test Group")
# Simulate user_associations for calculate_list_cost_summary if list is group-based
group.user_associations = [
UserGroupModel(user_id=user_a.id, group_id=group.id, user=user_a),
UserGroupModel(user_id=user_b.id, group_id=group.id, user=user_b)
]
return group
@pytest.fixture
def list_model_with_group(group_model_with_members):
return ListModel(id=2, name="Group List", group_id=group_model_with_members.id, group=group_model_with_members, items=[])
# --- calculate_list_cost_summary Tests ---
@pytest.mark.asyncio
async def test_calculate_list_cost_summary_personal_list_no_items(mock_db_session, basic_list_model, user_a):
mock_db_session.execute.return_value.scalars.return_value.first.return_value = basic_list_model
basic_list_model.items = [] # Ensure no items
basic_list_model.group = None # Ensure it's a personal list
basic_list_model.creator = user_a
summary = await calculate_list_cost_summary(mock_db_session, basic_list_model.id)
assert summary.list_id == basic_list_model.id
assert summary.total_list_cost == Decimal("0.00")
assert summary.num_participating_users == 1
assert summary.equal_share_per_user == Decimal("0.00")
assert len(summary.user_balances) == 1
assert summary.user_balances[0].user_id == user_a.id
assert summary.user_balances[0].balance == Decimal("0.00")
@pytest.mark.asyncio
async def test_calculate_list_cost_summary_group_list_with_items(mock_db_session, list_model_with_group, group_model_with_members, user_a, user_b):
item1 = ItemModel(id=1, name="Milk", price=Decimal("3.00"), added_by_id=user_a.id, added_by_user=user_a, list_id=list_model_with_group.id)
item2 = ItemModel(id=2, name="Bread", price=Decimal("2.00"), added_by_id=user_b.id, added_by_user=user_b, list_id=list_model_with_group.id)
list_model_with_group.items = [item1, item2]
mock_db_session.execute.return_value.scalars.return_value.first.return_value = list_model_with_group
summary = await calculate_list_cost_summary(mock_db_session, list_model_with_group.id)
assert summary.total_list_cost == Decimal("5.00")
assert summary.num_participating_users == 2
assert summary.equal_share_per_user == Decimal("2.50") # 5.00 / 2
balances_map = {ub.user_id: ub for ub in summary.user_balances}
assert balances_map[user_a.id].items_added_value == Decimal("3.00")
assert balances_map[user_a.id].amount_due == Decimal("2.50")
assert balances_map[user_a.id].balance == Decimal("0.50") # 3.00 - 2.50
assert balances_map[user_b.id].items_added_value == Decimal("2.00")
assert balances_map[user_b.id].amount_due == Decimal("2.50")
assert balances_map[user_b.id].balance == Decimal("-0.50") # 2.00 - 2.50
@pytest.mark.asyncio
async def test_calculate_list_cost_summary_list_not_found(mock_db_session):
mock_db_session.execute.return_value.scalars.return_value.first.return_value = None
with pytest.raises(ListNotFoundError):
await calculate_list_cost_summary(mock_db_session, 999)
# --- calculate_suggested_settlements Tests ---
@patch('app.crud.cost.logger') # Mock the logger used in the function
def test_calculate_suggested_settlements_simple_case(mock_logger):
user_balances = [
UserBalanceDetail(user_id=1, user_identifier="User A", net_balance=Decimal("-10.00")),
UserBalanceDetail(user_id=2, user_identifier="User B", net_balance=Decimal("10.00")),
]
settlements = calculate_suggested_settlements(user_balances)
assert len(settlements) == 1
assert settlements[0].from_user_id == 1
assert settlements[0].to_user_id == 2
assert settlements[0].amount == Decimal("10.00")
mock_logger.warning.assert_not_called()
@patch('app.crud.cost.logger')
def test_calculate_suggested_settlements_multiple(mock_logger):
user_balances = [
UserBalanceDetail(user_id=1, user_identifier="User A", net_balance=Decimal("-30.00")),
UserBalanceDetail(user_id=2, user_identifier="User B", net_balance=Decimal("10.00")),
UserBalanceDetail(user_id=3, user_identifier="User C", net_balance=Decimal("20.00")),
]
settlements = calculate_suggested_settlements(user_balances)
# Expected: A owes B 10, A owes C 20
assert len(settlements) == 2
s_map = {(s.from_user_id, s.to_user_id): s.amount for s in settlements}
assert s_map[(1, 3)] == Decimal("20.00") # A -> C (largest creditor first)
assert s_map[(1, 2)] == Decimal("10.00") # A -> B
mock_logger.warning.assert_not_called()
@patch('app.crud.cost.logger')
def test_calculate_suggested_settlements_complex_case_with_rounding_check(mock_logger):
user_balances = [
UserBalanceDetail(user_id=1, user_identifier="User A", net_balance=Decimal("-33.33")),
UserBalanceDetail(user_id=2, user_identifier="User B", net_balance=Decimal("-33.33")),
UserBalanceDetail(user_id=3, user_identifier="User C", net_balance=Decimal("66.66")),
]
settlements = calculate_suggested_settlements(user_balances)
# A -> C 33.33, B -> C 33.33
assert len(settlements) == 2
total_settled_to_C = sum(s.amount for s in settlements if s.to_user_id == 3)
assert total_settled_to_C == Decimal("66.66")
mock_logger.warning.assert_not_called() # Assuming exact match after quantization
# --- calculate_group_balance_summary Tests ---
@pytest.mark.asyncio
async def test_calculate_group_balance_summary_no_activity(mock_db_session, group_model_with_members, user_a, user_b):
mock_db_session.get.return_value = group_model_with_members # For group fetch
# Mock group members query
mock_members_result = AsyncMock()
mock_members_result.scalars.return_value.all.return_value = [user_a, user_b]
# Mock expenses query (no expenses)
mock_expenses_result = AsyncMock()
mock_expenses_result.scalars.return_value.all.return_value = []
# Mock settlements query (no settlements)
mock_settlements_result = AsyncMock()
mock_settlements_result.scalars.return_value.all.return_value = []
mock_db_session.execute.side_effect = [mock_members_result, mock_expenses_result, mock_settlements_result]
summary = await calculate_group_balance_summary(mock_db_session, group_model_with_members.id)
assert summary.group_id == group_model_with_members.id
assert len(summary.user_balances) == 2
assert len(summary.suggested_settlements) == 0
for ub in summary.user_balances:
assert ub.net_balance == Decimal("0.00")
@pytest.mark.asyncio
async def test_calculate_group_balance_summary_with_expenses_and_settlements(mock_db_session, group_model_with_members, user_a, user_b, user_c):
# Group members for this test: A, B, C
group_model_with_members.user_associations.append(UserGroupModel(user_id=user_c.id, group_id=group_model_with_members.id, user=user_c))
all_users = [user_a, user_b, user_c]
mock_db_session.get.return_value = group_model_with_members
mock_members_result = AsyncMock()
mock_members_result.scalars.return_value.all.return_value = all_users
# Expenses: A paid 90, split equally (A, B, C -> 30 each)
# A: paid 90, share 30 -> +60
# B: paid 0, share 30 -> -30
# C: paid 0, share 30 -> -30
expense1 = ExpenseModel(id=1, group_id=group_model_with_members.id, paid_by_user_id=user_a.id, total_amount=Decimal("90.00"))
expense1.splits = [
ExpenseSplitModel(expense_id=1, user_id=user_a.id, owed_amount=Decimal("30.00")),
ExpenseSplitModel(expense_id=1, user_id=user_b.id, owed_amount=Decimal("30.00")),
ExpenseSplitModel(expense_id=1, user_id=user_c.id, owed_amount=Decimal("30.00")),
]
mock_expenses_result = AsyncMock()
mock_expenses_result.scalars.return_value.all.return_value = [expense1]
# Settlements: B paid A 10
# A: received 10 -> +60 + 10 = +70
# B: paid 10 -> -30 - 10 = -40
# C: no settlement -> -30
settlement1 = SettlementModel(id=1, group_id=group_model_with_members.id, paid_by_user_id=user_b.id, paid_to_user_id=user_a.id, amount=Decimal("10.00"))
mock_settlements_result = AsyncMock()
mock_settlements_result.scalars.return_value.all.return_value = [settlement1]
mock_db_session.execute.side_effect = [mock_members_result, mock_expenses_result, mock_settlements_result]
with patch('app.crud.cost.logger') as mock_settlement_logger: # Patch logger for calculate_suggested_settlements
summary = await calculate_group_balance_summary(mock_db_session, group_model_with_members.id)
balances_map = {ub.user_id: ub for ub in summary.user_balances}
assert balances_map[user_a.id].net_balance == Decimal("70.00")
assert balances_map[user_b.id].net_balance == Decimal("-40.00")
assert balances_map[user_c.id].net_balance == Decimal("-30.00")
# Suggested settlements: B owes A 30 (40-10), C owes A 30
# Net balances: A: +70, B: -40, C: -30
# Algorithm should suggest: B -> A (40), C -> A (30)
assert len(summary.suggested_settlements) == 2
settlement_map = {(s.from_user_id, s.to_user_id): s.amount for s in summary.suggested_settlements}
# Order might vary, check presence and amounts
assert (user_b.id, user_a.id) in settlement_map
assert settlement_map[(user_b.id, user_a.id)] == Decimal("40.00")
assert (user_c.id, user_a.id) in settlement_map
assert settlement_map[(user_c.id, user_a.id)] == Decimal("30.00")
mock_settlement_logger.warning.assert_not_called()
@pytest.mark.asyncio
async def test_calculate_group_balance_summary_group_not_found(mock_db_session):
mock_db_session.get.return_value = None
with pytest.raises(GroupNotFoundError):
await calculate_group_balance_summary(mock_db_session, 999)
# TODO:
# - Test calculate_list_cost_summary with item added by user not in group/not creator.
# - Test calculate_list_cost_summary with remainder distribution for equal share.
# - Test calculate_suggested_settlements with zero balances, one debtor, one creditor, etc.
# - Test calculate_suggested_settlements with logger warning for remaining balances.
# - Test calculate_group_balance_summary with more complex expense/settlement scenarios.
# - Test calculate_group_balance_summary when a user in splits/settlements is not in group_members (should ideally not happen if data is clean).

View File

@ -0,0 +1,317 @@
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from sqlalchemy.exc import IntegrityError, OperationalError
from decimal import Decimal, ROUND_HALF_UP
from datetime import datetime, timezone
from typing import List as PyList, Optional
from app.crud.expense import (
create_expense,
get_expense_by_id,
get_expenses_for_list,
get_expenses_for_group,
update_expense, # Assuming update_expense exists
delete_expense, # Assuming delete_expense exists
get_users_for_splitting # Helper, might test indirectly
)
from app.schemas.expense import ExpenseCreate, ExpenseUpdate, ExpenseSplitCreate
from app.models import (
Expense as ExpenseModel,
ExpenseSplit as ExpenseSplitModel,
User as UserModel,
List as ListModel,
Group as GroupModel,
UserGroup as UserGroupModel,
Item as ItemModel,
SplitTypeEnum
)
from app.core.exceptions import (
ListNotFoundError,
GroupNotFoundError,
UserNotFoundError,
InvalidOperationError
)
# General Fixtures
@pytest.fixture
def mock_db_session():
session = AsyncMock()
session.commit = AsyncMock()
session.rollback = AsyncMock()
session.refresh = AsyncMock()
session.add = MagicMock()
session.delete = MagicMock()
session.execute = AsyncMock()
session.get = AsyncMock()
session.flush = AsyncMock() # create_expense uses flush
return session
@pytest.fixture
def basic_user_model():
return UserModel(id=1, name="Test User", email="test@example.com")
@pytest.fixture
def another_user_model():
return UserModel(id=2, name="Another User", email="another@example.com")
@pytest.fixture
def basic_group_model():
group = GroupModel(id=1, name="Test Group")
# Simulate member_associations for get_users_for_splitting if needed directly
# group.member_associations = [UserGroupModel(user_id=1, group_id=1, user=basic_user_model()), UserGroupModel(user_id=2, group_id=1, user=another_user_model())]
return group
@pytest.fixture
def basic_list_model(basic_group_model, basic_user_model):
return ListModel(id=1, name="Test List", group_id=basic_group_model.id, group=basic_group_model, creator_id=basic_user_model.id, creator=basic_user_model)
@pytest.fixture
def expense_create_data_equal_split_list_ctx(basic_list_model, basic_user_model):
return ExpenseCreate(
description="Grocery run",
total_amount=Decimal("30.00"),
currency="USD",
expense_date=datetime.now(timezone.utc),
split_type=SplitTypeEnum.EQUAL,
list_id=basic_list_model.id,
group_id=None, # Derived from list
item_id=None,
paid_by_user_id=basic_user_model.id,
splits_in=None
)
@pytest.fixture
def expense_create_data_equal_split_group_ctx(basic_group_model, basic_user_model):
return ExpenseCreate(
description="Movies",
total_amount=Decimal("50.00"),
currency="USD",
expense_date=datetime.now(timezone.utc),
split_type=SplitTypeEnum.EQUAL,
list_id=None,
group_id=basic_group_model.id,
item_id=None,
paid_by_user_id=basic_user_model.id,
splits_in=None
)
@pytest.fixture
def expense_create_data_exact_split(basic_group_model, basic_user_model, another_user_model):
return ExpenseCreate(
description="Dinner",
total_amount=Decimal("100.00"),
split_type=SplitTypeEnum.EXACT_AMOUNTS,
group_id=basic_group_model.id,
paid_by_user_id=basic_user_model.id,
splits_in=[
ExpenseSplitCreate(user_id=basic_user_model.id, owed_amount=Decimal("60.00")),
ExpenseSplitCreate(user_id=another_user_model.id, owed_amount=Decimal("40.00")),
]
)
@pytest.fixture
def db_expense_model(expense_create_data_equal_split_group_ctx, basic_user_model):
return ExpenseModel(
id=1,
description=expense_create_data_equal_split_group_ctx.description,
total_amount=expense_create_data_equal_split_group_ctx.total_amount,
currency=expense_create_data_equal_split_group_ctx.currency,
expense_date=expense_create_data_equal_split_group_ctx.expense_date,
split_type=expense_create_data_equal_split_group_ctx.split_type,
list_id=expense_create_data_equal_split_group_ctx.list_id,
group_id=expense_create_data_equal_split_group_ctx.group_id,
item_id=expense_create_data_equal_split_group_ctx.item_id,
paid_by_user_id=expense_create_data_equal_split_group_ctx.paid_by_user_id,
paid_by=basic_user_model, # Assuming paid_by relation is loaded
# splits would be populated after creation usually
version=1
)
# Tests for get_users_for_splitting (indirectly tested via create_expense, but stubs for direct if needed)
@pytest.mark.asyncio
async def test_get_users_for_splitting_group_context(mock_db_session, basic_group_model, basic_user_model, another_user_model):
# Setup group with members
user_group_assoc1 = UserGroupModel(user=basic_user_model, user_id=basic_user_model.id)
user_group_assoc2 = UserGroupModel(user=another_user_model, user_id=another_user_model.id)
basic_group_model.member_associations = [user_group_assoc1, user_group_assoc2]
mock_execute = AsyncMock()
mock_execute.scalars.return_value.first.return_value = basic_group_model
mock_db_session.execute.return_value = mock_execute
users = await get_users_for_splitting(mock_db_session, expense_group_id=1, expense_list_id=None, expense_paid_by_user_id=1)
assert len(users) == 2
assert basic_user_model in users
assert another_user_model in users
# --- create_expense Tests ---
@pytest.mark.asyncio
async def test_create_expense_equal_split_group_success(mock_db_session, expense_create_data_equal_split_group_ctx, basic_user_model, basic_group_model, another_user_model):
mock_db_session.get.side_effect = [basic_user_model, basic_group_model] # Payer, Group
# Mock get_users_for_splitting call within create_expense
# This is a bit tricky as it's an internal call. Patching is an option.
with patch('app.crud.expense.get_users_for_splitting', new_callable=AsyncMock) as mock_get_users:
mock_get_users.return_value = [basic_user_model, another_user_model]
created_expense = await create_expense(mock_db_session, expense_create_data_equal_split_group_ctx, current_user_id=1)
mock_db_session.add.assert_called()
mock_db_session.flush.assert_called_once()
# mock_db_session.commit.assert_called_once() # create_expense does not commit itself
# mock_db_session.refresh.assert_called_once() # create_expense does not refresh itself
assert created_expense is not None
assert created_expense.total_amount == expense_create_data_equal_split_group_ctx.total_amount
assert created_expense.split_type == SplitTypeEnum.EQUAL
assert len(created_expense.splits) == 2 # Expect splits to be added to the model instance
# Check split amounts
expected_amount_per_user = (expense_create_data_equal_split_group_ctx.total_amount / 2).quantize(Decimal("0.01"), rounding=ROUND_HALF_UP)
for split in created_expense.splits:
assert split.owed_amount == expected_amount_per_user
@pytest.mark.asyncio
async def test_create_expense_exact_split_success(mock_db_session, expense_create_data_exact_split, basic_user_model, basic_group_model, another_user_model):
mock_db_session.get.side_effect = [basic_user_model, basic_group_model] # Payer, Group
# Mock the select for user validation in exact splits
mock_user_select_result = AsyncMock()
mock_user_select_result.all.return_value = [(basic_user_model.id,), (another_user_model.id,)] # Simulate (id,) tuples
# To make it behave like scalars().all() that returns a list of IDs:
# We need to mock the scalars().all() part, or the whole execute chain for user validation.
# A simpler way for this specific case might be to mock the select for User.id
mock_execute_user_ids = AsyncMock()
# Construct a mock result that mimics what `await db.execute(select(UserModel.id)...)` would give then process
# It's a bit involved, usually `found_user_ids = {row[0] for row in user_results}`
# Let's assume the select returns a list of Row objects or tuples with one element
mock_user_ids_result_proxy = MagicMock()
mock_user_ids_result_proxy.__iter__.return_value = iter([(basic_user_model.id,), (another_user_model.id,)])
mock_db_session.execute.return_value = mock_user_ids_result_proxy
created_expense = await create_expense(mock_db_session, expense_create_data_exact_split, current_user_id=1)
mock_db_session.add.assert_called()
mock_db_session.flush.assert_called_once()
assert created_expense is not None
assert created_expense.split_type == SplitTypeEnum.EXACT_AMOUNTS
assert len(created_expense.splits) == 2
assert created_expense.splits[0].owed_amount == Decimal("60.00")
assert created_expense.splits[1].owed_amount == Decimal("40.00")
@pytest.mark.asyncio
async def test_create_expense_payer_not_found(mock_db_session, expense_create_data_equal_split_group_ctx):
mock_db_session.get.return_value = None # Payer not found
with pytest.raises(UserNotFoundError):
await create_expense(mock_db_session, expense_create_data_equal_split_group_ctx, 1)
@pytest.mark.asyncio
async def test_create_expense_no_list_or_group(mock_db_session, expense_create_data_equal_split_group_ctx, basic_user_model):
mock_db_session.get.return_value = basic_user_model # Payer found
expense_data = expense_create_data_equal_split_group_ctx.model_copy()
expense_data.list_id = None
expense_data.group_id = None
with pytest.raises(InvalidOperationError, match="Expense must be associated with a list or a group"):
await create_expense(mock_db_session, expense_data, 1)
# --- get_expense_by_id Tests ---
@pytest.mark.asyncio
async def test_get_expense_by_id_found(mock_db_session, db_expense_model):
mock_result = AsyncMock()
mock_result.scalars.return_value.first.return_value = db_expense_model
mock_db_session.execute.return_value = mock_result
expense = await get_expense_by_id(mock_db_session, 1)
assert expense is not None
assert expense.id == 1
mock_db_session.execute.assert_called_once()
@pytest.mark.asyncio
async def test_get_expense_by_id_not_found(mock_db_session):
mock_result = AsyncMock()
mock_result.scalars.return_value.first.return_value = None
mock_db_session.execute.return_value = mock_result
expense = await get_expense_by_id(mock_db_session, 999)
assert expense is None
# --- get_expenses_for_list Tests ---
@pytest.mark.asyncio
async def test_get_expenses_for_list_success(mock_db_session, db_expense_model):
mock_result = AsyncMock()
mock_result.scalars.return_value.all.return_value = [db_expense_model]
mock_db_session.execute.return_value = mock_result
expenses = await get_expenses_for_list(mock_db_session, list_id=1)
assert len(expenses) == 1
assert expenses[0].id == db_expense_model.id
mock_db_session.execute.assert_called_once()
# --- get_expenses_for_group Tests ---
@pytest.mark.asyncio
async def test_get_expenses_for_group_success(mock_db_session, db_expense_model):
mock_result = AsyncMock()
mock_result.scalars.return_value.all.return_value = [db_expense_model]
mock_db_session.execute.return_value = mock_result
expenses = await get_expenses_for_group(mock_db_session, group_id=1)
assert len(expenses) == 1
assert expenses[0].id == db_expense_model.id
mock_db_session.execute.assert_called_once()
# --- Stubs for update_expense and delete_expense ---
# These will need more details once the actual implementation of update/delete is clear
# For example, how splits are handled on update, versioning, etc.
@pytest.mark.asyncio
async def test_update_expense_stub(mock_db_session):
# Placeholder: Test logic for update_expense will be more complex
# Needs ExpenseUpdate schema, existing expense object, and mocking of commit/refresh
# Also depends on what fields are updatable and how splits are managed.
expense_to_update = MagicMock(spec=ExpenseModel)
expense_to_update.version = 1
update_payload = ExpenseUpdate(description="New description", version=1) # Add other fields as per schema definition
# Simulate the update_expense function behavior
# For example, if it loads the expense, modifies, commits, refreshes:
# mock_db_session.get.return_value = expense_to_update
# updated_expense = await update_expense(mock_db_session, expense_to_update, update_payload)
# assert updated_expense.description == "New description"
# mock_db_session.commit.assert_called_once()
# mock_db_session.refresh.assert_called_once()
pass # Replace with actual test logic
@pytest.mark.asyncio
async def test_delete_expense_stub(mock_db_session):
# Placeholder: Test logic for delete_expense
# Needs an existing expense object and mocking of delete/commit
# Also, consider implications (e.g., are splits deleted?)
expense_to_delete = MagicMock(spec=ExpenseModel)
expense_to_delete.id = 1
expense_to_delete.version = 1
# Simulate delete_expense behavior
# mock_db_session.get.return_value = expense_to_delete # If it re-fetches
# await delete_expense(mock_db_session, expense_to_delete, expected_version=1)
# mock_db_session.delete.assert_called_once_with(expense_to_delete)
# mock_db_session.commit.assert_called_once()
pass # Replace with actual test logic
# TODO: Add more tests for create_expense covering:
# - List context success
# - Percentage, Shares, Item-based splits
# - Error cases for each split type (e.g., total mismatch, invalid inputs)
# - Validation of list_id/group_id consistency
# - User not found in splits_in
# - Item not found for ITEM_BASED split
# TODO: Flesh out update_expense tests:
# - Success case
# - Version mismatch
# - Trying to update immutable fields
# - How splits are handled (recalculated, deleted/recreated, or not changeable)
# TODO: Flesh out delete_expense tests:
# - Success case
# - Version mismatch (if applicable)
# - Ensure associated splits are also deleted (cascade behavior)

270
be/tests/crud/test_group.py Normal file
View File

@ -0,0 +1,270 @@
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from sqlalchemy.exc import IntegrityError, OperationalError, SQLAlchemyError
from sqlalchemy.future import select
from sqlalchemy import delete, func # For remove_user_from_group and get_group_member_count
from app.crud.group import (
create_group,
get_user_groups,
get_group_by_id,
is_user_member,
get_user_role_in_group,
add_user_to_group,
remove_user_from_group,
get_group_member_count,
check_group_membership,
check_user_role_in_group
)
from app.schemas.group import GroupCreate
from app.models import User as UserModel, Group as GroupModel, UserGroup as UserGroupModel, UserRoleEnum
from app.core.exceptions import (
GroupOperationError,
GroupNotFoundError,
DatabaseConnectionError,
DatabaseIntegrityError,
DatabaseQueryError,
DatabaseTransactionError,
GroupMembershipError,
GroupPermissionError
)
# Fixtures
@pytest.fixture
def mock_db_session():
session = AsyncMock()
# Patch begin_nested for SQLAlchemy 1.4+ if used, or just begin() if that's the pattern
# For simplicity, assuming `async with db.begin():` translates to db.begin() and db.commit()/rollback()
session.begin = AsyncMock() # Mock the begin call used in async with db.begin()
session.commit = AsyncMock()
session.rollback = AsyncMock()
session.refresh = AsyncMock()
session.add = MagicMock()
session.delete = MagicMock() # For remove_user_from_group (if it uses session.delete)
session.execute = AsyncMock()
session.get = AsyncMock()
session.flush = AsyncMock()
return session
@pytest.fixture
def group_create_data():
return GroupCreate(name="Test Group")
@pytest.fixture
def creator_user_model():
return UserModel(id=1, name="Creator User", email="creator@example.com")
@pytest.fixture
def member_user_model():
return UserModel(id=2, name="Member User", email="member@example.com")
@pytest.fixture
def db_group_model(creator_user_model):
return GroupModel(id=1, name="Test Group", created_by_id=creator_user_model.id, creator=creator_user_model)
@pytest.fixture
def db_user_group_owner_assoc(db_group_model, creator_user_model):
return UserGroupModel(user_id=creator_user_model.id, group_id=db_group_model.id, role=UserRoleEnum.owner, user=creator_user_model, group=db_group_model)
@pytest.fixture
def db_user_group_member_assoc(db_group_model, member_user_model):
return UserGroupModel(user_id=member_user_model.id, group_id=db_group_model.id, role=UserRoleEnum.member, user=member_user_model, group=db_group_model)
# --- create_group Tests ---
@pytest.mark.asyncio
async def test_create_group_success(mock_db_session, group_create_data, creator_user_model):
async def mock_refresh(instance):
instance.id = 1 # Simulate ID assignment by DB
return None
mock_db_session.refresh = AsyncMock(side_effect=mock_refresh)
created_group = await create_group(mock_db_session, group_create_data, creator_user_model.id)
assert mock_db_session.add.call_count == 2 # Group and UserGroup
mock_db_session.flush.assert_called() # Called multiple times
mock_db_session.refresh.assert_called_once_with(created_group)
assert created_group is not None
assert created_group.name == group_create_data.name
assert created_group.created_by_id == creator_user_model.id
# Further check if UserGroup was created correctly by inspecting mock_db_session.add calls or by fetching
@pytest.mark.asyncio
async def test_create_group_integrity_error(mock_db_session, group_create_data, creator_user_model):
mock_db_session.flush.side_effect = IntegrityError("mock integrity error", "params", "orig")
with pytest.raises(DatabaseIntegrityError):
await create_group(mock_db_session, group_create_data, creator_user_model.id)
mock_db_session.rollback.assert_called_once() # Assuming rollback within the except block of create_group
# --- get_user_groups Tests ---
@pytest.mark.asyncio
async def test_get_user_groups_success(mock_db_session, db_group_model, creator_user_model):
mock_result = AsyncMock()
mock_result.scalars.return_value.all.return_value = [db_group_model]
mock_db_session.execute.return_value = mock_result
groups = await get_user_groups(mock_db_session, creator_user_model.id)
assert len(groups) == 1
assert groups[0].name == db_group_model.name
mock_db_session.execute.assert_called_once()
# --- get_group_by_id Tests ---
@pytest.mark.asyncio
async def test_get_group_by_id_found(mock_db_session, db_group_model):
mock_result = AsyncMock()
mock_result.scalars.return_value.first.return_value = db_group_model
mock_db_session.execute.return_value = mock_result
group = await get_group_by_id(mock_db_session, db_group_model.id)
assert group is not None
assert group.id == db_group_model.id
# Add assertions for eager loaded members if applicable and mocked
@pytest.mark.asyncio
async def test_get_group_by_id_not_found(mock_db_session):
mock_result = AsyncMock()
mock_result.scalars.return_value.first.return_value = None
mock_db_session.execute.return_value = mock_result
group = await get_group_by_id(mock_db_session, 999)
assert group is None
# --- is_user_member Tests ---
@pytest.mark.asyncio
async def test_is_user_member_true(mock_db_session, db_group_model, creator_user_model):
mock_result = AsyncMock()
mock_result.scalar_one_or_none.return_value = 1 # Simulate UserGroup.id found
mock_db_session.execute.return_value = mock_result
is_member = await is_user_member(mock_db_session, db_group_model.id, creator_user_model.id)
assert is_member is True
@pytest.mark.asyncio
async def test_is_user_member_false(mock_db_session, db_group_model, member_user_model):
mock_result = AsyncMock()
mock_result.scalar_one_or_none.return_value = None # Simulate no UserGroup.id found
mock_db_session.execute.return_value = mock_result
is_member = await is_user_member(mock_db_session, db_group_model.id, member_user_model.id + 1) # Non-member
assert is_member is False
# --- get_user_role_in_group Tests ---
@pytest.mark.asyncio
async def test_get_user_role_in_group_owner(mock_db_session, db_group_model, creator_user_model):
mock_result = AsyncMock()
mock_result.scalar_one_or_none.return_value = UserRoleEnum.owner
mock_db_session.execute.return_value = mock_result
role = await get_user_role_in_group(mock_db_session, db_group_model.id, creator_user_model.id)
assert role == UserRoleEnum.owner
# --- add_user_to_group Tests ---
@pytest.mark.asyncio
async def test_add_user_to_group_new_member(mock_db_session, db_group_model, member_user_model):
# First execute call for checking existing membership returns None
mock_existing_check_result = AsyncMock()
mock_existing_check_result.scalar_one_or_none.return_value = None
mock_db_session.execute.return_value = mock_existing_check_result
async def mock_refresh_user_group(instance):
instance.id = 100 # Simulate ID for UserGroupModel
return None
mock_db_session.refresh = AsyncMock(side_effect=mock_refresh_user_group)
user_group_assoc = await add_user_to_group(mock_db_session, db_group_model.id, member_user_model.id, UserRoleEnum.member)
mock_db_session.add.assert_called_once()
mock_db_session.flush.assert_called_once()
mock_db_session.refresh.assert_called_once()
assert user_group_assoc is not None
assert user_group_assoc.user_id == member_user_model.id
assert user_group_assoc.group_id == db_group_model.id
assert user_group_assoc.role == UserRoleEnum.member
@pytest.mark.asyncio
async def test_add_user_to_group_already_member(mock_db_session, db_group_model, creator_user_model, db_user_group_owner_assoc):
mock_existing_check_result = AsyncMock()
mock_existing_check_result.scalar_one_or_none.return_value = db_user_group_owner_assoc # User is already a member
mock_db_session.execute.return_value = mock_existing_check_result
user_group_assoc = await add_user_to_group(mock_db_session, db_group_model.id, creator_user_model.id)
assert user_group_assoc is None
mock_db_session.add.assert_not_called()
# --- remove_user_from_group Tests ---
@pytest.mark.asyncio
async def test_remove_user_from_group_success(mock_db_session, db_group_model, member_user_model):
mock_delete_result = AsyncMock()
mock_delete_result.scalar_one_or_none.return_value = 1 # Simulate a row was deleted (returning ID)
mock_db_session.execute.return_value = mock_delete_result
removed = await remove_user_from_group(mock_db_session, db_group_model.id, member_user_model.id)
assert removed is True
# Assert that db.execute was called with a delete statement
# This requires inspecting the call args of mock_db_session.execute
# For simplicity, we check it was called. A deeper check would validate the SQL query itself.
mock_db_session.execute.assert_called_once()
# --- get_group_member_count Tests ---
@pytest.mark.asyncio
async def test_get_group_member_count_success(mock_db_session, db_group_model):
mock_count_result = AsyncMock()
mock_count_result.scalar_one.return_value = 5
mock_db_session.execute.return_value = mock_count_result
count = await get_group_member_count(mock_db_session, db_group_model.id)
assert count == 5
# --- check_group_membership Tests ---
@pytest.mark.asyncio
async def test_check_group_membership_is_member(mock_db_session, db_group_model, creator_user_model):
mock_db_session.get.return_value = db_group_model # Group exists
mock_membership_result = AsyncMock()
mock_membership_result.scalar_one_or_none.return_value = 1 # User is a member
mock_db_session.execute.return_value = mock_membership_result
await check_group_membership(mock_db_session, db_group_model.id, creator_user_model.id)
# No exception means success
@pytest.mark.asyncio
async def test_check_group_membership_group_not_found(mock_db_session, creator_user_model):
mock_db_session.get.return_value = None # Group does not exist
with pytest.raises(GroupNotFoundError):
await check_group_membership(mock_db_session, 999, creator_user_model.id)
@pytest.mark.asyncio
async def test_check_group_membership_not_member(mock_db_session, db_group_model, member_user_model):
mock_db_session.get.return_value = db_group_model # Group exists
mock_membership_result = AsyncMock()
mock_membership_result.scalar_one_or_none.return_value = None # User is not a member
mock_db_session.execute.return_value = mock_membership_result
with pytest.raises(GroupMembershipError):
await check_group_membership(mock_db_session, db_group_model.id, member_user_model.id)
# --- check_user_role_in_group Tests ---
@pytest.mark.asyncio
async def test_check_user_role_in_group_sufficient_role(mock_db_session, db_group_model, creator_user_model):
# Mock check_group_membership (implicitly called)
mock_db_session.get.return_value = db_group_model
mock_membership_check = AsyncMock()
mock_membership_check.scalar_one_or_none.return_value = 1 # User is member
# Mock get_user_role_in_group
mock_role_check = AsyncMock()
mock_role_check.scalar_one_or_none.return_value = UserRoleEnum.owner
mock_db_session.execute.side_effect = [mock_membership_check, mock_role_check]
await check_user_role_in_group(mock_db_session, db_group_model.id, creator_user_model.id, UserRoleEnum.member)
# No exception means success
@pytest.mark.asyncio
async def test_check_user_role_in_group_insufficient_role(mock_db_session, db_group_model, member_user_model):
mock_db_session.get.return_value = db_group_model # Group exists
mock_membership_check = AsyncMock()
mock_membership_check.scalar_one_or_none.return_value = 1 # User is member (for check_group_membership call)
mock_role_check = AsyncMock()
mock_role_check.scalar_one_or_none.return_value = UserRoleEnum.member # User's actual role
mock_db_session.execute.side_effect = [mock_membership_check, mock_role_check]
with pytest.raises(GroupPermissionError):
await check_user_role_in_group(mock_db_session, db_group_model.id, member_user_model.id, UserRoleEnum.owner)
# TODO: Add tests for DB operational/SQLAlchemy errors for each function similar to create_group_integrity_error
# TODO: Test edge cases like trying to add user to non-existent group (should be caught by FK constraints or prior checks)

View File

@ -0,0 +1,174 @@
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from sqlalchemy.exc import IntegrityError, OperationalError, SQLAlchemyError # Assuming these might be raised
from datetime import datetime, timedelta, timezone
import secrets
from app.crud.invite import (
create_invite,
get_active_invite_by_code,
deactivate_invite,
MAX_CODE_GENERATION_ATTEMPTS
)
from app.models import Invite as InviteModel, User as UserModel, Group as GroupModel # For context
# No specific schemas for invite CRUD usually, but models are used.
# Fixtures
@pytest.fixture
def mock_db_session():
session = AsyncMock()
session.commit = AsyncMock()
session.rollback = AsyncMock()
session.refresh = AsyncMock()
session.add = MagicMock()
session.execute = AsyncMock()
return session
@pytest.fixture
def group_model():
return GroupModel(id=1, name="Test Group")
@pytest.fixture
def user_model(): # Creator
return UserModel(id=1, name="Creator User")
@pytest.fixture
def db_invite_model(group_model, user_model):
return InviteModel(
id=1,
code="test_invite_code_123",
group_id=group_model.id,
created_by_id=user_model.id,
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
is_active=True
)
# --- create_invite Tests ---
@pytest.mark.asyncio
@patch('app.crud.invite.secrets.token_urlsafe') # Patch secrets.token_urlsafe
async def test_create_invite_success_first_attempt(mock_token_urlsafe, mock_db_session, group_model, user_model):
generated_code = "unique_code_123"
mock_token_urlsafe.return_value = generated_code
# Mock DB execute for checking existing code (first attempt, no existing code)
mock_existing_check_result = AsyncMock()
mock_existing_check_result.scalar_one_or_none.return_value = None
mock_db_session.execute.return_value = mock_existing_check_result
invite = await create_invite(mock_db_session, group_model.id, user_model.id, expires_in_days=5)
mock_token_urlsafe.assert_called_once_with(16)
mock_db_session.execute.assert_called_once() # For the uniqueness check
mock_db_session.add.assert_called_once()
mock_db_session.commit.assert_called_once()
mock_db_session.refresh.assert_called_once_with(invite)
assert invite is not None
assert invite.code == generated_code
assert invite.group_id == group_model.id
assert invite.created_by_id == user_model.id
assert invite.is_active is True
assert invite.expires_at > datetime.now(timezone.utc) + timedelta(days=4) # Check expiry is roughly correct
@pytest.mark.asyncio
@patch('app.crud.invite.secrets.token_urlsafe')
async def test_create_invite_success_after_collision(mock_token_urlsafe, mock_db_session, group_model, user_model):
colliding_code = "colliding_code"
unique_code = "finally_unique_code"
mock_token_urlsafe.side_effect = [colliding_code, unique_code] # First call collides, second is unique
# Mock DB execute for checking existing code
mock_collision_check_result = AsyncMock()
mock_collision_check_result.scalar_one_or_none.return_value = 1 # Simulate collision (ID found)
mock_no_collision_check_result = AsyncMock()
mock_no_collision_check_result.scalar_one_or_none.return_value = None # No collision
mock_db_session.execute.side_effect = [mock_collision_check_result, mock_no_collision_check_result]
invite = await create_invite(mock_db_session, group_model.id, user_model.id)
assert mock_token_urlsafe.call_count == 2
assert mock_db_session.execute.call_count == 2
assert invite is not None
assert invite.code == unique_code
@pytest.mark.asyncio
@patch('app.crud.invite.secrets.token_urlsafe')
async def test_create_invite_fails_after_max_attempts(mock_token_urlsafe, mock_db_session, group_model, user_model):
mock_token_urlsafe.return_value = "always_colliding_code"
mock_collision_check_result = AsyncMock()
mock_collision_check_result.scalar_one_or_none.return_value = 1 # Always collide
mock_db_session.execute.return_value = mock_collision_check_result
invite = await create_invite(mock_db_session, group_model.id, user_model.id)
assert invite is None
assert mock_token_urlsafe.call_count == MAX_CODE_GENERATION_ATTEMPTS
assert mock_db_session.execute.call_count == MAX_CODE_GENERATION_ATTEMPTS
mock_db_session.add.assert_not_called()
# --- get_active_invite_by_code Tests ---
@pytest.mark.asyncio
async def test_get_active_invite_by_code_found_active(mock_db_session, db_invite_model):
db_invite_model.is_active = True
db_invite_model.expires_at = datetime.now(timezone.utc) + timedelta(days=1)
mock_result = AsyncMock()
mock_result.scalars.return_value.first.return_value = db_invite_model
mock_db_session.execute.return_value = mock_result
invite = await get_active_invite_by_code(mock_db_session, db_invite_model.code)
assert invite is not None
assert invite.code == db_invite_model.code
mock_db_session.execute.assert_called_once()
@pytest.mark.asyncio
async def test_get_active_invite_by_code_not_found(mock_db_session):
mock_result = AsyncMock()
mock_result.scalars.return_value.first.return_value = None
mock_db_session.execute.return_value = mock_result
invite = await get_active_invite_by_code(mock_db_session, "non_existent_code")
assert invite is None
@pytest.mark.asyncio
async def test_get_active_invite_by_code_inactive(mock_db_session, db_invite_model):
db_invite_model.is_active = False # Inactive
db_invite_model.expires_at = datetime.now(timezone.utc) + timedelta(days=1)
mock_result = AsyncMock()
mock_result.scalars.return_value.first.return_value = None # Should not be found by query
mock_db_session.execute.return_value = mock_result
invite = await get_active_invite_by_code(mock_db_session, db_invite_model.code)
assert invite is None
@pytest.mark.asyncio
async def test_get_active_invite_by_code_expired(mock_db_session, db_invite_model):
db_invite_model.is_active = True
db_invite_model.expires_at = datetime.now(timezone.utc) - timedelta(days=1) # Expired
mock_result = AsyncMock()
mock_result.scalars.return_value.first.return_value = None # Should not be found by query
mock_db_session.execute.return_value = mock_result
invite = await get_active_invite_by_code(mock_db_session, db_invite_model.code)
assert invite is None
# --- deactivate_invite Tests ---
@pytest.mark.asyncio
async def test_deactivate_invite_success(mock_db_session, db_invite_model):
db_invite_model.is_active = True # Ensure it starts active
deactivated_invite = await deactivate_invite(mock_db_session, db_invite_model)
mock_db_session.add.assert_called_once_with(db_invite_model)
mock_db_session.commit.assert_called_once()
mock_db_session.refresh.assert_called_once_with(db_invite_model)
assert deactivated_invite.is_active is False
# It might be useful to test DB error cases (OperationalError, etc.) for each function
# if they have specific try-except blocks, but invite.py seems to rely on caller/framework for some of that.
# create_invite has its own DB interaction within the loop, so that's covered.
# get_active_invite_by_code and deactivate_invite are simpler DB ops.

184
be/tests/crud/test_item.py Normal file
View File

@ -0,0 +1,184 @@
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from sqlalchemy.exc import IntegrityError, OperationalError, SQLAlchemyError
from datetime import datetime, timezone
from app.crud.item import (
create_item,
get_items_by_list_id,
get_item_by_id,
update_item,
delete_item
)
from app.schemas.item import ItemCreate, ItemUpdate
from app.models import Item as ItemModel, User as UserModel, List as ListModel
from app.core.exceptions import (
ItemNotFoundError, # Not directly raised by CRUD but good for API layer tests
DatabaseConnectionError,
DatabaseIntegrityError,
DatabaseQueryError,
DatabaseTransactionError,
ConflictError
)
# Fixtures
@pytest.fixture
def mock_db_session():
session = AsyncMock()
session.begin = AsyncMock()
session.commit = AsyncMock()
session.rollback = AsyncMock()
session.refresh = AsyncMock()
session.add = MagicMock()
session.delete = MagicMock()
session.execute = AsyncMock()
session.get = AsyncMock() # Though not directly used in item.py, good for consistency
session.flush = AsyncMock()
return session
@pytest.fixture
def item_create_data():
return ItemCreate(name="Test Item", quantity="1 pack")
@pytest.fixture
def item_update_data():
return ItemUpdate(name="Updated Test Item", quantity="2 packs", version=1, is_complete=False)
@pytest.fixture
def user_model():
return UserModel(id=1, name="Test User", email="test@example.com")
@pytest.fixture
def list_model():
return ListModel(id=1, name="Test List")
@pytest.fixture
def db_item_model(list_model, user_model):
return ItemModel(
id=1,
name="Existing Item",
quantity="1 unit",
list_id=list_model.id,
added_by_id=user_model.id,
is_complete=False,
version=1,
created_at=datetime.now(timezone.utc),
updated_at=datetime.now(timezone.utc)
)
# --- create_item Tests ---
@pytest.mark.asyncio
async def test_create_item_success(mock_db_session, item_create_data, list_model, user_model):
async def mock_refresh(instance):
instance.id = 10 # Simulate ID assignment
instance.version = 1 # Simulate version init
instance.created_at = datetime.now(timezone.utc)
instance.updated_at = datetime.now(timezone.utc)
return None
mock_db_session.refresh = AsyncMock(side_effect=mock_refresh)
created_item = await create_item(mock_db_session, item_create_data, list_model.id, user_model.id)
mock_db_session.add.assert_called_once()
mock_db_session.flush.assert_called_once()
mock_db_session.refresh.assert_called_once_with(created_item)
assert created_item is not None
assert created_item.name == item_create_data.name
assert created_item.list_id == list_model.id
assert created_item.added_by_id == user_model.id
assert created_item.is_complete is False
assert created_item.version == 1
@pytest.mark.asyncio
async def test_create_item_integrity_error(mock_db_session, item_create_data, list_model, user_model):
mock_db_session.flush.side_effect = IntegrityError("mock integrity error", "params", "orig")
with pytest.raises(DatabaseIntegrityError):
await create_item(mock_db_session, item_create_data, list_model.id, user_model.id)
mock_db_session.rollback.assert_called_once()
# --- get_items_by_list_id Tests ---
@pytest.mark.asyncio
async def test_get_items_by_list_id_success(mock_db_session, db_item_model, list_model):
mock_result = AsyncMock()
mock_result.scalars.return_value.all.return_value = [db_item_model]
mock_db_session.execute.return_value = mock_result
items = await get_items_by_list_id(mock_db_session, list_model.id)
assert len(items) == 1
assert items[0].id == db_item_model.id
mock_db_session.execute.assert_called_once()
# --- get_item_by_id Tests ---
@pytest.mark.asyncio
async def test_get_item_by_id_found(mock_db_session, db_item_model):
mock_result = AsyncMock()
mock_result.scalars.return_value.first.return_value = db_item_model
mock_db_session.execute.return_value = mock_result
item = await get_item_by_id(mock_db_session, db_item_model.id)
assert item is not None
assert item.id == db_item_model.id
@pytest.mark.asyncio
async def test_get_item_by_id_not_found(mock_db_session):
mock_result = AsyncMock()
mock_result.scalars.return_value.first.return_value = None
mock_db_session.execute.return_value = mock_result
item = await get_item_by_id(mock_db_session, 999)
assert item is None
# --- update_item Tests ---
@pytest.mark.asyncio
async def test_update_item_success(mock_db_session, db_item_model, item_update_data, user_model):
item_update_data.version = db_item_model.version # Match versions for successful update
item_update_data.name = "Newly Updated Name"
item_update_data.is_complete = True # Test completion logic
updated_item = await update_item(mock_db_session, db_item_model, item_update_data, user_model.id)
mock_db_session.add.assert_called_once_with(db_item_model) # add is used for existing objects too
mock_db_session.flush.assert_called_once()
mock_db_session.refresh.assert_called_once_with(db_item_model)
assert updated_item.name == "Newly Updated Name"
assert updated_item.version == db_item_model.version # Check version increment logic in test
assert updated_item.is_complete is True
assert updated_item.completed_by_id == user_model.id
@pytest.mark.asyncio
async def test_update_item_version_conflict(mock_db_session, db_item_model, item_update_data, user_model):
item_update_data.version = db_item_model.version + 1 # Create a version mismatch
with pytest.raises(ConflictError):
await update_item(mock_db_session, db_item_model, item_update_data, user_model.id)
mock_db_session.rollback.assert_called_once()
@pytest.mark.asyncio
async def test_update_item_set_incomplete(mock_db_session, db_item_model, item_update_data, user_model):
db_item_model.is_complete = True # Start as complete
db_item_model.completed_by_id = user_model.id
db_item_model.version = 1
item_update_data.version = 1
item_update_data.is_complete = False
item_update_data.name = db_item_model.name # No name change for this test
item_update_data.quantity = db_item_model.quantity
updated_item = await update_item(mock_db_session, db_item_model, item_update_data, user_model.id)
assert updated_item.is_complete is False
assert updated_item.completed_by_id is None
assert updated_item.version == 2
# --- delete_item Tests ---
@pytest.mark.asyncio
async def test_delete_item_success(mock_db_session, db_item_model):
result = await delete_item(mock_db_session, db_item_model)
assert result is None
mock_db_session.delete.assert_called_once_with(db_item_model)
mock_db_session.commit.assert_called_once() # Commit happens in the `async with db.begin()` context manager
@pytest.mark.asyncio
async def test_delete_item_db_error(mock_db_session, db_item_model):
mock_db_session.delete.side_effect = OperationalError("mock op error", "params", "orig")
with pytest.raises(DatabaseConnectionError):
await delete_item(mock_db_session, db_item_model)
mock_db_session.rollback.assert_called_once()
# TODO: Add more specific DB error tests (Operational, SQLAlchemyError) for each function.

259
be/tests/crud/test_list.py Normal file
View File

@ -0,0 +1,259 @@
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from sqlalchemy.exc import IntegrityError, OperationalError, SQLAlchemyError
from sqlalchemy.future import select
from sqlalchemy import func as sql_func # For get_list_status
from datetime import datetime, timezone
from app.crud.list import (
create_list,
get_lists_for_user,
get_list_by_id,
update_list,
delete_list,
check_list_permission,
get_list_status
)
from app.schemas.list import ListCreate, ListUpdate, ListStatus
from app.models import List as ListModel, User as UserModel, Group as GroupModel, UserGroup as UserGroupModel, Item as ItemModel
from app.core.exceptions import (
ListNotFoundError,
ListPermissionError,
ListCreatorRequiredError,
DatabaseConnectionError,
DatabaseIntegrityError,
DatabaseQueryError,
DatabaseTransactionError,
ConflictError
)
# Fixtures
@pytest.fixture
def mock_db_session():
session = AsyncMock()
session.begin = AsyncMock()
session.commit = AsyncMock()
session.rollback = AsyncMock()
session.refresh = AsyncMock()
session.add = MagicMock()
session.delete = MagicMock()
session.execute = AsyncMock()
session.get = AsyncMock() # Used by check_list_permission via get_list_by_id
session.flush = AsyncMock()
return session
@pytest.fixture
def list_create_data():
return ListCreate(name="New Shopping List", description="Groceries for the week")
@pytest.fixture
def list_update_data():
return ListUpdate(name="Updated Shopping List", description="Weekend Groceries", version=1)
@pytest.fixture
def user_model():
return UserModel(id=1, name="Test User", email="test@example.com")
@pytest.fixture
def another_user_model():
return UserModel(id=2, name="Another User", email="another@example.com")
@pytest.fixture
def group_model():
return GroupModel(id=1, name="Test Group")
@pytest.fixture
def db_list_personal_model(user_model):
return ListModel(
id=1, name="Personal List", created_by_id=user_model.id, creator=user_model,
version=1, updated_at=datetime.now(timezone.utc), items=[]
)
@pytest.fixture
def db_list_group_model(user_model, group_model):
return ListModel(
id=2, name="Group List", created_by_id=user_model.id, creator=user_model,
group_id=group_model.id, group=group_model, version=1, updated_at=datetime.now(timezone.utc), items=[]
)
# --- create_list Tests ---
@pytest.mark.asyncio
async def test_create_list_success(mock_db_session, list_create_data, user_model):
async def mock_refresh(instance):
instance.id = 100
instance.version = 1
instance.updated_at = datetime.now(timezone.utc)
return None
mock_db_session.refresh.return_value = None
mock_db_session.refresh.side_effect = mock_refresh
created_list = await create_list(mock_db_session, list_create_data, user_model.id)
mock_db_session.add.assert_called_once()
mock_db_session.flush.assert_called_once()
mock_db_session.refresh.assert_called_once()
assert created_list.name == list_create_data.name
assert created_list.created_by_id == user_model.id
# --- get_lists_for_user Tests ---
@pytest.mark.asyncio
async def test_get_lists_for_user_mix(mock_db_session, user_model, db_list_personal_model, db_list_group_model):
# Simulate user is part of group for db_list_group_model
mock_group_ids_result = AsyncMock()
mock_group_ids_result.scalars.return_value.all.return_value = [db_list_group_model.group_id]
mock_lists_result = AsyncMock()
# Order should be personal list (created by user_id) then group list
mock_lists_result.scalars.return_value.all.return_value = [db_list_personal_model, db_list_group_model]
mock_db_session.execute.side_effect = [mock_group_ids_result, mock_lists_result]
lists = await get_lists_for_user(mock_db_session, user_model.id)
assert len(lists) == 2
assert db_list_personal_model in lists
assert db_list_group_model in lists
assert mock_db_session.execute.call_count == 2
# --- get_list_by_id Tests ---
@pytest.mark.asyncio
async def test_get_list_by_id_found_no_items(mock_db_session, db_list_personal_model):
mock_result = AsyncMock()
mock_result.scalars.return_value.first.return_value = db_list_personal_model
mock_db_session.execute.return_value = mock_result
found_list = await get_list_by_id(mock_db_session, db_list_personal_model.id, load_items=False)
assert found_list is not None
assert found_list.id == db_list_personal_model.id
# query options should not include selectinload for items
# (difficult to assert directly without inspecting query object in detail)
@pytest.mark.asyncio
async def test_get_list_by_id_found_with_items(mock_db_session, db_list_personal_model):
# Simulate items loaded for the list
db_list_personal_model.items = [ItemModel(id=1, name="Test Item")]
mock_result = AsyncMock()
mock_result.scalars.return_value.first.return_value = db_list_personal_model
mock_db_session.execute.return_value = mock_result
found_list = await get_list_by_id(mock_db_session, db_list_personal_model.id, load_items=True)
assert found_list is not None
assert len(found_list.items) == 1
# query options should include selectinload for items
# --- update_list Tests ---
@pytest.mark.asyncio
async def test_update_list_success(mock_db_session, db_list_personal_model, list_update_data):
list_update_data.version = db_list_personal_model.version # Match version
updated_list = await update_list(mock_db_session, db_list_personal_model, list_update_data)
assert updated_list.name == list_update_data.name
assert updated_list.version == db_list_personal_model.version # version incremented in db_list_personal_model
mock_db_session.add.assert_called_once_with(db_list_personal_model)
mock_db_session.flush.assert_called_once()
mock_db_session.refresh.assert_called_once_with(db_list_personal_model)
@pytest.mark.asyncio
async def test_update_list_conflict(mock_db_session, db_list_personal_model, list_update_data):
list_update_data.version = db_list_personal_model.version + 1 # Version mismatch
with pytest.raises(ConflictError):
await update_list(mock_db_session, db_list_personal_model, list_update_data)
mock_db_session.rollback.assert_called_once()
# --- delete_list Tests ---
@pytest.mark.asyncio
async def test_delete_list_success(mock_db_session, db_list_personal_model):
await delete_list(mock_db_session, db_list_personal_model)
mock_db_session.delete.assert_called_once_with(db_list_personal_model)
mock_db_session.commit.assert_called_once() # from async with db.begin()
# --- check_list_permission Tests ---
@pytest.mark.asyncio
async def test_check_list_permission_creator_access_personal_list(mock_db_session, db_list_personal_model, user_model):
# get_list_by_id (called by check_list_permission) will mock execute
mock_list_fetch_result = AsyncMock()
mock_list_fetch_result.scalars.return_value.first.return_value = db_list_personal_model
mock_db_session.execute.return_value = mock_list_fetch_result
ret_list = await check_list_permission(mock_db_session, db_list_personal_model.id, user_model.id)
assert ret_list.id == db_list_personal_model.id
@pytest.mark.asyncio
async def test_check_list_permission_group_member_access_group_list(mock_db_session, db_list_group_model, another_user_model, group_model):
# User `another_user_model` is not creator but member of the group
db_list_group_model.creator_id = user_model.id # Original creator is user_model
db_list_group_model.creator = user_model
# Mock get_list_by_id internal call
mock_list_fetch_result = AsyncMock()
mock_list_fetch_result.scalars.return_value.first.return_value = db_list_group_model
# Mock is_user_member call
with patch('app.crud.list.is_user_member', new_callable=AsyncMock) as mock_is_member:
mock_is_member.return_value = True # another_user_model is a member
mock_db_session.execute.return_value = mock_list_fetch_result
ret_list = await check_list_permission(mock_db_session, db_list_group_model.id, another_user_model.id)
assert ret_list.id == db_list_group_model.id
mock_is_member.assert_called_once_with(mock_db_session, group_id=group_model.id, user_id=another_user_model.id)
@pytest.mark.asyncio
async def test_check_list_permission_non_member_no_access_group_list(mock_db_session, db_list_group_model, another_user_model):
db_list_group_model.creator_id = user_model.id # Creator is not another_user_model
mock_list_fetch_result = AsyncMock()
mock_list_fetch_result.scalars.return_value.first.return_value = db_list_group_model
with patch('app.crud.list.is_user_member', new_callable=AsyncMock) as mock_is_member:
mock_is_member.return_value = False # another_user_model is NOT a member
mock_db_session.execute.return_value = mock_list_fetch_result
with pytest.raises(ListPermissionError):
await check_list_permission(mock_db_session, db_list_group_model.id, another_user_model.id)
@pytest.mark.asyncio
async def test_check_list_permission_list_not_found(mock_db_session, user_model):
mock_list_fetch_result = AsyncMock()
mock_list_fetch_result.scalars.return_value.first.return_value = None # List not found
mock_db_session.execute.return_value = mock_list_fetch_result
with pytest.raises(ListNotFoundError):
await check_list_permission(mock_db_session, 999, user_model.id)
# --- get_list_status Tests ---
@pytest.mark.asyncio
async def test_get_list_status_success(mock_db_session, db_list_personal_model):
list_updated_at = datetime.now(timezone.utc) - timezone.timedelta(hours=1)
item_updated_at = datetime.now(timezone.utc)
item_count = 5
db_list_personal_model.updated_at = list_updated_at
# Mock for ListModel.updated_at query
mock_list_updated_result = AsyncMock()
mock_list_updated_result.scalar_one_or_none.return_value = list_updated_at
# Mock for ItemModel status query
mock_item_status_result = AsyncMock()
# SQLAlchemy query for func.max and func.count returns a Row-like object or None
mock_item_status_row = MagicMock()
mock_item_status_row.latest_item_updated_at = item_updated_at
mock_item_status_row.item_count = item_count
mock_item_status_result.first.return_value = mock_item_status_row
mock_db_session.execute.side_effect = [mock_list_updated_result, mock_item_status_result]
status = await get_list_status(mock_db_session, db_list_personal_model.id)
assert status.list_updated_at == list_updated_at
assert status.latest_item_updated_at == item_updated_at
assert status.item_count == item_count
assert mock_db_session.execute.call_count == 2
@pytest.mark.asyncio
async def test_get_list_status_list_not_found(mock_db_session):
mock_list_updated_result = AsyncMock()
mock_list_updated_result.scalar_one_or_none.return_value = None # List not found
mock_db_session.execute.return_value = mock_list_updated_result
with pytest.raises(ListNotFoundError):
await get_list_status(mock_db_session, 999)
# TODO: Add more specific DB error tests (Operational, SQLAlchemyError, IntegrityError) for each function.
# TODO: Test check_list_permission with require_creator=True cases.

View File

@ -0,0 +1,250 @@
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from sqlalchemy.exc import IntegrityError, OperationalError
from sqlalchemy.future import select
from decimal import Decimal, ROUND_HALF_UP
from datetime import datetime, timezone
from typing import List as PyList
from app.crud.settlement import (
create_settlement,
get_settlement_by_id,
get_settlements_for_group,
get_settlements_involving_user,
update_settlement,
delete_settlement
)
from app.schemas.expense import SettlementCreate, SettlementUpdate
from app.models import Settlement as SettlementModel, User as UserModel, Group as GroupModel
from app.core.exceptions import UserNotFoundError, GroupNotFoundError, InvalidOperationError
# Fixtures
@pytest.fixture
def mock_db_session():
session = AsyncMock()
session.commit = AsyncMock()
session.rollback = AsyncMock()
session.refresh = AsyncMock()
session.add = MagicMock()
session.delete = MagicMock()
session.execute = AsyncMock()
session.get = AsyncMock()
return session
@pytest.fixture
def settlement_create_data():
return SettlementCreate(
group_id=1,
paid_by_user_id=1,
paid_to_user_id=2,
amount=Decimal("10.50"),
settlement_date=datetime.now(timezone.utc),
description="Test settlement"
)
@pytest.fixture
def settlement_update_data():
return SettlementUpdate(
description="Updated settlement description",
settlement_date=datetime.now(timezone.utc),
version=1 # Assuming version is required for update
)
@pytest.fixture
def db_settlement_model():
return SettlementModel(
id=1,
group_id=1,
paid_by_user_id=1,
paid_to_user_id=2,
amount=Decimal("10.50"),
settlement_date=datetime.now(timezone.utc),
description="Original settlement",
version=1, # Initial version
created_at=datetime.now(timezone.utc),
updated_at=datetime.now(timezone.utc),
payer=UserModel(id=1, name="Payer User"),
payee=UserModel(id=2, name="Payee User"),
group=GroupModel(id=1, name="Test Group")
)
@pytest.fixture
def payer_user_model():
return UserModel(id=1, name="Payer User", email="payer@example.com")
@pytest.fixture
def payee_user_model():
return UserModel(id=2, name="Payee User", email="payee@example.com")
@pytest.fixture
def group_model():
return GroupModel(id=1, name="Test Group")
# Tests for create_settlement
@pytest.mark.asyncio
async def test_create_settlement_success(mock_db_session, settlement_create_data, payer_user_model, payee_user_model, group_model):
mock_db_session.get.side_effect = [payer_user_model, payee_user_model, group_model] # Order of gets
created_settlement = await create_settlement(mock_db_session, settlement_create_data, current_user_id=1)
mock_db_session.add.assert_called_once()
mock_db_session.commit.assert_called_once()
mock_db_session.refresh.assert_called_once()
assert created_settlement is not None
assert created_settlement.amount == settlement_create_data.amount.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP)
assert created_settlement.paid_by_user_id == settlement_create_data.paid_by_user_id
assert created_settlement.paid_to_user_id == settlement_create_data.paid_to_user_id
@pytest.mark.asyncio
async def test_create_settlement_payer_not_found(mock_db_session, settlement_create_data):
mock_db_session.get.side_effect = [None, payee_user_model, group_model]
with pytest.raises(UserNotFoundError) as excinfo:
await create_settlement(mock_db_session, settlement_create_data, 1)
assert "Payer" in str(excinfo.value)
@pytest.mark.asyncio
async def test_create_settlement_payee_not_found(mock_db_session, settlement_create_data, payer_user_model):
mock_db_session.get.side_effect = [payer_user_model, None, group_model]
with pytest.raises(UserNotFoundError) as excinfo:
await create_settlement(mock_db_session, settlement_create_data, 1)
assert "Payee" in str(excinfo.value)
@pytest.mark.asyncio
async def test_create_settlement_group_not_found(mock_db_session, settlement_create_data, payer_user_model, payee_user_model):
mock_db_session.get.side_effect = [payer_user_model, payee_user_model, None]
with pytest.raises(GroupNotFoundError):
await create_settlement(mock_db_session, settlement_create_data, 1)
@pytest.mark.asyncio
async def test_create_settlement_payer_equals_payee(mock_db_session, settlement_create_data, payer_user_model, group_model):
settlement_create_data.paid_to_user_id = settlement_create_data.paid_by_user_id
mock_db_session.get.side_effect = [payer_user_model, payer_user_model, group_model]
with pytest.raises(InvalidOperationError) as excinfo:
await create_settlement(mock_db_session, settlement_create_data, 1)
assert "Payer and Payee cannot be the same user" in str(excinfo.value)
@pytest.mark.asyncio
async def test_create_settlement_commit_failure(mock_db_session, settlement_create_data, payer_user_model, payee_user_model, group_model):
mock_db_session.get.side_effect = [payer_user_model, payee_user_model, group_model]
mock_db_session.commit.side_effect = Exception("DB commit failed")
with pytest.raises(InvalidOperationError) as excinfo:
await create_settlement(mock_db_session, settlement_create_data, 1)
assert "Failed to save settlement" in str(excinfo.value)
mock_db_session.rollback.assert_called_once()
# Tests for get_settlement_by_id
@pytest.mark.asyncio
async def test_get_settlement_by_id_found(mock_db_session, db_settlement_model):
mock_db_session.execute.return_value.scalars.return_value.first.return_value = db_settlement_model
settlement = await get_settlement_by_id(mock_db_session, 1)
assert settlement is not None
assert settlement.id == 1
mock_db_session.execute.assert_called_once()
@pytest.mark.asyncio
async def test_get_settlement_by_id_not_found(mock_db_session):
mock_db_session.execute.return_value.scalars.return_value.first.return_value = None
settlement = await get_settlement_by_id(mock_db_session, 999)
assert settlement is None
# Tests for get_settlements_for_group
@pytest.mark.asyncio
async def test_get_settlements_for_group_success(mock_db_session, db_settlement_model):
mock_db_session.execute.return_value.scalars.return_value.all.return_value = [db_settlement_model]
settlements = await get_settlements_for_group(mock_db_session, group_id=1)
assert len(settlements) == 1
assert settlements[0].group_id == 1
mock_db_session.execute.assert_called_once()
# Tests for get_settlements_involving_user
@pytest.mark.asyncio
async def test_get_settlements_involving_user_success(mock_db_session, db_settlement_model):
mock_db_session.execute.return_value.scalars.return_value.all.return_value = [db_settlement_model]
settlements = await get_settlements_involving_user(mock_db_session, user_id=1)
assert len(settlements) == 1
assert settlements[0].paid_by_user_id == 1 or settlements[0].paid_to_user_id == 1
mock_db_session.execute.assert_called_once()
@pytest.mark.asyncio
async def test_get_settlements_involving_user_with_group_filter(mock_db_session, db_settlement_model):
mock_db_session.execute.return_value.scalars.return_value.all.return_value = [db_settlement_model]
settlements = await get_settlements_involving_user(mock_db_session, user_id=1, group_id=1)
assert len(settlements) == 1
# More specific assertions about the query would require deeper mocking of SQLAlchemy query construction
mock_db_session.execute.assert_called_once()
# Tests for update_settlement
@pytest.mark.asyncio
async def test_update_settlement_success(mock_db_session, db_settlement_model, settlement_update_data):
# Ensure settlement_update_data.version matches db_settlement_model.version
settlement_update_data.version = db_settlement_model.version
# Mock datetime.now()
fixed_datetime_now = datetime.now(timezone.utc)
with patch('app.crud.settlement.datetime', wraps=datetime) as mock_datetime:
mock_datetime.now.return_value = fixed_datetime_now
updated_settlement = await update_settlement(mock_db_session, db_settlement_model, settlement_update_data)
mock_db_session.commit.assert_called_once()
mock_db_session.refresh.assert_called_once()
assert updated_settlement.description == settlement_update_data.description
assert updated_settlement.settlement_date == settlement_update_data.settlement_date
assert updated_settlement.version == db_settlement_model.version + 1 # Version incremented
assert updated_settlement.updated_at == fixed_datetime_now
@pytest.mark.asyncio
async def test_update_settlement_version_mismatch(mock_db_session, db_settlement_model, settlement_update_data):
settlement_update_data.version = db_settlement_model.version + 1 # Mismatched version
with pytest.raises(InvalidOperationError) as excinfo:
await update_settlement(mock_db_session, db_settlement_model, settlement_update_data)
assert "version does not match" in str(excinfo.value)
@pytest.mark.asyncio
async def test_update_settlement_disallowed_field(mock_db_session, db_settlement_model):
# Create an update payload with a field not allowed to be updated, e.g., 'amount'
invalid_update_data = SettlementUpdate(amount=Decimal("100.00"), version=db_settlement_model.version)
with pytest.raises(InvalidOperationError) as excinfo:
await update_settlement(mock_db_session, db_settlement_model, invalid_update_data)
assert "Field 'amount' cannot be updated" in str(excinfo.value)
@pytest.mark.asyncio
async def test_update_settlement_commit_failure(mock_db_session, db_settlement_model, settlement_update_data):
settlement_update_data.version = db_settlement_model.version
mock_db_session.commit.side_effect = Exception("DB commit failed")
with pytest.raises(InvalidOperationError) as excinfo:
await update_settlement(mock_db_session, db_settlement_model, settlement_update_data)
assert "Failed to update settlement" in str(excinfo.value)
mock_db_session.rollback.assert_called_once()
# Tests for delete_settlement
@pytest.mark.asyncio
async def test_delete_settlement_success(mock_db_session, db_settlement_model):
await delete_settlement(mock_db_session, db_settlement_model)
mock_db_session.delete.assert_called_once_with(db_settlement_model)
mock_db_session.commit.assert_called_once()
@pytest.mark.asyncio
async def test_delete_settlement_success_with_version_check(mock_db_session, db_settlement_model):
await delete_settlement(mock_db_session, db_settlement_model, expected_version=db_settlement_model.version)
mock_db_session.delete.assert_called_once_with(db_settlement_model)
mock_db_session.commit.assert_called_once()
@pytest.mark.asyncio
async def test_delete_settlement_version_mismatch(mock_db_session, db_settlement_model):
with pytest.raises(InvalidOperationError) as excinfo:
await delete_settlement(mock_db_session, db_settlement_model, expected_version=db_settlement_model.version + 1)
assert "Expected version" in str(excinfo.value)
assert "does not match current version" in str(excinfo.value)
mock_db_session.delete.assert_not_called()
@pytest.mark.asyncio
async def test_delete_settlement_commit_failure(mock_db_session, db_settlement_model):
mock_db_session.commit.side_effect = Exception("DB commit failed")
with pytest.raises(InvalidOperationError) as excinfo:
await delete_settlement(mock_db_session, db_settlement_model)
assert "Failed to delete settlement" in str(excinfo.value)
mock_db_session.rollback.assert_called_once()

117
be/tests/crud/test_user.py Normal file
View File

@ -0,0 +1,117 @@
import pytest
from unittest.mock import AsyncMock, MagicMock
from sqlalchemy.exc import IntegrityError, OperationalError
from app.crud.user import get_user_by_email, create_user
from app.schemas.user import UserCreate
from app.models import User as UserModel
from app.core.exceptions import (
UserCreationError,
EmailAlreadyRegisteredError,
DatabaseConnectionError,
DatabaseIntegrityError,
DatabaseQueryError,
DatabaseTransactionError
)
# Fixtures
@pytest.fixture
def mock_db_session():
return AsyncMock()
@pytest.fixture
def user_create_data():
return UserCreate(email="test@example.com", password="password123", name="Test User")
@pytest.fixture
def existing_user_data():
return UserModel(id=1, email="exists@example.com", password_hash="hashed_password", name="Existing User")
# Tests for get_user_by_email
@pytest.mark.asyncio
async def test_get_user_by_email_found(mock_db_session, existing_user_data):
mock_db_session.execute.return_value.scalars.return_value.first.return_value = existing_user_data
user = await get_user_by_email(mock_db_session, "exists@example.com")
assert user is not None
assert user.email == "exists@example.com"
mock_db_session.execute.assert_called_once()
@pytest.mark.asyncio
async def test_get_user_by_email_not_found(mock_db_session):
mock_db_session.execute.return_value.scalars.return_value.first.return_value = None
user = await get_user_by_email(mock_db_session, "nonexistent@example.com")
assert user is None
mock_db_session.execute.assert_called_once()
@pytest.mark.asyncio
async def test_get_user_by_email_db_connection_error(mock_db_session):
mock_db_session.execute.side_effect = OperationalError("mock_op_error", "params", "orig")
with pytest.raises(DatabaseConnectionError):
await get_user_by_email(mock_db_session, "test@example.com")
@pytest.mark.asyncio
async def test_get_user_by_email_db_query_error(mock_db_session):
# Simulate a generic SQLAlchemyError that is not OperationalError
mock_db_session.execute.side_effect = IntegrityError("mock_sql_error", "params", "orig") # Using IntegrityError as an example of SQLAlchemyError
with pytest.raises(DatabaseQueryError):
await get_user_by_email(mock_db_session, "test@example.com")
# Tests for create_user
@pytest.mark.asyncio
async def test_create_user_success(mock_db_session, user_create_data):
# The actual user object returned would be created by SQLAlchemy based on db_user
# We mock the process: db.add is called, then db.flush, then db.refresh updates db_user
async def mock_refresh(user_model_instance):
user_model_instance.id = 1 # Simulate DB assigning an ID
# Simulate other db-generated fields if necessary
return None
mock_db_session.refresh = AsyncMock(side_effect=mock_refresh)
mock_db_session.flush = AsyncMock()
mock_db_session.add = MagicMock()
created_user = await create_user(mock_db_session, user_create_data)
mock_db_session.add.assert_called_once()
mock_db_session.flush.assert_called_once()
mock_db_session.refresh.assert_called_once()
assert created_user is not None
assert created_user.email == user_create_data.email
assert created_user.name == user_create_data.name
assert hasattr(created_user, 'id') # Check if ID was assigned (simulated by mock_refresh)
# Password hash check would be more involved, ensure hash_password was called correctly
# For now, we assume hash_password works as intended and is tested elsewhere.
@pytest.mark.asyncio
async def test_create_user_email_already_registered(mock_db_session, user_create_data):
mock_db_session.flush.side_effect = IntegrityError("mock error (unique constraint)", "params", "orig")
with pytest.raises(EmailAlreadyRegisteredError):
await create_user(mock_db_session, user_create_data)
@pytest.mark.asyncio
async def test_create_user_db_integrity_error_not_unique(mock_db_session, user_create_data):
# Simulate an IntegrityError that is not related to a unique constraint
mock_db_session.flush.side_effect = IntegrityError("mock error (not unique constraint)", "params", "orig")
with pytest.raises(DatabaseIntegrityError):
await create_user(mock_db_session, user_create_data)
@pytest.mark.asyncio
async def test_create_user_db_connection_error(mock_db_session, user_create_data):
mock_db_session.begin.side_effect = OperationalError("mock_op_error", "params", "orig")
with pytest.raises(DatabaseConnectionError):
await create_user(mock_db_session, user_create_data)
# also test OperationalError on flush
mock_db_session.begin.side_effect = None # reset side effect
mock_db_session.flush.side_effect = OperationalError("mock_op_error", "params", "orig")
with pytest.raises(DatabaseConnectionError):
await create_user(mock_db_session, user_create_data)
@pytest.mark.asyncio
async def test_create_user_db_transaction_error(mock_db_session, user_create_data):
# Simulate a generic SQLAlchemyError on flush that is not IntegrityError or OperationalError
mock_db_session.flush.side_effect = UserCreationError("Simulated non-specific SQLAlchemyError") # Or any other SQLAlchemyError
with pytest.raises(DatabaseTransactionError):
await create_user(mock_db_session, user_create_data)