from decimal import ROUND_HALF_UP, Decimal
from typing import Dict, List as PyList, Optional, Set

from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy.orm import joinedload, selectinload

from app.core.exceptions import ListNotFoundError, UserNotFoundError  # Assuming UserNotFoundError might be useful
from app.models import (
    Group as GroupModel,
    Item as ItemModel,
    List as ListModel,
    User as UserModel,
    UserGroup as UserGroupModel,
)
from app.schemas.cost import ListCostSummary, UserCostShare

# Monetary amounts in the summary are reported to two decimal places.
_CENT = Decimal("0.01")


def _to_cents(amount: Decimal) -> Decimal:
    """Quantize *amount* to two decimal places, rounding halves with ROUND_HALF_UP."""
    return amount.quantize(_CENT, rounding=ROUND_HALF_UP)


async def _load_list_with_relations(db: AsyncSession, list_id: int) -> ListModel:
    """Fetch a list with its items (plus each item's adder) and its group members.

    Raises:
        ListNotFoundError: if no list with ``list_id`` exists.
    """
    list_result = await db.execute(
        select(ListModel)
        .options(
            selectinload(ListModel.items).options(joinedload(ItemModel.added_by_user)),
            selectinload(ListModel.group).options(
                selectinload(GroupModel.user_associations).options(
                    selectinload(UserGroupModel.user)
                )
            ),
        )
        .where(ListModel.id == list_id)
    )
    db_list: Optional[ListModel] = list_result.scalars().first()
    if not db_list:
        raise ListNotFoundError(list_id)
    return db_list


async def _collect_participating_users(
    db: AsyncSession, db_list: ListModel
) -> Dict[int, UserModel]:
    """Return the users who share the list's cost, keyed by user id.

    Group lists include every group member whose user object is loaded.
    Personal lists include only the creator (fetched with ``db.get``). In both
    cases, any user who added an item is also included, even if they are not
    (or are no longer) a group member.

    Raises:
        UserNotFoundError: if the list is personal and its creator cannot be found.
    """
    participants: Dict[int, UserModel] = {}

    if db_list.group:
        for ug_assoc in db_list.group.user_associations:
            if ug_assoc.user:  # skip associations whose user did not load
                participants[ug_assoc.user.id] = ug_assoc.user
    else:
        # For the MVP a personal list is split only with its creator; revisit if
        # shared personal lists become a feature.
        creator_user = await db.get(UserModel, db_list.created_by_id)
        if not creator_user:
            # Should not happen while referential integrity holds.
            raise UserNotFoundError(user_id=db_list.created_by_id)
        participants[creator_user.id] = creator_user

    # Edge case: credit item adders who are outside the group so their
    # contributions appear in the split.
    for item in db_list.items:
        adder = item.added_by_user
        if adder and adder.id not in participants:
            participants[adder.id] = adder

    return participants


async def calculate_list_cost_summary(db: AsyncSession, list_id: int) -> ListCostSummary:
    """
    Calculates the cost summary for a given list, splitting costs equally among relevant users
    (group members if list is in a group, or creator if personal).

    Only items with a positive, non-null ``price`` count toward the total. Each
    participant's balance is the value of the items they added minus the equal
    per-user share; a positive balance means they paid more than their share.

    Args:
        db: Async SQLAlchemy session used to load the list and its related rows.
        list_id: Primary key of the list to summarize.

    Returns:
        A ``ListCostSummary`` whose amounts are quantized to cents (ROUND_HALF_UP).
        Per-user entries are sorted by identifier, then by user id.

    Raises:
        ListNotFoundError: if the list does not exist.
        UserNotFoundError: if the list is personal and its creator cannot be loaded.
    """
    db_list = await _load_list_with_relations(db, list_id)
    participating_users = await _collect_participating_users(db, db_list)

    num_participating_users = len(participating_users)
    if num_participating_users == 0:
        # No one to split with (e.g. a group with no loaded members and no item
        # adders); report an empty, zero-valued summary.
        return ListCostSummary(
            list_id=db_list.id,
            list_name=db_list.name,
            total_list_cost=Decimal("0.00"),
            num_participating_users=0,
            equal_share_per_user=Decimal("0.00"),
            user_balances=[],
        )

    # Total cost, plus the value each participant contributed by adding items.
    total_list_cost = Decimal("0.00")
    user_items_added_value: Dict[int, Decimal] = {
        user_id: Decimal("0.00") for user_id in participating_users
    }
    for item in db_list.items:
        if item.price is not None and item.price > Decimal("0"):
            total_list_cost += item.price
            # An item whose adder is not a participant (e.g. the adder did not
            # load) still counts toward the total but is credited to no one.
            if item.added_by_id in user_items_added_value:
                user_items_added_value[item.added_by_id] += item.price

    # num_participating_users > 0 is guaranteed by the early return above.
    # NOTE: with rounded shares, sum(equal_share) may differ from the total by
    # a few cents (e.g. 10.00 split three ways gives 3.33 each).
    equal_share_per_user = _to_cents(total_list_cost / Decimal(num_participating_users))

    user_balances: PyList[UserCostShare] = []
    for user_id, user_obj in participating_users.items():
        items_added = user_items_added_value[user_id]
        user_balances.append(
            UserCostShare(
                user_id=user_id,
                user_identifier=user_obj.name or user_obj.email,  # prefer name, fall back to email
                items_added_value=_to_cents(items_added),
                amount_due=equal_share_per_user,  # already quantized
                balance=_to_cents(items_added - equal_share_per_user),
            )
        )

    # Deterministic ordering: identifier first, user id breaks ties (the item
    # query has no ORDER BY, so insertion order alone is not stable).
    user_balances.sort(key=lambda share: (share.user_identifier or "", share.user_id))

    return ListCostSummary(
        list_id=db_list.id,
        list_name=db_list.name,
        total_list_cost=_to_cents(total_list_cost),
        num_participating_users=num_participating_users,
        equal_share_per_user=equal_share_per_user,  # already quantized
        user_balances=user_balances,
    )