254 lines
12 KiB
Python
254 lines
12 KiB
Python
import pytest
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
from decimal import Decimal, ROUND_HALF_UP
|
|
from typing import List as PyList, Dict
|
|
import logging # For mocking the logger
|
|
|
|
from app.crud.cost import (
|
|
calculate_list_cost_summary,
|
|
calculate_suggested_settlements,
|
|
calculate_group_balance_summary
|
|
)
|
|
from app.schemas.cost import (
|
|
ListCostSummary,
|
|
UserCostShare,
|
|
GroupBalanceSummary,
|
|
UserBalanceDetail,
|
|
SuggestedSettlement
|
|
)
|
|
from app.models import (
|
|
List as ListModel,
|
|
Item as ItemModel,
|
|
User as UserModel,
|
|
Group as GroupModel,
|
|
UserGroup as UserGroupModel,
|
|
Expense as ExpenseModel,
|
|
ExpenseSplit as ExpenseSplitModel,
|
|
Settlement as SettlementModel
|
|
)
|
|
from app.core.exceptions import ListNotFoundError, GroupNotFoundError
|
|
|
|
# Fixtures
@pytest.fixture
def mock_db_session():
    """Provide a mocked async DB session for the cost CRUD functions.

    ``execute`` and ``get`` are ``AsyncMock`` objects so the code under test can
    ``await`` them. The object produced by awaiting ``execute`` is a plain
    ``MagicMock``, because the tests in this module configure synchronous
    result chains such as
    ``mock_db_session.execute.return_value.scalars.return_value.first.return_value``.
    Tests that need several different results replace ``execute.side_effect``
    with a list of result mocks instead.
    """
    session = AsyncMock()
    # Child mocks of an AsyncMock (including its return_value) are AsyncMocks
    # too. Without an explicit MagicMock result, `result.scalars()` would return
    # an un-awaited coroutine rather than the configured `.first()`/`.all()` chain.
    session.execute = AsyncMock(return_value=MagicMock())
    session.get = AsyncMock()
    return session


@pytest.fixture
def user_a():
    """First test user: id 1, named "User A"."""
    user = UserModel(id=1, name="User A", email="a@example.com")
    return user


@pytest.fixture
def user_b():
    """Second test user: id 2, named "User B"."""
    user = UserModel(id=2, name="User B", email="b@example.com")
    return user


@pytest.fixture
def user_c():
    """Third test user: id 3, named "User C"."""
    user = UserModel(id=3, name="User C", email="c@example.com")
    return user


@pytest.fixture
def basic_list_model(user_a):
    """List with id 1, created by ``user_a``, with no group set and no items."""
    personal_list = ListModel(
        id=1,
        name="Test List",
        creator_id=user_a.id,
        creator=user_a,
        items=[],
    )
    return personal_list


@pytest.fixture
def group_model_with_members(user_a, user_b):
    """Group with id 1 whose ``user_associations`` link ``user_a`` and ``user_b``.

    The associations are attached in memory, presumably so that
    ``calculate_list_cost_summary`` can read the members of a group-based list
    without any database access.
    """
    group = GroupModel(id=1, name="Test Group")
    members = (user_a, user_b)
    group.user_associations = [
        UserGroupModel(user_id=member.id, group_id=group.id, user=member)
        for member in members
    ]
    return group


@pytest.fixture
def list_model_with_group(group_model_with_members):
    """List with id 2 that belongs to ``group_model_with_members`` and has no items."""
    owning_group = group_model_with_members
    return ListModel(
        id=2,
        name="Group List",
        group_id=owning_group.id,
        group=owning_group,
        items=[],
    )


