277 lines
12 KiB
Python
277 lines
12 KiB
Python
import pytest
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
from sqlalchemy.exc import IntegrityError, OperationalError
|
|
from sqlalchemy.future import select
|
|
from decimal import Decimal, ROUND_HALF_UP
|
|
from datetime import datetime, timezone
|
|
from typing import List as PyList
|
|
|
|
from app.crud.settlement import (
|
|
create_settlement,
|
|
get_settlement_by_id,
|
|
get_settlements_for_group,
|
|
get_settlements_involving_user,
|
|
update_settlement,
|
|
delete_settlement
|
|
)
|
|
from app.schemas.expense import SettlementCreate, SettlementUpdate
|
|
from app.models import Settlement as SettlementModel, User as UserModel, Group as GroupModel
|
|
from app.core.exceptions import UserNotFoundError, GroupNotFoundError, InvalidOperationError, ConflictError
|
|
|
|
# Fixtures
|
|
@pytest.fixture
def mock_db_session():
    """Build a stand-in database session for the CRUD tests.

    The session itself is an ``AsyncMock``. The methods listed in
    ``awaitable_methods`` are replaced with fresh ``AsyncMock`` instances,
    ``add`` and ``delete`` with plain ``MagicMock`` instances, and
    ``in_transaction()`` is a ``MagicMock`` that reports ``False``.
    """
    db = AsyncMock()

    awaitable_methods = (
        "begin",
        "begin_nested",
        "commit",
        "rollback",
        "refresh",
        "execute",
        "get",
        "flush",
    )
    for method_name in awaitable_methods:
        setattr(db, method_name, AsyncMock())

    # ``add`` and ``delete`` are mocked as ordinary synchronous callables.
    for method_name in ("add", "delete"):
        setattr(db, method_name, MagicMock())

    db.in_transaction = MagicMock(return_value=False)
    return db
|
|
|
|
@pytest.fixture
def settlement_create_data():
    """Return a ``SettlementCreate`` payload: user 1 pays user 2 10.50 in group 1."""
    payload_fields = {
        "group_id": 1,
        "paid_by_user_id": 1,
        "paid_to_user_id": 2,
        "amount": Decimal("10.50"),
        "settlement_date": datetime.now(timezone.utc),
        "description": "Test settlement",
    }
    return SettlementCreate(**payload_fields)
|
|
|
|
@pytest.fixture
def settlement_update_data():
    """Return a ``SettlementUpdate`` payload with a new description, a new date and version 1."""
    new_date = datetime.now(timezone.utc)
    update_payload = SettlementUpdate(
        description="Updated settlement description",
        settlement_date=new_date,
        version=1,  # Assuming version is required for update
    )
    return update_payload
|
|
|
|
@pytest.fixture
def db_settlement_model():
    """Return a ``SettlementModel`` (id 1, version 1) with its related objects attached.

    Payer, payee, group and creator are separate model instances; the creator
    carries the same id and name as the payer.
    """
    column_values = {
        "id": 1,
        "group_id": 1,
        "paid_by_user_id": 1,
        "paid_to_user_id": 2,
        "amount": Decimal("10.50"),
        "settlement_date": datetime.now(timezone.utc),
        "description": "Original settlement",
        "created_by_user_id": 1,
        "version": 1,  # Initial version
        "created_at": datetime.now(timezone.utc),
        "updated_at": datetime.now(timezone.utc),
    }
    related_objects = {
        "payer": UserModel(id=1, name="Payer User"),
        "payee": UserModel(id=2, name="Payee User"),
        "group": GroupModel(id=1, name="Test Group"),
        # Same identity data as the payer, for simplicity.
        "created_by_user": UserModel(id=1, name="Payer User"),
    }
    return SettlementModel(**column_values, **related_objects)
|
|
|
|
@pytest.fixture
def payer_user_model():
    """Return the ``UserModel`` used as the paying user (id 1)."""
    payer = UserModel(id=1, name="Payer User", email="payer@example.com")
    return payer
|
|
|
|
@pytest.fixture
def payee_user_model():
    """Return the ``UserModel`` used as the receiving user (id 2)."""
    payee = UserModel(id=2, name="Payee User", email="payee@example.com")
    return payee
|
|
|
|
@pytest.fixture
def group_model():
    """Return the ``GroupModel`` used by the settlement tests (id 1)."""
    group = GroupModel(id=1, name="Test Group")
    return group
