"""Unit tests for the cost-calculation helpers in ``app.crud.cost``.

The database layer is replaced by mocks: the session's ``execute``/``get``
methods are awaitable ``AsyncMock``s, while the objects they resolve to are
synchronous ``MagicMock``s configured per test.
"""
import logging  # For mocking the logger
from decimal import Decimal, ROUND_HALF_UP
from typing import List as PyList, Dict
from unittest.mock import AsyncMock, MagicMock, patch

import pytest

from app.crud.cost import (
    calculate_list_cost_summary,
    calculate_suggested_settlements,
    calculate_group_balance_summary,
)
from app.schemas.cost import (
    ListCostSummary,
    UserCostShare,
    GroupBalanceSummary,
    UserBalanceDetail,
    SuggestedSettlement,
)
from app.models import (
    List as ListModel,
    Item as ItemModel,
    User as UserModel,
    Group as GroupModel,
    UserGroup as UserGroupModel,
    Expense as ExpenseModel,
    ExpenseSplit as ExpenseSplitModel,
    Settlement as SettlementModel,
)
from app.core.exceptions import ListNotFoundError, GroupNotFoundError


# Fixtures

@pytest.fixture
def mock_db_session():
    """Return a mocked async DB session.

    ``execute`` and ``get`` are ``AsyncMock``s so the code under test can
    await them. The value that ``await session.execute(...)`` resolves to is
    deliberately a plain ``MagicMock``. Child mocks of an ``AsyncMock``,
    including its ``return_value``, are themselves ``AsyncMock``s. With the
    default, ``result.scalars()`` would therefore return a coroutine, and the
    ``.scalars.return_value.first.return_value`` chain configured by the
    tests would never be reached.
    """
    session = AsyncMock()
    session.execute = AsyncMock(return_value=MagicMock())
    session.get = AsyncMock()
    return session


@pytest.fixture
def user_a():
    return UserModel(id=1, name="User A", email="a@example.com")


@pytest.fixture
def user_b():
    return UserModel(id=2, name="User B", email="b@example.com")


@pytest.fixture
def user_c():
    return UserModel(id=3, name="User C", email="c@example.com")


@pytest.fixture
def basic_list_model(user_a):
    """A personal list (no group) created by ``user_a`` with no items."""
    return ListModel(id=1, name="Test List", creator_id=user_a.id, creator=user_a, items=[])


@pytest.fixture
def group_model_with_members(user_a, user_b):
    """A group whose members are ``user_a`` and ``user_b``."""
    group = GroupModel(id=1, name="Test Group")
    # Simulate user_associations for calculate_list_cost_summary if list is group-based
    group.user_associations = [
        UserGroupModel(user_id=user_a.id, group_id=group.id, user=user_a),
        UserGroupModel(user_id=user_b.id, group_id=group.id, user=user_b),
    ]
    return group


@pytest.fixture
def list_model_with_group(group_model_with_members):
    """A list attached to ``group_model_with_members`` with no items yet."""
    return ListModel(
        id=2,
        name="Group List",
        group_id=group_model_with_members.id,
        group=group_model_with_members,
        items=[],
    )


# --- calculate_list_cost_summary Tests ---

@pytest.mark.asyncio
async def test_calculate_list_cost_summary_personal_list_no_items(mock_db_session, basic_list_model, user_a):
    mock_db_session.execute.return_value.scalars.return_value.first.return_value = basic_list_model
    basic_list_model.items = []  # Ensure no items
    basic_list_model.group = None  # Ensure it's a personal list
    basic_list_model.creator = user_a

    summary = await calculate_list_cost_summary(mock_db_session, basic_list_model.id)

    assert summary.list_id == basic_list_model.id
    assert summary.total_list_cost == Decimal("0.00")
    assert summary.num_participating_users == 1
    assert summary.equal_share_per_user == Decimal("0.00")
    assert len(summary.user_balances) == 1
    assert summary.user_balances[0].user_id == user_a.id
    assert summary.user_balances[0].balance == Decimal("0.00")


