remove old cost splitter

This commit is contained in:
Mohamad 2025-05-08 14:31:34 +02:00
parent bbb3c3b7df
commit e3024ccd07
4 changed files with 195 additions and 562 deletions

View File

@ -4,13 +4,25 @@ from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session, selectinload
from decimal import Decimal, ROUND_HALF_UP
from app.database import get_db
from app.api.dependencies import get_current_user
from app.models import User as UserModel, Group as GroupModel # For get_current_user dependency and Group model
from app.models import (
User as UserModel,
Group as GroupModel,
List as ListModel,
Expense as ExpenseModel,
Item as ItemModel,
UserGroup as UserGroupModel,
SplitTypeEnum,
ExpenseSplit as ExpenseSplitModel,
Settlement as SettlementModel
)
from app.schemas.cost import ListCostSummary, GroupBalanceSummary
from app.crud import cost as crud_cost
from app.crud import list as crud_list # For permission checking
from app.schemas.expense import ExpenseCreate
from app.crud import list as crud_list
from app.crud import expense as crud_expense
from app.core.exceptions import ListNotFoundError, ListPermissionError, UserNotFoundError, GroupNotFoundError
logger = logging.getLogger(__name__)
@ -47,28 +59,117 @@ async def get_list_cost_summary(
await crud_list.check_list_permission(db=db, list_id=list_id, user_id=current_user.id)
except ListPermissionError as e:
logger.warning(f"Permission denied for user {current_user.email} on list {list_id}: {str(e)}")
raise # Re-raise the original exception to be handled by FastAPI
raise
except ListNotFoundError as e:
logger.warning(f"List {list_id} not found when checking permissions for cost summary: {str(e)}")
raise # Re-raise
raise
# 2. Calculate the cost summary
try:
cost_summary = await crud_cost.calculate_list_cost_summary(db=db, list_id=list_id)
logger.info(f"Successfully generated cost summary for list {list_id} for user {current_user.email}")
return cost_summary
except ListNotFoundError as e:
logger.warning(f"List {list_id} not found during cost summary calculation: {str(e)}")
# This might be redundant if check_list_permission already confirmed list existence,
# but calculate_list_cost_summary also fetches the list.
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e))
except UserNotFoundError as e:
logger.error(f"User not found during cost summary calculation for list {list_id}: {str(e)}")
# This indicates a data integrity issue (e.g., list creator or item adder missing)
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e))
except Exception as e:
logger.error(f"Unexpected error generating cost summary for list {list_id} for user {current_user.email}: {str(e)}", exc_info=True)
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="An unexpected error occurred while generating the cost summary.")
# 2. Get the list with its items and users
list_result = await db.execute(
select(ListModel)
.options(
selectinload(ListModel.items).options(selectinload(ItemModel.added_by_user)),
selectinload(ListModel.group).options(selectinload(GroupModel.user_associations).options(selectinload(UserGroupModel.user))),
selectinload(ListModel.creator)
)
.where(ListModel.id == list_id)
)
db_list = list_result.scalars().first()
if not db_list:
raise ListNotFoundError(list_id)
# 3. Get or create an expense for this list
expense_result = await db.execute(
select(ExpenseModel)
.where(ExpenseModel.list_id == list_id)
.options(selectinload(ExpenseModel.splits))
)
db_expense = expense_result.scalars().first()
if not db_expense:
# Create a new expense for this list
total_amount = sum(item.price for item in db_list.items if item.price is not None and item.price > Decimal("0"))
if total_amount == Decimal("0"):
return ListCostSummary(
list_id=db_list.id,
list_name=db_list.name,
total_list_cost=Decimal("0.00"),
num_participating_users=0,
equal_share_per_user=Decimal("0.00"),
user_balances=[]
)
# Create expense with ITEM_BASED split type
expense_in = ExpenseCreate(
description=f"Cost summary for list {db_list.name}",
total_amount=total_amount,
list_id=list_id,
split_type=SplitTypeEnum.ITEM_BASED,
paid_by_user_id=current_user.id # Use current user as payer for now
)
db_expense = await crud_expense.create_expense(db=db, expense_in=expense_in)
# 4. Calculate cost summary from expense splits
participating_users = set()
user_items_added_value = {}
total_list_cost = Decimal("0.00")
# Get all users who added items
for item in db_list.items:
if item.price is not None and item.price > Decimal("0") and item.added_by_user:
participating_users.add(item.added_by_user)
user_items_added_value[item.added_by_user.id] = user_items_added_value.get(item.added_by_user.id, Decimal("0.00")) + item.price
total_list_cost += item.price
# Get all users from expense splits
for split in db_expense.splits:
if split.user:
participating_users.add(split.user)
num_participating_users = len(participating_users)
if num_participating_users == 0:
return ListCostSummary(
list_id=db_list.id,
list_name=db_list.name,
total_list_cost=Decimal("0.00"),
num_participating_users=0,
equal_share_per_user=Decimal("0.00"),
user_balances=[]
)
equal_share_per_user = (total_list_cost / Decimal(num_participating_users)).quantize(Decimal("0.01"), rounding=ROUND_HALF_UP)
remainder = total_list_cost - (equal_share_per_user * num_participating_users)
user_balances = []
first_user_processed = False
for user in participating_users:
items_added = user_items_added_value.get(user.id, Decimal("0.00"))
current_user_share = equal_share_per_user
if not first_user_processed and remainder != Decimal("0"):
current_user_share += remainder
first_user_processed = True
balance = items_added - current_user_share
user_identifier = user.name if user.name else user.email
user_balances.append(
UserCostShare(
user_id=user.id,
user_identifier=user_identifier,
items_added_value=items_added.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP),
amount_due=current_user_share.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP),
balance=balance.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP)
)
)
user_balances.sort(key=lambda x: x.user_identifier)
return ListCostSummary(
list_id=db_list.id,
list_name=db_list.name,
total_list_cost=total_list_cost.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP),
num_participating_users=num_participating_users,
equal_share_per_user=equal_share_per_user.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP),
user_balances=user_balances
)
@router.get(
"/groups/{group_id}/balance-summary",
@ -92,11 +193,7 @@ async def get_group_balance_summary(
"""
logger.info(f"User {current_user.email} requesting balance summary for group {group_id}")
# 1. Verify user is a member of the target group (using crud_group.check_group_membership or similar)
# Assuming a function like this exists in app.crud.group or we add it.
# For now, let's placeholder this check logic.
# await crud_group.check_group_membership(db=db, group_id=group_id, user_id=current_user.id)
# A simpler check for now: fetch the group and see if user is part of member_associations
# 1. Verify user is a member of the target group
group_check = await db.execute(
select(GroupModel)
.options(selectinload(GroupModel.member_associations))
@ -109,20 +206,75 @@ async def get_group_balance_summary(
user_is_member = any(assoc.user_id == current_user.id for assoc in db_group_for_check.member_associations)
if not user_is_member:
# If ListPermissionError is generic enough for "access resource", use it, or a new GroupPermissionError
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=f"User not a member of group {group_id}")
# 2. Calculate the group balance summary
try:
balance_summary = await crud_cost.calculate_group_balance_summary(db=db, group_id=group_id)
logger.info(f"Successfully generated balance summary for group {group_id} for user {current_user.email}")
return balance_summary
except GroupNotFoundError as e:
logger.warning(f"Group {group_id} not found during balance summary calculation: {str(e)}")
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e))
except UserNotFoundError as e: # Should not happen if group members are correctly fetched
logger.error(f"User not found during balance summary for group {group_id}: {str(e)}")
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="An internal error occurred finding a user for the summary.")
except Exception as e:
logger.error(f"Unexpected error generating balance summary for group {group_id}: {str(e)}", exc_info=True)
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="An unexpected error occurred while generating the group balance summary.")
# 2. Get all expenses and settlements for the group
expenses_result = await db.execute(
select(ExpenseModel)
.where(ExpenseModel.group_id == group_id)
.options(selectinload(ExpenseModel.splits).selectinload(ExpenseSplitModel.user))
)
expenses = expenses_result.scalars().all()
settlements_result = await db.execute(
select(SettlementModel)
.where(SettlementModel.group_id == group_id)
.options(
selectinload(SettlementModel.paid_by_user),
selectinload(SettlementModel.paid_to_user)
)
)
settlements = settlements_result.scalars().all()
# 3. Calculate user balances
user_balances_data = {}
for assoc in db_group_for_check.member_associations:
if assoc.user:
user_balances_data[assoc.user.id] = UserBalanceDetail(
user_id=assoc.user.id,
user_identifier=assoc.user.name if assoc.user.name else assoc.user.email
)
# Process expenses
for expense in expenses:
if expense.paid_by_user_id in user_balances_data:
user_balances_data[expense.paid_by_user_id].total_paid_for_expenses += expense.total_amount
for split in expense.splits:
if split.user_id in user_balances_data:
user_balances_data[split.user_id].total_share_of_expenses += split.owed_amount
# Process settlements
for settlement in settlements:
if settlement.paid_by_user_id in user_balances_data:
user_balances_data[settlement.paid_by_user_id].total_settlements_paid += settlement.amount
if settlement.paid_to_user_id in user_balances_data:
user_balances_data[settlement.paid_to_user_id].total_settlements_received += settlement.amount
# Calculate net balances
final_user_balances = []
for user_id, data in user_balances_data.items():
data.net_balance = (
data.total_paid_for_expenses + data.total_settlements_received
) - (data.total_share_of_expenses + data.total_settlements_paid)
data.total_paid_for_expenses = data.total_paid_for_expenses.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP)
data.total_share_of_expenses = data.total_share_of_expenses.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP)
data.total_settlements_paid = data.total_settlements_paid.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP)
data.total_settlements_received = data.total_settlements_received.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP)
data.net_balance = data.net_balance.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP)
final_user_balances.append(data)
# Sort by user identifier
final_user_balances.sort(key=lambda x: x.user_identifier)
# Calculate suggested settlements
suggested_settlements = calculate_suggested_settlements(final_user_balances)
return GroupBalanceSummary(
group_id=db_group_for_check.id,
group_name=db_group_for_check.name,
user_balances=final_user_balances,
suggested_settlements=suggested_settlements
)

View File

@ -60,6 +60,7 @@ Organic Bananas
CORS_ORIGINS: list[str] = [
"http://localhost:5174",
"http://localhost:8000",
"http://localhost:9000",
# Add your deployed frontend URL here later
# "https://your-frontend-domain.com",
]

View File

@ -1,266 +0,0 @@
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy.orm import selectinload, joinedload
from sqlalchemy import func as sql_func, or_
from decimal import Decimal, ROUND_HALF_UP
from typing import Dict, Optional, Sequence, List as PyList
from collections import defaultdict
import logging
from app.models import (
List as ListModel,
Item as ItemModel,
User as UserModel,
UserGroup as UserGroupModel,
Group as GroupModel,
Expense as ExpenseModel,
ExpenseSplit as ExpenseSplitModel,
Settlement as SettlementModel
)
from app.schemas.cost import (
ListCostSummary,
UserCostShare,
GroupBalanceSummary,
UserBalanceDetail,
SuggestedSettlement
)
from app.core.exceptions import (
ListNotFoundError,
UserNotFoundError,
GroupNotFoundError,
InvalidOperationError
)
logger = logging.getLogger(__name__)
async def calculate_list_cost_summary(db: AsyncSession, list_id: int) -> ListCostSummary:
"""
Calculates the cost summary for a given list based purely on item prices and who added them.
This is a simpler calculation and does not involve the Expense/Settlement system.
"""
list_result = await db.execute(
select(ListModel)
.options(
selectinload(ListModel.items).options(joinedload(ItemModel.added_by_user)),
selectinload(ListModel.group).options(selectinload(GroupModel.user_associations).options(selectinload(UserGroupModel.user))),
selectinload(ListModel.creator)
)
.where(ListModel.id == list_id)
)
db_list: Optional[ListModel] = list_result.scalars().first()
if not db_list:
raise ListNotFoundError(list_id)
participating_users_map: Dict[int, UserModel] = {}
if db_list.group:
for ug_assoc in db_list.group.user_associations:
if ug_assoc.user:
participating_users_map[ug_assoc.user.id] = ug_assoc.user
elif db_list.creator: # Personal list
participating_users_map[db_list.creator.id] = db_list.creator
# Include all users who added items with prices, even if not in the primary context (group/creator)
for item in db_list.items:
if item.price is not None and item.price > Decimal("0") and item.added_by_user and item.added_by_user.id not in participating_users_map:
participating_users_map[item.added_by_user.id] = item.added_by_user
num_participating_users = len(participating_users_map)
if num_participating_users == 0:
return ListCostSummary(
list_id=db_list.id, list_name=db_list.name, total_list_cost=Decimal("0.00"),
num_participating_users=0, equal_share_per_user=Decimal("0.00"), user_balances=[]
)
total_list_cost = Decimal("0.00")
user_items_added_value: Dict[int, Decimal] = defaultdict(Decimal)
for item in db_list.items:
if item.price is not None and item.price > Decimal("0"):
total_list_cost += item.price
if item.added_by_id in participating_users_map:
user_items_added_value[item.added_by_id] += item.price
equal_share_per_user = (total_list_cost / Decimal(num_participating_users)).quantize(Decimal("0.01"), rounding=ROUND_HALF_UP)
remainder = total_list_cost - (equal_share_per_user * num_participating_users)
user_balances: PyList[UserCostShare] = []
first_user_processed = False
for user_id, user_obj in participating_users_map.items():
items_added = user_items_added_value.get(user_id, Decimal("0.00"))
current_user_share = equal_share_per_user
if not first_user_processed and remainder != Decimal("0"):
current_user_share += remainder
first_user_processed = True
balance = items_added - current_user_share
user_identifier = user_obj.name if user_obj.name else user_obj.email
user_balances.append(
UserCostShare(
user_id=user_id, user_identifier=user_identifier,
items_added_value=items_added.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP),
amount_due=current_user_share.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP),
balance=balance.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP)
)
)
user_balances.sort(key=lambda x: x.user_identifier)
return ListCostSummary(
list_id=db_list.id, list_name=db_list.name,
total_list_cost=total_list_cost.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP),
num_participating_users=num_participating_users,
equal_share_per_user=equal_share_per_user.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP),
user_balances=user_balances
)
# --- Helper for Settlement Suggestions ---
def calculate_suggested_settlements(user_balances: PyList[UserBalanceDetail]) -> PyList[SuggestedSettlement]:
"""
Calculates a list of suggested settlements to resolve group debts.
Uses a greedy algorithm to minimize the number of transactions.
Input: List of UserBalanceDetail objects with calculated net_balances.
Output: List of SuggestedSettlement objects.
"""
# Use a small tolerance for floating point comparisons with Decimal
tolerance = Decimal("0.001")
debtors = sorted([ub for ub in user_balances if ub.net_balance < -tolerance], key=lambda x: x.net_balance)
creditors = sorted([ub for ub in user_balances if ub.net_balance > tolerance], key=lambda x: x.net_balance, reverse=True)
settlements: PyList[SuggestedSettlement] = []
debtor_idx = 0
creditor_idx = 0
# Create mutable copies of balances to track remaining amounts
debtor_balances = {d.user_id: d.net_balance for d in debtors}
creditor_balances = {c.user_id: c.net_balance for c in creditors}
user_identifiers = {ub.user_id: ub.user_identifier for ub in user_balances}
while debtor_idx < len(debtors) and creditor_idx < len(creditors):
debtor = debtors[debtor_idx]
creditor = creditors[creditor_idx]
debtor_remaining = debtor_balances[debtor.user_id]
creditor_remaining = creditor_balances[creditor.user_id]
# Amount to transfer is the minimum of what debtor owes and what creditor is owed
transfer_amount = min(abs(debtor_remaining), creditor_remaining)
transfer_amount = transfer_amount.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP)
if transfer_amount > tolerance: # Only record meaningful transfers
settlements.append(SuggestedSettlement(
from_user_id=debtor.user_id,
from_user_identifier=user_identifiers.get(debtor.user_id, "Unknown Debtor"),
to_user_id=creditor.user_id,
to_user_identifier=user_identifiers.get(creditor.user_id, "Unknown Creditor"),
amount=transfer_amount
))
# Update remaining balances
debtor_balances[debtor.user_id] += transfer_amount
creditor_balances[creditor.user_id] -= transfer_amount
# Move to next debtor if current one is settled (or very close)
if abs(debtor_balances[debtor.user_id]) < tolerance:
debtor_idx += 1
# Move to next creditor if current one is settled (or very close)
if creditor_balances[creditor.user_id] < tolerance:
creditor_idx += 1
# Log if lists aren't empty - indicates potential imbalance or rounding issue
if debtor_idx < len(debtors) or creditor_idx < len(creditors):
# Calculate remaining balances for logging
remaining_debt = sum(bal for bal in debtor_balances.values() if bal < -tolerance)
remaining_credit = sum(bal for bal in creditor_balances.values() if bal > tolerance)
logger.warning(f"Settlement suggestion calculation finished with remaining balances. Debt: {remaining_debt}, Credit: {remaining_credit}. This might be due to minor rounding discrepancies.")
return settlements
# --- NEW: Detailed Group Balance Summary ---
async def calculate_group_balance_summary(db: AsyncSession, group_id: int) -> GroupBalanceSummary:
"""
Calculates a detailed balance summary for all users in a group,
considering all expenses, splits, and settlements within that group.
Also calculates suggested settlements.
"""
group = await db.get(GroupModel, group_id)
if not group:
raise GroupNotFoundError(group_id)
# 1. Get all group members
group_members_result = await db.execute(
select(UserModel)
.join(UserGroupModel, UserModel.id == UserGroupModel.user_id)
.where(UserGroupModel.group_id == group_id)
)
group_members: Dict[int, UserModel] = {user.id: user for user in group_members_result.scalars().all()}
if not group_members:
return GroupBalanceSummary(
group_id=group.id, group_name=group.name, user_balances=[], suggested_settlements=[]
)
user_balances_data: Dict[int, UserBalanceDetail] = {}
for user_id, user_obj in group_members.items():
user_balances_data[user_id] = UserBalanceDetail(
user_id=user_id,
user_identifier=user_obj.name if user_obj.name else user_obj.email
)
overall_total_expenses = Decimal("0.00")
overall_total_settlements = Decimal("0.00")
# 2. Process Expenses and ExpenseSplits for the group
expenses_result = await db.execute(
select(ExpenseModel)
.where(ExpenseModel.group_id == group_id)
.options(selectinload(ExpenseModel.splits))
)
for expense in expenses_result.scalars().all():
overall_total_expenses += expense.total_amount
if expense.paid_by_user_id in user_balances_data:
user_balances_data[expense.paid_by_user_id].total_paid_for_expenses += expense.total_amount
for split in expense.splits:
if split.user_id in user_balances_data:
user_balances_data[split.user_id].total_share_of_expenses += split.owed_amount
# 3. Process Settlements for the group
settlements_result = await db.execute(
select(SettlementModel).where(SettlementModel.group_id == group_id)
)
for settlement in settlements_result.scalars().all():
overall_total_settlements += settlement.amount
if settlement.paid_by_user_id in user_balances_data:
user_balances_data[settlement.paid_by_user_id].total_settlements_paid += settlement.amount
if settlement.paid_to_user_id in user_balances_data:
user_balances_data[settlement.paid_to_user_id].total_settlements_received += settlement.amount
# 4. Calculate net balances and prepare final list
final_user_balances: PyList[UserBalanceDetail] = []
for user_id in group_members.keys():
data = user_balances_data[user_id]
data.net_balance = (
data.total_paid_for_expenses + data.total_settlements_received
) - (data.total_share_of_expenses + data.total_settlements_paid)
data.total_paid_for_expenses = data.total_paid_for_expenses.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP)
data.total_share_of_expenses = data.total_share_of_expenses.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP)
data.total_settlements_paid = data.total_settlements_paid.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP)
data.total_settlements_received = data.total_settlements_received.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP)
data.net_balance = data.net_balance.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP)
final_user_balances.append(data)
final_user_balances.sort(key=lambda x: x.user_identifier)
# 5. Calculate suggested settlements (NEW)
suggested_settlements = calculate_suggested_settlements(final_user_balances)
return GroupBalanceSummary(
group_id=group.id,
group_name=group.name,
overall_total_expenses=overall_total_expenses.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP),
overall_total_settlements=overall_total_settlements.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP),
user_balances=final_user_balances,
suggested_settlements=suggested_settlements # Add suggestions to response
)

View File

@ -1,254 +0,0 @@
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from decimal import Decimal, ROUND_HALF_UP
from typing import List as PyList, Dict
import logging # For mocking the logger
from app.crud.cost import (
calculate_list_cost_summary,
calculate_suggested_settlements,
calculate_group_balance_summary
)
from app.schemas.cost import (
ListCostSummary,
UserCostShare,
GroupBalanceSummary,
UserBalanceDetail,
SuggestedSettlement
)
from app.models import (
List as ListModel,
Item as ItemModel,
User as UserModel,
Group as GroupModel,
UserGroup as UserGroupModel,
Expense as ExpenseModel,
ExpenseSplit as ExpenseSplitModel,
Settlement as SettlementModel
)
from app.core.exceptions import ListNotFoundError, GroupNotFoundError
# Fixtures
@pytest.fixture
def mock_db_session():
session = AsyncMock()
session.execute = AsyncMock()
session.get = AsyncMock()
return session
@pytest.fixture
def user_a():
return UserModel(id=1, name="User A", email="a@example.com")
@pytest.fixture
def user_b():
return UserModel(id=2, name="User B", email="b@example.com")
@pytest.fixture
def user_c():
return UserModel(id=3, name="User C", email="c@example.com")
@pytest.fixture
def basic_list_model(user_a):
return ListModel(id=1, name="Test List", creator_id=user_a.id, creator=user_a, items=[])
@pytest.fixture
def group_model_with_members(user_a, user_b):
group = GroupModel(id=1, name="Test Group")
# Simulate user_associations for calculate_list_cost_summary if list is group-based
group.user_associations = [
UserGroupModel(user_id=user_a.id, group_id=group.id, user=user_a),
UserGroupModel(user_id=user_b.id, group_id=group.id, user=user_b)
]
return group
@pytest.fixture
def list_model_with_group(group_model_with_members):
return ListModel(id=2, name="Group List", group_id=group_model_with_members.id, group=group_model_with_members, items=[])
# --- calculate_list_cost_summary Tests ---
@pytest.mark.asyncio
async def test_calculate_list_cost_summary_personal_list_no_items(mock_db_session, basic_list_model, user_a):
mock_db_session.execute.return_value.scalars.return_value.first.return_value = basic_list_model
basic_list_model.items = [] # Ensure no items
basic_list_model.group = None # Ensure it's a personal list
basic_list_model.creator = user_a
summary = await calculate_list_cost_summary(mock_db_session, basic_list_model.id)
assert summary.list_id == basic_list_model.id
assert summary.total_list_cost == Decimal("0.00")
assert summary.num_participating_users == 1
assert summary.equal_share_per_user == Decimal("0.00")
assert len(summary.user_balances) == 1
assert summary.user_balances[0].user_id == user_a.id
assert summary.user_balances[0].balance == Decimal("0.00")
@pytest.mark.asyncio
async def test_calculate_list_cost_summary_group_list_with_items(mock_db_session, list_model_with_group, group_model_with_members, user_a, user_b):
item1 = ItemModel(id=1, name="Milk", price=Decimal("3.00"), added_by_id=user_a.id, added_by_user=user_a, list_id=list_model_with_group.id)
item2 = ItemModel(id=2, name="Bread", price=Decimal("2.00"), added_by_id=user_b.id, added_by_user=user_b, list_id=list_model_with_group.id)
list_model_with_group.items = [item1, item2]
mock_db_session.execute.return_value.scalars.return_value.first.return_value = list_model_with_group
summary = await calculate_list_cost_summary(mock_db_session, list_model_with_group.id)
assert summary.total_list_cost == Decimal("5.00")
assert summary.num_participating_users == 2
assert summary.equal_share_per_user == Decimal("2.50") # 5.00 / 2
balances_map = {ub.user_id: ub for ub in summary.user_balances}
assert balances_map[user_a.id].items_added_value == Decimal("3.00")
assert balances_map[user_a.id].amount_due == Decimal("2.50")
assert balances_map[user_a.id].balance == Decimal("0.50") # 3.00 - 2.50
assert balances_map[user_b.id].items_added_value == Decimal("2.00")
assert balances_map[user_b.id].amount_due == Decimal("2.50")
assert balances_map[user_b.id].balance == Decimal("-0.50") # 2.00 - 2.50
@pytest.mark.asyncio
async def test_calculate_list_cost_summary_list_not_found(mock_db_session):
mock_db_session.execute.return_value.scalars.return_value.first.return_value = None
with pytest.raises(ListNotFoundError):
await calculate_list_cost_summary(mock_db_session, 999)
# --- calculate_suggested_settlements Tests ---
@patch('app.crud.cost.logger') # Mock the logger used in the function
def test_calculate_suggested_settlements_simple_case(mock_logger):
user_balances = [
UserBalanceDetail(user_id=1, user_identifier="User A", net_balance=Decimal("-10.00")),
UserBalanceDetail(user_id=2, user_identifier="User B", net_balance=Decimal("10.00")),
]
settlements = calculate_suggested_settlements(user_balances)
assert len(settlements) == 1
assert settlements[0].from_user_id == 1
assert settlements[0].to_user_id == 2
assert settlements[0].amount == Decimal("10.00")
mock_logger.warning.assert_not_called()
@patch('app.crud.cost.logger')
def test_calculate_suggested_settlements_multiple(mock_logger):
user_balances = [
UserBalanceDetail(user_id=1, user_identifier="User A", net_balance=Decimal("-30.00")),
UserBalanceDetail(user_id=2, user_identifier="User B", net_balance=Decimal("10.00")),
UserBalanceDetail(user_id=3, user_identifier="User C", net_balance=Decimal("20.00")),
]
settlements = calculate_suggested_settlements(user_balances)
# Expected: A owes B 10, A owes C 20
assert len(settlements) == 2
s_map = {(s.from_user_id, s.to_user_id): s.amount for s in settlements}
assert s_map[(1, 3)] == Decimal("20.00") # A -> C (largest creditor first)
assert s_map[(1, 2)] == Decimal("10.00") # A -> B
mock_logger.warning.assert_not_called()
@patch('app.crud.cost.logger')
def test_calculate_suggested_settlements_complex_case_with_rounding_check(mock_logger):
user_balances = [
UserBalanceDetail(user_id=1, user_identifier="User A", net_balance=Decimal("-33.33")),
UserBalanceDetail(user_id=2, user_identifier="User B", net_balance=Decimal("-33.33")),
UserBalanceDetail(user_id=3, user_identifier="User C", net_balance=Decimal("66.66")),
]
settlements = calculate_suggested_settlements(user_balances)
# A -> C 33.33, B -> C 33.33
assert len(settlements) == 2
total_settled_to_C = sum(s.amount for s in settlements if s.to_user_id == 3)
assert total_settled_to_C == Decimal("66.66")
mock_logger.warning.assert_not_called() # Assuming exact match after quantization
# --- calculate_group_balance_summary Tests ---
@pytest.mark.asyncio
async def test_calculate_group_balance_summary_no_activity(mock_db_session, group_model_with_members, user_a, user_b):
mock_db_session.get.return_value = group_model_with_members # For group fetch
# Mock group members query
mock_members_result = AsyncMock()
mock_members_result.scalars.return_value.all.return_value = [user_a, user_b]
# Mock expenses query (no expenses)
mock_expenses_result = AsyncMock()
mock_expenses_result.scalars.return_value.all.return_value = []
# Mock settlements query (no settlements)
mock_settlements_result = AsyncMock()
mock_settlements_result.scalars.return_value.all.return_value = []
mock_db_session.execute.side_effect = [mock_members_result, mock_expenses_result, mock_settlements_result]
summary = await calculate_group_balance_summary(mock_db_session, group_model_with_members.id)
assert summary.group_id == group_model_with_members.id
assert len(summary.user_balances) == 2
assert len(summary.suggested_settlements) == 0
for ub in summary.user_balances:
assert ub.net_balance == Decimal("0.00")
@pytest.mark.asyncio
async def test_calculate_group_balance_summary_with_expenses_and_settlements(mock_db_session, group_model_with_members, user_a, user_b, user_c):
# Group members for this test: A, B, C
group_model_with_members.user_associations.append(UserGroupModel(user_id=user_c.id, group_id=group_model_with_members.id, user=user_c))
all_users = [user_a, user_b, user_c]
mock_db_session.get.return_value = group_model_with_members
mock_members_result = AsyncMock()
mock_members_result.scalars.return_value.all.return_value = all_users
# Expenses: A paid 90, split equally (A, B, C -> 30 each)
# A: paid 90, share 30 -> +60
# B: paid 0, share 30 -> -30
# C: paid 0, share 30 -> -30
expense1 = ExpenseModel(id=1, group_id=group_model_with_members.id, paid_by_user_id=user_a.id, total_amount=Decimal("90.00"))
expense1.splits = [
ExpenseSplitModel(expense_id=1, user_id=user_a.id, owed_amount=Decimal("30.00")),
ExpenseSplitModel(expense_id=1, user_id=user_b.id, owed_amount=Decimal("30.00")),
ExpenseSplitModel(expense_id=1, user_id=user_c.id, owed_amount=Decimal("30.00")),
]
mock_expenses_result = AsyncMock()
mock_expenses_result.scalars.return_value.all.return_value = [expense1]
# Settlements: B paid A 10
# A: received 10 -> +60 + 10 = +70
# B: paid 10 -> -30 - 10 = -40
# C: no settlement -> -30
settlement1 = SettlementModel(id=1, group_id=group_model_with_members.id, paid_by_user_id=user_b.id, paid_to_user_id=user_a.id, amount=Decimal("10.00"))
mock_settlements_result = AsyncMock()
mock_settlements_result.scalars.return_value.all.return_value = [settlement1]
mock_db_session.execute.side_effect = [mock_members_result, mock_expenses_result, mock_settlements_result]
with patch('app.crud.cost.logger') as mock_settlement_logger: # Patch logger for calculate_suggested_settlements
summary = await calculate_group_balance_summary(mock_db_session, group_model_with_members.id)
balances_map = {ub.user_id: ub for ub in summary.user_balances}
assert balances_map[user_a.id].net_balance == Decimal("70.00")
assert balances_map[user_b.id].net_balance == Decimal("-40.00")
assert balances_map[user_c.id].net_balance == Decimal("-30.00")
# Suggested settlements: B owes A 30 (40-10), C owes A 30
# Net balances: A: +70, B: -40, C: -30
# Algorithm should suggest: B -> A (40), C -> A (30)
assert len(summary.suggested_settlements) == 2
settlement_map = {(s.from_user_id, s.to_user_id): s.amount for s in summary.suggested_settlements}
# Order might vary, check presence and amounts
assert (user_b.id, user_a.id) in settlement_map
assert settlement_map[(user_b.id, user_a.id)] == Decimal("40.00")
assert (user_c.id, user_a.id) in settlement_map
assert settlement_map[(user_c.id, user_a.id)] == Decimal("30.00")
mock_settlement_logger.warning.assert_not_called()
@pytest.mark.asyncio
async def test_calculate_group_balance_summary_group_not_found(mock_db_session):
mock_db_session.get.return_value = None
with pytest.raises(GroupNotFoundError):
await calculate_group_balance_summary(mock_db_session, 999)
# TODO:
# - Test calculate_list_cost_summary with item added by user not in group/not creator.
# - Test calculate_list_cost_summary with remainder distribution for equal share.
# - Test calculate_suggested_settlements with zero balances, one debtor, one creditor, etc.
# - Test calculate_suggested_settlements with logger warning for remaining balances.
# - Test calculate_group_balance_summary with more complex expense/settlement scenarios.
# - Test calculate_group_balance_summary when a user in splits/settlements is not in group_members (should ideally not happen if data is clean).