116 lines
5.9 KiB
Python
116 lines
5.9 KiB
Python
from sqlalchemy.ext.asyncio import AsyncSession
|
|
from sqlalchemy.future import select
|
|
from sqlalchemy.orm import selectinload, joinedload
|
|
from decimal import Decimal, ROUND_HALF_UP
|
|
from typing import List as PyList, Dict, Set
|
|
|
|
from app.models import List as ListModel, Item as ItemModel, User as UserModel, UserGroup as UserGroupModel, Group as GroupModel
|
|
from app.schemas.cost import ListCostSummary, UserCostShare
|
|
from app.core.exceptions import ListNotFoundError, UserNotFoundError # Assuming UserNotFoundError might be useful
|
|
|
|
async def calculate_list_cost_summary(db: AsyncSession, list_id: int) -> ListCostSummary:
    """Build an equal-split cost summary for one list.

    The total cost is the sum of all strictly positive item prices on the
    list. It is split equally among the participating users:

    * if the list belongs to a group, every user associated with that group;
    * otherwise, the user referenced by ``created_by_id``;
    * in both cases, any user who added an item to the list is included too.

    Each participant's balance is ``items_added_value - equal_share``, i.e.
    positive when they added more value than their share.

    Note: the equal share is rounded to cents (ROUND_HALF_UP) before balances
    are computed, so ``equal_share * participants`` can differ from the total
    by a few cents and the balances need not sum to exactly zero.

    Args:
        db: Async SQLAlchemy session used for all queries.
        list_id: Primary key of the list to summarise.

    Returns:
        A ``ListCostSummary`` with per-user entries sorted by user identifier
        (name, falling back to email). If no participant can be determined,
        a zeroed summary with an empty ``user_balances`` list is returned.

    Raises:
        ListNotFoundError: No list with ``list_id`` exists.
        UserNotFoundError: The list has no group and its creator row cannot
            be loaded.
    """
    cent = Decimal("0.01")
    zero = Decimal("0.00")

    # 1. Fetch the list with its items (and who added them) and, if present,
    #    its group with member associations and their users, all eagerly
    #    loaded so no lazy loads are triggered on the async session.
    list_result = await db.execute(
        select(ListModel)
        .options(
            selectinload(ListModel.items).options(joinedload(ItemModel.added_by_user)),
            selectinload(ListModel.group).options(
                selectinload(GroupModel.user_associations).options(
                    selectinload(UserGroupModel.user)
                )
            ),
        )
        .where(ListModel.id == list_id)
    )
    db_list = list_result.scalars().first()
    if db_list is None:
        raise ListNotFoundError(list_id)

    # 2. Determine participating users, keyed by user id to de-duplicate.
    participating_users: Dict[int, UserModel] = {}
    if db_list.group:
        # Group list: every group member participates.
        for ug_assoc in db_list.group.user_associations:
            if ug_assoc.user:  # Skip associations whose user is missing.
                participating_users[ug_assoc.user.id] = ug_assoc.user
    else:
        # Personal list: only the creator participates (MVP behaviour; revisit
        # if shared personal lists become a feature).
        creator_user = await db.get(UserModel, db_list.created_by_id)
        if not creator_user:
            # Should not happen if referential integrity is maintained.
            raise UserNotFoundError(user_id=db_list.created_by_id)
        participating_users[creator_user.id] = creator_user

    # Anyone who added an item takes part in the split as well, even if they
    # are not (or no longer) a member of the group.
    for item in db_list.items:
        adder = item.added_by_user
        if adder and adder.id not in participating_users:
            participating_users[adder.id] = adder

    num_participating_users = len(participating_users)
    if num_participating_users == 0:
        # E.g. a group without members and no items with a known adder:
        # nothing to split, so return a zeroed summary.
        return ListCostSummary(
            list_id=db_list.id,
            list_name=db_list.name,
            total_list_cost=zero,
            num_participating_users=0,
            equal_share_per_user=zero,
            user_balances=[],
        )

    # 3. Total cost and the value each participant added. Items with no
    #    price or a non-positive price are ignored.
    total_list_cost = zero
    user_items_added_value: Dict[int, Decimal] = {
        user_id: zero for user_id in participating_users
    }
    for item in db_list.items:
        if item.price is None or item.price <= Decimal("0"):
            continue
        total_list_cost += item.price
        # An item whose adder is not a participant still counts toward the
        # total, but is credited to nobody.
        if item.added_by_id in user_items_added_value:
            user_items_added_value[item.added_by_id] += item.price

    # 4. Equal share, rounded to cents. The participant count is guaranteed
    #    to be non-zero here because of the early return above.
    equal_share_per_user = (
        total_list_cost / Decimal(num_participating_users)
    ).quantize(cent, rounding=ROUND_HALF_UP)

    # 5. Per-user balance: value added minus the (already rounded) share.
    user_balances: PyList[UserCostShare] = []
    for user_id, user_obj in participating_users.items():
        items_added = user_items_added_value.get(user_id, zero)
        balance = items_added - equal_share_per_user

        # Prefer the display name; fall back to the email address.
        user_identifier = user_obj.name if user_obj.name else user_obj.email

        user_balances.append(
            UserCostShare(
                user_id=user_id,
                user_identifier=user_identifier,
                items_added_value=items_added.quantize(cent, rounding=ROUND_HALF_UP),
                amount_due=equal_share_per_user,  # Already quantized.
                balance=balance.quantize(cent, rounding=ROUND_HALF_UP),
            )
        )

    # Stable, human-friendly ordering of the output.
    user_balances.sort(key=lambda x: x.user_identifier)

    # 6. Assemble the summary.
    return ListCostSummary(
        list_id=db_list.id,
        list_name=db_list.name,
        total_list_cost=total_list_cost.quantize(cent, rounding=ROUND_HALF_UP),
        num_participating_users=num_participating_users,
        equal_share_per_user=equal_share_per_user,  # Already quantized.
        user_balances=user_balances,
    )