@pytest.mark.asyncio
async def test_calculate_list_cost_summary_group_list_with_items(
    mock_db_session, list_model_with_group, group_model_with_members, user_a, user_b
):
    item1 = ItemModel(
        id=1, name="Milk", price=Decimal("3.00"),
        added_by_id=user_a.id, added_by_user=user_a, list_id=list_model_with_group.id,
    )
    item2 = ItemModel(
        id=2, name="Bread", price=Decimal("2.00"),
        added_by_id=user_b.id, added_by_user=user_b, list_id=list_model_with_group.id,
    )
    list_model_with_group.items = [item1, item2]
    mock_db_session.execute.return_value.scalars.return_value.first.return_value = list_model_with_group

    summary = await calculate_list_cost_summary(mock_db_session, list_model_with_group.id)

    assert summary.total_list_cost == Decimal("5.00")
    assert summary.num_participating_users == 2
    assert summary.equal_share_per_user == Decimal("2.50")  # 5.00 / 2

    balances_map = {ub.user_id: ub for ub in summary.user_balances}

    assert balances_map[user_a.id].items_added_value == Decimal("3.00")
    assert balances_map[user_a.id].amount_due == Decimal("2.50")
    assert balances_map[user_a.id].balance == Decimal("0.50")  # 3.00 - 2.50

    assert balances_map[user_b.id].items_added_value == Decimal("2.00")
    assert balances_map[user_b.id].amount_due == Decimal("2.50")
    assert balances_map[user_b.id].balance == Decimal("-0.50")  # 2.00 - 2.50
@pytest.mark.asyncio
async def test_calculate_list_cost_summary_list_not_found(mock_db_session):
    # The awaited execute() result must be a synchronous MagicMock so that
    # result.scalars().first() yields None instead of a coroutine.
    not_found_result = MagicMock()
    not_found_result.scalars.return_value.first.return_value = None
    mock_db_session.execute.return_value = not_found_result

    with pytest.raises(ListNotFoundError):
        await calculate_list_cost_summary(mock_db_session, 999)


# --- calculate_suggested_settlements Tests ---

@patch('app.crud.cost.logger')  # Mock the logger used in the function
def test_calculate_suggested_settlements_simple_case(mock_logger):
    user_balances = [
        UserBalanceDetail(user_id=1, user_identifier="User A", net_balance=Decimal("-10.00")),
        UserBalanceDetail(user_id=2, user_identifier="User B", net_balance=Decimal("10.00")),
    ]

    settlements = calculate_suggested_settlements(user_balances)

    assert len(settlements) == 1
    assert settlements[0].from_user_id == 1
    assert settlements[0].to_user_id == 2
    assert settlements[0].amount == Decimal("10.00")
    mock_logger.warning.assert_not_called()


@patch('app.crud.cost.logger')
def test_calculate_suggested_settlements_multiple(mock_logger):
    user_balances = [
        UserBalanceDetail(user_id=1, user_identifier="User A", net_balance=Decimal("-30.00")),
        UserBalanceDetail(user_id=2, user_identifier="User B", net_balance=Decimal("10.00")),
        UserBalanceDetail(user_id=3, user_identifier="User C", net_balance=Decimal("20.00")),
    ]

    settlements = calculate_suggested_settlements(user_balances)

    # Expected: A owes B 10, A owes C 20
    assert len(settlements) == 2
    s_map = {(s.from_user_id, s.to_user_id): s.amount for s in settlements}
    assert s_map[(1, 3)] == Decimal("20.00")  # A -> C (largest creditor first)
    assert s_map[(1, 2)] == Decimal("10.00")  # A -> B
    mock_logger.warning.assert_not_called()


@patch('app.crud.cost.logger')
def test_calculate_suggested_settlements_complex_case_with_rounding_check(mock_logger):
    user_balances = [
        UserBalanceDetail(user_id=1, user_identifier="User A", net_balance=Decimal("-33.33")),
        UserBalanceDetail(user_id=2, user_identifier="User B", net_balance=Decimal("-33.33")),
        UserBalanceDetail(user_id=3, user_identifier="User C", net_balance=Decimal("66.66")),
    ]

    settlements = calculate_suggested_settlements(user_balances)

    # A -> C 33.33, B -> C 33.33
    assert len(settlements) == 2
    s_map = {(s.from_user_id, s.to_user_id): s.amount for s in settlements}
    assert s_map[(1, 3)] == Decimal("33.33")
    assert s_map[(2, 3)] == Decimal("33.33")
    total_settled_to_C = sum(s.amount for s in settlements if s.to_user_id == 3)
    assert total_settled_to_C == Decimal("66.66")
    mock_logger.warning.assert_not_called()  # Assuming exact match after quantization


# --- calculate_group_balance_summary Tests ---

def _scalars_all_result(rows):
    """Build a fake awaited ``execute()`` result whose ``.scalars().all()`` returns *rows*.

    A synchronous ``MagicMock`` is used on purpose. An ``AsyncMock`` here
    would make ``.scalars()`` return a coroutine, so the configured ``.all()``
    value would never be returned.
    """
    result = MagicMock()
    result.scalars.return_value.all.return_value = rows
    return result


