379 lines
16 KiB
Python
379 lines
16 KiB
Python
# app/api/v1/endpoints/costs.py
|
|
import logging
|
|
from fastapi import APIRouter, Depends, HTTPException, status
|
|
from sqlalchemy import select
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
from sqlalchemy.orm import Session, selectinload
|
|
from decimal import Decimal, ROUND_HALF_UP, ROUND_DOWN
|
|
from typing import List
|
|
|
|
from app.database import get_transactional_session
|
|
from app.auth import current_active_user
|
|
from app.models import (
|
|
User as UserModel,
|
|
Group as GroupModel,
|
|
List as ListModel,
|
|
Expense as ExpenseModel,
|
|
Item as ItemModel,
|
|
UserGroup as UserGroupModel,
|
|
SplitTypeEnum,
|
|
ExpenseSplit as ExpenseSplitModel,
|
|
Settlement as SettlementModel
|
|
)
|
|
from app.schemas.cost import ListCostSummary, GroupBalanceSummary, UserCostShare, UserBalanceDetail, SuggestedSettlement
|
|
from app.schemas.expense import ExpenseCreate
|
|
from app.crud import list as crud_list
|
|
from app.crud import expense as crud_expense
|
|
from app.core.exceptions import ListNotFoundError, ListPermissionError, UserNotFoundError, GroupNotFoundError
|
|
|
|
# Module logger; records are emitted under this module's dotted name.
logger = logging.getLogger(name=__name__)

# Router collecting the cost / balance endpoints declared in this module.
router = APIRouter()
|
|
|
|
def calculate_suggested_settlements(user_balances: List[UserBalanceDetail]) -> List[SuggestedSettlement]:
    """
    Suggest payments that settle the outstanding balances within a group.

    A negative ``net_balance`` means the user owes money (debtor); a positive
    one means the user is owed money (creditor). Balances smaller than one cent
    in absolute value are treated as already settled and skipped.

    Debtors and creditors are each sorted by amount, largest first, and matched
    greedily in that order. Every step transfers the smaller of the two
    outstanding amounts, so at least one party is fully settled per transfer.
    This keeps the number of transfers low, but it is a heuristic: it does not
    guarantee the minimum possible number of transactions.

    Args:
        user_balances: Current balances. Each item is read for ``user_id``,
            ``user_identifier`` and a ``Decimal`` ``net_balance``; the input
            objects themselves are not modified.

    Returns:
        SuggestedSettlement objects in the order they were generated, each
        with a positive ``amount`` rounded half-up to the cent.
    """
    cent = Decimal('0.01')
    # Amounts below one cent are treated as zero (rounding residue).
    epsilon = cent

    debtors = []    # negative balances; 'amount' holds what the user owes
    creditors = []  # positive balances; 'amount' holds what the user is owed

    for user in user_balances:
        if abs(user.net_balance) < epsilon:
            continue
        entry = {
            'user_id': user.user_id,
            'user_identifier': user.user_identifier,
            'amount': abs(user.net_balance),
        }
        if user.net_balance < Decimal('0'):
            debtors.append(entry)
        else:
            creditors.append(entry)

    # Largest amounts first; list.sort is stable, so ties keep input order.
    debtors.sort(key=lambda x: x['amount'], reverse=True)
    creditors.sort(key=lambda x: x['amount'], reverse=True)

    settlements = []
    # Walk both lists with cursors instead of list.pop(0), which is O(n) per
    # call and made the matching loop quadratic in the number of users.
    d_idx = 0
    c_idx = 0
    while d_idx < len(debtors) and c_idx < len(creditors):
        debtor = debtors[d_idx]
        creditor = creditors[c_idx]

        amount = min(debtor['amount'], creditor['amount']).quantize(cent, rounding=ROUND_HALF_UP)
        if amount > Decimal('0'):
            settlements.append(
                SuggestedSettlement(
                    from_user_id=debtor['user_id'],
                    from_user_identifier=debtor['user_identifier'],
                    to_user_id=creditor['user_id'],
                    to_user_identifier=creditor['user_identifier'],
                    amount=amount
                )
            )

        debtor['amount'] -= amount
        creditor['amount'] -= amount

        # Quantizing moves 'amount' by at most half a cent, so the smaller
        # party always drops below epsilon here and at least one cursor
        # advances on every iteration (the loop terminates).
        if debtor['amount'] < epsilon:
            d_idx += 1
        if creditor['amount'] < epsilon:
            c_idx += 1

    return settlements
