"""Unit tests for ``app.crud.expense`` run against a mocked SQLAlchemy ``AsyncSession``."""

from datetime import datetime, timezone
from decimal import ROUND_HALF_UP, Decimal
from typing import List as PyList, Optional
from unittest.mock import ANY, AsyncMock, MagicMock, patch

import pytest
from sqlalchemy.exc import IntegrityError, OperationalError

from app.crud.expense import (
    create_expense,
    get_expense_by_id,
    get_expenses_for_list,
    get_expenses_for_group,
    update_expense,
    delete_expense,
    get_users_for_splitting,
)
from app.schemas.expense import ExpenseCreate, ExpenseUpdate, ExpenseSplitCreate, ExpenseRead
from app.models import (
    Expense as ExpenseModel,
    ExpenseSplit as ExpenseSplitModel,
    User as UserModel,
    List as ListModel,
    Group as GroupModel,
    UserGroup as UserGroupModel,
    Item as ItemModel,
    SplitTypeEnum,
)
from app.core.exceptions import (
    ListNotFoundError,
    GroupNotFoundError,
    UserNotFoundError,
    InvalidOperationError,
    ExpenseNotFoundError,
    DatabaseTransactionError,
    ConflictError,
)


# --- General Fixtures ---

def _make_transaction_mock():
    """Build a mock for the object returned by ``session.begin()`` / ``begin_nested()``.

    The mock supports ``async with``. Its ``__aenter__`` yields the mock itself, so
    ``async with session.begin() as tx`` binds the same object that tests inspect.
    """
    transaction = AsyncMock()
    transaction.__aenter__.return_value = transaction
    return transaction


@pytest.fixture
def mock_db_session():
    """Return a mock standing in for a SQLAlchemy ``AsyncSession``.

    Coroutine methods (``commit``, ``rollback``, ``refresh``, ``execute``, ``get``,
    ``flush``, ``delete``) are ``AsyncMock`` so they can be awaited. Synchronous
    methods (``add``, ``in_transaction``, ``begin``, ``begin_nested``) are
    ``MagicMock``.
    """
    session = AsyncMock()
    session.commit = AsyncMock()
    session.rollback = AsyncMock()
    session.refresh = AsyncMock()
    session.execute = AsyncMock()
    session.get = AsyncMock()
    session.flush = AsyncMock()
    # AsyncSession.delete is a coroutine method; a plain MagicMock is not awaitable.
    session.delete = AsyncMock()
    session.add = MagicMock()
    session.in_transaction = MagicMock(return_value=False)
    # begin()/begin_nested() are synchronous calls that return an async context
    # manager. The commit or rollback happens in its __aexit__.
    session.begin = MagicMock(return_value=_make_transaction_mock())
    session.begin_nested = MagicMock(return_value=_make_transaction_mock())
    return session


@pytest.fixture
def basic_user_model():
    return UserModel(id=1, name="Test User", email="test@example.com", version=1)


@pytest.fixture
def another_user_model():
    return UserModel(id=2, name="Another User", email="another@example.com", version=1)


@pytest.fixture
def basic_group_model(basic_user_model, another_user_model):
    """Group with id 1.

    member_associations is deliberately left unset here. Tests that need group
    members (e.g. the get_users_for_splitting tests) assign them explicitly.
    """
    group = GroupModel(id=1, name="Test Group", version=1)
    return group


@pytest.fixture
def basic_list_model(basic_group_model, basic_user_model):
    return ListModel(
        id=1,
        name="Test List",
        group_id=basic_group_model.id,
        group=basic_group_model,
        created_by_id=basic_user_model.id,
        creator=basic_user_model,
        version=1,
    )


@pytest.fixture
def expense_create_data_equal_split_list_ctx(basic_list_model, basic_user_model):
    return ExpenseCreate(
        description="Grocery run",
        total_amount=Decimal("30.00"),
        currency="USD",
        expense_date=datetime.now(timezone.utc).date(),
        split_type=SplitTypeEnum.EQUAL,
        list_id=basic_list_model.id,
        group_id=None,  # Derived from list
        item_id=None,
        paid_by_user_id=basic_user_model.id,
        splits_in=None,
    )


@pytest.fixture
def expense_create_data_equal_split_group_ctx(basic_group_model, basic_user_model):
    return ExpenseCreate(
        description="Movies",
        total_amount=Decimal("50.00"),
        currency="USD",
        expense_date=datetime.now(timezone.utc).date(),
        split_type=SplitTypeEnum.EQUAL,
        list_id=None,
        group_id=basic_group_model.id,
        item_id=None,
        paid_by_user_id=basic_user_model.id,
        splits_in=None,
    )


