"""Unit tests for ``app.crud.settlement`` run against a mocked ``AsyncSession``."""

from datetime import datetime, timezone
from decimal import Decimal, ROUND_HALF_UP
from typing import List as PyList
from unittest.mock import AsyncMock, MagicMock, patch

import pytest
from sqlalchemy.exc import IntegrityError, OperationalError
from sqlalchemy.future import select

from app.crud.settlement import (
    create_settlement,
    get_settlement_by_id,
    get_settlements_for_group,
    get_settlements_involving_user,
    update_settlement,
    delete_settlement,
)
from app.schemas.expense import SettlementCreate, SettlementUpdate
from app.models import Settlement as SettlementModel, User as UserModel, Group as GroupModel
from app.core.exceptions import (
    UserNotFoundError,
    GroupNotFoundError,
    InvalidOperationError,
    ConflictError,
)


def _make_execute_result():
    """Return a mock for the object produced by ``await session.execute(...)``.

    The accessors the CRUD code calls on it (``scalar_one_or_none()``,
    ``scalars().first()``, ``scalars().all()``) are synchronous, so this must be
    a ``MagicMock``: on an ``AsyncMock`` every such call would return an
    un-awaited coroutine instead of the configured ``return_value``.
    """
    return MagicMock()


# Fixtures

@pytest.fixture
def mock_db_session():
    """Return a mock standing in for an SQLAlchemy ``AsyncSession``.

    Coroutine methods are ``AsyncMock`` so the code under test can ``await``
    them; synchronous methods (``add``, ``in_transaction``) are ``MagicMock``.
    """
    session = AsyncMock()
    # NOTE(review): calling an AsyncMock returns a coroutine, so
    # ``await session.begin()`` works but ``async with session.begin():`` does
    # not -- switch these to a context-manager mock if the CRUD layer uses that form.
    session.begin = AsyncMock()
    session.begin_nested = AsyncMock()
    session.commit = AsyncMock()
    session.rollback = AsyncMock()
    session.refresh = AsyncMock()
    session.add = MagicMock()
    # AsyncSession.delete is a coroutine (unlike the sync Session.delete);
    # call assertions such as assert_called_once_with still work on AsyncMock.
    session.delete = AsyncMock()
    session.execute = AsyncMock()
    session.get = AsyncMock()
    session.flush = AsyncMock()
    session.in_transaction = MagicMock(return_value=False)
    return session


@pytest.fixture
def settlement_create_data():
    """Return a valid create payload: user 1 pays user 2 ``10.50`` in group 1."""
    return SettlementCreate(
        group_id=1,
        paid_by_user_id=1,
        paid_to_user_id=2,
        amount=Decimal("10.50"),
        settlement_date=datetime.now(timezone.utc),
        description="Test settlement",
    )


@pytest.fixture
def settlement_update_data():
    """Return an update payload touching only description and date."""
    return SettlementUpdate(
        description="Updated settlement description",
        settlement_date=datetime.now(timezone.utc),
        version=1,  # Assuming version is required for update
    )


@pytest.fixture
def db_settlement_model():
    """Return a persisted-looking settlement (id 1, version 1) with relationships set."""
    return SettlementModel(
        id=1,
        group_id=1,
        paid_by_user_id=1,
        paid_to_user_id=2,
        amount=Decimal("10.50"),
        settlement_date=datetime.now(timezone.utc),
        description="Original settlement",
        created_by_user_id=1,
        version=1,  # Initial version
        created_at=datetime.now(timezone.utc),
        updated_at=datetime.now(timezone.utc),
        payer=UserModel(id=1, name="Payer User"),
        payee=UserModel(id=2, name="Payee User"),
        group=GroupModel(id=1, name="Test Group"),
        created_by_user=UserModel(id=1, name="Payer User"),  # Same as payer for simplicity
    )


@pytest.fixture
def payer_user_model():
    """Return the paying user (id 1)."""
    return UserModel(id=1, name="Payer User", email="payer@example.com")


@pytest.fixture
def payee_user_model():
    """Return the receiving user (id 2)."""
    return UserModel(id=2, name="Payee User", email="payee@example.com")


@pytest.fixture
def group_model():
    """Return the group (id 1) the settlement belongs to."""
    return GroupModel(id=1, name="Test Group")


# Tests for create_settlement
# The session.get side_effect lists below assume create_settlement looks up the
# payer, then the payee, then the group.

