remove old cost splitter

This commit is contained in:
Mohamad 2025-05-08 14:31:34 +02:00
parent bbb3c3b7df
commit e3024ccd07
4 changed files with 195 additions and 562 deletions

View File

@ -4,13 +4,25 @@ from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session, selectinload from sqlalchemy.orm import Session, selectinload
from decimal import Decimal, ROUND_HALF_UP
from app.database import get_db from app.database import get_db
from app.api.dependencies import get_current_user from app.api.dependencies import get_current_user
from app.models import User as UserModel, Group as GroupModel # For get_current_user dependency and Group model from app.models import (
User as UserModel,
Group as GroupModel,
List as ListModel,
Expense as ExpenseModel,
Item as ItemModel,
UserGroup as UserGroupModel,
SplitTypeEnum,
ExpenseSplit as ExpenseSplitModel,
Settlement as SettlementModel
)
from app.schemas.cost import ListCostSummary, GroupBalanceSummary from app.schemas.cost import ListCostSummary, GroupBalanceSummary
from app.crud import cost as crud_cost from app.schemas.expense import ExpenseCreate
from app.crud import list as crud_list # For permission checking from app.crud import list as crud_list
from app.crud import expense as crud_expense
from app.core.exceptions import ListNotFoundError, ListPermissionError, UserNotFoundError, GroupNotFoundError from app.core.exceptions import ListNotFoundError, ListPermissionError, UserNotFoundError, GroupNotFoundError
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -47,28 +59,117 @@ async def get_list_cost_summary(
await crud_list.check_list_permission(db=db, list_id=list_id, user_id=current_user.id) await crud_list.check_list_permission(db=db, list_id=list_id, user_id=current_user.id)
except ListPermissionError as e: except ListPermissionError as e:
logger.warning(f"Permission denied for user {current_user.email} on list {list_id}: {str(e)}") logger.warning(f"Permission denied for user {current_user.email} on list {list_id}: {str(e)}")
raise # Re-raise the original exception to be handled by FastAPI raise
except ListNotFoundError as e: except ListNotFoundError as e:
logger.warning(f"List {list_id} not found when checking permissions for cost summary: {str(e)}") logger.warning(f"List {list_id} not found when checking permissions for cost summary: {str(e)}")
raise # Re-raise raise
# 2. Calculate the cost summary # 2. Get the list with its items and users
try: list_result = await db.execute(
cost_summary = await crud_cost.calculate_list_cost_summary(db=db, list_id=list_id) select(ListModel)
logger.info(f"Successfully generated cost summary for list {list_id} for user {current_user.email}") .options(
return cost_summary selectinload(ListModel.items).options(selectinload(ItemModel.added_by_user)),
except ListNotFoundError as e: selectinload(ListModel.group).options(selectinload(GroupModel.user_associations).options(selectinload(UserGroupModel.user))),
logger.warning(f"List {list_id} not found during cost summary calculation: {str(e)}") selectinload(ListModel.creator)
# This might be redundant if check_list_permission already confirmed list existence, )
# but calculate_list_cost_summary also fetches the list. .where(ListModel.id == list_id)
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e)) )
except UserNotFoundError as e: db_list = list_result.scalars().first()
logger.error(f"User not found during cost summary calculation for list {list_id}: {str(e)}") if not db_list:
# This indicates a data integrity issue (e.g., list creator or item adder missing) raise ListNotFoundError(list_id)
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e))
except Exception as e: # 3. Get or create an expense for this list
logger.error(f"Unexpected error generating cost summary for list {list_id} for user {current_user.email}: {str(e)}", exc_info=True) expense_result = await db.execute(
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="An unexpected error occurred while generating the cost summary.") select(ExpenseModel)
.where(ExpenseModel.list_id == list_id)
.options(selectinload(ExpenseModel.splits))
)
db_expense = expense_result.scalars().first()
if not db_expense:
# Create a new expense for this list
total_amount = sum(item.price for item in db_list.items if item.price is not None and item.price > Decimal("0"))
if total_amount == Decimal("0"):
return ListCostSummary(
list_id=db_list.id,
list_name=db_list.name,
total_list_cost=Decimal("0.00"),
num_participating_users=0,
equal_share_per_user=Decimal("0.00"),
user_balances=[]
)
# Create expense with ITEM_BASED split type
expense_in = ExpenseCreate(
description=f"Cost summary for list {db_list.name}",
total_amount=total_amount,
list_id=list_id,
split_type=SplitTypeEnum.ITEM_BASED,
paid_by_user_id=current_user.id # Use current user as payer for now
)
db_expense = await crud_expense.create_expense(db=db, expense_in=expense_in)
# 4. Calculate cost summary from expense splits
participating_users = set()
user_items_added_value = {}
total_list_cost = Decimal("0.00")
# Get all users who added items
for item in db_list.items:
if item.price is not None and item.price > Decimal("0") and item.added_by_user:
participating_users.add(item.added_by_user)
user_items_added_value[item.added_by_user.id] = user_items_added_value.get(item.added_by_user.id, Decimal("0.00")) + item.price
total_list_cost += item.price
# Get all users from expense splits
for split in db_expense.splits:
if split.user:
participating_users.add(split.user)
num_participating_users = len(participating_users)
if num_participating_users == 0:
return ListCostSummary(
list_id=db_list.id,
list_name=db_list.name,
total_list_cost=Decimal("0.00"),
num_participating_users=0,
equal_share_per_user=Decimal("0.00"),
user_balances=[]
)
equal_share_per_user = (total_list_cost / Decimal(num_participating_users)).quantize(Decimal("0.01"), rounding=ROUND_HALF_UP)
remainder = total_list_cost - (equal_share_per_user * num_participating_users)
user_balances = []
first_user_processed = False
for user in participating_users:
items_added = user_items_added_value.get(user.id, Decimal("0.00"))
current_user_share = equal_share_per_user
if not first_user_processed and remainder != Decimal("0"):
current_user_share += remainder
first_user_processed = True
balance = items_added - current_user_share
user_identifier = user.name if user.name else user.email
user_balances.append(
UserCostShare(
user_id=user.id,
user_identifier=user_identifier,
items_added_value=items_added.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP),
amount_due=current_user_share.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP),
balance=balance.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP)
)
)
user_balances.sort(key=lambda x: x.user_identifier)
return ListCostSummary(
list_id=db_list.id,
list_name=db_list.name,
total_list_cost=total_list_cost.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP),
num_participating_users=num_participating_users,
equal_share_per_user=equal_share_per_user.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP),
user_balances=user_balances
)
@router.get( @router.get(
"/groups/{group_id}/balance-summary", "/groups/{group_id}/balance-summary",
@ -92,11 +193,7 @@ async def get_group_balance_summary(
""" """
logger.info(f"User {current_user.email} requesting balance summary for group {group_id}") logger.info(f"User {current_user.email} requesting balance summary for group {group_id}")
# 1. Verify user is a member of the target group (using crud_group.check_group_membership or similar) # 1. Verify user is a member of the target group
# Assuming a function like this exists in app.crud.group or we add it.
# For now, let's placeholder this check logic.
# await crud_group.check_group_membership(db=db, group_id=group_id, user_id=current_user.id)
# A simpler check for now: fetch the group and see if user is part of member_associations
group_check = await db.execute( group_check = await db.execute(
select(GroupModel) select(GroupModel)
.options(selectinload(GroupModel.member_associations)) .options(selectinload(GroupModel.member_associations))
@ -109,20 +206,75 @@ async def get_group_balance_summary(
user_is_member = any(assoc.user_id == current_user.id for assoc in db_group_for_check.member_associations) user_is_member = any(assoc.user_id == current_user.id for assoc in db_group_for_check.member_associations)
if not user_is_member: if not user_is_member:
# If ListPermissionError is generic enough for "access resource", use it, or a new GroupPermissionError
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=f"User not a member of group {group_id}") raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=f"User not a member of group {group_id}")
# 2. Calculate the group balance summary # 2. Get all expenses and settlements for the group
try: expenses_result = await db.execute(
balance_summary = await crud_cost.calculate_group_balance_summary(db=db, group_id=group_id) select(ExpenseModel)
logger.info(f"Successfully generated balance summary for group {group_id} for user {current_user.email}") .where(ExpenseModel.group_id == group_id)
return balance_summary .options(selectinload(ExpenseModel.splits).selectinload(ExpenseSplitModel.user))
except GroupNotFoundError as e: )
logger.warning(f"Group {group_id} not found during balance summary calculation: {str(e)}") expenses = expenses_result.scalars().all()
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e))
except UserNotFoundError as e: # Should not happen if group members are correctly fetched settlements_result = await db.execute(
logger.error(f"User not found during balance summary for group {group_id}: {str(e)}") select(SettlementModel)
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="An internal error occurred finding a user for the summary.") .where(SettlementModel.group_id == group_id)
except Exception as e: .options(
logger.error(f"Unexpected error generating balance summary for group {group_id}: {str(e)}", exc_info=True) selectinload(SettlementModel.paid_by_user),
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="An unexpected error occurred while generating the group balance summary.") selectinload(SettlementModel.paid_to_user)
)
)
settlements = settlements_result.scalars().all()
# 3. Calculate user balances
user_balances_data = {}
for assoc in db_group_for_check.member_associations:
if assoc.user:
user_balances_data[assoc.user.id] = UserBalanceDetail(
user_id=assoc.user.id,
user_identifier=assoc.user.name if assoc.user.name else assoc.user.email
)
# Process expenses
for expense in expenses:
if expense.paid_by_user_id in user_balances_data:
user_balances_data[expense.paid_by_user_id].total_paid_for_expenses += expense.total_amount
for split in expense.splits:
if split.user_id in user_balances_data:
user_balances_data[split.user_id].total_share_of_expenses += split.owed_amount
# Process settlements
for settlement in settlements:
if settlement.paid_by_user_id in user_balances_data:
user_balances_data[settlement.paid_by_user_id].total_settlements_paid += settlement.amount
if settlement.paid_to_user_id in user_balances_data:
user_balances_data[settlement.paid_to_user_id].total_settlements_received += settlement.amount
# Calculate net balances
final_user_balances = []
for user_id, data in user_balances_data.items():
data.net_balance = (
data.total_paid_for_expenses + data.total_settlements_received
) - (data.total_share_of_expenses + data.total_settlements_paid)
data.total_paid_for_expenses = data.total_paid_for_expenses.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP)
data.total_share_of_expenses = data.total_share_of_expenses.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP)
data.total_settlements_paid = data.total_settlements_paid.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP)
data.total_settlements_received = data.total_settlements_received.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP)
data.net_balance = data.net_balance.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP)
final_user_balances.append(data)
# Sort by user identifier
final_user_balances.sort(key=lambda x: x.user_identifier)
# Calculate suggested settlements
suggested_settlements = calculate_suggested_settlements(final_user_balances)
return GroupBalanceSummary(
group_id=db_group_for_check.id,
group_name=db_group_for_check.name,
user_balances=final_user_balances,
suggested_settlements=suggested_settlements
)