# --- calculate_list_cost_summary Tests ---
@pytest.mark.asyncio
async def test_calculate_list_cost_summary_personal_list_no_items(mock_db_session, basic_list_model, user_a):
    """An empty personal list costs nothing and leaves its creator at a zero balance."""
    # The list lookup resolves to the fixture model.
    query_result = mock_db_session.execute.return_value
    query_result.scalars.return_value.first.return_value = basic_list_model

    # Shape the fixture into an item-less, group-less list owned by user_a.
    basic_list_model.items = []
    basic_list_model.group = None
    basic_list_model.creator = user_a

    summary = await calculate_list_cost_summary(mock_db_session, basic_list_model.id)

    assert summary.list_id == basic_list_model.id
    assert summary.total_list_cost == Decimal("0.00")
    assert summary.num_participating_users == 1
    assert summary.equal_share_per_user == Decimal("0.00")

    assert len(summary.user_balances) == 1
    creator_balance = summary.user_balances[0]
    assert creator_balance.user_id == user_a.id
    assert creator_balance.balance == Decimal("0.00")


@pytest.mark.asyncio
async def test_calculate_list_cost_summary_group_list_with_items(mock_db_session, list_model_with_group, group_model_with_members, user_a, user_b):
    """Item costs on a two-member group list are shared equally.

    user_a adds 3.00 of items and user_b adds 2.00. Each therefore owes 2.50,
    leaving balances of +0.50 and -0.50 respectively.
    """
    list_id = list_model_with_group.id
    milk = ItemModel(id=1, name="Milk", price=Decimal("3.00"), added_by_id=user_a.id, added_by_user=user_a, list_id=list_id)
    bread = ItemModel(id=2, name="Bread", price=Decimal("2.00"), added_by_id=user_b.id, added_by_user=user_b, list_id=list_id)
    list_model_with_group.items = [milk, bread]

    mock_db_session.execute.return_value.scalars.return_value.first.return_value = list_model_with_group

    summary = await calculate_list_cost_summary(mock_db_session, list_id)

    assert summary.total_list_cost == Decimal("5.00")
    assert summary.num_participating_users == 2
    assert summary.equal_share_per_user == Decimal("2.50")  # 5.00 split two ways

    by_user = {detail.user_id: detail for detail in summary.user_balances}

    # user id -> (items_added_value, amount_due, balance)
    expected = {
        user_a.id: (Decimal("3.00"), Decimal("2.50"), Decimal("0.50")),   # 3.00 - 2.50
        user_b.id: (Decimal("2.00"), Decimal("2.50"), Decimal("-0.50")),  # 2.00 - 2.50
    }
    for uid, (added, due, balance) in expected.items():
        assert by_user[uid].items_added_value == added
        assert by_user[uid].amount_due == due
        assert by_user[uid].balance == balance


@pytest.mark.asyncio
async def test_calculate_list_cost_summary_list_not_found(mock_db_session):
    """If the list lookup yields nothing, ``ListNotFoundError`` is raised."""
    missing_list_id = 999
    mock_db_session.execute.return_value.scalars.return_value.first.return_value = None

    with pytest.raises(ListNotFoundError):
        await calculate_list_cost_summary(mock_db_session, missing_list_id)


# --- calculate_suggested_settlements Tests ---
@patch('app.crud.cost.logger')  # Replace the module logger so warnings can be checked.
def test_calculate_suggested_settlements_simple_case(mock_logger):
    """One debtor and one creditor settle with a single direct payment."""
    debtor = UserBalanceDetail(user_id=1, user_identifier="User A", net_balance=Decimal("-10.00"))
    creditor = UserBalanceDetail(user_id=2, user_identifier="User B", net_balance=Decimal("10.00"))

    settlements = calculate_suggested_settlements([debtor, creditor])

    assert len(settlements) == 1
    payment = settlements[0]
    assert payment.from_user_id == 1
    assert payment.to_user_id == 2
    assert payment.amount == Decimal("10.00")
    mock_logger.warning.assert_not_called()


@patch('app.crud.cost.logger')
def test_calculate_suggested_settlements_multiple(mock_logger):
    """One debtor owing two creditors produces one payment to each.

    A (-30) must pay C 20 and B 10. Only the (payer, payee) pairs and their
    amounts are checked, not the order in which settlements are returned.
    """
    balances = [
        UserBalanceDetail(user_id=1, user_identifier="User A", net_balance=Decimal("-30.00")),
        UserBalanceDetail(user_id=2, user_identifier="User B", net_balance=Decimal("10.00")),
        UserBalanceDetail(user_id=3, user_identifier="User C", net_balance=Decimal("20.00")),
    ]

    settlements = calculate_suggested_settlements(balances)

    assert len(settlements) == 2
    paid = {}
    for settlement in settlements:
        paid[(settlement.from_user_id, settlement.to_user_id)] = settlement.amount
    assert paid[(1, 3)] == Decimal("20.00")  # A -> C
    assert paid[(1, 2)] == Decimal("10.00")  # A -> B
    mock_logger.warning.assert_not_called()


