From e3024ccd077dfc36b750aade1da2d7f30f6b2754 Mon Sep 17 00:00:00 2001 From: Mohamad Date: Thu, 8 May 2025 14:31:34 +0200 Subject: [PATCH] remove old cost splitter --- be/app/api/v1/endpoints/costs.py | 236 ++++++++++++++++++++++----- be/app/config.py | 1 + be/app/crud/cost.py | 266 ------------------------------- be/tests/crud/test_cost.py | 254 ----------------------------- 4 files changed, 195 insertions(+), 562 deletions(-) delete mode 100644 be/app/crud/cost.py delete mode 100644 be/tests/crud/test_cost.py diff --git a/be/app/api/v1/endpoints/costs.py b/be/app/api/v1/endpoints/costs.py index a16d0ab..ff6e158 100644 --- a/be/app/api/v1/endpoints/costs.py +++ b/be/app/api/v1/endpoints/costs.py @@ -4,13 +4,25 @@ from fastapi import APIRouter, Depends, HTTPException, status from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import Session, selectinload +from decimal import Decimal, ROUND_HALF_UP from app.database import get_db from app.api.dependencies import get_current_user -from app.models import User as UserModel, Group as GroupModel # For get_current_user dependency and Group model +from app.models import ( + User as UserModel, + Group as GroupModel, + List as ListModel, + Expense as ExpenseModel, + Item as ItemModel, + UserGroup as UserGroupModel, + SplitTypeEnum, + ExpenseSplit as ExpenseSplitModel, + Settlement as SettlementModel +) from app.schemas.cost import ListCostSummary, GroupBalanceSummary -from app.crud import cost as crud_cost -from app.crud import list as crud_list # For permission checking +from app.schemas.expense import ExpenseCreate +from app.crud import list as crud_list +from app.crud import expense as crud_expense from app.core.exceptions import ListNotFoundError, ListPermissionError, UserNotFoundError, GroupNotFoundError logger = logging.getLogger(__name__) @@ -47,28 +59,117 @@ async def get_list_cost_summary( await crud_list.check_list_permission(db=db, list_id=list_id, user_id=current_user.id) except ListPermissionError as e: logger.warning(f"Permission denied for user {current_user.email} on list {list_id}: {str(e)}") - raise # Re-raise the original exception to be handled by FastAPI + raise except ListNotFoundError as e: logger.warning(f"List {list_id} not found when checking permissions for cost summary: {str(e)}") - raise # Re-raise + raise - # 2. Calculate the cost summary - try: - cost_summary = await crud_cost.calculate_list_cost_summary(db=db, list_id=list_id) - logger.info(f"Successfully generated cost summary for list {list_id} for user {current_user.email}") - return cost_summary - except ListNotFoundError as e: - logger.warning(f"List {list_id} not found during cost summary calculation: {str(e)}") - # This might be redundant if check_list_permission already confirmed list existence, - # but calculate_list_cost_summary also fetches the list. - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e)) - except UserNotFoundError as e: - logger.error(f"User not found during cost summary calculation for list {list_id}: {str(e)}") - # This indicates a data integrity issue (e.g., list creator or item adder missing) - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e)) - except Exception as e: - logger.error(f"Unexpected error generating cost summary for list {list_id} for user {current_user.email}: {str(e)}", exc_info=True) - raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="An unexpected error occurred while generating the cost summary.") + # 2. Get the list with its items and users + list_result = await db.execute( + select(ListModel) + .options( + selectinload(ListModel.items).options(selectinload(ItemModel.added_by_user)), + selectinload(ListModel.group).options(selectinload(GroupModel.user_associations).options(selectinload(UserGroupModel.user))), + selectinload(ListModel.creator) + ) + .where(ListModel.id == list_id) + ) + db_list = list_result.scalars().first() + if not db_list: + raise ListNotFoundError(list_id) + + # 3. Get or create an expense for this list + expense_result = await db.execute( + select(ExpenseModel) + .where(ExpenseModel.list_id == list_id) + .options(selectinload(ExpenseModel.splits)) + ) + db_expense = expense_result.scalars().first() + + if not db_expense: + # Create a new expense for this list + total_amount = sum(item.price for item in db_list.items if item.price is not None and item.price > Decimal("0")) + if total_amount == Decimal("0"): + return ListCostSummary( + list_id=db_list.id, + list_name=db_list.name, + total_list_cost=Decimal("0.00"), + num_participating_users=0, + equal_share_per_user=Decimal("0.00"), + user_balances=[] + ) + + # Create expense with ITEM_BASED split type + expense_in = ExpenseCreate( + description=f"Cost summary for list {db_list.name}", + total_amount=total_amount, + list_id=list_id, + split_type=SplitTypeEnum.ITEM_BASED, + paid_by_user_id=current_user.id # Use current user as payer for now + ) + db_expense = await crud_expense.create_expense(db=db, expense_in=expense_in) + + # 4. Calculate cost summary from expense splits + participating_users = set() + user_items_added_value = {} + total_list_cost = Decimal("0.00") + + # Get all users who added items + for item in db_list.items: + if item.price is not None and item.price > Decimal("0") and item.added_by_user: + participating_users.add(item.added_by_user) + user_items_added_value[item.added_by_user.id] = user_items_added_value.get(item.added_by_user.id, Decimal("0.00")) + item.price + total_list_cost += item.price + + # Get all users from expense splits + for split in db_expense.splits: + if split.user: + participating_users.add(split.user) + + num_participating_users = len(participating_users) + if num_participating_users == 0: + return ListCostSummary( + list_id=db_list.id, + list_name=db_list.name, + total_list_cost=Decimal("0.00"), + num_participating_users=0, + equal_share_per_user=Decimal("0.00"), + user_balances=[] + ) + + equal_share_per_user = (total_list_cost / Decimal(num_participating_users)).quantize(Decimal("0.01"), rounding=ROUND_HALF_UP) + remainder = total_list_cost - (equal_share_per_user * num_participating_users) + + user_balances = [] + first_user_processed = False + for user in participating_users: + items_added = user_items_added_value.get(user.id, Decimal("0.00")) + current_user_share = equal_share_per_user + if not first_user_processed and remainder != Decimal("0"): + current_user_share += remainder + first_user_processed = True + + balance = items_added - current_user_share + user_identifier = user.name if user.name else user.email + user_balances.append( + UserCostShare( + user_id=user.id, + user_identifier=user_identifier, + items_added_value=items_added.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP), + amount_due=current_user_share.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP), + balance=balance.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP) + ) + ) + + user_balances.sort(key=lambda x: x.user_identifier) + return ListCostSummary( + list_id=db_list.id, + list_name=db_list.name, + total_list_cost=total_list_cost.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP), + num_participating_users=num_participating_users, + equal_share_per_user=equal_share_per_user.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP), + user_balances=user_balances + ) @router.get( "/groups/{group_id}/balance-summary", @@ -92,11 +193,7 @@ async def get_group_balance_summary( """ logger.info(f"User {current_user.email} requesting balance summary for group {group_id}") - # 1. Verify user is a member of the target group (using crud_group.check_group_membership or similar) - # Assuming a function like this exists in app.crud.group or we add it. - # For now, let's placeholder this check logic. - # await crud_group.check_group_membership(db=db, group_id=group_id, user_id=current_user.id) - # A simpler check for now: fetch the group and see if user is part of member_associations + # 1. Verify user is a member of the target group group_check = await db.execute( select(GroupModel) .options(selectinload(GroupModel.member_associations)) @@ -109,20 +206,75 @@ async def get_group_balance_summary( user_is_member = any(assoc.user_id == current_user.id for assoc in db_group_for_check.member_associations) if not user_is_member: - # If ListPermissionError is generic enough for "access resource", use it, or a new GroupPermissionError raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=f"User not a member of group {group_id}") - # 2. Calculate the group balance summary - try: - balance_summary = await crud_cost.calculate_group_balance_summary(db=db, group_id=group_id) - logger.info(f"Successfully generated balance summary for group {group_id} for user {current_user.email}") - return balance_summary - except GroupNotFoundError as e: - logger.warning(f"Group {group_id} not found during balance summary calculation: {str(e)}") - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e)) - except UserNotFoundError as e: # Should not happen if group members are correctly fetched - logger.error(f"User not found during balance summary for group {group_id}: {str(e)}") - raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="An internal error occurred finding a user for the summary.") - except Exception as e: - logger.error(f"Unexpected error generating balance summary for group {group_id}: {str(e)}", exc_info=True) - raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="An unexpected error occurred while generating the group balance summary.") \ No newline at end of file + # 2. Get all expenses and settlements for the group + expenses_result = await db.execute( + select(ExpenseModel) + .where(ExpenseModel.group_id == group_id) + .options(selectinload(ExpenseModel.splits).selectinload(ExpenseSplitModel.user)) + ) + expenses = expenses_result.scalars().all() + + settlements_result = await db.execute( + select(SettlementModel) + .where(SettlementModel.group_id == group_id) + .options( + selectinload(SettlementModel.paid_by_user), + selectinload(SettlementModel.paid_to_user) + ) + ) + settlements = settlements_result.scalars().all() + + # 3. Calculate user balances + user_balances_data = {} + for assoc in db_group_for_check.member_associations: + if assoc.user: + user_balances_data[assoc.user.id] = UserBalanceDetail( + user_id=assoc.user.id, + user_identifier=assoc.user.name if assoc.user.name else assoc.user.email + ) + + # Process expenses + for expense in expenses: + if expense.paid_by_user_id in user_balances_data: + user_balances_data[expense.paid_by_user_id].total_paid_for_expenses += expense.total_amount + + for split in expense.splits: + if split.user_id in user_balances_data: + user_balances_data[split.user_id].total_share_of_expenses += split.owed_amount + + # Process settlements + for settlement in settlements: + if settlement.paid_by_user_id in user_balances_data: + user_balances_data[settlement.paid_by_user_id].total_settlements_paid += settlement.amount + if settlement.paid_to_user_id in user_balances_data: + user_balances_data[settlement.paid_to_user_id].total_settlements_received += settlement.amount + + # Calculate net balances + final_user_balances = [] + for user_id, data in user_balances_data.items(): + data.net_balance = ( + data.total_paid_for_expenses + data.total_settlements_received + ) - (data.total_share_of_expenses + data.total_settlements_paid) + + data.total_paid_for_expenses = data.total_paid_for_expenses.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP) + data.total_share_of_expenses = data.total_share_of_expenses.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP) + data.total_settlements_paid = data.total_settlements_paid.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP) + data.total_settlements_received = data.total_settlements_received.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP) + data.net_balance = data.net_balance.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP) + + final_user_balances.append(data) + + # Sort by user identifier + final_user_balances.sort(key=lambda x: x.user_identifier) + + # Calculate suggested settlements + suggested_settlements = calculate_suggested_settlements(final_user_balances) + + return GroupBalanceSummary( + group_id=db_group_for_check.id, + group_name=db_group_for_check.name, + user_balances=final_user_balances, + suggested_settlements=suggested_settlements + ) \ No newline at end of file diff --git a/be/app/config.py b/be/app/config.py index 8333247..b5ef157 100644 --- a/be/app/config.py +++ b/be/app/config.py @@ -60,6 +60,7 @@ Organic Bananas CORS_ORIGINS: list[str] = [ "http://localhost:5174", "http://localhost:8000", + "http://localhost:9000", # Add your deployed frontend URL here later # "https://your-frontend-domain.com", ] diff --git a/be/app/crud/cost.py b/be/app/crud/cost.py deleted file mode 100644 index d874ba2..0000000 --- a/be/app/crud/cost.py +++ /dev/null @@ -1,266 +0,0 @@ -from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy.future import select -from sqlalchemy.orm import selectinload, joinedload -from sqlalchemy import func as sql_func, or_ -from decimal import Decimal, ROUND_HALF_UP -from typing import Dict, Optional, Sequence, List as PyList -from collections import defaultdict -import logging - -from app.models import ( - List as ListModel, - Item as ItemModel, - User as UserModel, - UserGroup as UserGroupModel, - Group as GroupModel, - Expense as ExpenseModel, - ExpenseSplit as ExpenseSplitModel, - Settlement as SettlementModel -) -from app.schemas.cost import ( - ListCostSummary, - UserCostShare, - GroupBalanceSummary, - UserBalanceDetail, - SuggestedSettlement -) -from app.core.exceptions import ( - ListNotFoundError, - UserNotFoundError, - GroupNotFoundError, - InvalidOperationError -) - -logger = logging.getLogger(__name__) - -async def calculate_list_cost_summary(db: AsyncSession, list_id: int) -> ListCostSummary: - """ - Calculates the cost summary for a given list based purely on item prices and who added them. - This is a simpler calculation and does not involve the Expense/Settlement system. - """ - list_result = await db.execute( - select(ListModel) - .options( - selectinload(ListModel.items).options(joinedload(ItemModel.added_by_user)), - selectinload(ListModel.group).options(selectinload(GroupModel.user_associations).options(selectinload(UserGroupModel.user))), - selectinload(ListModel.creator) - ) - .where(ListModel.id == list_id) - ) - db_list: Optional[ListModel] = list_result.scalars().first() - - if not db_list: - raise ListNotFoundError(list_id) - - participating_users_map: Dict[int, UserModel] = {} - if db_list.group: - for ug_assoc in db_list.group.user_associations: - if ug_assoc.user: - participating_users_map[ug_assoc.user.id] = ug_assoc.user - elif db_list.creator: # Personal list - participating_users_map[db_list.creator.id] = db_list.creator - - # Include all users who added items with prices, even if not in the primary context (group/creator) - for item in db_list.items: - if item.price is not None and item.price > Decimal("0") and item.added_by_user and item.added_by_user.id not in participating_users_map: - participating_users_map[item.added_by_user.id] = item.added_by_user - - num_participating_users = len(participating_users_map) - if num_participating_users == 0: - return ListCostSummary( - list_id=db_list.id, list_name=db_list.name, total_list_cost=Decimal("0.00"), - num_participating_users=0, equal_share_per_user=Decimal("0.00"), user_balances=[] - ) - - total_list_cost = Decimal("0.00") - user_items_added_value: Dict[int, Decimal] = defaultdict(Decimal) - - for item in db_list.items: - if item.price is not None and item.price > Decimal("0"): - total_list_cost += item.price - if item.added_by_id in participating_users_map: - user_items_added_value[item.added_by_id] += item.price - - equal_share_per_user = (total_list_cost / Decimal(num_participating_users)).quantize(Decimal("0.01"), rounding=ROUND_HALF_UP) - remainder = total_list_cost - (equal_share_per_user * num_participating_users) - - user_balances: PyList[UserCostShare] = [] - first_user_processed = False - for user_id, user_obj in participating_users_map.items(): - items_added = user_items_added_value.get(user_id, Decimal("0.00")) - current_user_share = equal_share_per_user - if not first_user_processed and remainder != Decimal("0"): - current_user_share += remainder - first_user_processed = True - - balance = items_added - current_user_share - user_identifier = user_obj.name if user_obj.name else user_obj.email - user_balances.append( - UserCostShare( - user_id=user_id, user_identifier=user_identifier, - items_added_value=items_added.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP), - amount_due=current_user_share.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP), - balance=balance.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP) - ) - ) - user_balances.sort(key=lambda x: x.user_identifier) - return ListCostSummary( - list_id=db_list.id, list_name=db_list.name, - total_list_cost=total_list_cost.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP), - num_participating_users=num_participating_users, - equal_share_per_user=equal_share_per_user.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP), - user_balances=user_balances - ) - -# --- Helper for Settlement Suggestions --- -def calculate_suggested_settlements(user_balances: PyList[UserBalanceDetail]) -> PyList[SuggestedSettlement]: - """ - Calculates a list of suggested settlements to resolve group debts. - Uses a greedy algorithm to minimize the number of transactions. - Input: List of UserBalanceDetail objects with calculated net_balances. - Output: List of SuggestedSettlement objects. - """ - # Use a small tolerance for floating point comparisons with Decimal - tolerance = Decimal("0.001") - - debtors = sorted([ub for ub in user_balances if ub.net_balance < -tolerance], key=lambda x: x.net_balance) - creditors = sorted([ub for ub in user_balances if ub.net_balance > tolerance], key=lambda x: x.net_balance, reverse=True) - - settlements: PyList[SuggestedSettlement] = [] - debtor_idx = 0 - creditor_idx = 0 - - # Create mutable copies of balances to track remaining amounts - debtor_balances = {d.user_id: d.net_balance for d in debtors} - creditor_balances = {c.user_id: c.net_balance for c in creditors} - user_identifiers = {ub.user_id: ub.user_identifier for ub in user_balances} - - while debtor_idx < len(debtors) and creditor_idx < len(creditors): - debtor = debtors[debtor_idx] - creditor = creditors[creditor_idx] - - debtor_remaining = debtor_balances[debtor.user_id] - creditor_remaining = creditor_balances[creditor.user_id] - - # Amount to transfer is the minimum of what debtor owes and what creditor is owed - transfer_amount = min(abs(debtor_remaining), creditor_remaining) - transfer_amount = transfer_amount.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP) - - if transfer_amount > tolerance: # Only record meaningful transfers - settlements.append(SuggestedSettlement( - from_user_id=debtor.user_id, - from_user_identifier=user_identifiers.get(debtor.user_id, "Unknown Debtor"), - to_user_id=creditor.user_id, - to_user_identifier=user_identifiers.get(creditor.user_id, "Unknown Creditor"), - amount=transfer_amount - )) - - # Update remaining balances - debtor_balances[debtor.user_id] += transfer_amount - creditor_balances[creditor.user_id] -= transfer_amount - - # Move to next debtor if current one is settled (or very close) - if abs(debtor_balances[debtor.user_id]) < tolerance: - debtor_idx += 1 - - # Move to next creditor if current one is settled (or very close) - if creditor_balances[creditor.user_id] < tolerance: - creditor_idx += 1 - - # Log if lists aren't empty - indicates potential imbalance or rounding issue - if debtor_idx < len(debtors) or creditor_idx < len(creditors): - # Calculate remaining balances for logging - remaining_debt = sum(bal for bal in debtor_balances.values() if bal < -tolerance) - remaining_credit = sum(bal for bal in creditor_balances.values() if bal > tolerance) - logger.warning(f"Settlement suggestion calculation finished with remaining balances. Debt: {remaining_debt}, Credit: {remaining_credit}. This might be due to minor rounding discrepancies.") - - return settlements - -# --- NEW: Detailed Group Balance Summary --- -async def calculate_group_balance_summary(db: AsyncSession, group_id: int) -> GroupBalanceSummary: - """ - Calculates a detailed balance summary for all users in a group, - considering all expenses, splits, and settlements within that group. - Also calculates suggested settlements. - """ - group = await db.get(GroupModel, group_id) - if not group: - raise GroupNotFoundError(group_id) - - # 1. Get all group members - group_members_result = await db.execute( - select(UserModel) - .join(UserGroupModel, UserModel.id == UserGroupModel.user_id) - .where(UserGroupModel.group_id == group_id) - ) - group_members: Dict[int, UserModel] = {user.id: user for user in group_members_result.scalars().all()} - if not group_members: - return GroupBalanceSummary( - group_id=group.id, group_name=group.name, user_balances=[], suggested_settlements=[] - ) - - user_balances_data: Dict[int, UserBalanceDetail] = {} - for user_id, user_obj in group_members.items(): - user_balances_data[user_id] = UserBalanceDetail( - user_id=user_id, - user_identifier=user_obj.name if user_obj.name else user_obj.email - ) - - overall_total_expenses = Decimal("0.00") - overall_total_settlements = Decimal("0.00") - - # 2. Process Expenses and ExpenseSplits for the group - expenses_result = await db.execute( - select(ExpenseModel) - .where(ExpenseModel.group_id == group_id) - .options(selectinload(ExpenseModel.splits)) - ) - for expense in expenses_result.scalars().all(): - overall_total_expenses += expense.total_amount - if expense.paid_by_user_id in user_balances_data: - user_balances_data[expense.paid_by_user_id].total_paid_for_expenses += expense.total_amount - - for split in expense.splits: - if split.user_id in user_balances_data: - user_balances_data[split.user_id].total_share_of_expenses += split.owed_amount - - # 3. Process Settlements for the group - settlements_result = await db.execute( - select(SettlementModel).where(SettlementModel.group_id == group_id) - ) - for settlement in settlements_result.scalars().all(): - overall_total_settlements += settlement.amount - if settlement.paid_by_user_id in user_balances_data: - user_balances_data[settlement.paid_by_user_id].total_settlements_paid += settlement.amount - if settlement.paid_to_user_id in user_balances_data: - user_balances_data[settlement.paid_to_user_id].total_settlements_received += settlement.amount - - # 4. Calculate net balances and prepare final list - final_user_balances: PyList[UserBalanceDetail] = [] - for user_id in group_members.keys(): - data = user_balances_data[user_id] - data.net_balance = ( - data.total_paid_for_expenses + data.total_settlements_received - ) - (data.total_share_of_expenses + data.total_settlements_paid) - - data.total_paid_for_expenses = data.total_paid_for_expenses.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP) - data.total_share_of_expenses = data.total_share_of_expenses.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP) - data.total_settlements_paid = data.total_settlements_paid.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP) - data.total_settlements_received = data.total_settlements_received.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP) - data.net_balance = data.net_balance.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP) - - final_user_balances.append(data) - - final_user_balances.sort(key=lambda x: x.user_identifier) - - # 5. Calculate suggested settlements (NEW) - suggested_settlements = calculate_suggested_settlements(final_user_balances) - - return GroupBalanceSummary( - group_id=group.id, - group_name=group.name, - overall_total_expenses=overall_total_expenses.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP), - overall_total_settlements=overall_total_settlements.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP), - user_balances=final_user_balances, - suggested_settlements=suggested_settlements # Add suggestions to response - ) \ No newline at end of file diff --git a/be/tests/crud/test_cost.py b/be/tests/crud/test_cost.py deleted file mode 100644 index 0cfd839..0000000 --- a/be/tests/crud/test_cost.py +++ /dev/null @@ -1,254 +0,0 @@ -import pytest -from unittest.mock import AsyncMock, MagicMock, patch -from decimal import Decimal, ROUND_HALF_UP -from typing import List as PyList, Dict -import logging # For mocking the logger - -from app.crud.cost import ( - calculate_list_cost_summary, - calculate_suggested_settlements, - calculate_group_balance_summary -) -from app.schemas.cost import ( - ListCostSummary, - UserCostShare, - GroupBalanceSummary, - UserBalanceDetail, - SuggestedSettlement -) -from app.models import ( - List as ListModel, - Item as ItemModel, - User as UserModel, - Group as GroupModel, - UserGroup as UserGroupModel, - Expense as ExpenseModel, - ExpenseSplit as ExpenseSplitModel, - Settlement as SettlementModel -) -from app.core.exceptions import ListNotFoundError, GroupNotFoundError - -# Fixtures -@pytest.fixture -def mock_db_session(): - session = AsyncMock() - session.execute = AsyncMock() - session.get = AsyncMock() - return session - -@pytest.fixture -def user_a(): - return UserModel(id=1, name="User A", email="a@example.com") - -@pytest.fixture -def user_b(): - return UserModel(id=2, name="User B", email="b@example.com") - -@pytest.fixture -def user_c(): - return UserModel(id=3, name="User C", email="c@example.com") - -@pytest.fixture -def basic_list_model(user_a): - return ListModel(id=1, name="Test List", creator_id=user_a.id, creator=user_a, items=[]) - -@pytest.fixture -def group_model_with_members(user_a, user_b): - group = GroupModel(id=1, name="Test Group") - # Simulate user_associations for calculate_list_cost_summary if list is group-based - group.user_associations = [ - UserGroupModel(user_id=user_a.id, group_id=group.id, user=user_a), - UserGroupModel(user_id=user_b.id, group_id=group.id, user=user_b) - ] - return group - -@pytest.fixture -def list_model_with_group(group_model_with_members): - return ListModel(id=2, name="Group List", group_id=group_model_with_members.id, group=group_model_with_members, items=[]) - -# --- calculate_list_cost_summary Tests --- -@pytest.mark.asyncio -async def test_calculate_list_cost_summary_personal_list_no_items(mock_db_session, basic_list_model, user_a): - mock_db_session.execute.return_value.scalars.return_value.first.return_value = basic_list_model - basic_list_model.items = [] # Ensure no items - basic_list_model.group = None # Ensure it's a personal list - basic_list_model.creator = user_a - - summary = await calculate_list_cost_summary(mock_db_session, basic_list_model.id) - - assert summary.list_id == basic_list_model.id - assert summary.total_list_cost == Decimal("0.00") - assert summary.num_participating_users == 1 - assert summary.equal_share_per_user == Decimal("0.00") - assert len(summary.user_balances) == 1 - assert summary.user_balances[0].user_id == user_a.id - assert summary.user_balances[0].balance == Decimal("0.00") - -@pytest.mark.asyncio -async def test_calculate_list_cost_summary_group_list_with_items(mock_db_session, list_model_with_group, group_model_with_members, user_a, user_b): - item1 = ItemModel(id=1, name="Milk", price=Decimal("3.00"), added_by_id=user_a.id, added_by_user=user_a, list_id=list_model_with_group.id) - item2 = ItemModel(id=2, name="Bread", price=Decimal("2.00"), added_by_id=user_b.id, added_by_user=user_b, list_id=list_model_with_group.id) - list_model_with_group.items = [item1, item2] - - mock_db_session.execute.return_value.scalars.return_value.first.return_value = list_model_with_group - - summary = await calculate_list_cost_summary(mock_db_session, list_model_with_group.id) - - assert summary.total_list_cost == Decimal("5.00") - assert summary.num_participating_users == 2 - assert summary.equal_share_per_user == Decimal("2.50") # 5.00 / 2 - - balances_map = {ub.user_id: ub for ub in summary.user_balances} - assert balances_map[user_a.id].items_added_value == Decimal("3.00") - assert balances_map[user_a.id].amount_due == Decimal("2.50") - assert balances_map[user_a.id].balance == Decimal("0.50") # 3.00 - 2.50 - - assert balances_map[user_b.id].items_added_value == Decimal("2.00") - assert balances_map[user_b.id].amount_due == Decimal("2.50") - assert balances_map[user_b.id].balance == Decimal("-0.50") # 2.00 - 2.50 - -@pytest.mark.asyncio -async def test_calculate_list_cost_summary_list_not_found(mock_db_session): - mock_db_session.execute.return_value.scalars.return_value.first.return_value = None - with pytest.raises(ListNotFoundError): - await calculate_list_cost_summary(mock_db_session, 999) - -# --- calculate_suggested_settlements Tests --- -@patch('app.crud.cost.logger') # Mock the logger used in the function -def test_calculate_suggested_settlements_simple_case(mock_logger): - user_balances = [ - UserBalanceDetail(user_id=1, user_identifier="User A", net_balance=Decimal("-10.00")), - UserBalanceDetail(user_id=2, user_identifier="User B", net_balance=Decimal("10.00")), - ] - settlements = calculate_suggested_settlements(user_balances) - assert len(settlements) == 1 - assert settlements[0].from_user_id == 1 - assert settlements[0].to_user_id == 2 - assert settlements[0].amount == Decimal("10.00") - mock_logger.warning.assert_not_called() - -@patch('app.crud.cost.logger') -def test_calculate_suggested_settlements_multiple(mock_logger): - user_balances = [ - UserBalanceDetail(user_id=1, user_identifier="User A", net_balance=Decimal("-30.00")), - UserBalanceDetail(user_id=2, user_identifier="User B", net_balance=Decimal("10.00")), - UserBalanceDetail(user_id=3, user_identifier="User C", net_balance=Decimal("20.00")), - ] - settlements = calculate_suggested_settlements(user_balances) - # Expected: A owes B 10, A owes C 20 - assert len(settlements) == 2 - s_map = {(s.from_user_id, s.to_user_id): s.amount for s in settlements} - assert s_map[(1, 3)] == Decimal("20.00") # A -> C (largest creditor first) - assert s_map[(1, 2)] == Decimal("10.00") # A -> B - mock_logger.warning.assert_not_called() - -@patch('app.crud.cost.logger') -def test_calculate_suggested_settlements_complex_case_with_rounding_check(mock_logger): - user_balances = [ - UserBalanceDetail(user_id=1, user_identifier="User A", net_balance=Decimal("-33.33")), - UserBalanceDetail(user_id=2, user_identifier="User B", net_balance=Decimal("-33.33")), - UserBalanceDetail(user_id=3, user_identifier="User C", net_balance=Decimal("66.66")), - ] - settlements = calculate_suggested_settlements(user_balances) - # A -> C 33.33, B -> C 33.33 - assert len(settlements) == 2 - total_settled_to_C = sum(s.amount for s in settlements if s.to_user_id == 3) - assert total_settled_to_C == Decimal("66.66") - mock_logger.warning.assert_not_called() # Assuming exact match after quantization - -# --- calculate_group_balance_summary Tests --- -@pytest.mark.asyncio -async def test_calculate_group_balance_summary_no_activity(mock_db_session, group_model_with_members, user_a, user_b): - mock_db_session.get.return_value = group_model_with_members # For group fetch - - # Mock group members query - mock_members_result = AsyncMock() - mock_members_result.scalars.return_value.all.return_value = [user_a, user_b] - - # Mock expenses query (no expenses) - mock_expenses_result = AsyncMock() - mock_expenses_result.scalars.return_value.all.return_value = [] - - # Mock settlements query (no settlements) - mock_settlements_result = AsyncMock() - mock_settlements_result.scalars.return_value.all.return_value = [] - - mock_db_session.execute.side_effect = [mock_members_result, mock_expenses_result, mock_settlements_result] - - summary = await calculate_group_balance_summary(mock_db_session, group_model_with_members.id) - - assert summary.group_id == group_model_with_members.id - assert len(summary.user_balances) == 2 - assert len(summary.suggested_settlements) == 0 - for ub in summary.user_balances: - assert ub.net_balance == Decimal("0.00") - -@pytest.mark.asyncio -async def test_calculate_group_balance_summary_with_expenses_and_settlements(mock_db_session, group_model_with_members, user_a, user_b, user_c): - # Group members for this test: A, B, C - group_model_with_members.user_associations.append(UserGroupModel(user_id=user_c.id, group_id=group_model_with_members.id, user=user_c)) - all_users = [user_a, user_b, user_c] - - mock_db_session.get.return_value = group_model_with_members - - mock_members_result = AsyncMock() - mock_members_result.scalars.return_value.all.return_value = all_users - - # Expenses: A paid 90, split equally (A, B, C -> 30 each) - # A: paid 90, share 30 -> +60 - # B: paid 0, share 30 -> -30 - # C: paid 0, share 30 -> -30 - expense1 = ExpenseModel(id=1, group_id=group_model_with_members.id, paid_by_user_id=user_a.id, total_amount=Decimal("90.00")) - expense1.splits = [ - ExpenseSplitModel(expense_id=1, user_id=user_a.id, owed_amount=Decimal("30.00")), - ExpenseSplitModel(expense_id=1, user_id=user_b.id, owed_amount=Decimal("30.00")), - ExpenseSplitModel(expense_id=1, user_id=user_c.id, owed_amount=Decimal("30.00")), - ] - mock_expenses_result = AsyncMock() - mock_expenses_result.scalars.return_value.all.return_value = [expense1] - - # Settlements: B paid A 10 - # A: received 10 -> +60 + 10 = +70 - # B: paid 10 -> -30 - 10 = -40 - # C: no settlement -> -30 - settlement1 = SettlementModel(id=1, group_id=group_model_with_members.id, paid_by_user_id=user_b.id, paid_to_user_id=user_a.id, amount=Decimal("10.00")) - mock_settlements_result = AsyncMock() - mock_settlements_result.scalars.return_value.all.return_value = [settlement1] - - mock_db_session.execute.side_effect = [mock_members_result, mock_expenses_result, mock_settlements_result] - - with patch('app.crud.cost.logger') as mock_settlement_logger: # Patch logger for calculate_suggested_settlements - summary = await calculate_group_balance_summary(mock_db_session, group_model_with_members.id) - - balances_map = {ub.user_id: ub for ub in summary.user_balances} - assert balances_map[user_a.id].net_balance == Decimal("70.00") - assert balances_map[user_b.id].net_balance == Decimal("-40.00") - assert balances_map[user_c.id].net_balance == Decimal("-30.00") - - # Suggested settlements: B owes A 30 (40-10), C owes A 30 - # Net balances: A: +70, B: -40, C: -30 - # Algorithm should suggest: B -> A (40), C -> A (30) - assert len(summary.suggested_settlements) == 2 - settlement_map = {(s.from_user_id, s.to_user_id): s.amount for s in summary.suggested_settlements} - - # Order might vary, check presence and amounts - assert (user_b.id, user_a.id) in settlement_map - assert settlement_map[(user_b.id, user_a.id)] == Decimal("40.00") - - assert (user_c.id, user_a.id) in settlement_map - assert settlement_map[(user_c.id, user_a.id)] == Decimal("30.00") - mock_settlement_logger.warning.assert_not_called() - -@pytest.mark.asyncio -async def test_calculate_group_balance_summary_group_not_found(mock_db_session): - mock_db_session.get.return_value = None - with pytest.raises(GroupNotFoundError): - await calculate_group_balance_summary(mock_db_session, 999) - -# TODO: -# - Test calculate_list_cost_summary with item added by user not in group/not creator. -# - Test calculate_list_cost_summary with remainder distribution for equal share. -# - Test calculate_suggested_settlements with zero balances, one debtor, one creditor, etc. -# - Test calculate_suggested_settlements with logger warning for remaining balances. -# - Test calculate_group_balance_summary with more complex expense/settlement scenarios. -# - Test calculate_group_balance_summary when a user in splits/settlements is not in group_members (should ideally not happen if data is clean). \ No newline at end of file