View File

@ -60,6 +60,7 @@ Organic Bananas
CORS_ORIGINS: list[str] = [ CORS_ORIGINS: list[str] = [
"http://localhost:5174", "http://localhost:5174",
"http://localhost:8000", "http://localhost:8000",
"http://localhost:9000",
# Add your deployed frontend URL here later # Add your deployed frontend URL here later
# "https://your-frontend-domain.com", # "https://your-frontend-domain.com",
] ]

View File

@ -1,266 +0,0 @@
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy.orm import selectinload, joinedload
from sqlalchemy import func as sql_func, or_
from decimal import Decimal, ROUND_HALF_UP
from typing import Dict, Optional, Sequence, List as PyList
from collections import defaultdict
import logging
from app.models import (
List as ListModel,
Item as ItemModel,
User as UserModel,
UserGroup as UserGroupModel,
Group as GroupModel,
Expense as ExpenseModel,
ExpenseSplit as ExpenseSplitModel,
Settlement as SettlementModel
)
from app.schemas.cost import (
ListCostSummary,
UserCostShare,
GroupBalanceSummary,
UserBalanceDetail,
SuggestedSettlement
)
from app.core.exceptions import (
ListNotFoundError,
UserNotFoundError,
GroupNotFoundError,
InvalidOperationError
)
logger = logging.getLogger(__name__)
async def calculate_list_cost_summary(db: AsyncSession, list_id: int) -> ListCostSummary:
"""
Calculates the cost summary for a given list based purely on item prices and who added them.
This is a simpler calculation and does not involve the Expense/Settlement system.
"""
list_result = await db.execute(
select(ListModel)
.options(
selectinload(ListModel.items).options(joinedload(ItemModel.added_by_user)),
selectinload(ListModel.group).options(selectinload(GroupModel.user_associations).options(selectinload(UserGroupModel.user))),
selectinload(ListModel.creator)
)
.where(ListModel.id == list_id)
)
db_list: Optional[ListModel] = list_result.scalars().first()
if not db_list:
raise ListNotFoundError(list_id)
participating_users_map: Dict[int, UserModel] = {}
if db_list.group:
for ug_assoc in db_list.group.user_associations:
if ug_assoc.user:
participating_users_map[ug_assoc.user.id] = ug_assoc.user
elif db_list.creator: # Personal list
participating_users_map[db_list.creator.id] = db_list.creator
# Include all users who added items with prices, even if not in the primary context (group/creator)
for item in db_list.items:
if item.price is not None and item.price > Decimal("0") and item.added_by_user and item.added_by_user.id not in participating_users_map:
participating_users_map[item.added_by_user.id] = item.added_by_user
num_participating_users = len(participating_users_map)
if num_participating_users == 0:
return ListCostSummary(
list_id=db_list.id, list_name=db_list.name, total_list_cost=Decimal("0.00"),
num_participating_users=0, equal_share_per_user=Decimal("0.00"), user_balances=[]
)
total_list_cost = Decimal("0.00")
user_items_added_value: Dict[int, Decimal] = defaultdict(Decimal)
for item in db_list.items:
if item.price is not None and item.price > Decimal("0"):
total_list_cost += item.price
if item.added_by_id in participating_users_map:
user_items_added_value[item.added_by_id] += item.price
equal_share_per_user = (total_list_cost / Decimal(num_participating_users)).quantize(Decimal("0.01"), rounding=ROUND_HALF_UP)
remainder = total_list_cost - (equal_share_per_user * num_participating_users)
user_balances: PyList[UserCostShare] = []
first_user_processed = False
for user_id, user_obj in participating_users_map.items():
items_added = user_items_added_value.get(user_id, Decimal("0.00"))
current_user_share = equal_share_per_user
if not first_user_processed and remainder != Decimal("0"):
current_user_share += remainder
first_user_processed = True
balance = items_added - current_user_share
user_identifier = user_obj.name if user_obj.name else user_obj.email
user_balances.append(
UserCostShare(
user_id=user_id, user_identifier=user_identifier,
items_added_value=items_added.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP),
amount_due=current_user_share.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP),
balance=balance.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP)
)
)
user_balances.sort(key=lambda x: x.user_identifier)
return ListCostSummary(
list_id=db_list.id, list_name=db_list.name,
total_list_cost=total_list_cost.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP),
num_participating_users=num_participating_users,
equal_share_per_user=equal_share_per_user.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP),
user_balances=user_balances
)
# --- Helper for Settlement Suggestions ---
def calculate_suggested_settlements(user_balances: PyList[UserBalanceDetail]) -> PyList[SuggestedSettlement]:
"""
Calculates a list of suggested settlements to resolve group debts.
Uses a greedy algorithm to minimize the number of transactions.
Input: List of UserBalanceDetail objects with calculated net_balances.
Output: List of SuggestedSettlement objects.
"""
# Use a small tolerance for floating point comparisons with Decimal
tolerance = Decimal("0.001")
debtors = sorted([ub for ub in user_balances if ub.net_balance < -tolerance], key=lambda x: x.net_balance)
creditors = sorted([ub for ub in user_balances if ub.net_balance > tolerance], key=lambda x: x.net_balance, reverse=True)
settlements: PyList[SuggestedSettlement] = []
debtor_idx = 0
creditor_idx = 0
# Create mutable copies of balances to track remaining amounts
debtor_balances = {d.user_id: d.net_balance for d in debtors}
creditor_balances = {c.user_id: c.net_balance for c in creditors}
user_identifiers = {ub.user_id: ub.user_identifier for ub in user_balances}
while debtor_idx < len(debtors) and creditor_idx < len(creditors):
debtor = debtors[debtor_idx]
creditor = creditors[creditor_idx]
debtor_remaining = debtor_balances[debtor.user_id]
creditor_remaining = creditor_balances[creditor.user_id]
# Amount to transfer is the minimum of what debtor owes and what creditor is owed
transfer_amount = min(abs(debtor_remaining), creditor_remaining)
transfer_amount = transfer_amount.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP)
if transfer_amount > tolerance: # Only record meaningful transfers
settlements.append(SuggestedSettlement(
from_user_id=debtor.user_id,
from_user_identifier=user_identifiers.get(debtor.user_id, "Unknown Debtor"),
to_user_id=creditor.user_id,
to_user_identifier=user_identifiers.get(creditor.user_id, "Unknown Creditor"),
amount=transfer_amount
))
# Update remaining balances
debtor_balances[debtor.user_id] += transfer_amount
creditor_balances[creditor.user_id] -= transfer_amount
# Move to next debtor if current one is settled (or very close)
if abs(debtor_balances[debtor.user_id]) < tolerance:
debtor_idx += 1
# Move to next creditor if current one is settled (or very close)
if creditor_balances[creditor.user_id] < tolerance:
creditor_idx += 1
# Log if lists aren't empty - indicates potential imbalance or rounding issue
if debtor_idx < len(debtors) or creditor_idx < len(creditors):
# Calculate remaining balances for logging
remaining_debt = sum(bal for bal in debtor_balances.values() if bal < -tolerance)
remaining_credit = sum(bal for bal in creditor_balances.values() if bal > tolerance)
logger.warning(f"Settlement suggestion calculation finished with remaining balances. Debt: {remaining_debt}, Credit: {remaining_credit}. This might be due to minor rounding discrepancies.")
return settlements
# --- NEW: Detailed Group Balance Summary ---
async def calculate_group_balance_summary(db: AsyncSession, group_id: int) -> GroupBalanceSummary:
"""
Calculates a detailed balance summary for all users in a group,
considering all expenses, splits, and settlements within that group.
Also calculates suggested settlements.
"""
group = await db.get(GroupModel, group_id)
if not group:
raise GroupNotFoundError(group_id)
# 1. Get all group members
group_members_result = await db.execute(
select(UserModel)
.join(UserGroupModel, UserModel.id == UserGroupModel.user_id)
.where(UserGroupModel.group_id == group_id)
)
group_members: Dict[int, UserModel] = {user.id: user for user in group_members_result.scalars().all()}
if not group_members:
return GroupBalanceSummary(
group_id=group.id, group_name=group.name, user_balances=[], suggested_settlements=[]
)
user_balances_data: Dict[int, UserBalanceDetail] = {}
for user_id, user_obj in group_members.items():
user_balances_data[user_id] = UserBalanceDetail(
user_id=user_id,
user_identifier=user_obj.name if user_obj.name else user_obj.email
)
overall_total_expenses = Decimal("0.00")
overall_total_settlements = Decimal("0.00")
# 2. Process Expenses and ExpenseSplits for the group
expenses_result = await db.execute(
select(ExpenseModel)
.where(ExpenseModel.group_id == group_id)
.options(selectinload(ExpenseModel.splits))
)
for expense in expenses_result.scalars().all():
overall_total_expenses += expense.total_amount
if expense.paid_by_user_id in user_balances_data:
user_balances_data[expense.paid_by_user_id].total_paid_for_expenses += expense.total_amount
for split in expense.splits:
if split.user_id in user_balances_data:
user_balances_data[split.user_id].total_share_of_expenses += split.owed_amount
# 3. Process Settlements for the group
settlements_result = await db.execute(
select(SettlementModel).where(SettlementModel.group_id == group_id)
)
for settlement in settlements_result.scalars().all():
overall_total_settlements += settlement.amount
if settlement.paid_by_user_id in user_balances_data:
user_balances_data[settlement.paid_by_user_id].total_settlements_paid += settlement.amount
if settlement.paid_to_user_id in user_balances_data:
user_balances_data[settlement.paid_to_user_id].total_settlements_received += settlement.amount
# 4. Calculate net balances and prepare final list
final_user_balances: PyList[UserBalanceDetail] = []
for user_id in group_members.keys():
data = user_balances_data[user_id]
data.net_balance = (
data.total_paid_for_expenses + data.total_settlements_received
) - (data.total_share_of_expenses + data.total_settlements_paid)
data.total_paid_for_expenses = data.total_paid_for_expenses.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP)
data.total_share_of_expenses = data.total_share_of_expenses.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP)
data.total_settlements_paid = data.total_settlements_paid.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP)
data.total_settlements_received = data.total_settlements_received.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP)
data.net_balance = data.net_balance.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP)
final_user_balances.append(data)
final_user_balances.sort(key=lambda x: x.user_identifier)
# 5. Calculate suggested settlements (NEW)
suggested_settlements = calculate_suggested_settlements(final_user_balances)
return GroupBalanceSummary(
group_id=group.id,
group_name=group.name,
overall_total_expenses=overall_total_expenses.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP),
overall_total_settlements=overall_total_settlements.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP),
user_balances=final_user_balances,
suggested_settlements=suggested_settlements # Add suggestions to response
)