@patch('app.crud.cost.logger')
def test_calculate_suggested_settlements_complex_case_with_rounding_check(mock_logger):
    """Two equal debtors each pay the single creditor exactly what they owe.

    A and B each owe 33.33 and C is owed 66.66, so the only valid outcome is
    A -> C 33.33 and B -> C 33.33. Every (payer, payee, amount) triple is
    checked, not just the total. A total-only check would accept an uneven
    split such as 40.00 + 26.66, which also sums to 66.66.
    """
    user_balances = [
        UserBalanceDetail(user_id=1, user_identifier="User A", net_balance=Decimal("-33.33")),
        UserBalanceDetail(user_id=2, user_identifier="User B", net_balance=Decimal("-33.33")),
        UserBalanceDetail(user_id=3, user_identifier="User C", net_balance=Decimal("66.66")),
    ]

    settlements = calculate_suggested_settlements(user_balances)

    assert len(settlements) == 2
    settlement_map = {(s.from_user_id, s.to_user_id): s.amount for s in settlements}
    assert settlement_map == {
        (1, 3): Decimal("33.33"),  # A -> C
        (2, 3): Decimal("33.33"),  # B -> C
    }

    total_settled_to_c = sum(s.amount for s in settlements if s.to_user_id == 3)
    assert total_settled_to_c == Decimal("66.66")

    # The balances cancel exactly, so no leftover-balance warning is expected.
    mock_logger.warning.assert_not_called()


# --- calculate_group_balance_summary Tests ---
@pytest.mark.asyncio
async def test_calculate_group_balance_summary_no_activity(mock_db_session, group_model_with_members, user_a, user_b):
    """A group with members but no expenses or settlements has all-zero balances.

    ``session.get`` returns the group. The three ``execute`` calls are assumed
    to fetch members, expenses and settlements, in that order.
    """
    mock_db_session.get.return_value = group_model_with_members  # group lookup

    # The result stand-ins must be MagicMock, not AsyncMock. On an AsyncMock,
    # `result.scalars()` returns an un-awaited coroutine instead of the
    # configured `.all()` value, so the code under test would never see these lists.
    mock_members_result = MagicMock()
    mock_members_result.scalars.return_value.all.return_value = [user_a, user_b]

    mock_expenses_result = MagicMock()  # no expenses
    mock_expenses_result.scalars.return_value.all.return_value = []

    mock_settlements_result = MagicMock()  # no settlements
    mock_settlements_result.scalars.return_value.all.return_value = []

    mock_db_session.execute.side_effect = [mock_members_result, mock_expenses_result, mock_settlements_result]

    summary = await calculate_group_balance_summary(mock_db_session, group_model_with_members.id)

    assert summary.group_id == group_model_with_members.id
    assert len(summary.user_balances) == 2
    assert len(summary.suggested_settlements) == 0
    for ub in summary.user_balances:
        assert ub.net_balance == Decimal("0.00")