@pytest.fixture
def expense_create_data_exact_split(basic_group_model, basic_user_model, another_user_model):
    return ExpenseCreate(
        description="Dinner",
        total_amount=Decimal("100.00"),
        expense_date=datetime.now(timezone.utc).date(),
        currency="USD",
        split_type=SplitTypeEnum.EXACT_AMOUNTS,
        group_id=basic_group_model.id,
        paid_by_user_id=basic_user_model.id,
        splits_in=[
            ExpenseSplitCreate(user_id=basic_user_model.id, owed_amount=Decimal("60.00")),
            ExpenseSplitCreate(user_id=another_user_model.id, owed_amount=Decimal("40.00")),
        ],
    )


@pytest.fixture
def expense_update_data():
    return ExpenseUpdate(
        description="Updated Dinner",
        total_amount=Decimal("120.00"),
        version=1,  # Ensure version is provided for updates
    )


@pytest.fixture
def db_expense_model(
    expense_create_data_equal_split_group_ctx,
    basic_user_model,
    basic_group_model,
    another_user_model,
):
    """Persisted-looking expense (id 1) built from the group-context create data."""
    data = expense_create_data_equal_split_group_ctx
    now = datetime.now(timezone.utc)
    expense = ExpenseModel(
        id=1,
        description=data.description,
        total_amount=data.total_amount,
        currency=data.currency,
        expense_date=data.expense_date,
        split_type=data.split_type,
        list_id=data.list_id,
        group_id=data.group_id,
        group=basic_group_model,  # Link to group fixture
        item_id=data.item_id,
        paid_by_user_id=data.paid_by_user_id,
        created_by_user_id=basic_user_model.id,
        paid_by=basic_user_model,
        created_by_user=basic_user_model,
        version=1,
        created_at=now,
        updated_at=now,
    )
    # Simulate the splits of an existing expense (2 x 25.00 = 50.00 total).
    expense.splits = [
        ExpenseSplitModel(id=1, expense_id=1, user_id=basic_user_model.id, owed_amount=Decimal("25.00"), version=1),
        ExpenseSplitModel(id=2, expense_id=1, user_id=another_user_model.id, owed_amount=Decimal("25.00"), version=1),
    ]
    return expense


# --- get_users_for_splitting Tests ---

@pytest.mark.asyncio
async def test_get_users_for_splitting_group_context(mock_db_session, basic_group_model, basic_user_model, another_user_model):
    user_group_assoc1 = UserGroupModel(user=basic_user_model, user_id=basic_user_model.id, group_id=basic_group_model.id)
    user_group_assoc2 = UserGroupModel(user=another_user_model, user_id=another_user_model.id, group_id=basic_group_model.id)
    basic_group_model.member_associations = [user_group_assoc1, user_group_assoc2]
    mock_db_session.get.return_value = basic_group_model  # Mock get for group

    users = await get_users_for_splitting(
        mock_db_session,
        expense_group_id=basic_group_model.id,
        expense_list_id=None,
        expense_paid_by_user_id=basic_user_model.id,
    )

    assert len(users) == 2
    assert basic_user_model in users
    assert another_user_model in users


@pytest.mark.asyncio
async def test_get_users_for_splitting_list_context(mock_db_session, basic_list_model, basic_group_model, basic_user_model, another_user_model):
    user_group_assoc1 = UserGroupModel(user=basic_user_model, user_id=basic_user_model.id, group_id=basic_group_model.id)
    user_group_assoc2 = UserGroupModel(user=another_user_model, user_id=another_user_model.id, group_id=basic_group_model.id)
    basic_group_model.member_associations = [user_group_assoc1, user_group_assoc2]
    basic_list_model.group = basic_group_model  # Ensure list is associated with the group
    mock_db_session.get.return_value = basic_list_model  # Mock get for list

    users = await get_users_for_splitting(
        mock_db_session,
        expense_group_id=None,
        expense_list_id=basic_list_model.id,
        expense_paid_by_user_id=basic_user_model.id,
    )

    assert len(users) == 2
    assert basic_user_model in users
    assert another_user_model in users


# --- create_expense Tests ---

