"""Unit tests for the expense CRUD functions in ``app.crud.expense``.

The database session is fully mocked. Session methods that the CRUD code
awaits (``execute``, ``get``, ``flush``, ``commit``, ...) are ``AsyncMock``s,
but the result objects returned by ``execute`` are plain ``MagicMock``s:
SQLAlchemy's ``Result`` API (``scalars()``, ``first()``, ``all()``, row
iteration) is synchronous, and an ``AsyncMock`` result would make
``result.scalars()`` return an un-awaited coroutine instead of the configured
value.
"""
from datetime import datetime, timezone
from decimal import Decimal, ROUND_HALF_UP
from typing import List as PyList, Optional
from unittest.mock import AsyncMock, MagicMock, patch

import pytest
from sqlalchemy.exc import IntegrityError, OperationalError

from app.crud.expense import (
    create_expense,
    get_expense_by_id,
    get_expenses_for_list,
    get_expenses_for_group,
    update_expense,
    delete_expense,
    get_users_for_splitting,
)
from app.schemas.expense import ExpenseCreate, ExpenseUpdate, ExpenseSplitCreate
from app.models import (
    Expense as ExpenseModel,
    ExpenseSplit as ExpenseSplitModel,
    User as UserModel,
    List as ListModel,
    Group as GroupModel,
    UserGroup as UserGroupModel,
    Item as ItemModel,
    SplitTypeEnum,
)
from app.core.exceptions import (
    ListNotFoundError,
    GroupNotFoundError,
    UserNotFoundError,
    InvalidOperationError,
)


# --- General fixtures ---


@pytest.fixture
def mock_db_session():
    """Return a stand-in for an async SQLAlchemy session.

    Awaited methods are ``AsyncMock``s; ``add`` and ``delete`` are called
    without ``await`` and are therefore plain ``MagicMock``s.
    """
    session = AsyncMock()
    session.commit = AsyncMock()
    session.rollback = AsyncMock()
    session.refresh = AsyncMock()
    session.add = MagicMock()
    session.delete = MagicMock()
    session.execute = AsyncMock()
    session.get = AsyncMock()
    session.flush = AsyncMock()  # create_expense flushes rather than commits
    return session


@pytest.fixture
def basic_user_model():
    return UserModel(id=1, name="Test User", email="test@example.com")


@pytest.fixture
def another_user_model():
    return UserModel(id=2, name="Another User", email="another@example.com")


@pytest.fixture
def basic_group_model():
    """Group with id 1 and no members; tests that need members attach them."""
    return GroupModel(id=1, name="Test Group")


@pytest.fixture
def basic_list_model(basic_group_model, basic_user_model):
    return ListModel(
        id=1,
        name="Test List",
        group_id=basic_group_model.id,
        group=basic_group_model,
        creator_id=basic_user_model.id,
        creator=basic_user_model,
    )


@pytest.fixture
def expense_create_data_equal_split_list_ctx(basic_list_model, basic_user_model):
    """EQUAL-split payload scoped to a list."""
    return ExpenseCreate(
        description="Grocery run",
        total_amount=Decimal("30.00"),
        currency="USD",
        expense_date=datetime.now(timezone.utc),
        split_type=SplitTypeEnum.EQUAL,
        list_id=basic_list_model.id,
        group_id=None,  # derived from the list
        item_id=None,
        paid_by_user_id=basic_user_model.id,
        splits_in=None,
    )


@pytest.fixture
def expense_create_data_equal_split_group_ctx(basic_group_model, basic_user_model):
    """EQUAL-split payload scoped directly to a group."""
    return ExpenseCreate(
        description="Movies",
        total_amount=Decimal("50.00"),
        currency="USD",
        expense_date=datetime.now(timezone.utc),
        split_type=SplitTypeEnum.EQUAL,
        list_id=None,
        group_id=basic_group_model.id,
        item_id=None,
        paid_by_user_id=basic_user_model.id,
        splits_in=None,
    )


@pytest.fixture
def expense_create_data_exact_split(basic_group_model, basic_user_model, another_user_model):
    """EXACT_AMOUNTS payload: 60.00 + 40.00 over a 100.00 total."""
    return ExpenseCreate(
        description="Dinner",
        total_amount=Decimal("100.00"),
        split_type=SplitTypeEnum.EXACT_AMOUNTS,
        group_id=basic_group_model.id,
        paid_by_user_id=basic_user_model.id,
        splits_in=[
            ExpenseSplitCreate(user_id=basic_user_model.id, owed_amount=Decimal("60.00")),
            ExpenseSplitCreate(user_id=another_user_model.id, owed_amount=Decimal("40.00")),
        ],
    )