|
|
|
|
# Tests for create_settlement
|
|
@pytest.mark.asyncio
async def test_create_settlement_success(mock_db_session, settlement_create_data, payer_user_model, payee_user_model, group_model):
    """create_settlement adds and flushes one settlement and returns it with the quantized amount."""
    expected_amount = settlement_create_data.amount.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP)

    mock_db_session.get.side_effect = [payer_user_model, payee_user_model, group_model]

    # The object produced by ``await session.execute(...)`` is configured here
    # through a plain call chain (``scalar_one_or_none.return_value``), so it
    # must be a MagicMock. Calling a child of an AsyncMock returns a coroutine
    # rather than the configured return value.
    mock_result = MagicMock()
    mock_result.scalar_one_or_none.return_value = SettlementModel(
        id=1,
        group_id=settlement_create_data.group_id,
        paid_by_user_id=settlement_create_data.paid_by_user_id,
        paid_to_user_id=settlement_create_data.paid_to_user_id,
        amount=expected_amount,
        settlement_date=settlement_create_data.settlement_date,
        description=settlement_create_data.description,
        created_by_user_id=1,
        version=1,
        created_at=datetime.now(timezone.utc),
        updated_at=datetime.now(timezone.utc),
    )
    mock_db_session.execute.return_value = mock_result

    created_settlement = await create_settlement(mock_db_session, settlement_create_data, current_user_id=1)

    mock_db_session.add.assert_called_once()
    mock_db_session.flush.assert_called_once()
    assert created_settlement is not None
    assert created_settlement.amount == expected_amount
    assert created_settlement.paid_by_user_id == settlement_create_data.paid_by_user_id
    assert created_settlement.paid_to_user_id == settlement_create_data.paid_to_user_id
|
|
|
|
@pytest.mark.asyncio
async def test_create_settlement_payer_not_found(mock_db_session, settlement_create_data, payee_user_model, group_model):
    """create_settlement raises UserNotFoundError mentioning "Payer" when the first lookup returns None."""
    # payee_user_model and group_model are requested as fixtures so that the
    # side_effect list holds model instances, not the module-level fixture
    # function objects those bare names would otherwise resolve to.
    mock_db_session.get.side_effect = [None, payee_user_model, group_model]

    with pytest.raises(UserNotFoundError) as excinfo:
        await create_settlement(mock_db_session, settlement_create_data, 1)

    assert "Payer" in str(excinfo.value)
|
|
|
|
@pytest.mark.asyncio
async def test_create_settlement_payee_not_found(mock_db_session, settlement_create_data, payer_user_model, group_model):
    """create_settlement raises UserNotFoundError mentioning "Payee" when the second lookup returns None."""
    # group_model is requested as a fixture so that the side_effect list holds
    # a model instance, not the module-level fixture function object.
    mock_db_session.get.side_effect = [payer_user_model, None, group_model]

    with pytest.raises(UserNotFoundError) as excinfo:
        await create_settlement(mock_db_session, settlement_create_data, 1)

    assert "Payee" in str(excinfo.value)
|
|
|
|
@pytest.mark.asyncio
async def test_create_settlement_group_not_found(
    mock_db_session,
    settlement_create_data,
    payer_user_model,
    payee_user_model,
):
    """create_settlement raises GroupNotFoundError when the third lookup returns None."""
    lookup_results = [payer_user_model, payee_user_model, None]
    mock_db_session.get.side_effect = lookup_results

    with pytest.raises(GroupNotFoundError):
        await create_settlement(mock_db_session, settlement_create_data, 1)
|
|
|
|
@pytest.mark.asyncio
async def test_create_settlement_payer_equals_payee(
    mock_db_session,
    settlement_create_data,
    payer_user_model,
    group_model,
):
    """create_settlement rejects a payload whose payer and payee are the same user."""
    same_user_id = settlement_create_data.paid_by_user_id
    settlement_create_data.paid_to_user_id = same_user_id

    # Both user lookups resolve to the payer.
    lookup_results = [payer_user_model, payer_user_model, group_model]
    mock_db_session.get.side_effect = lookup_results

    with pytest.raises(InvalidOperationError) as raised:
        await create_settlement(mock_db_session, settlement_create_data, 1)

    message = str(raised.value)
    assert "Payer and Payee cannot be the same user" in message