@pytest.mark.asyncio
async def test_calculate_group_balance_summary_no_activity(mock_db_session, group_model_with_members, user_a, user_b):
    mock_db_session.get.return_value = group_model_with_members  # For group fetch

    # Awaited execute() results are consumed in order:
    # group members, expenses (none), settlements (none).
    mock_db_session.execute.side_effect = [
        _scalars_all_result([user_a, user_b]),
        _scalars_all_result([]),
        _scalars_all_result([]),
    ]

    summary = await calculate_group_balance_summary(mock_db_session, group_model_with_members.id)

    assert summary.group_id == group_model_with_members.id
    assert len(summary.user_balances) == 2
    assert len(summary.suggested_settlements) == 0
    for ub in summary.user_balances:
        assert ub.net_balance == Decimal("0.00")


@pytest.mark.asyncio
async def test_calculate_group_balance_summary_with_expenses_and_settlements(
    mock_db_session, group_model_with_members, user_a, user_b, user_c
):
    # Group members for this test: A, B, C
    group_model_with_members.user_associations.append(
        UserGroupModel(user_id=user_c.id, group_id=group_model_with_members.id, user=user_c)
    )
    all_users = [user_a, user_b, user_c]
    mock_db_session.get.return_value = group_model_with_members

    # Expenses: A paid 90, split equally (A, B, C -> 30 each)
    # A: paid 90, share 30 -> +60
    # B: paid 0, share 30 -> -30
    # C: paid 0, share 30 -> -30
    expense1 = ExpenseModel(
        id=1, group_id=group_model_with_members.id,
        paid_by_user_id=user_a.id, total_amount=Decimal("90.00"),
    )
    expense1.splits = [
        ExpenseSplitModel(expense_id=1, user_id=user_a.id, owed_amount=Decimal("30.00")),
        ExpenseSplitModel(expense_id=1, user_id=user_b.id, owed_amount=Decimal("30.00")),
        ExpenseSplitModel(expense_id=1, user_id=user_c.id, owed_amount=Decimal("30.00")),
    ]

    # Settlements: B paid A 10.
    # A repayment moves both parties toward zero: the payer owes less and
    # the recipient is owed less.
    # A: owed 60, received 10 -> +60 - 10 = +50
    # B: owed 30, paid 10 -> -30 + 10 = -20
    # C: no settlement -> -30
    settlement1 = SettlementModel(
        id=1, group_id=group_model_with_members.id,
        paid_by_user_id=user_b.id, paid_to_user_id=user_a.id, amount=Decimal("10.00"),
    )

    # Consumed in order: group members, expenses, settlements.
    mock_db_session.execute.side_effect = [
        _scalars_all_result(all_users),
        _scalars_all_result([expense1]),
        _scalars_all_result([settlement1]),
    ]

    with patch('app.crud.cost.logger') as mock_settlement_logger:  # Patch logger for calculate_suggested_settlements
        summary = await calculate_group_balance_summary(mock_db_session, group_model_with_members.id)

    balances_map = {ub.user_id: ub for ub in summary.user_balances}
    assert balances_map[user_a.id].net_balance == Decimal("50.00")
    assert balances_map[user_b.id].net_balance == Decimal("-20.00")
    assert balances_map[user_c.id].net_balance == Decimal("-30.00")

    # Net balances: A: +50, B: -20, C: -30
    # Suggested settlements: B -> A (20), C -> A (30)
    assert len(summary.suggested_settlements) == 2
    settlement_map = {(s.from_user_id, s.to_user_id): s.amount for s in summary.suggested_settlements}

    # Order might vary, check presence and amounts
    assert (user_b.id, user_a.id) in settlement_map
    assert settlement_map[(user_b.id, user_a.id)] == Decimal("20.00")
    assert (user_c.id, user_a.id) in settlement_map
    assert settlement_map[(user_c.id, user_a.id)] == Decimal("30.00")
    mock_settlement_logger.warning.assert_not_called()


@pytest.mark.asyncio
async def test_calculate_group_balance_summary_group_not_found(mock_db_session):
    mock_db_session.get.return_value = None
    with pytest.raises(GroupNotFoundError):
        await calculate_group_balance_summary(mock_db_session, 999)


# TODO:
# - Test calculate_list_cost_summary with item added by user not in group/not creator.
# - Test calculate_list_cost_summary with remainder distribution for equal share.
# - Test calculate_suggested_settlements with zero balances, one debtor, one creditor, etc.
# - Test calculate_suggested_settlements with logger warning for remaining balances.
# - Test calculate_group_balance_summary with more complex expense/settlement scenarios.
# - Test calculate_group_balance_summary when a user in splits/settlements is not in
#   group_members (should ideally not happen if data is clean).