|
|
|
|
def _empty_list_cost_summary(db_list: ListModel) -> ListCostSummary:
    """Build the all-zero summary returned when a list has nothing to split."""
    return ListCostSummary(
        list_id=db_list.id,
        list_name=db_list.name,
        total_list_cost=Decimal("0.00"),
        num_participating_users=0,
        equal_share_per_user=Decimal("0.00"),
        user_balances=[]
    )


def _allocate_equal_shares(total: Decimal, user_ids: List[int]) -> dict:
    """
    Split ``total`` equally into per-user cent amounts.

    Every user first receives the equal share rounded down to the cent. The
    leftover cents (``total`` rounded half-up to the cent, minus the sum of
    those shares) are then handed out one at a time, in the order of
    ``user_ids``. The returned shares therefore sum to ``total`` rounded
    half-up to the cent. ``user_ids`` must be non-empty.
    """
    count = len(user_ids)
    base_share = (total / Decimal(count)).quantize(Decimal("0.01"), rounding=ROUND_DOWN)
    shares = {user_id: base_share for user_id in user_ids}

    sum_of_rounded_shares = sum(shares.values())
    remaining_pennies = int(
        ((total - sum_of_rounded_shares) * Decimal("100")).to_integral_value(rounding=ROUND_HALF_UP)
    )
    for i in range(remaining_pennies):
        shares[user_ids[i % count]] += Decimal("0.01")
    return shares