View File

@ -1,254 +0,0 @@
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from decimal import Decimal, ROUND_HALF_UP
from typing import List as PyList, Dict
import logging # For mocking the logger
from app.crud.cost import (
calculate_list_cost_summary,
calculate_suggested_settlements,
calculate_group_balance_summary
)
from app.schemas.cost import (
ListCostSummary,
UserCostShare,
GroupBalanceSummary,
UserBalanceDetail,
SuggestedSettlement
)
from app.models import (
List as ListModel,
Item as ItemModel,
User as UserModel,
Group as GroupModel,
UserGroup as UserGroupModel,
Expense as ExpenseModel,
ExpenseSplit as ExpenseSplitModel,
Settlement as SettlementModel
)
from app.core.exceptions import ListNotFoundError, GroupNotFoundError
# Fixtures
@pytest.fixture
def mock_db_session():
session = AsyncMock()
session.execute = AsyncMock()
session.get = AsyncMock()
return session
@pytest.fixture
def user_a():
return UserModel(id=1, name="User A", email="a@example.com")
@pytest.fixture
def user_b():
return UserModel(id=2, name="User B", email="b@example.com")
@pytest.fixture
def user_c():
return UserModel(id=3, name="User C", email="c@example.com")
@pytest.fixture
def basic_list_model(user_a):
return ListModel(id=1, name="Test List", creator_id=user_a.id, creator=user_a, items=[])
@pytest.fixture
def group_model_with_members(user_a, user_b):
group = GroupModel(id=1, name="Test Group")
# Simulate user_associations for calculate_list_cost_summary if list is group-based
group.user_associations = [
UserGroupModel(user_id=user_a.id, group_id=group.id, user=user_a),
UserGroupModel(user_id=user_b.id, group_id=group.id, user=user_b)
]
return group
@pytest.fixture
def list_model_with_group(group_model_with_members):
return ListModel(id=2, name="Group List", group_id=group_model_with_members.id, group=group_model_with_members, items=[])
# --- calculate_list_cost_summary Tests ---
@pytest.mark.asyncio
async def test_calculate_list_cost_summary_personal_list_no_items(mock_db_session, basic_list_model, user_a):
mock_db_session.execute.return_value.scalars.return_value.first.return_value = basic_list_model
basic_list_model.items = [] # Ensure no items
basic_list_model.group = None # Ensure it's a personal list
basic_list_model.creator = user_a
summary = await calculate_list_cost_summary(mock_db_session, basic_list_model.id)
assert summary.list_id == basic_list_model.id
assert summary.total_list_cost == Decimal("0.00")
assert summary.num_participating_users == 1
assert summary.equal_share_per_user == Decimal("0.00")
assert len(summary.user_balances) == 1
assert summary.user_balances[0].user_id == user_a.id
assert summary.user_balances[0].balance == Decimal("0.00")
@pytest.mark.asyncio
async def test_calculate_list_cost_summary_group_list_with_items(mock_db_session, list_model_with_group, group_model_with_members, user_a, user_b):
item1 = ItemModel(id=1, name="Milk", price=Decimal("3.00"), added_by_id=user_a.id, added_by_user=user_a, list_id=list_model_with_group.id)
item2 = ItemModel(id=2, name="Bread", price=Decimal("2.00"), added_by_id=user_b.id, added_by_user=user_b, list_id=list_model_with_group.id)
list_model_with_group.items = [item1, item2]
mock_db_session.execute.return_value.scalars.return_value.first.return_value = list_model_with_group
summary = await calculate_list_cost_summary(mock_db_session, list_model_with_group.id)
assert summary.total_list_cost == Decimal("5.00")
assert summary.num_participating_users == 2
assert summary.equal_share_per_user == Decimal("2.50") # 5.00 / 2
balances_map = {ub.user_id: ub for ub in summary.user_balances}
assert balances_map[user_a.id].items_added_value == Decimal("3.00")
assert balances_map[user_a.id].amount_due == Decimal("2.50")
assert balances_map[user_a.id].balance == Decimal("0.50") # 3.00 - 2.50
assert balances_map[user_b.id].items_added_value == Decimal("2.00")
assert balances_map[user_b.id].amount_due == Decimal("2.50")
assert balances_map[user_b.id].balance == Decimal("-0.50") # 2.00 - 2.50
@pytest.mark.asyncio
async def test_calculate_list_cost_summary_list_not_found(mock_db_session):
mock_db_session.execute.return_value.scalars.return_value.first.return_value = None
with pytest.raises(ListNotFoundError):
await calculate_list_cost_summary(mock_db_session, 999)
# --- calculate_suggested_settlements Tests ---
@patch('app.crud.cost.logger') # Mock the logger used in the function
def test_calculate_suggested_settlements_simple_case(mock_logger):
user_balances = [
UserBalanceDetail(user_id=1, user_identifier="User A", net_balance=Decimal("-10.00")),
UserBalanceDetail(user_id=2, user_identifier="User B", net_balance=Decimal("10.00")),
]
settlements = calculate_suggested_settlements(user_balances)
assert len(settlements) == 1
assert settlements[0].from_user_id == 1
assert settlements[0].to_user_id == 2
assert settlements[0].amount == Decimal("10.00")
mock_logger.warning.assert_not_called()
@patch('app.crud.cost.logger')
def test_calculate_suggested_settlements_multiple(mock_logger):
user_balances = [
UserBalanceDetail(user_id=1, user_identifier="User A", net_balance=Decimal("-30.00")),
UserBalanceDetail(user_id=2, user_identifier="User B", net_balance=Decimal("10.00")),
UserBalanceDetail(user_id=3, user_identifier="User C", net_balance=Decimal("20.00")),
]
settlements = calculate_suggested_settlements(user_balances)
# Expected: A owes B 10, A owes C 20
assert len(settlements) == 2
s_map = {(s.from_user_id, s.to_user_id): s.amount for s in settlements}
assert s_map[(1, 3)] == Decimal("20.00") # A -> C (largest creditor first)
assert s_map[(1, 2)] == Decimal("10.00") # A -> B
mock_logger.warning.assert_not_called()
@patch('app.crud.cost.logger')
def test_calculate_suggested_settlements_complex_case_with_rounding_check(mock_logger):
user_balances = [
UserBalanceDetail(user_id=1, user_identifier="User A", net_balance=Decimal("-33.33")),
UserBalanceDetail(user_id=2, user_identifier="User B", net_balance=Decimal("-33.33")),
UserBalanceDetail(user_id=3, user_identifier="User C", net_balance=Decimal("66.66")),
]
settlements = calculate_suggested_settlements(user_balances)
# A -> C 33.33, B -> C 33.33
assert len(settlements) == 2
total_settled_to_C = sum(s.amount for s in settlements if s.to_user_id == 3)
assert total_settled_to_C == Decimal("66.66")
mock_logger.warning.assert_not_called() # Assuming exact match after quantization
# --- calculate_group_balance_summary Tests ---
@pytest.mark.asyncio
async def test_calculate_group_balance_summary_no_activity(mock_db_session, group_model_with_members, user_a, user_b):
mock_db_session.get.return_value = group_model_with_members # For group fetch
# Mock group members query
mock_members_result = AsyncMock()
mock_members_result.scalars.return_value.all.return_value = [user_a, user_b]
# Mock expenses query (no expenses)
mock_expenses_result = AsyncMock()
mock_expenses_result.scalars.return_value.all.return_value = []
# Mock settlements query (no settlements)
mock_settlements_result = AsyncMock()
mock_settlements_result.scalars.return_value.all.return_value = []
mock_db_session.execute.side_effect = [mock_members_result, mock_expenses_result, mock_settlements_result]
summary = await calculate_group_balance_summary(mock_db_session, group_model_with_members.id)
assert summary.group_id == group_model_with_members.id
assert len(summary.user_balances) == 2
assert len(summary.suggested_settlements) == 0
for ub in summary.user_balances:
assert ub.net_balance == Decimal("0.00")
@pytest.mark.asyncio
async def test_calculate_group_balance_summary_with_expenses_and_settlements(mock_db_session, group_model_with_members, user_a, user_b, user_c):
# Group members for this test: A, B, C
group_model_with_members.user_associations.append(UserGroupModel(user_id=user_c.id, group_id=group_model_with_members.id, user=user_c))
all_users = [user_a, user_b, user_c]
mock_db_session.get.return_value = group_model_with_members
mock_members_result = AsyncMock()
mock_members_result.scalars.return_value.all.return_value = all_users
# Expenses: A paid 90, split equally (A, B, C -> 30 each)
# A: paid 90, share 30 -> +60
# B: paid 0, share 30 -> -30
# C: paid 0, share 30 -> -30
expense1 = ExpenseModel(id=1, group_id=group_model_with_members.id, paid_by_user_id=user_a.id, total_amount=Decimal("90.00"))
expense1.splits = [
ExpenseSplitModel(expense_id=1, user_id=user_a.id, owed_amount=Decimal("30.00")),
ExpenseSplitModel(expense_id=1, user_id=user_b.id, owed_amount=Decimal("30.00")),
ExpenseSplitModel(expense_id=1, user_id=user_c.id, owed_amount=Decimal("30.00")),
]
mock_expenses_result = AsyncMock()
mock_expenses_result.scalars.return_value.all.return_value = [expense1]
# Settlements: B paid A 10
# A: received 10 -> +60 + 10 = +70
# B: paid 10 -> -30 - 10 = -40
# C: no settlement -> -30
settlement1 = SettlementModel(id=1, group_id=group_model_with_members.id, paid_by_user_id=user_b.id, paid_to_user_id=user_a.id, amount=Decimal("10.00"))
mock_settlements_result = AsyncMock()
mock_settlements_result.scalars.return_value.all.return_value = [settlement1]
mock_db_session.execute.side_effect = [mock_members_result, mock_expenses_result, mock_settlements_result]
with patch('app.crud.cost.logger') as mock_settlement_logger: # Patch logger for calculate_suggested_settlements
summary = await calculate_group_balance_summary(mock_db_session, group_model_with_members.id)
balances_map = {ub.user_id: ub for ub in summary.user_balances}
assert balances_map[user_a.id].net_balance == Decimal("70.00")
assert balances_map[user_b.id].net_balance == Decimal("-40.00")
assert balances_map[user_c.id].net_balance == Decimal("-30.00")
# Suggested settlements: B owes A 30 (40-10), C owes A 30
# Net balances: A: +70, B: -40, C: -30
# Algorithm should suggest: B -> A (40), C -> A (30)
assert len(summary.suggested_settlements) == 2
settlement_map = {(s.from_user_id, s.to_user_id): s.amount for s in summary.suggested_settlements}
# Order might vary, check presence and amounts
assert (user_b.id, user_a.id) in settlement_map
assert settlement_map[(user_b.id, user_a.id)] == Decimal("40.00")
assert (user_c.id, user_a.id) in settlement_map
assert settlement_map[(user_c.id, user_a.id)] == Decimal("30.00")
mock_settlement_logger.warning.assert_not_called()
@pytest.mark.asyncio
async def test_calculate_group_balance_summary_group_not_found(mock_db_session):
mock_db_session.get.return_value = None
with pytest.raises(GroupNotFoundError):
await calculate_group_balance_summary(mock_db_session, 999)
# TODO:
# - Test calculate_list_cost_summary with item added by user not in group/not creator.
# - Test calculate_list_cost_summary with remainder distribution for equal share.
# - Test calculate_suggested_settlements with zero balances, one debtor, one creditor, etc.
# - Test calculate_suggested_settlements with logger warning for remaining balances.
# - Test calculate_group_balance_summary with more complex expense/settlement scenarios.
# - Test calculate_group_balance_summary when a user in splits/settlements is not in group_members (should ideally not happen if data is clean).