@pytest.mark.asyncio
async def test_calculate_group_balance_summary_with_expenses_and_settlements(mock_db_session, group_model_with_members, user_a, user_b, user_c):
    """Net balances and suggested settlements for a three-member group.

    Scenario: A pays a 90.00 expense split equally (30.00 each), and a 10.00
    settlement from B to A is recorded. The expected figures follow the
    convention spelled out in the comments below:
    net = paid + settlements_received - share - settlements_paid.
    """
    # Group members for this test: A, B, C.
    group_model_with_members.user_associations.append(
        UserGroupModel(user_id=user_c.id, group_id=group_model_with_members.id, user=user_c)
    )
    all_users = [user_a, user_b, user_c]

    mock_db_session.get.return_value = group_model_with_members

    # The result stand-ins must be MagicMock, not AsyncMock. On an AsyncMock,
    # `result.scalars()` returns an un-awaited coroutine instead of the
    # configured `.all()` value.
    mock_members_result = MagicMock()
    mock_members_result.scalars.return_value.all.return_value = all_users

    # Expense: A paid 90, split equally (A, B, C -> 30 each)
    #   A: paid 90, share 30 -> +60
    #   B: paid 0,  share 30 -> -30
    #   C: paid 0,  share 30 -> -30
    expense1 = ExpenseModel(id=1, group_id=group_model_with_members.id, paid_by_user_id=user_a.id, total_amount=Decimal("90.00"))
    expense1.splits = [
        ExpenseSplitModel(expense_id=1, user_id=user_a.id, owed_amount=Decimal("30.00")),
        ExpenseSplitModel(expense_id=1, user_id=user_b.id, owed_amount=Decimal("30.00")),
        ExpenseSplitModel(expense_id=1, user_id=user_c.id, owed_amount=Decimal("30.00")),
    ]
    mock_expenses_result = MagicMock()
    mock_expenses_result.scalars.return_value.all.return_value = [expense1]

    # Settlement: B paid A 10
    #   A: received 10   -> +60 + 10 = +70
    #   B: paid 10       -> -30 - 10 = -40
    #   C: no settlement -> -30
    # NOTE(review): conventionally a payment from B to A would reduce B's debt
    # (B -20, A +50). These expectations encode the opposite sign. Confirm this
    # matches the intended semantics in app.crud.cost before relying on it.
    settlement1 = SettlementModel(id=1, group_id=group_model_with_members.id, paid_by_user_id=user_b.id, paid_to_user_id=user_a.id, amount=Decimal("10.00"))
    mock_settlements_result = MagicMock()
    mock_settlements_result.scalars.return_value.all.return_value = [settlement1]

    mock_db_session.execute.side_effect = [mock_members_result, mock_expenses_result, mock_settlements_result]

    # Patch the module logger so warnings from calculate_suggested_settlements can be checked.
    with patch('app.crud.cost.logger') as mock_settlement_logger:
        summary = await calculate_group_balance_summary(mock_db_session, group_model_with_members.id)

    balances_map = {ub.user_id: ub for ub in summary.user_balances}
    assert balances_map[user_a.id].net_balance == Decimal("70.00")
    assert balances_map[user_b.id].net_balance == Decimal("-40.00")
    assert balances_map[user_c.id].net_balance == Decimal("-30.00")

    # Net balances A +70, B -40, C -30: each debtor pays A their full balance,
    # i.e. B -> A 40 and C -> A 30. The return order is not guaranteed.
    assert len(summary.suggested_settlements) == 2
    settlement_map = {(s.from_user_id, s.to_user_id): s.amount for s in summary.suggested_settlements}

    assert (user_b.id, user_a.id) in settlement_map
    assert settlement_map[(user_b.id, user_a.id)] == Decimal("40.00")

    assert (user_c.id, user_a.id) in settlement_map
    assert settlement_map[(user_c.id, user_a.id)] == Decimal("30.00")

    mock_settlement_logger.warning.assert_not_called()


@pytest.mark.asyncio
async def test_calculate_group_balance_summary_group_not_found(mock_db_session):
    """If the group lookup returns nothing, ``GroupNotFoundError`` is raised."""
    unknown_group_id = 999
    mock_db_session.get.return_value = None

    with pytest.raises(GroupNotFoundError):
        await calculate_group_balance_summary(mock_db_session, unknown_group_id)


# TODO:
# - Test calculate_list_cost_summary with an item added by a user who is neither a group member nor the list creator.
# - Test calculate_list_cost_summary with remainder distribution when the total does not divide evenly into equal shares.
# - Test calculate_suggested_settlements edge cases: all-zero balances, a lone debtor, a lone creditor, etc.
# - Test calculate_suggested_settlements logging a warning when balances remain unsettled.
# - Test calculate_group_balance_summary with more complex expense/settlement scenarios.
# - Test calculate_group_balance_summary when a user in splits/settlements is not in group_members (should not happen with clean data).