|
|
|
|
@pytest.mark.asyncio
async def test_create_settlement_commit_failure(
    mock_db_session,
    settlement_create_data,
    payer_user_model,
    payee_user_model,
    group_model,
):
    """A failing commit is expected to surface as InvalidOperationError after a single rollback."""
    lookup_results = [payer_user_model, payee_user_model, group_model]
    mock_db_session.get.side_effect = lookup_results
    mock_db_session.commit.side_effect = Exception("DB commit failed")

    with pytest.raises(InvalidOperationError) as raised:
        await create_settlement(mock_db_session, settlement_create_data, 1)

    message = str(raised.value)
    assert "Failed to save settlement" in message
    mock_db_session.rollback.assert_called_once()
|
|
|
|
|
|
# Tests for get_settlement_by_id
|
|
@pytest.mark.asyncio
async def test_get_settlement_by_id_found(mock_db_session, db_settlement_model):
    """get_settlement_by_id returns the settlement yielded by the single executed query."""
    # The awaited execute() result is configured through a synchronous call
    # chain (scalars().first()), so it must be a MagicMock. Calling a child of
    # an AsyncMock returns a coroutine, which has no ``first`` attribute.
    mock_result = MagicMock()
    mock_result.scalars.return_value.first.return_value = db_settlement_model
    mock_db_session.execute.return_value = mock_result

    settlement = await get_settlement_by_id(mock_db_session, 1)

    assert settlement is not None
    assert settlement.id == 1
    mock_db_session.execute.assert_called_once()
|
|
|
|
@pytest.mark.asyncio
async def test_get_settlement_by_id_not_found(mock_db_session):
    """get_settlement_by_id returns None when the query yields no row."""
    # MagicMock (not AsyncMock) so that the synchronous scalars().first()
    # chain yields the configured None rather than a coroutine.
    mock_result = MagicMock()
    mock_result.scalars.return_value.first.return_value = None
    mock_db_session.execute.return_value = mock_result

    settlement = await get_settlement_by_id(mock_db_session, 999)

    assert settlement is None
|
|
|
|
# Tests for get_settlements_for_group
|
|
@pytest.mark.asyncio
async def test_get_settlements_for_group_success(mock_db_session, db_settlement_model):
    """get_settlements_for_group returns the settlements yielded by the single executed query."""
    # MagicMock (not AsyncMock) so that the synchronous scalars().all() chain
    # yields the configured list rather than a coroutine.
    mock_result = MagicMock()
    mock_result.scalars.return_value.all.return_value = [db_settlement_model]
    mock_db_session.execute.return_value = mock_result

    settlements = await get_settlements_for_group(mock_db_session, group_id=1)

    assert len(settlements) == 1
    assert settlements[0].group_id == 1
    mock_db_session.execute.assert_called_once()
|
|
|
|
# Tests for get_settlements_involving_user
|
|
@pytest.mark.asyncio
async def test_get_settlements_involving_user_success(mock_db_session, db_settlement_model):
    """get_settlements_involving_user returns settlements where the user is payer or payee."""
    # MagicMock (not AsyncMock) so that the synchronous scalars().all() chain
    # yields the configured list rather than a coroutine.
    mock_result = MagicMock()
    mock_result.scalars.return_value.all.return_value = [db_settlement_model]
    mock_db_session.execute.return_value = mock_result

    settlements = await get_settlements_involving_user(mock_db_session, user_id=1)

    assert len(settlements) == 1
    assert settlements[0].paid_by_user_id == 1 or settlements[0].paid_to_user_id == 1
    mock_db_session.execute.assert_called_once()
|
|
|
|
@pytest.mark.asyncio
async def test_get_settlements_involving_user_with_group_filter(mock_db_session, db_settlement_model):
    """get_settlements_involving_user accepts a group_id filter and still issues one query."""
    # MagicMock (not AsyncMock) so that the synchronous scalars().all() chain
    # yields the configured list rather than a coroutine.
    mock_result = MagicMock()
    mock_result.scalars.return_value.all.return_value = [db_settlement_model]
    mock_db_session.execute.return_value = mock_result

    settlements = await get_settlements_involving_user(mock_db_session, user_id=1, group_id=1)

    assert len(settlements) == 1
    mock_db_session.execute.assert_called_once()