@router.get(
    "/lists/{list_id}/cost-summary",
    response_model=ListCostSummary,
    summary="Get Cost Summary for a List",
    tags=["Costs"],
    responses={
        status.HTTP_403_FORBIDDEN: {"description": "User does not have permission to access this list"},
        status.HTTP_404_NOT_FOUND: {"description": "List or associated user not found"}
    }
)
async def get_list_cost_summary(
    list_id: int,
    db: AsyncSession = Depends(get_transactional_session),
    current_user: UserModel = Depends(current_active_user),
):
    """
    Return a cost summary for a list.

    The participants are:
    - every user who added an item with a positive price, and
    - every user that has a split on the list's expense.

    The list total is the sum of those positive item prices. It is divided
    equally between the participants to the cent: everyone gets the share
    rounded down, and leftover cents go one at a time to participants in
    ascending user-id order. Each participant's ``balance`` is the value of
    the items they added minus their share.

    Side effect: if the list has no expense yet and its total is positive, an
    ITEM_BASED expense paid by the list creator is created through
    ``crud_expense.create_expense``. In other words, this GET endpoint may
    write to the database.

    Raises:
        ListPermissionError: the current user may not access the list.
        ListNotFoundError: the list does not exist.
    """
    logger.info(f"User {current_user.email} requesting cost summary for list {list_id}")

    # 1. Verify user has access to the target list
    try:
        await crud_list.check_list_permission(db=db, list_id=list_id, user_id=current_user.id)
    except ListPermissionError as e:
        logger.warning(f"Permission denied for user {current_user.email} on list {list_id}: {str(e)}")
        raise
    except ListNotFoundError as e:
        logger.warning(f"List {list_id} not found when checking permissions for cost summary: {str(e)}")
        raise

    # 2. Get the list with its items (and who added them) and its creator.
    # NOTE(review): group members are eager-loaded here but not used by the
    # calculation below; confirm whether they are still needed.
    list_result = await db.execute(
        select(ListModel)
        .options(
            selectinload(ListModel.items).options(selectinload(ItemModel.added_by_user)),
            selectinload(ListModel.group).options(selectinload(GroupModel.member_associations).options(selectinload(UserGroupModel.user))),
            selectinload(ListModel.creator)
        )
        .where(ListModel.id == list_id)
    )
    db_list = list_result.scalars().first()
    if not db_list:
        raise ListNotFoundError(list_id)

    # 3. Get or create an expense for this list.
    # split.user is read below, so eager-load it. A lazy load on an
    # AsyncSession is implicit IO and fails outside the greenlet context.
    # This matches the loading used by the group balance endpoint.
    expense_result = await db.execute(
        select(ExpenseModel)
        .where(ExpenseModel.list_id == list_id)
        .options(selectinload(ExpenseModel.splits).selectinload(ExpenseSplitModel.user))
    )
    db_expense = expense_result.scalars().first()

    if not db_expense:
        total_amount = sum(item.price for item in db_list.items if item.price is not None and item.price > Decimal("0"))
        if total_amount == Decimal("0"):
            return _empty_list_cost_summary(db_list)

        # Create expense with ITEM_BASED split type
        expense_in = ExpenseCreate(
            description=f"Cost summary for list {db_list.name}",
            total_amount=total_amount,
            list_id=list_id,
            split_type=SplitTypeEnum.ITEM_BASED,
            paid_by_user_id=db_list.creator.id
        )
        # NOTE(review): assumes create_expense returns the expense with
        # splits (and split users) loaded; confirm in crud_expense.
        db_expense = await crud_expense.create_expense(db=db, expense_in=expense_in)

    # 4. Collect participants and what each of them added.
    participating_users = set()
    user_items_added_value = {}
    total_list_cost = Decimal("0.00")

    for item in db_list.items:
        if item.price is not None and item.price > Decimal("0") and item.added_by_user:
            participating_users.add(item.added_by_user)
            user_items_added_value[item.added_by_user.id] = user_items_added_value.get(item.added_by_user.id, Decimal("0.00")) + item.price
            total_list_cost += item.price

    for split in db_expense.splits:
        if split.user:
            participating_users.add(split.user)

    num_participating_users = len(participating_users)
    if num_participating_users == 0:
        return _empty_list_cost_summary(db_list)

    # Ideal (rounded) equal share, reported as-is in the summary.
    equal_share_per_user_for_response = (total_list_cost / Decimal(num_participating_users)).quantize(Decimal("0.01"), rounding=ROUND_HALF_UP)

    # Sort by id so the leftover-cent distribution is deterministic.
    sorted_participating_users = sorted(participating_users, key=lambda u: u.id)
    user_final_shares = _allocate_equal_shares(
        total_list_cost, [user.id for user in sorted_participating_users]
    )

    user_balances = []
    for user in sorted_participating_users:
        items_added = user_items_added_value.get(user.id, Decimal("0.00"))
        current_user_share = user_final_shares.get(user.id, Decimal("0.00"))
        balance = items_added - current_user_share
        user_identifier = user.name if user.name else user.email
        user_balances.append(
            UserCostShare(
                user_id=user.id,
                user_identifier=user_identifier,
                items_added_value=items_added.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP),
                amount_due=current_user_share.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP),
                balance=balance.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP)
            )
        )

    user_balances.sort(key=lambda x: x.user_identifier)
    return ListCostSummary(
        list_id=db_list.id,
        list_name=db_list.name,
        total_list_cost=total_list_cost.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP),
        num_participating_users=num_participating_users,
        equal_share_per_user=equal_share_per_user_for_response,
        user_balances=user_balances
    )