@pytest.fixture
def db_expense_model(expense_create_data_equal_split_group_ctx, basic_user_model):
    """ExpenseModel (id 1) built from the group-context payload, standing in for a DB row."""
    data = expense_create_data_equal_split_group_ctx
    return ExpenseModel(
        id=1,
        description=data.description,
        total_amount=data.total_amount,
        currency=data.currency,
        expense_date=data.expense_date,
        split_type=data.split_type,
        list_id=data.list_id,
        group_id=data.group_id,
        item_id=data.item_id,
        paid_by_user_id=data.paid_by_user_id,
        created_by_user_id=basic_user_model.id,
        paid_by=basic_user_model,  # relationship pre-populated as if loaded
        created_by_user=basic_user_model,  # relationship pre-populated as if loaded
        # splits are left unset here
        version=1,
    )


# --- get_users_for_splitting tests (also exercised indirectly via create_expense) ---


@pytest.mark.asyncio
async def test_get_users_for_splitting_group_context(
    mock_db_session, basic_group_model, basic_user_model, another_user_model
):
    basic_group_model.member_associations = [
        UserGroupModel(user=basic_user_model, user_id=basic_user_model.id),
        UserGroupModel(user=another_user_model, user_id=another_user_model.id),
    ]
    # The awaited execute() returns a synchronous Result, hence MagicMock.
    group_result = MagicMock()
    group_result.scalars.return_value.first.return_value = basic_group_model
    mock_db_session.execute.return_value = group_result

    users = await get_users_for_splitting(
        mock_db_session,
        expense_group_id=1,
        expense_list_id=None,
        expense_paid_by_user_id=1,
    )

    assert len(users) == 2
    assert basic_user_model in users
    assert another_user_model in users


# --- create_expense tests ---


@pytest.mark.asyncio
async def test_create_expense_equal_split_group_success(
    mock_db_session,
    expense_create_data_equal_split_group_ctx,
    basic_user_model,
    basic_group_model,
    another_user_model,
):
    mock_db_session.get.side_effect = [basic_user_model, basic_group_model]  # payer, group

    # get_users_for_splitting is called internally by create_expense; patch it
    # where create_expense looks it up so the test controls the member list.
    with patch(
        "app.crud.expense.get_users_for_splitting", new_callable=AsyncMock
    ) as mock_get_users:
        mock_get_users.return_value = [basic_user_model, another_user_model]
        created_expense = await create_expense(
            mock_db_session, expense_create_data_equal_split_group_ctx, current_user_id=1
        )

    mock_db_session.add.assert_called()
    # create_expense only flushes; commit/refresh are left to the caller.
    mock_db_session.flush.assert_awaited_once()

    assert created_expense is not None
    assert created_expense.total_amount == expense_create_data_equal_split_group_ctx.total_amount
    assert created_expense.split_type == SplitTypeEnum.EQUAL
    # Splits are expected to be attached to the returned model instance.
    assert len(created_expense.splits) == 2

    expected_share = (expense_create_data_equal_split_group_ctx.total_amount / 2).quantize(
        Decimal("0.01"), rounding=ROUND_HALF_UP
    )
    for split in created_expense.splits:
        assert split.owed_amount == expected_share


@pytest.mark.asyncio
async def test_create_expense_exact_split_success(
    mock_db_session,
    expense_create_data_exact_split,
    basic_user_model,
    basic_group_model,
    another_user_model,
):
    mock_db_session.get.side_effect = [basic_user_model, basic_group_model]  # payer, group

    # Presumably create_expense validates the users named in splits_in with a
    # select(User.id) query and iterates the returned rows; simulate (id,)
    # rows. A list (not an iterator) keeps the mock re-iterable.
    user_id_rows = MagicMock()
    user_id_rows.__iter__.return_value = [(basic_user_model.id,), (another_user_model.id,)]
    mock_db_session.execute.return_value = user_id_rows

    created_expense = await create_expense(
        mock_db_session, expense_create_data_exact_split, current_user_id=1
    )

    mock_db_session.add.assert_called()
    mock_db_session.flush.assert_awaited_once()

    assert created_expense is not None
    assert created_expense.split_type == SplitTypeEnum.EXACT_AMOUNTS
    assert len(created_expense.splits) == 2
    assert created_expense.splits[0].owed_amount == Decimal("60.00")
    assert created_expense.splits[1].owed_amount == Decimal("40.00")


@pytest.mark.asyncio
async def test_create_expense_payer_not_found(mock_db_session, expense_create_data_equal_split_group_ctx):
    mock_db_session.get.return_value = None  # payer lookup misses

    with pytest.raises(UserNotFoundError):
        await create_expense(
            mock_db_session, expense_create_data_equal_split_group_ctx, current_user_id=1
        )