|
|
|
|
|
|
# Tests for update_settlement
|
|
@pytest.mark.asyncio
async def test_update_settlement_success(mock_db_session, db_settlement_model, settlement_update_data):
    """update_settlement applies the new description/date and bumps the version by one."""
    # Capture the version before the call. The mocked query returns the same
    # db_settlement_model instance, so comparing the result against
    # ``db_settlement_model.version + 1`` afterwards could never be true.
    original_version = db_settlement_model.version
    settlement_update_data.version = original_version

    # MagicMock (not AsyncMock) so that the synchronous scalar_one_or_none()
    # call yields the configured model rather than a coroutine.
    mock_result = MagicMock()
    mock_result.scalar_one_or_none.return_value = db_settlement_model
    mock_db_session.execute.return_value = mock_result

    updated_settlement = await update_settlement(mock_db_session, db_settlement_model, settlement_update_data)

    mock_db_session.add.assert_called_once_with(db_settlement_model)
    mock_db_session.flush.assert_called_once()
    assert updated_settlement.description == settlement_update_data.description
    assert updated_settlement.settlement_date == settlement_update_data.settlement_date
    assert updated_settlement.version == original_version + 1
|
|
|
|
@pytest.mark.asyncio
async def test_update_settlement_version_mismatch(
    mock_db_session,
    db_settlement_model,
    settlement_update_data,
):
    """update_settlement raises ConflictError for a mismatched version and rolls back once."""
    mismatched_version = db_settlement_model.version + 1
    settlement_update_data.version = mismatched_version

    with pytest.raises(ConflictError):
        await update_settlement(mock_db_session, db_settlement_model, settlement_update_data)

    mock_db_session.rollback.assert_called_once()
|
|
|
|
@pytest.mark.asyncio
async def test_update_settlement_disallowed_field(mock_db_session, db_settlement_model):
    """update_settlement rejects an update payload that tries to change ``amount``."""
    # 'amount' stands in for a field that is not allowed to be updated.
    current_version = db_settlement_model.version
    invalid_update_data = SettlementUpdate(amount=Decimal("100.00"), version=current_version)

    with pytest.raises(InvalidOperationError) as raised:
        await update_settlement(mock_db_session, db_settlement_model, invalid_update_data)

    message = str(raised.value)
    assert "Field 'amount' cannot be updated" in message
|
|
|
|
@pytest.mark.asyncio
async def test_update_settlement_commit_failure(
    mock_db_session,
    db_settlement_model,
    settlement_update_data,
):
    """A failing commit during update is expected to raise InvalidOperationError after one rollback."""
    mock_db_session.commit.side_effect = Exception("DB commit failed")
    settlement_update_data.version = db_settlement_model.version

    with pytest.raises(InvalidOperationError) as raised:
        await update_settlement(mock_db_session, db_settlement_model, settlement_update_data)

    message = str(raised.value)
    assert "Failed to update settlement" in message
    mock_db_session.rollback.assert_called_once()
|
|
|
|
# Tests for delete_settlement
|
|
@pytest.mark.asyncio
async def test_delete_settlement_success(mock_db_session, db_settlement_model):
    """delete_settlement deletes the given settlement and commits exactly once."""
    session = mock_db_session

    await delete_settlement(session, db_settlement_model)

    session.delete.assert_called_once_with(db_settlement_model)
    session.commit.assert_called_once()
|
|
|
|
@pytest.mark.asyncio
async def test_delete_settlement_success_with_version_check(mock_db_session, db_settlement_model):
    """delete_settlement succeeds when expected_version equals the settlement's version."""
    session = mock_db_session
    matching_version = db_settlement_model.version

    await delete_settlement(session, db_settlement_model, expected_version=matching_version)

    session.delete.assert_called_once_with(db_settlement_model)
    session.commit.assert_called_once()
|
|
|
|
@pytest.mark.asyncio
async def test_delete_settlement_version_mismatch(mock_db_session, db_settlement_model):
    """delete_settlement raises ConflictError for a mismatched expected_version and rolls back once."""
    stale_version = 1
    db_settlement_model.version = 2

    with pytest.raises(ConflictError):
        await delete_settlement(mock_db_session, db_settlement_model, expected_version=stale_version)

    mock_db_session.rollback.assert_called_once()
|
|
|
|
@pytest.mark.asyncio
async def test_delete_settlement_commit_failure(mock_db_session, db_settlement_model):
    """A failing commit during delete is expected to raise InvalidOperationError after one rollback."""
    commit_error = Exception("DB commit failed")
    mock_db_session.commit.side_effect = commit_error

    with pytest.raises(InvalidOperationError) as raised:
        await delete_settlement(mock_db_session, db_settlement_model)

    message = str(raised.value)
    assert "Failed to delete settlement" in message
    mock_db_session.rollback.assert_called_once()