@pytest.mark.asyncio
async def test_create_expense_equal_split_group_success(mock_db_session, expense_create_data_equal_split_group_ctx, basic_user_model, basic_group_model, another_user_model):
    mock_db_session.get.side_effect = [basic_user_model, basic_group_model]  # paid_by_user, then group

    with patch('app.crud.expense.get_users_for_splitting', new_callable=AsyncMock) as mock_get_users:
        mock_get_users.return_value = [basic_user_model, another_user_model]

        async def mock_refresh(instance, attribute_names=None, with_for_update=None):
            """Simulate the DB-assigned fields and split relationship loaded on refresh."""
            if isinstance(instance, ExpenseModel):
                now = datetime.now(timezone.utc)
                instance.id = 1  # Simulate ID assignment after flush
                instance.version = 1
                instance.created_at = now
                instance.updated_at = now
                instance.splits = [
                    ExpenseSplitModel(expense_id=instance.id, user_id=basic_user_model.id, owed_amount=Decimal("25.00"), version=1),
                    ExpenseSplitModel(expense_id=instance.id, user_id=another_user_model.id, owed_amount=Decimal("25.00"), version=1),
                ]
            return None

        mock_db_session.refresh.side_effect = mock_refresh

        created_expense = await create_expense(mock_db_session, expense_create_data_equal_split_group_ctx, current_user_id=basic_user_model.id)

    mock_db_session.add.assert_called()
    mock_db_session.flush.assert_called_once()
    mock_db_session.refresh.assert_called_once()
    assert created_expense is not None
    assert created_expense.total_amount == expense_create_data_equal_split_group_ctx.total_amount
    assert created_expense.split_type == SplitTypeEnum.EQUAL
    assert len(created_expense.splits) == 2
    expected_amount_per_user = (expense_create_data_equal_split_group_ctx.total_amount / 2).quantize(
        Decimal("0.01"), rounding=ROUND_HALF_UP
    )
    for split in created_expense.splits:
        assert split.owed_amount == expected_amount_per_user


@pytest.mark.asyncio
async def test_create_expense_exact_split_success(mock_db_session, expense_create_data_exact_split, basic_user_model, basic_group_model, another_user_model):
    # Payer, Group, User1 in split, User2 in split
    mock_db_session.get.side_effect = [basic_user_model, basic_group_model, basic_user_model, another_user_model]

    async def mock_refresh(instance, attribute_names=None, with_for_update=None):
        """Simulate the DB-assigned fields and split relationship loaded on refresh."""
        if isinstance(instance, ExpenseModel):
            instance.id = 2
            instance.version = 1
            instance.splits = [
                ExpenseSplitModel(expense_id=instance.id, user_id=basic_user_model.id, owed_amount=Decimal("60.00")),
                ExpenseSplitModel(expense_id=instance.id, user_id=another_user_model.id, owed_amount=Decimal("40.00")),
            ]
        return None

    mock_db_session.refresh.side_effect = mock_refresh

    created_expense = await create_expense(mock_db_session, expense_create_data_exact_split, current_user_id=basic_user_model.id)

    mock_db_session.add.assert_called()
    mock_db_session.flush.assert_called_once()
    assert created_expense is not None
    assert created_expense.split_type == SplitTypeEnum.EXACT_AMOUNTS
    assert len(created_expense.splits) == 2
    assert created_expense.splits[0].owed_amount == Decimal("60.00")
    assert created_expense.splits[1].owed_amount == Decimal("40.00")


@pytest.mark.asyncio
async def test_create_expense_payer_not_found(mock_db_session, expense_create_data_equal_split_group_ctx):
    mock_db_session.get.side_effect = [None]  # Payer not found, group lookup won't happen

    with pytest.raises(UserNotFoundError):
        # current_user_id is for the creator; paid_by_user_id comes from the schema.
        await create_expense(mock_db_session, expense_create_data_equal_split_group_ctx, 999)


@pytest.mark.asyncio
async def test_create_expense_no_list_or_group(mock_db_session, expense_create_data_equal_split_group_ctx, basic_user_model):
    mock_db_session.get.return_value = basic_user_model  # Payer found
    expense_data = expense_create_data_equal_split_group_ctx.model_copy()
    expense_data.list_id = None
    expense_data.group_id = None

    with pytest.raises(InvalidOperationError, match="Expense must be associated with a list or a group"):
        await create_expense(mock_db_session, expense_data, basic_user_model.id)


# --- get_expense_by_id Tests ---

@pytest.mark.asyncio
async def test_get_expense_by_id_found(mock_db_session, db_expense_model):
    mock_db_session.get.return_value = db_expense_model

    expense = await get_expense_by_id(mock_db_session, db_expense_model.id)

    assert expense is not None
    assert expense.id == db_expense_model.id
    # Fresh MagicMock() instances compare by identity and could never match the
    # loader options actually passed; ANY accepts whatever options are used.
    mock_db_session.get.assert_called_once_with(ExpenseModel, db_expense_model.id, options=ANY)


@pytest.mark.asyncio
async def test_get_expense_by_id_not_found(mock_db_session):
    mock_db_session.get.return_value = None

    expense = await get_expense_by_id(mock_db_session, 999)

    assert expense is None


# --- get_expenses_for_list Tests ---