|
|
|
|
@router.get(
    "/groups/{group_id}/balance-summary",
    response_model=GroupBalanceSummary,
    summary="Get Detailed Balance Summary for a Group",
    tags=["Costs", "Groups"],
    responses={
        status.HTTP_403_FORBIDDEN: {"description": "User does not have permission to access this group"},
        status.HTTP_404_NOT_FOUND: {"description": "Group not found"}
    }
)
async def get_group_balance_summary(
    group_id: int,
    db: AsyncSession = Depends(get_transactional_session),
    current_user: UserModel = Depends(current_active_user),
):
    """
    Return every member's balance in a group plus suggested settlements.

    For each member of the group:
    - ``total_paid_for_expenses`` sums the group expenses they paid;
    - ``total_share_of_expenses`` sums their expense splits;
    - ``total_settlements_paid`` / ``total_settlements_received`` sum the
      settlements they paid / received.

    The net balance is what the member put in minus what they got:

        net = (paid_for_expenses + settlements_paid)
              - (share_of_expenses + settlements_received)

    A positive net balance means the member is owed money; a negative one
    means they owe money. Amounts attributed to users who are not current
    members are ignored.

    Raises:
        GroupNotFoundError: the group does not exist.
        HTTPException: 403 when the current user is not a member of the group.
    """
    logger.info(f"User {current_user.email} requesting balance summary for group {group_id}")

    # 1. Load the group with its members. assoc.user is read below, so it is
    # eager-loaded too; a lazy load on an AsyncSession is implicit IO.
    group_check = await db.execute(
        select(GroupModel)
        .options(selectinload(GroupModel.member_associations).selectinload(UserGroupModel.user))
        .where(GroupModel.id == group_id)
    )
    db_group_for_check = group_check.scalars().first()

    if not db_group_for_check:
        raise GroupNotFoundError(group_id)

    user_is_member = any(assoc.user_id == current_user.id for assoc in db_group_for_check.member_associations)
    if not user_is_member:
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=f"User not a member of group {group_id}")

    # 2. Get all expenses and settlements for the group
    expenses_result = await db.execute(
        select(ExpenseModel)
        .where(ExpenseModel.group_id == group_id)
        .options(selectinload(ExpenseModel.splits).selectinload(ExpenseSplitModel.user))
    )
    expenses = expenses_result.scalars().all()

    settlements_result = await db.execute(
        select(SettlementModel)
        .where(SettlementModel.group_id == group_id)
        .options(
            selectinload(SettlementModel.paid_by_user),
            selectinload(SettlementModel.paid_to_user)
        )
    )
    settlements = settlements_result.scalars().all()

    # 3. One balance record per current member.
    user_balances_data = {}
    for assoc in db_group_for_check.member_associations:
        if assoc.user:
            user_balances_data[assoc.user.id] = UserBalanceDetail(
                user_id=assoc.user.id,
                user_identifier=assoc.user.name if assoc.user.name else assoc.user.email
            )

    for expense in expenses:
        if expense.paid_by_user_id in user_balances_data:
            user_balances_data[expense.paid_by_user_id].total_paid_for_expenses += expense.total_amount
        for split in expense.splits:
            if split.user_id in user_balances_data:
                user_balances_data[split.user_id].total_share_of_expenses += split.owed_amount

    for settlement in settlements:
        if settlement.paid_by_user_id in user_balances_data:
            user_balances_data[settlement.paid_by_user_id].total_settlements_paid += settlement.amount
        if settlement.paid_to_user_id in user_balances_data:
            user_balances_data[settlement.paid_to_user_id].total_settlements_received += settlement.amount

    # 4. Net balances.
    cent = Decimal("0.01")
    final_user_balances = []
    for data in user_balances_data.values():
        # A settlement the member paid reduces their debt exactly like paying
        # an expense does, and one they received reduces what they are owed.
        # The previous formula had these two terms swapped, so paying a debt
        # made the balance more negative and receiving one more positive.
        data.net_balance = (
            data.total_paid_for_expenses + data.total_settlements_paid
        ) - (data.total_share_of_expenses + data.total_settlements_received)

        data.total_paid_for_expenses = data.total_paid_for_expenses.quantize(cent, rounding=ROUND_HALF_UP)
        data.total_share_of_expenses = data.total_share_of_expenses.quantize(cent, rounding=ROUND_HALF_UP)
        data.total_settlements_paid = data.total_settlements_paid.quantize(cent, rounding=ROUND_HALF_UP)
        data.total_settlements_received = data.total_settlements_received.quantize(cent, rounding=ROUND_HALF_UP)
        data.net_balance = data.net_balance.quantize(cent, rounding=ROUND_HALF_UP)

        final_user_balances.append(data)

    final_user_balances.sort(key=lambda x: x.user_identifier)

    suggested_settlements = calculate_suggested_settlements(final_user_balances)

    return GroupBalanceSummary(
        group_id=db_group_for_check.id,
        group_name=db_group_for_check.name,
        user_balances=final_user_balances,
        suggested_settlements=suggested_settlements
    )