remove old cost splitter
This commit is contained in:
parent
bbb3c3b7df
commit
e3024ccd07
@ -4,13 +4,25 @@ from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import Session, selectinload
|
||||
from decimal import Decimal, ROUND_HALF_UP
|
||||
|
||||
from app.database import get_db
|
||||
from app.api.dependencies import get_current_user
|
||||
from app.models import User as UserModel, Group as GroupModel # For get_current_user dependency and Group model
|
||||
from app.models import (
|
||||
User as UserModel,
|
||||
Group as GroupModel,
|
||||
List as ListModel,
|
||||
Expense as ExpenseModel,
|
||||
Item as ItemModel,
|
||||
UserGroup as UserGroupModel,
|
||||
SplitTypeEnum,
|
||||
ExpenseSplit as ExpenseSplitModel,
|
||||
Settlement as SettlementModel
|
||||
)
|
||||
from app.schemas.cost import ListCostSummary, GroupBalanceSummary
|
||||
from app.crud import cost as crud_cost
|
||||
from app.crud import list as crud_list # For permission checking
|
||||
from app.schemas.expense import ExpenseCreate
|
||||
from app.crud import list as crud_list
|
||||
from app.crud import expense as crud_expense
|
||||
from app.core.exceptions import ListNotFoundError, ListPermissionError, UserNotFoundError, GroupNotFoundError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -47,28 +59,117 @@ async def get_list_cost_summary(
|
||||
await crud_list.check_list_permission(db=db, list_id=list_id, user_id=current_user.id)
|
||||
except ListPermissionError as e:
|
||||
logger.warning(f"Permission denied for user {current_user.email} on list {list_id}: {str(e)}")
|
||||
raise # Re-raise the original exception to be handled by FastAPI
|
||||
raise
|
||||
except ListNotFoundError as e:
|
||||
logger.warning(f"List {list_id} not found when checking permissions for cost summary: {str(e)}")
|
||||
raise # Re-raise
|
||||
raise
|
||||
|
||||
# 2. Calculate the cost summary
|
||||
try:
|
||||
cost_summary = await crud_cost.calculate_list_cost_summary(db=db, list_id=list_id)
|
||||
logger.info(f"Successfully generated cost summary for list {list_id} for user {current_user.email}")
|
||||
return cost_summary
|
||||
except ListNotFoundError as e:
|
||||
logger.warning(f"List {list_id} not found during cost summary calculation: {str(e)}")
|
||||
# This might be redundant if check_list_permission already confirmed list existence,
|
||||
# but calculate_list_cost_summary also fetches the list.
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e))
|
||||
except UserNotFoundError as e:
|
||||
logger.error(f"User not found during cost summary calculation for list {list_id}: {str(e)}")
|
||||
# This indicates a data integrity issue (e.g., list creator or item adder missing)
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error generating cost summary for list {list_id} for user {current_user.email}: {str(e)}", exc_info=True)
|
||||
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="An unexpected error occurred while generating the cost summary.")
|
||||
# 2. Get the list with its items and users
|
||||
list_result = await db.execute(
|
||||
select(ListModel)
|
||||
.options(
|
||||
selectinload(ListModel.items).options(selectinload(ItemModel.added_by_user)),
|
||||
selectinload(ListModel.group).options(selectinload(GroupModel.user_associations).options(selectinload(UserGroupModel.user))),
|
||||
selectinload(ListModel.creator)
|
||||
)
|
||||
.where(ListModel.id == list_id)
|
||||
)
|
||||
db_list = list_result.scalars().first()
|
||||
if not db_list:
|
||||
raise ListNotFoundError(list_id)
|
||||
|
||||
# 3. Get or create an expense for this list
|
||||
expense_result = await db.execute(
|
||||
select(ExpenseModel)
|
||||
.where(ExpenseModel.list_id == list_id)
|
||||
.options(selectinload(ExpenseModel.splits))
|
||||
)
|
||||
db_expense = expense_result.scalars().first()
|
||||
|
||||
if not db_expense:
|
||||
# Create a new expense for this list
|
||||
total_amount = sum(item.price for item in db_list.items if item.price is not None and item.price > Decimal("0"))
|
||||
if total_amount == Decimal("0"):
|
||||
return ListCostSummary(
|
||||
list_id=db_list.id,
|
||||
list_name=db_list.name,
|
||||
total_list_cost=Decimal("0.00"),
|
||||
num_participating_users=0,
|
||||
equal_share_per_user=Decimal("0.00"),
|
||||
user_balances=[]
|
||||
)
|
||||
|
||||
# Create expense with ITEM_BASED split type
|
||||
expense_in = ExpenseCreate(
|
||||
description=f"Cost summary for list {db_list.name}",
|
||||
total_amount=total_amount,
|
||||
list_id=list_id,
|
||||
split_type=SplitTypeEnum.ITEM_BASED,
|
||||
paid_by_user_id=current_user.id # Use current user as payer for now
|
||||
)
|
||||
db_expense = await crud_expense.create_expense(db=db, expense_in=expense_in)
|
||||
|
||||
# 4. Calculate cost summary from expense splits
|
||||
participating_users = set()
|
||||
user_items_added_value = {}
|
||||
total_list_cost = Decimal("0.00")
|
||||
|
||||
# Get all users who added items
|
||||
for item in db_list.items:
|
||||
if item.price is not None and item.price > Decimal("0") and item.added_by_user:
|
||||
participating_users.add(item.added_by_user)
|
||||
user_items_added_value[item.added_by_user.id] = user_items_added_value.get(item.added_by_user.id, Decimal("0.00")) + item.price
|
||||
total_list_cost += item.price
|
||||
|
||||
# Get all users from expense splits
|
||||
for split in db_expense.splits:
|
||||
if split.user:
|
||||
participating_users.add(split.user)
|
||||
|
||||
num_participating_users = len(participating_users)
|
||||
if num_participating_users == 0:
|
||||
return ListCostSummary(
|
||||
list_id=db_list.id,
|
||||
list_name=db_list.name,
|
||||
total_list_cost=Decimal("0.00"),
|
||||
num_participating_users=0,
|
||||
equal_share_per_user=Decimal("0.00"),
|
||||
user_balances=[]
|
||||
)
|
||||
|
||||
equal_share_per_user = (total_list_cost / Decimal(num_participating_users)).quantize(Decimal("0.01"), rounding=ROUND_HALF_UP)
|
||||
remainder = total_list_cost - (equal_share_per_user * num_participating_users)
|
||||
|
||||
user_balances = []
|
||||
first_user_processed = False
|
||||
for user in participating_users:
|
||||
items_added = user_items_added_value.get(user.id, Decimal("0.00"))
|
||||
current_user_share = equal_share_per_user
|
||||
if not first_user_processed and remainder != Decimal("0"):
|
||||
current_user_share += remainder
|
||||
first_user_processed = True
|
||||
|
||||
balance = items_added - current_user_share
|
||||
user_identifier = user.name if user.name else user.email
|
||||
user_balances.append(
|
||||
UserCostShare(
|
||||
user_id=user.id,
|
||||
user_identifier=user_identifier,
|
||||
items_added_value=items_added.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP),
|
||||
amount_due=current_user_share.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP),
|
||||
balance=balance.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP)
|
||||
)
|
||||
)
|
||||
|
||||
user_balances.sort(key=lambda x: x.user_identifier)
|
||||
return ListCostSummary(
|
||||
list_id=db_list.id,
|
||||
list_name=db_list.name,
|
||||
total_list_cost=total_list_cost.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP),
|
||||
num_participating_users=num_participating_users,
|
||||
equal_share_per_user=equal_share_per_user.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP),
|
||||
user_balances=user_balances
|
||||
)
|
||||
|
||||
@router.get(
|
||||
"/groups/{group_id}/balance-summary",
|
||||
@ -92,11 +193,7 @@ async def get_group_balance_summary(
|
||||
"""
|
||||
logger.info(f"User {current_user.email} requesting balance summary for group {group_id}")
|
||||
|
||||
# 1. Verify user is a member of the target group (using crud_group.check_group_membership or similar)
|
||||
# Assuming a function like this exists in app.crud.group or we add it.
|
||||
# For now, let's placeholder this check logic.
|
||||
# await crud_group.check_group_membership(db=db, group_id=group_id, user_id=current_user.id)
|
||||
# A simpler check for now: fetch the group and see if user is part of member_associations
|
||||
# 1. Verify user is a member of the target group
|
||||
group_check = await db.execute(
|
||||
select(GroupModel)
|
||||
.options(selectinload(GroupModel.member_associations))
|
||||
@ -109,20 +206,75 @@ async def get_group_balance_summary(
|
||||
|
||||
user_is_member = any(assoc.user_id == current_user.id for assoc in db_group_for_check.member_associations)
|
||||
if not user_is_member:
|
||||
# If ListPermissionError is generic enough for "access resource", use it, or a new GroupPermissionError
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=f"User not a member of group {group_id}")
|
||||
|
||||
# 2. Calculate the group balance summary
|
||||
try:
|
||||
balance_summary = await crud_cost.calculate_group_balance_summary(db=db, group_id=group_id)
|
||||
logger.info(f"Successfully generated balance summary for group {group_id} for user {current_user.email}")
|
||||
return balance_summary
|
||||
except GroupNotFoundError as e:
|
||||
logger.warning(f"Group {group_id} not found during balance summary calculation: {str(e)}")
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e))
|
||||
except UserNotFoundError as e: # Should not happen if group members are correctly fetched
|
||||
logger.error(f"User not found during balance summary for group {group_id}: {str(e)}")
|
||||
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="An internal error occurred finding a user for the summary.")
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error generating balance summary for group {group_id}: {str(e)}", exc_info=True)
|
||||
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="An unexpected error occurred while generating the group balance summary.")
|
||||
# 2. Get all expenses and settlements for the group
|
||||
expenses_result = await db.execute(
|
||||
select(ExpenseModel)
|
||||
.where(ExpenseModel.group_id == group_id)
|
||||
.options(selectinload(ExpenseModel.splits).selectinload(ExpenseSplitModel.user))
|
||||
)
|
||||
expenses = expenses_result.scalars().all()
|
||||
|
||||
settlements_result = await db.execute(
|
||||
select(SettlementModel)
|
||||
.where(SettlementModel.group_id == group_id)
|
||||
.options(
|
||||
selectinload(SettlementModel.paid_by_user),
|
||||
selectinload(SettlementModel.paid_to_user)
|
||||
)
|
||||
)
|
||||
settlements = settlements_result.scalars().all()
|
||||
|
||||
# 3. Calculate user balances
|
||||
user_balances_data = {}
|
||||
for assoc in db_group_for_check.member_associations:
|
||||
if assoc.user:
|
||||
user_balances_data[assoc.user.id] = UserBalanceDetail(
|
||||
user_id=assoc.user.id,
|
||||
user_identifier=assoc.user.name if assoc.user.name else assoc.user.email
|
||||
)
|
||||
|
||||
# Process expenses
|
||||
for expense in expenses:
|
||||
if expense.paid_by_user_id in user_balances_data:
|
||||
user_balances_data[expense.paid_by_user_id].total_paid_for_expenses += expense.total_amount
|
||||
|
||||
for split in expense.splits:
|
||||
if split.user_id in user_balances_data:
|
||||
user_balances_data[split.user_id].total_share_of_expenses += split.owed_amount
|
||||
|
||||
# Process settlements
|
||||
for settlement in settlements:
|
||||
if settlement.paid_by_user_id in user_balances_data:
|
||||
user_balances_data[settlement.paid_by_user_id].total_settlements_paid += settlement.amount
|
||||
if settlement.paid_to_user_id in user_balances_data:
|
||||
user_balances_data[settlement.paid_to_user_id].total_settlements_received += settlement.amount
|
||||
|
||||
# Calculate net balances
|
||||
final_user_balances = []
|
||||
for user_id, data in user_balances_data.items():
|
||||
data.net_balance = (
|
||||
data.total_paid_for_expenses + data.total_settlements_received
|
||||
) - (data.total_share_of_expenses + data.total_settlements_paid)
|
||||
|
||||
data.total_paid_for_expenses = data.total_paid_for_expenses.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP)
|
||||
data.total_share_of_expenses = data.total_share_of_expenses.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP)
|
||||
data.total_settlements_paid = data.total_settlements_paid.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP)
|
||||
data.total_settlements_received = data.total_settlements_received.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP)
|
||||
data.net_balance = data.net_balance.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP)
|
||||
|
||||
final_user_balances.append(data)
|
||||
|
||||
# Sort by user identifier
|
||||
final_user_balances.sort(key=lambda x: x.user_identifier)
|
||||
|
||||
# Calculate suggested settlements
|
||||
suggested_settlements = calculate_suggested_settlements(final_user_balances)
|
||||
|
||||
return GroupBalanceSummary(
|
||||
group_id=db_group_for_check.id,
|
||||
group_name=db_group_for_check.name,
|
||||
user_balances=final_user_balances,
|
||||
suggested_settlements=suggested_settlements
|
||||
)
|
@ -60,6 +60,7 @@ Organic Bananas
|
||||
CORS_ORIGINS: list[str] = [
|
||||
"http://localhost:5174",
|
||||
"http://localhost:8000",
|
||||
"http://localhost:9000",
|
||||
# Add your deployed frontend URL here later
|
||||
# "https://your-frontend-domain.com",
|
||||
]
|
||||
|
@ -1,266 +0,0 @@
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
from sqlalchemy.orm import selectinload, joinedload
|
||||
from sqlalchemy import func as sql_func, or_
|
||||
from decimal import Decimal, ROUND_HALF_UP
|
||||
from typing import Dict, Optional, Sequence, List as PyList
|
||||
from collections import defaultdict
|
||||
import logging
|
||||
|
||||
from app.models import (
|
||||
List as ListModel,
|
||||
Item as ItemModel,
|
||||
User as UserModel,
|
||||
UserGroup as UserGroupModel,
|
||||
Group as GroupModel,
|
||||
Expense as ExpenseModel,
|
||||
ExpenseSplit as ExpenseSplitModel,
|
||||
Settlement as SettlementModel
|
||||
)
|
||||
from app.schemas.cost import (
|
||||
ListCostSummary,
|
||||
UserCostShare,
|
||||
GroupBalanceSummary,
|
||||
UserBalanceDetail,
|
||||
SuggestedSettlement
|
||||
)
|
||||
from app.core.exceptions import (
|
||||
ListNotFoundError,
|
||||
UserNotFoundError,
|
||||
GroupNotFoundError,
|
||||
InvalidOperationError
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
async def calculate_list_cost_summary(db: AsyncSession, list_id: int) -> ListCostSummary:
|
||||
"""
|
||||
Calculates the cost summary for a given list based purely on item prices and who added them.
|
||||
This is a simpler calculation and does not involve the Expense/Settlement system.
|
||||
"""
|
||||
list_result = await db.execute(
|
||||
select(ListModel)
|
||||
.options(
|
||||
selectinload(ListModel.items).options(joinedload(ItemModel.added_by_user)),
|
||||
selectinload(ListModel.group).options(selectinload(GroupModel.user_associations).options(selectinload(UserGroupModel.user))),
|
||||
selectinload(ListModel.creator)
|
||||
)
|
||||
.where(ListModel.id == list_id)
|
||||
)
|
||||
db_list: Optional[ListModel] = list_result.scalars().first()
|
||||
|
||||
if not db_list:
|
||||
raise ListNotFoundError(list_id)
|
||||
|
||||
participating_users_map: Dict[int, UserModel] = {}
|
||||
if db_list.group:
|
||||
for ug_assoc in db_list.group.user_associations:
|
||||
if ug_assoc.user:
|
||||
participating_users_map[ug_assoc.user.id] = ug_assoc.user
|
||||
elif db_list.creator: # Personal list
|
||||
participating_users_map[db_list.creator.id] = db_list.creator
|
||||
|
||||
# Include all users who added items with prices, even if not in the primary context (group/creator)
|
||||
for item in db_list.items:
|
||||
if item.price is not None and item.price > Decimal("0") and item.added_by_user and item.added_by_user.id not in participating_users_map:
|
||||
participating_users_map[item.added_by_user.id] = item.added_by_user
|
||||
|
||||
num_participating_users = len(participating_users_map)
|
||||
if num_participating_users == 0:
|
||||
return ListCostSummary(
|
||||
list_id=db_list.id, list_name=db_list.name, total_list_cost=Decimal("0.00"),
|
||||
num_participating_users=0, equal_share_per_user=Decimal("0.00"), user_balances=[]
|
||||
)
|
||||
|
||||
total_list_cost = Decimal("0.00")
|
||||
user_items_added_value: Dict[int, Decimal] = defaultdict(Decimal)
|
||||
|
||||
for item in db_list.items:
|
||||
if item.price is not None and item.price > Decimal("0"):
|
||||
total_list_cost += item.price
|
||||
if item.added_by_id in participating_users_map:
|
||||
user_items_added_value[item.added_by_id] += item.price
|
||||
|
||||
equal_share_per_user = (total_list_cost / Decimal(num_participating_users)).quantize(Decimal("0.01"), rounding=ROUND_HALF_UP)
|
||||
remainder = total_list_cost - (equal_share_per_user * num_participating_users)
|
||||
|
||||
user_balances: PyList[UserCostShare] = []
|
||||
first_user_processed = False
|
||||
for user_id, user_obj in participating_users_map.items():
|
||||
items_added = user_items_added_value.get(user_id, Decimal("0.00"))
|
||||
current_user_share = equal_share_per_user
|
||||
if not first_user_processed and remainder != Decimal("0"):
|
||||
current_user_share += remainder
|
||||
first_user_processed = True
|
||||
|
||||
balance = items_added - current_user_share
|
||||
user_identifier = user_obj.name if user_obj.name else user_obj.email
|
||||
user_balances.append(
|
||||
UserCostShare(
|
||||
user_id=user_id, user_identifier=user_identifier,
|
||||
items_added_value=items_added.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP),
|
||||
amount_due=current_user_share.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP),
|
||||
balance=balance.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP)
|
||||
)
|
||||
)
|
||||
user_balances.sort(key=lambda x: x.user_identifier)
|
||||
return ListCostSummary(
|
||||
list_id=db_list.id, list_name=db_list.name,
|
||||
total_list_cost=total_list_cost.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP),
|
||||
num_participating_users=num_participating_users,
|
||||
equal_share_per_user=equal_share_per_user.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP),
|
||||
user_balances=user_balances
|
||||
)
|
||||
|
||||
# --- Helper for Settlement Suggestions ---
|
||||
def calculate_suggested_settlements(user_balances: PyList[UserBalanceDetail]) -> PyList[SuggestedSettlement]:
|
||||
"""
|
||||
Calculates a list of suggested settlements to resolve group debts.
|
||||
Uses a greedy algorithm to minimize the number of transactions.
|
||||
Input: List of UserBalanceDetail objects with calculated net_balances.
|
||||
Output: List of SuggestedSettlement objects.
|
||||
"""
|
||||
# Use a small tolerance for floating point comparisons with Decimal
|
||||
tolerance = Decimal("0.001")
|
||||
|
||||
debtors = sorted([ub for ub in user_balances if ub.net_balance < -tolerance], key=lambda x: x.net_balance)
|
||||
creditors = sorted([ub for ub in user_balances if ub.net_balance > tolerance], key=lambda x: x.net_balance, reverse=True)
|
||||
|
||||
settlements: PyList[SuggestedSettlement] = []
|
||||
debtor_idx = 0
|
||||
creditor_idx = 0
|
||||
|
||||
# Create mutable copies of balances to track remaining amounts
|
||||
debtor_balances = {d.user_id: d.net_balance for d in debtors}
|
||||
creditor_balances = {c.user_id: c.net_balance for c in creditors}
|
||||
user_identifiers = {ub.user_id: ub.user_identifier for ub in user_balances}
|
||||
|
||||
while debtor_idx < len(debtors) and creditor_idx < len(creditors):
|
||||
debtor = debtors[debtor_idx]
|
||||
creditor = creditors[creditor_idx]
|
||||
|
||||
debtor_remaining = debtor_balances[debtor.user_id]
|
||||
creditor_remaining = creditor_balances[creditor.user_id]
|
||||
|
||||
# Amount to transfer is the minimum of what debtor owes and what creditor is owed
|
||||
transfer_amount = min(abs(debtor_remaining), creditor_remaining)
|
||||
transfer_amount = transfer_amount.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP)
|
||||
|
||||
if transfer_amount > tolerance: # Only record meaningful transfers
|
||||
settlements.append(SuggestedSettlement(
|
||||
from_user_id=debtor.user_id,
|
||||
from_user_identifier=user_identifiers.get(debtor.user_id, "Unknown Debtor"),
|
||||
to_user_id=creditor.user_id,
|
||||
to_user_identifier=user_identifiers.get(creditor.user_id, "Unknown Creditor"),
|
||||
amount=transfer_amount
|
||||
))
|
||||
|
||||
# Update remaining balances
|
||||
debtor_balances[debtor.user_id] += transfer_amount
|
||||
creditor_balances[creditor.user_id] -= transfer_amount
|
||||
|
||||
# Move to next debtor if current one is settled (or very close)
|
||||
if abs(debtor_balances[debtor.user_id]) < tolerance:
|
||||
debtor_idx += 1
|
||||
|
||||
# Move to next creditor if current one is settled (or very close)
|
||||
if creditor_balances[creditor.user_id] < tolerance:
|
||||
creditor_idx += 1
|
||||
|
||||
# Log if lists aren't empty - indicates potential imbalance or rounding issue
|
||||
if debtor_idx < len(debtors) or creditor_idx < len(creditors):
|
||||
# Calculate remaining balances for logging
|
||||
remaining_debt = sum(bal for bal in debtor_balances.values() if bal < -tolerance)
|
||||
remaining_credit = sum(bal for bal in creditor_balances.values() if bal > tolerance)
|
||||
logger.warning(f"Settlement suggestion calculation finished with remaining balances. Debt: {remaining_debt}, Credit: {remaining_credit}. This might be due to minor rounding discrepancies.")
|
||||
|
||||
return settlements
|
||||
|
||||
# --- NEW: Detailed Group Balance Summary ---
|
||||
async def calculate_group_balance_summary(db: AsyncSession, group_id: int) -> GroupBalanceSummary:
|
||||
"""
|
||||
Calculates a detailed balance summary for all users in a group,
|
||||
considering all expenses, splits, and settlements within that group.
|
||||
Also calculates suggested settlements.
|
||||
"""
|
||||
group = await db.get(GroupModel, group_id)
|
||||
if not group:
|
||||
raise GroupNotFoundError(group_id)
|
||||
|
||||
# 1. Get all group members
|
||||
group_members_result = await db.execute(
|
||||
select(UserModel)
|
||||
.join(UserGroupModel, UserModel.id == UserGroupModel.user_id)
|
||||
.where(UserGroupModel.group_id == group_id)
|
||||
)
|
||||
group_members: Dict[int, UserModel] = {user.id: user for user in group_members_result.scalars().all()}
|
||||
if not group_members:
|
||||
return GroupBalanceSummary(
|
||||
group_id=group.id, group_name=group.name, user_balances=[], suggested_settlements=[]
|
||||
)
|
||||
|
||||
user_balances_data: Dict[int, UserBalanceDetail] = {}
|
||||
for user_id, user_obj in group_members.items():
|
||||
user_balances_data[user_id] = UserBalanceDetail(
|
||||
user_id=user_id,
|
||||
user_identifier=user_obj.name if user_obj.name else user_obj.email
|
||||
)
|
||||
|
||||
overall_total_expenses = Decimal("0.00")
|
||||
overall_total_settlements = Decimal("0.00")
|
||||
|
||||
# 2. Process Expenses and ExpenseSplits for the group
|
||||
expenses_result = await db.execute(
|
||||
select(ExpenseModel)
|
||||
.where(ExpenseModel.group_id == group_id)
|
||||
.options(selectinload(ExpenseModel.splits))
|
||||
)
|
||||
for expense in expenses_result.scalars().all():
|
||||
overall_total_expenses += expense.total_amount
|
||||
if expense.paid_by_user_id in user_balances_data:
|
||||
user_balances_data[expense.paid_by_user_id].total_paid_for_expenses += expense.total_amount
|
||||
|
||||
for split in expense.splits:
|
||||
if split.user_id in user_balances_data:
|
||||
user_balances_data[split.user_id].total_share_of_expenses += split.owed_amount
|
||||
|
||||
# 3. Process Settlements for the group
|
||||
settlements_result = await db.execute(
|
||||
select(SettlementModel).where(SettlementModel.group_id == group_id)
|
||||
)
|
||||
for settlement in settlements_result.scalars().all():
|
||||
overall_total_settlements += settlement.amount
|
||||
if settlement.paid_by_user_id in user_balances_data:
|
||||
user_balances_data[settlement.paid_by_user_id].total_settlements_paid += settlement.amount
|
||||
if settlement.paid_to_user_id in user_balances_data:
|
||||
user_balances_data[settlement.paid_to_user_id].total_settlements_received += settlement.amount
|
||||
|
||||
# 4. Calculate net balances and prepare final list
|
||||
final_user_balances: PyList[UserBalanceDetail] = []
|
||||
for user_id in group_members.keys():
|
||||
data = user_balances_data[user_id]
|
||||
data.net_balance = (
|
||||
data.total_paid_for_expenses + data.total_settlements_received
|
||||
) - (data.total_share_of_expenses + data.total_settlements_paid)
|
||||
|
||||
data.total_paid_for_expenses = data.total_paid_for_expenses.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP)
|
||||
data.total_share_of_expenses = data.total_share_of_expenses.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP)
|
||||
data.total_settlements_paid = data.total_settlements_paid.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP)
|
||||
data.total_settlements_received = data.total_settlements_received.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP)
|
||||
data.net_balance = data.net_balance.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP)
|
||||
|
||||
final_user_balances.append(data)
|
||||
|
||||
final_user_balances.sort(key=lambda x: x.user_identifier)
|
||||
|
||||
# 5. Calculate suggested settlements (NEW)
|
||||
suggested_settlements = calculate_suggested_settlements(final_user_balances)
|
||||
|
||||
return GroupBalanceSummary(
|
||||
group_id=group.id,
|
||||
group_name=group.name,
|
||||
overall_total_expenses=overall_total_expenses.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP),
|
||||
overall_total_settlements=overall_total_settlements.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP),
|
||||
user_balances=final_user_balances,
|
||||
suggested_settlements=suggested_settlements # Add suggestions to response
|
||||
)
|
@ -1,254 +0,0 @@
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from decimal import Decimal, ROUND_HALF_UP
|
||||
from typing import List as PyList, Dict
|
||||
import logging # For mocking the logger
|
||||
|
||||
from app.crud.cost import (
|
||||
calculate_list_cost_summary,
|
||||
calculate_suggested_settlements,
|
||||
calculate_group_balance_summary
|
||||
)
|
||||
from app.schemas.cost import (
|
||||
ListCostSummary,
|
||||
UserCostShare,
|
||||
GroupBalanceSummary,
|
||||
UserBalanceDetail,
|
||||
SuggestedSettlement
|
||||
)
|
||||
from app.models import (
|
||||
List as ListModel,
|
||||
Item as ItemModel,
|
||||
User as UserModel,
|
||||
Group as GroupModel,
|
||||
UserGroup as UserGroupModel,
|
||||
Expense as ExpenseModel,
|
||||
ExpenseSplit as ExpenseSplitModel,
|
||||
Settlement as SettlementModel
|
||||
)
|
||||
from app.core.exceptions import ListNotFoundError, GroupNotFoundError
|
||||
|
||||
# Fixtures
|
||||
@pytest.fixture
|
||||
def mock_db_session():
|
||||
session = AsyncMock()
|
||||
session.execute = AsyncMock()
|
||||
session.get = AsyncMock()
|
||||
return session
|
||||
|
||||
@pytest.fixture
|
||||
def user_a():
|
||||
return UserModel(id=1, name="User A", email="a@example.com")
|
||||
|
||||
@pytest.fixture
|
||||
def user_b():
|
||||
return UserModel(id=2, name="User B", email="b@example.com")
|
||||
|
||||
@pytest.fixture
|
||||
def user_c():
|
||||
return UserModel(id=3, name="User C", email="c@example.com")
|
||||
|
||||
@pytest.fixture
|
||||
def basic_list_model(user_a):
|
||||
return ListModel(id=1, name="Test List", creator_id=user_a.id, creator=user_a, items=[])
|
||||
|
||||
@pytest.fixture
|
||||
def group_model_with_members(user_a, user_b):
|
||||
group = GroupModel(id=1, name="Test Group")
|
||||
# Simulate user_associations for calculate_list_cost_summary if list is group-based
|
||||
group.user_associations = [
|
||||
UserGroupModel(user_id=user_a.id, group_id=group.id, user=user_a),
|
||||
UserGroupModel(user_id=user_b.id, group_id=group.id, user=user_b)
|
||||
]
|
||||
return group
|
||||
|
||||
@pytest.fixture
|
||||
def list_model_with_group(group_model_with_members):
|
||||
return ListModel(id=2, name="Group List", group_id=group_model_with_members.id, group=group_model_with_members, items=[])
|
||||
|
||||
# --- calculate_list_cost_summary Tests ---
|
||||
@pytest.mark.asyncio
|
||||
async def test_calculate_list_cost_summary_personal_list_no_items(mock_db_session, basic_list_model, user_a):
|
||||
mock_db_session.execute.return_value.scalars.return_value.first.return_value = basic_list_model
|
||||
basic_list_model.items = [] # Ensure no items
|
||||
basic_list_model.group = None # Ensure it's a personal list
|
||||
basic_list_model.creator = user_a
|
||||
|
||||
summary = await calculate_list_cost_summary(mock_db_session, basic_list_model.id)
|
||||
|
||||
assert summary.list_id == basic_list_model.id
|
||||
assert summary.total_list_cost == Decimal("0.00")
|
||||
assert summary.num_participating_users == 1
|
||||
assert summary.equal_share_per_user == Decimal("0.00")
|
||||
assert len(summary.user_balances) == 1
|
||||
assert summary.user_balances[0].user_id == user_a.id
|
||||
assert summary.user_balances[0].balance == Decimal("0.00")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_calculate_list_cost_summary_group_list_with_items(mock_db_session, list_model_with_group, group_model_with_members, user_a, user_b):
|
||||
item1 = ItemModel(id=1, name="Milk", price=Decimal("3.00"), added_by_id=user_a.id, added_by_user=user_a, list_id=list_model_with_group.id)
|
||||
item2 = ItemModel(id=2, name="Bread", price=Decimal("2.00"), added_by_id=user_b.id, added_by_user=user_b, list_id=list_model_with_group.id)
|
||||
list_model_with_group.items = [item1, item2]
|
||||
|
||||
mock_db_session.execute.return_value.scalars.return_value.first.return_value = list_model_with_group
|
||||
|
||||
summary = await calculate_list_cost_summary(mock_db_session, list_model_with_group.id)
|
||||
|
||||
assert summary.total_list_cost == Decimal("5.00")
|
||||
assert summary.num_participating_users == 2
|
||||
assert summary.equal_share_per_user == Decimal("2.50") # 5.00 / 2
|
||||
|
||||
balances_map = {ub.user_id: ub for ub in summary.user_balances}
|
||||
assert balances_map[user_a.id].items_added_value == Decimal("3.00")
|
||||
assert balances_map[user_a.id].amount_due == Decimal("2.50")
|
||||
assert balances_map[user_a.id].balance == Decimal("0.50") # 3.00 - 2.50
|
||||
|
||||
assert balances_map[user_b.id].items_added_value == Decimal("2.00")
|
||||
assert balances_map[user_b.id].amount_due == Decimal("2.50")
|
||||
assert balances_map[user_b.id].balance == Decimal("-0.50") # 2.00 - 2.50
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_calculate_list_cost_summary_list_not_found(mock_db_session):
|
||||
mock_db_session.execute.return_value.scalars.return_value.first.return_value = None
|
||||
with pytest.raises(ListNotFoundError):
|
||||
await calculate_list_cost_summary(mock_db_session, 999)
|
||||
|
||||
# --- calculate_suggested_settlements Tests ---
|
||||
@patch('app.crud.cost.logger') # Mock the logger used in the function
|
||||
def test_calculate_suggested_settlements_simple_case(mock_logger):
|
||||
user_balances = [
|
||||
UserBalanceDetail(user_id=1, user_identifier="User A", net_balance=Decimal("-10.00")),
|
||||
UserBalanceDetail(user_id=2, user_identifier="User B", net_balance=Decimal("10.00")),
|
||||
]
|
||||
settlements = calculate_suggested_settlements(user_balances)
|
||||
assert len(settlements) == 1
|
||||
assert settlements[0].from_user_id == 1
|
||||
assert settlements[0].to_user_id == 2
|
||||
assert settlements[0].amount == Decimal("10.00")
|
||||
mock_logger.warning.assert_not_called()
|
||||
|
||||
@patch('app.crud.cost.logger')
|
||||
def test_calculate_suggested_settlements_multiple(mock_logger):
|
||||
user_balances = [
|
||||
UserBalanceDetail(user_id=1, user_identifier="User A", net_balance=Decimal("-30.00")),
|
||||
UserBalanceDetail(user_id=2, user_identifier="User B", net_balance=Decimal("10.00")),
|
||||
UserBalanceDetail(user_id=3, user_identifier="User C", net_balance=Decimal("20.00")),
|
||||
]
|
||||
settlements = calculate_suggested_settlements(user_balances)
|
||||
# Expected: A owes B 10, A owes C 20
|
||||
assert len(settlements) == 2
|
||||
s_map = {(s.from_user_id, s.to_user_id): s.amount for s in settlements}
|
||||
assert s_map[(1, 3)] == Decimal("20.00") # A -> C (largest creditor first)
|
||||
assert s_map[(1, 2)] == Decimal("10.00") # A -> B
|
||||
mock_logger.warning.assert_not_called()
|
||||
|
||||
@patch('app.crud.cost.logger')
|
||||
def test_calculate_suggested_settlements_complex_case_with_rounding_check(mock_logger):
|
||||
user_balances = [
|
||||
UserBalanceDetail(user_id=1, user_identifier="User A", net_balance=Decimal("-33.33")),
|
||||
UserBalanceDetail(user_id=2, user_identifier="User B", net_balance=Decimal("-33.33")),
|
||||
UserBalanceDetail(user_id=3, user_identifier="User C", net_balance=Decimal("66.66")),
|
||||
]
|
||||
settlements = calculate_suggested_settlements(user_balances)
|
||||
# A -> C 33.33, B -> C 33.33
|
||||
assert len(settlements) == 2
|
||||
total_settled_to_C = sum(s.amount for s in settlements if s.to_user_id == 3)
|
||||
assert total_settled_to_C == Decimal("66.66")
|
||||
mock_logger.warning.assert_not_called() # Assuming exact match after quantization
|
||||
|
||||
# --- calculate_group_balance_summary Tests ---
|
||||
@pytest.mark.asyncio
|
||||
async def test_calculate_group_balance_summary_no_activity(mock_db_session, group_model_with_members, user_a, user_b):
|
||||
mock_db_session.get.return_value = group_model_with_members # For group fetch
|
||||
|
||||
# Mock group members query
|
||||
mock_members_result = AsyncMock()
|
||||
mock_members_result.scalars.return_value.all.return_value = [user_a, user_b]
|
||||
|
||||
# Mock expenses query (no expenses)
|
||||
mock_expenses_result = AsyncMock()
|
||||
mock_expenses_result.scalars.return_value.all.return_value = []
|
||||
|
||||
# Mock settlements query (no settlements)
|
||||
mock_settlements_result = AsyncMock()
|
||||
mock_settlements_result.scalars.return_value.all.return_value = []
|
||||
|
||||
mock_db_session.execute.side_effect = [mock_members_result, mock_expenses_result, mock_settlements_result]
|
||||
|
||||
summary = await calculate_group_balance_summary(mock_db_session, group_model_with_members.id)
|
||||
|
||||
assert summary.group_id == group_model_with_members.id
|
||||
assert len(summary.user_balances) == 2
|
||||
assert len(summary.suggested_settlements) == 0
|
||||
for ub in summary.user_balances:
|
||||
assert ub.net_balance == Decimal("0.00")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_calculate_group_balance_summary_with_expenses_and_settlements(mock_db_session, group_model_with_members, user_a, user_b, user_c):
|
||||
# Group members for this test: A, B, C
|
||||
group_model_with_members.user_associations.append(UserGroupModel(user_id=user_c.id, group_id=group_model_with_members.id, user=user_c))
|
||||
all_users = [user_a, user_b, user_c]
|
||||
|
||||
mock_db_session.get.return_value = group_model_with_members
|
||||
|
||||
mock_members_result = AsyncMock()
|
||||
mock_members_result.scalars.return_value.all.return_value = all_users
|
||||
|
||||
# Expenses: A paid 90, split equally (A, B, C -> 30 each)
|
||||
# A: paid 90, share 30 -> +60
|
||||
# B: paid 0, share 30 -> -30
|
||||
# C: paid 0, share 30 -> -30
|
||||
expense1 = ExpenseModel(id=1, group_id=group_model_with_members.id, paid_by_user_id=user_a.id, total_amount=Decimal("90.00"))
|
||||
expense1.splits = [
|
||||
ExpenseSplitModel(expense_id=1, user_id=user_a.id, owed_amount=Decimal("30.00")),
|
||||
ExpenseSplitModel(expense_id=1, user_id=user_b.id, owed_amount=Decimal("30.00")),
|
||||
ExpenseSplitModel(expense_id=1, user_id=user_c.id, owed_amount=Decimal("30.00")),
|
||||
]
|
||||
mock_expenses_result = AsyncMock()
|
||||
mock_expenses_result.scalars.return_value.all.return_value = [expense1]
|
||||
|
||||
# Settlements: B paid A 10
|
||||
# A: received 10 -> +60 + 10 = +70
|
||||
# B: paid 10 -> -30 - 10 = -40
|
||||
# C: no settlement -> -30
|
||||
settlement1 = SettlementModel(id=1, group_id=group_model_with_members.id, paid_by_user_id=user_b.id, paid_to_user_id=user_a.id, amount=Decimal("10.00"))
|
||||
mock_settlements_result = AsyncMock()
|
||||
mock_settlements_result.scalars.return_value.all.return_value = [settlement1]
|
||||
|
||||
mock_db_session.execute.side_effect = [mock_members_result, mock_expenses_result, mock_settlements_result]
|
||||
|
||||
with patch('app.crud.cost.logger') as mock_settlement_logger: # Patch logger for calculate_suggested_settlements
|
||||
summary = await calculate_group_balance_summary(mock_db_session, group_model_with_members.id)
|
||||
|
||||
balances_map = {ub.user_id: ub for ub in summary.user_balances}
|
||||
assert balances_map[user_a.id].net_balance == Decimal("70.00")
|
||||
assert balances_map[user_b.id].net_balance == Decimal("-40.00")
|
||||
assert balances_map[user_c.id].net_balance == Decimal("-30.00")
|
||||
|
||||
# Suggested settlements: B owes A 30 (40-10), C owes A 30
|
||||
# Net balances: A: +70, B: -40, C: -30
|
||||
# Algorithm should suggest: B -> A (40), C -> A (30)
|
||||
assert len(summary.suggested_settlements) == 2
|
||||
settlement_map = {(s.from_user_id, s.to_user_id): s.amount for s in summary.suggested_settlements}
|
||||
|
||||
# Order might vary, check presence and amounts
|
||||
assert (user_b.id, user_a.id) in settlement_map
|
||||
assert settlement_map[(user_b.id, user_a.id)] == Decimal("40.00")
|
||||
|
||||
assert (user_c.id, user_a.id) in settlement_map
|
||||
assert settlement_map[(user_c.id, user_a.id)] == Decimal("30.00")
|
||||
mock_settlement_logger.warning.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_calculate_group_balance_summary_group_not_found(mock_db_session):
|
||||
mock_db_session.get.return_value = None
|
||||
with pytest.raises(GroupNotFoundError):
|
||||
await calculate_group_balance_summary(mock_db_session, 999)
|
||||
|
||||
# TODO:
|
||||
# - Test calculate_list_cost_summary with item added by user not in group/not creator.
|
||||
# - Test calculate_list_cost_summary with remainder distribution for equal share.
|
||||
# - Test calculate_suggested_settlements with zero balances, one debtor, one creditor, etc.
|
||||
# - Test calculate_suggested_settlements with logger warning for remaining balances.
|
||||
# - Test calculate_group_balance_summary with more complex expense/settlement scenarios.
|
||||
# - Test calculate_group_balance_summary when a user in splits/settlements is not in group_members (should ideally not happen if data is clean).
|
Loading…
Reference in New Issue
Block a user