@pytest.mark.asyncio
async def test_get_expenses_for_list_success(mock_db_session, db_expense_model, basic_list_model):
    # The awaited execute() yields a synchronous Result. An AsyncMock here would
    # make result.scalars() return a coroutine and break .all().
    mock_result = MagicMock()
    mock_result.scalars.return_value.all.return_value = [db_expense_model]
    mock_db_session.execute.return_value = mock_result

    expenses = await get_expenses_for_list(mock_db_session, basic_list_model.id)

    assert len(expenses) == 1
    assert expenses[0].id == db_expense_model.id
    mock_db_session.execute.assert_called_once()


# --- get_expenses_for_group Tests ---

@pytest.mark.asyncio
async def test_get_expenses_for_group_success(mock_db_session, db_expense_model, basic_group_model):
    # Result objects are synchronous; see test_get_expenses_for_list_success.
    mock_result = MagicMock()
    mock_result.scalars.return_value.all.return_value = [db_expense_model]
    mock_db_session.execute.return_value = mock_result

    expenses = await get_expenses_for_group(mock_db_session, basic_group_model.id)

    assert len(expenses) == 1
    assert expenses[0].id == db_expense_model.id
    mock_db_session.execute.assert_called_once()


# --- update_expense Tests ---

@pytest.mark.asyncio
async def test_update_expense_success(mock_db_session, db_expense_model, expense_update_data, basic_user_model):
    original_version = db_expense_model.version
    expense_update_data.version = original_version  # Match version
    mock_db_session.get.return_value = db_expense_model

    updated_expense = await update_expense(mock_db_session, db_expense_model.id, expense_update_data, basic_user_model.id)

    mock_db_session.add.assert_called_once_with(db_expense_model)
    mock_db_session.flush.assert_called_once()
    mock_db_session.refresh.assert_called_once_with(db_expense_model)
    assert updated_expense.description == expense_update_data.description
    assert updated_expense.total_amount == expense_update_data.total_amount
    # If the same instance is mutated and returned, db_expense_model.version is
    # already the new value. Compare against the version captured before the call.
    assert updated_expense.version == original_version + 1


@pytest.mark.asyncio
async def test_update_expense_not_found(mock_db_session, expense_update_data, basic_user_model):
    mock_db_session.get.return_value = None  # Expense not found

    with pytest.raises(ExpenseNotFoundError):
        await update_expense(mock_db_session, 999, expense_update_data, basic_user_model.id)


@pytest.mark.asyncio
async def test_update_expense_version_conflict(mock_db_session, db_expense_model, expense_update_data, basic_user_model):
    expense_update_data.version = db_expense_model.version + 1  # Create version mismatch
    mock_db_session.get.return_value = db_expense_model

    with pytest.raises(ConflictError):
        await update_expense(mock_db_session, db_expense_model.id, expense_update_data, basic_user_model.id)
    mock_db_session.rollback.assert_called_once()


# --- delete_expense Tests ---

@pytest.mark.asyncio
async def test_delete_expense_success(mock_db_session, db_expense_model, basic_user_model):
    mock_db_session.get.return_value = db_expense_model  # Simulate expense found

    await delete_expense(mock_db_session, db_expense_model.id, basic_user_model.id)

    mock_db_session.delete.assert_called_once_with(db_expense_model)
    # Assumes delete_expense runs inside `async with session.begin():`. A clean
    # exit of that context (no exception) is what commits the transaction.
    transaction = mock_db_session.begin.return_value
    transaction.__aexit__.assert_awaited_once_with(None, None, None)


@pytest.mark.asyncio
async def test_delete_expense_not_found(mock_db_session, basic_user_model):
    mock_db_session.get.return_value = None  # Expense not found

    with pytest.raises(ExpenseNotFoundError):
        await delete_expense(mock_db_session, 999, basic_user_model.id)
    # Any rollback would come from the begin() context manager's exit, not session.rollback.
    mock_db_session.rollback.assert_not_called()


@pytest.mark.asyncio
async def test_delete_expense_db_error(mock_db_session, db_expense_model, basic_user_model):
    mock_db_session.get.return_value = db_expense_model
    mock_db_session.delete.side_effect = OperationalError("mock op error", "params", "orig")

    with pytest.raises(DatabaseTransactionError):
        await delete_expense(mock_db_session, db_expense_model.id, basic_user_model.id)

    # Assumes delete_expense runs inside `async with session.begin():`. Exiting that
    # context with an exception is what rolls the transaction back.
    transaction = mock_db_session.begin.return_value
    transaction.__aexit__.assert_awaited_once()
    exc_type = transaction.__aexit__.await_args.args[0]
    assert exc_type is not None