@pytest.mark.asyncio
async def test_create_settlement_success(
    mock_db_session, settlement_create_data, payer_user_model, payee_user_model, group_model
):
    """A valid payload is added, flushed and returned with its amount at 2 dp."""
    expected_amount = settlement_create_data.amount.quantize(
        Decimal("0.01"), rounding=ROUND_HALF_UP
    )
    mock_db_session.get.side_effect = [payer_user_model, payee_user_model, group_model]
    mock_result = _make_execute_result()
    mock_result.scalar_one_or_none.return_value = SettlementModel(
        id=1,
        group_id=settlement_create_data.group_id,
        paid_by_user_id=settlement_create_data.paid_by_user_id,
        paid_to_user_id=settlement_create_data.paid_to_user_id,
        amount=expected_amount,
        settlement_date=settlement_create_data.settlement_date,
        description=settlement_create_data.description,
        created_by_user_id=1,
        version=1,
        created_at=datetime.now(timezone.utc),
        updated_at=datetime.now(timezone.utc),
    )
    mock_db_session.execute.return_value = mock_result

    created_settlement = await create_settlement(
        mock_db_session, settlement_create_data, current_user_id=1
    )

    mock_db_session.add.assert_called_once()
    mock_db_session.flush.assert_called_once()
    assert created_settlement is not None
    assert created_settlement.amount == expected_amount
    assert created_settlement.paid_by_user_id == settlement_create_data.paid_by_user_id
    assert created_settlement.paid_to_user_id == settlement_create_data.paid_to_user_id


@pytest.mark.asyncio
async def test_create_settlement_payer_not_found(
    mock_db_session, settlement_create_data, payee_user_model, group_model
):
    """A missing payer raises UserNotFoundError mentioning the payer."""
    mock_db_session.get.side_effect = [None, payee_user_model, group_model]
    with pytest.raises(UserNotFoundError) as excinfo:
        await create_settlement(mock_db_session, settlement_create_data, 1)
    assert "Payer" in str(excinfo.value)


@pytest.mark.asyncio
async def test_create_settlement_payee_not_found(
    mock_db_session, settlement_create_data, payer_user_model, group_model
):
    """A missing payee raises UserNotFoundError mentioning the payee."""
    mock_db_session.get.side_effect = [payer_user_model, None, group_model]
    with pytest.raises(UserNotFoundError) as excinfo:
        await create_settlement(mock_db_session, settlement_create_data, 1)
    assert "Payee" in str(excinfo.value)


@pytest.mark.asyncio
async def test_create_settlement_group_not_found(
    mock_db_session, settlement_create_data, payer_user_model, payee_user_model
):
    """A missing group raises GroupNotFoundError."""
    mock_db_session.get.side_effect = [payer_user_model, payee_user_model, None]
    with pytest.raises(GroupNotFoundError):
        await create_settlement(mock_db_session, settlement_create_data, 1)


@pytest.mark.asyncio
async def test_create_settlement_payer_equals_payee(
    mock_db_session, settlement_create_data, payer_user_model, group_model
):
    """Settling with oneself is rejected with InvalidOperationError."""
    settlement_create_data.paid_to_user_id = settlement_create_data.paid_by_user_id
    mock_db_session.get.side_effect = [payer_user_model, payer_user_model, group_model]
    with pytest.raises(InvalidOperationError) as excinfo:
        await create_settlement(mock_db_session, settlement_create_data, 1)
    assert "Payer and Payee cannot be the same user" in str(excinfo.value)


@pytest.mark.asyncio
async def test_create_settlement_commit_failure(
    mock_db_session, settlement_create_data, payer_user_model, payee_user_model, group_model
):
    """A failing commit is wrapped in InvalidOperationError and rolled back."""
    mock_db_session.get.side_effect = [payer_user_model, payee_user_model, group_model]
    mock_db_session.commit.side_effect = Exception("DB commit failed")
    with pytest.raises(InvalidOperationError) as excinfo:
        await create_settlement(mock_db_session, settlement_create_data, 1)
    assert "Failed to save settlement" in str(excinfo.value)
    mock_db_session.rollback.assert_called_once()


# Tests for get_settlement_by_id

@pytest.mark.asyncio
async def test_get_settlement_by_id_found(mock_db_session, db_settlement_model):
    """An existing id returns the settlement after a single query."""
    mock_result = _make_execute_result()
    mock_result.scalars.return_value.first.return_value = db_settlement_model
    mock_db_session.execute.return_value = mock_result

    settlement = await get_settlement_by_id(mock_db_session, 1)

    assert settlement is not None
    assert settlement.id == 1
    mock_db_session.execute.assert_called_once()


@pytest.mark.asyncio
async def test_get_settlement_by_id_not_found(mock_db_session):
    """An unknown id returns None."""
    mock_result = _make_execute_result()
    mock_result.scalars.return_value.first.return_value = None
    mock_db_session.execute.return_value = mock_result

    settlement = await get_settlement_by_id(mock_db_session, 999)

    assert settlement is None


# Tests for get_settlements_for_group

@pytest.mark.asyncio
async def test_get_settlements_for_group_success(mock_db_session, db_settlement_model):
    """All settlements of the group are returned from a single query."""
    mock_result = _make_execute_result()
    mock_result.scalars.return_value.all.return_value = [db_settlement_model]
    mock_db_session.execute.return_value = mock_result

    settlements = await get_settlements_for_group(mock_db_session, group_id=1)

    assert len(settlements) == 1
    assert settlements[0].group_id == 1
    mock_db_session.execute.assert_called_once()


# Tests for get_settlements_involving_user