@pytest.mark.asyncio
async def test_create_expense_no_list_or_group(
    mock_db_session, expense_create_data_equal_split_group_ctx, basic_user_model
):
    mock_db_session.get.return_value = basic_user_model  # payer found
    # model_copy(update=...) does not re-validate, so a payload with neither
    # context can reach create_expense's own check.
    expense_data = expense_create_data_equal_split_group_ctx.model_copy(
        update={"list_id": None, "group_id": None}
    )

    with pytest.raises(InvalidOperationError, match="Expense must be associated with a list or a group"):
        await create_expense(mock_db_session, expense_data, current_user_id=1)


# --- get_expense_by_id tests ---


@pytest.mark.asyncio
async def test_get_expense_by_id_found(mock_db_session, db_expense_model):
    result = MagicMock()  # synchronous Result returned by the awaited execute()
    result.scalars.return_value.first.return_value = db_expense_model
    mock_db_session.execute.return_value = result

    expense = await get_expense_by_id(mock_db_session, 1)

    assert expense is not None
    assert expense.id == 1
    mock_db_session.execute.assert_awaited_once()


@pytest.mark.asyncio
async def test_get_expense_by_id_not_found(mock_db_session):
    result = MagicMock()
    result.scalars.return_value.first.return_value = None
    mock_db_session.execute.return_value = result

    expense = await get_expense_by_id(mock_db_session, 999)

    assert expense is None


# --- get_expenses_for_list tests ---


@pytest.mark.asyncio
async def test_get_expenses_for_list_success(mock_db_session, db_expense_model):
    result = MagicMock()
    result.scalars.return_value.all.return_value = [db_expense_model]
    mock_db_session.execute.return_value = result

    expenses = await get_expenses_for_list(mock_db_session, list_id=1)

    assert len(expenses) == 1
    assert expenses[0].id == db_expense_model.id
    mock_db_session.execute.assert_awaited_once()


# --- get_expenses_for_group tests ---


@pytest.mark.asyncio
async def test_get_expenses_for_group_success(mock_db_session, db_expense_model):
    result = MagicMock()
    result.scalars.return_value.all.return_value = [db_expense_model]
    mock_db_session.execute.return_value = result

    expenses = await get_expenses_for_group(mock_db_session, group_id=1)

    assert len(expenses) == 1
    assert expenses[0].id == db_expense_model.id
    mock_db_session.execute.assert_awaited_once()


# --- update_expense / delete_expense stubs ---
# These need fleshing out once the update/delete implementations are settled
# (how splits are handled on update, versioning, cascade behaviour, ...).
@pytest.mark.asyncio
@pytest.mark.skip(reason="update_expense tests not implemented yet")
async def test_update_expense_stub(mock_db_session):
    """Placeholder for update_expense tests.

    Marked as skipped so the suite reports it as pending rather than as a
    passing test that asserts nothing.
    """
    # Intended shape once update_expense's contract is settled (updatable
    # fields, split handling, version checks) -- signature unconfirmed:
    #   expense_to_update = MagicMock(spec=ExpenseModel)
    #   expense_to_update.version = 1
    #   update_payload = ExpenseUpdate(description="New description", version=1)
    #   mock_db_session.get.return_value = expense_to_update  # if it re-fetches
    #   updated = await update_expense(mock_db_session, expense_to_update, update_payload)
    #   assert updated.description == "New description"
    #   mock_db_session.commit.assert_awaited_once()
    #   mock_db_session.refresh.assert_awaited_once()


@pytest.mark.asyncio
@pytest.mark.skip(reason="delete_expense tests not implemented yet")
async def test_delete_expense_stub(mock_db_session):
    """Placeholder for delete_expense tests (skipped until implemented)."""
    # Intended shape -- delete_expense signature unconfirmed:
    #   expense_to_delete = MagicMock(spec=ExpenseModel)
    #   expense_to_delete.id = 1
    #   expense_to_delete.version = 1
    #   mock_db_session.get.return_value = expense_to_delete  # if it re-fetches
    #   await delete_expense(mock_db_session, expense_to_delete, expected_version=1)
    #   mock_db_session.delete.assert_called_once_with(expense_to_delete)
    #   mock_db_session.commit.assert_awaited_once()
    # Also decide whether associated splits must be deleted with the expense.


# TODO: Add more tests for create_expense covering:
#   - List context success
#   - Percentage, Shares, Item-based splits
#   - Error cases for each split type (e.g. total mismatch, invalid inputs)
#   - Validation of list_id/group_id consistency
#   - User not found in splits_in
#   - Item not found for ITEM_BASED split
# TODO: Flesh out update_expense tests:
#   - Success case
#   - Version mismatch
#   - Trying to update immutable fields
#   - How splits are handled (recalculated, deleted/recreated, or not changeable)
# TODO: Flesh out delete_expense tests:
#   - Success case
#   - Version mismatch (if applicable)
#   - Ensure associated splits are also deleted (cascade behavior)