@pytest.mark.asyncio
async def test_get_settlements_involving_user_success(mock_db_session, db_settlement_model):
    """Settlements where the user is payer or payee are returned."""
    mock_result = _make_execute_result()
    mock_result.scalars.return_value.all.return_value = [db_settlement_model]
    mock_db_session.execute.return_value = mock_result

    settlements = await get_settlements_involving_user(mock_db_session, user_id=1)

    assert len(settlements) == 1
    assert 1 in (settlements[0].paid_by_user_id, settlements[0].paid_to_user_id)
    mock_db_session.execute.assert_called_once()


@pytest.mark.asyncio
async def test_get_settlements_involving_user_with_group_filter(
    mock_db_session, db_settlement_model
):
    """The optional group filter still issues a single query."""
    mock_result = _make_execute_result()
    mock_result.scalars.return_value.all.return_value = [db_settlement_model]
    mock_db_session.execute.return_value = mock_result

    settlements = await get_settlements_involving_user(mock_db_session, user_id=1, group_id=1)

    assert len(settlements) == 1
    mock_db_session.execute.assert_called_once()


# Tests for update_settlement

@pytest.mark.asyncio
async def test_update_settlement_success(
    mock_db_session, db_settlement_model, settlement_update_data
):
    """Allowed fields are applied and the version is bumped by exactly one."""
    # Capture before the call: update_settlement mutates db_settlement_model in
    # place, so reading .version afterwards would compare against the new value.
    original_version = db_settlement_model.version
    settlement_update_data.version = original_version
    mock_result = _make_execute_result()
    mock_result.scalar_one_or_none.return_value = db_settlement_model
    mock_db_session.execute.return_value = mock_result

    updated_settlement = await update_settlement(
        mock_db_session, db_settlement_model, settlement_update_data
    )

    mock_db_session.add.assert_called_once_with(db_settlement_model)
    mock_db_session.flush.assert_called_once()
    assert updated_settlement.description == settlement_update_data.description
    assert updated_settlement.settlement_date == settlement_update_data.settlement_date
    assert updated_settlement.version == original_version + 1


@pytest.mark.asyncio
async def test_update_settlement_version_mismatch(
    mock_db_session, db_settlement_model, settlement_update_data
):
    """A stale version raises ConflictError and rolls back."""
    settlement_update_data.version = db_settlement_model.version + 1
    with pytest.raises(ConflictError):
        await update_settlement(mock_db_session, db_settlement_model, settlement_update_data)
    mock_db_session.rollback.assert_called_once()


@pytest.mark.asyncio
async def test_update_settlement_disallowed_field(mock_db_session, db_settlement_model):
    """Attempting to change a protected field such as ``amount`` is rejected."""
    invalid_update_data = SettlementUpdate(
        amount=Decimal("100.00"), version=db_settlement_model.version
    )
    with pytest.raises(InvalidOperationError) as excinfo:
        await update_settlement(mock_db_session, db_settlement_model, invalid_update_data)
    assert "Field 'amount' cannot be updated" in str(excinfo.value)


@pytest.mark.asyncio
async def test_update_settlement_commit_failure(
    mock_db_session, db_settlement_model, settlement_update_data
):
    """A failing commit is wrapped in InvalidOperationError and rolled back."""
    settlement_update_data.version = db_settlement_model.version
    mock_db_session.commit.side_effect = Exception("DB commit failed")
    with pytest.raises(InvalidOperationError) as excinfo:
        await update_settlement(mock_db_session, db_settlement_model, settlement_update_data)
    assert "Failed to update settlement" in str(excinfo.value)
    mock_db_session.rollback.assert_called_once()


# Tests for delete_settlement

@pytest.mark.asyncio
async def test_delete_settlement_success(mock_db_session, db_settlement_model):
    """Deleting without a version check deletes and commits."""
    await delete_settlement(mock_db_session, db_settlement_model)
    mock_db_session.delete.assert_called_once_with(db_settlement_model)
    mock_db_session.commit.assert_called_once()


@pytest.mark.asyncio
async def test_delete_settlement_success_with_version_check(
    mock_db_session, db_settlement_model
):
    """Deleting with a matching expected_version deletes and commits."""
    await delete_settlement(
        mock_db_session, db_settlement_model, expected_version=db_settlement_model.version
    )
    mock_db_session.delete.assert_called_once_with(db_settlement_model)
    mock_db_session.commit.assert_called_once()


@pytest.mark.asyncio
async def test_delete_settlement_version_mismatch(mock_db_session, db_settlement_model):
    """A mismatching expected_version raises ConflictError and rolls back."""
    db_settlement_model.version = 2
    with pytest.raises(ConflictError):
        await delete_settlement(mock_db_session, db_settlement_model, expected_version=1)
    mock_db_session.rollback.assert_called_once()


@pytest.mark.asyncio
async def test_delete_settlement_commit_failure(mock_db_session, db_settlement_model):
    """A failing commit is wrapped in InvalidOperationError and rolled back."""
    mock_db_session.commit.side_effect = Exception("DB commit failed")
    with pytest.raises(InvalidOperationError) as excinfo:
        await delete_settlement(mock_db_session, db_settlement_model)
    assert "Failed to delete settlement" in str(excinfo.value)
    mock_db_session.rollback.assert_called_once()