# mitlist/be/tests/crud/test_expense.py
#
# 334 lines
# 14 KiB
# Python

import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from sqlalchemy.exc import IntegrityError, OperationalError
from decimal import Decimal, ROUND_HALF_UP
from datetime import datetime, timezone
from typing import List as PyList, Optional
from app.crud.expense import (
create_expense,
get_expense_by_id,
get_expenses_for_list,
get_expenses_for_group,
update_expense, # Assuming update_expense exists
delete_expense, # Assuming delete_expense exists
get_users_for_splitting # Helper, might test indirectly
)
from app.schemas.expense import ExpenseCreate, ExpenseUpdate, ExpenseSplitCreate
from app.models import (
Expense as ExpenseModel,
ExpenseSplit as ExpenseSplitModel,
User as UserModel,
List as ListModel,
Group as GroupModel,
UserGroup as UserGroupModel,
Item as ItemModel,
SplitTypeEnum
)
from app.core.exceptions import (
ListNotFoundError,
GroupNotFoundError,
UserNotFoundError,
InvalidOperationError
)
# General Fixtures
class _TransactionStub:
    """Minimal stand-in for SQLAlchemy's ``AsyncSessionTransaction``.

    ``AsyncSession.begin()`` / ``begin_nested()`` are plain (non-async) methods
    whose return value is consumed either as ``async with session.begin(): ...``
    or as ``await session.begin()``. This stub supports both forms and does
    nothing in either.
    """

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc, tb):
        # Returning False lets exceptions raised inside the block propagate.
        return False

    def __await__(self):
        async def _resolve():
            return self

        return _resolve().__await__()


@pytest.fixture
def mock_db_session():
    """Return a mock ``AsyncSession`` for CRUD unit tests.

    Methods that are coroutines on ``AsyncSession`` (``commit``, ``rollback``,
    ``refresh``, ``execute``, ``get``, ``flush``, ``delete``) are ``AsyncMock``
    so they can be awaited. Synchronous methods (``add``, ``in_transaction``,
    ``begin``, ``begin_nested``) are ``MagicMock``. ``begin``/``begin_nested``
    return a stub usable with ``async with`` or ``await``, so either calling
    convention works. ``in_transaction()`` reports False.
    """
    session = AsyncMock()
    # begin()/begin_nested() are sync on AsyncSession; an AsyncMock here would
    # hand back a coroutine, which cannot be used with ``async with``.
    session.begin = MagicMock(return_value=_TransactionStub())
    session.begin_nested = MagicMock(return_value=_TransactionStub())
    session.commit = AsyncMock()
    session.rollback = AsyncMock()
    session.refresh = AsyncMock()
    session.add = MagicMock()
    session.delete = AsyncMock()  # AsyncSession.delete is a coroutine
    session.execute = AsyncMock()
    session.get = AsyncMock()
    session.flush = AsyncMock()
    session.in_transaction = MagicMock(return_value=False)
    return session
@pytest.fixture
def basic_user_model():
    """Primary test user (id=1); the payer/creator in most fixtures below."""
    user = UserModel(id=1, name="Test User", email="test@example.com")
    return user
@pytest.fixture
def another_user_model():
    """Secondary test user (id=2) used as the other participant in splits."""
    user = UserModel(id=2, name="Another User", email="another@example.com")
    return user
@pytest.fixture
def basic_group_model():
    """Group with id=1; no ``member_associations`` are set here.

    Tests that need membership (e.g. for ``get_users_for_splitting``) assign
    ``member_associations`` themselves with ``UserGroupModel`` instances built
    from the user fixtures they request.
    """
    return GroupModel(id=1, name="Test Group")
@pytest.fixture
def basic_list_model(basic_group_model, basic_user_model):
    """List id=1 owned by ``basic_group_model`` and created by ``basic_user_model``."""
    group, creator = basic_group_model, basic_user_model
    return ListModel(
        id=1,
        name="Test List",
        group_id=group.id,
        group=group,
        creator_id=creator.id,
        creator=creator,
    )
@pytest.fixture
def expense_create_data_equal_split_list_ctx(basic_list_model, basic_user_model):
    """EQUAL-split payload (30.00 USD) attached to a list; group_id left unset."""
    list_id = basic_list_model.id
    payer_id = basic_user_model.id
    return ExpenseCreate(
        description="Grocery run",
        total_amount=Decimal("30.00"),
        currency="USD",
        expense_date=datetime.now(timezone.utc),
        split_type=SplitTypeEnum.EQUAL,
        list_id=list_id,
        group_id=None,  # expected to be derived from the list
        item_id=None,
        paid_by_user_id=payer_id,
        splits_in=None,
    )
@pytest.fixture
def expense_create_data_equal_split_group_ctx(basic_group_model, basic_user_model):
    """EQUAL-split payload (50.00 USD) attached directly to a group, no list."""
    group_id = basic_group_model.id
    payer_id = basic_user_model.id
    return ExpenseCreate(
        description="Movies",
        total_amount=Decimal("50.00"),
        currency="USD",
        expense_date=datetime.now(timezone.utc),
        split_type=SplitTypeEnum.EQUAL,
        list_id=None,
        group_id=group_id,
        item_id=None,
        paid_by_user_id=payer_id,
        splits_in=None,
    )
@pytest.fixture
def expense_create_data_exact_split(basic_group_model, basic_user_model, another_user_model):
    """EXACT_AMOUNTS payload in a group: 60.00 + 40.00 of a 100.00 total.

    ``currency`` and ``expense_date`` are omitted and left to ExpenseCreate's
    defaults.
    """
    payer_id = basic_user_model.id
    splits = [
        ExpenseSplitCreate(user_id=payer_id, owed_amount=Decimal("60.00")),
        ExpenseSplitCreate(user_id=another_user_model.id, owed_amount=Decimal("40.00")),
    ]
    return ExpenseCreate(
        description="Dinner",
        total_amount=Decimal("100.00"),
        split_type=SplitTypeEnum.EXACT_AMOUNTS,
        group_id=basic_group_model.id,
        paid_by_user_id=payer_id,
        splits_in=splits,
    )
@pytest.fixture
def db_expense_model(expense_create_data_equal_split_group_ctx, basic_user_model):
    """``ExpenseModel`` (id=1, version=1) mirroring the group-context payload.

    ``paid_by`` and ``created_by_user`` both point at ``basic_user_model``;
    ``splits`` is not populated. Note ``list_id`` is None (group context).
    """
    src = expense_create_data_equal_split_group_ctx
    owner = basic_user_model
    return ExpenseModel(
        id=1,
        description=src.description,
        total_amount=src.total_amount,
        currency=src.currency,
        expense_date=src.expense_date,
        split_type=src.split_type,
        list_id=src.list_id,
        group_id=src.group_id,
        item_id=src.item_id,
        paid_by_user_id=src.paid_by_user_id,
        created_by_user_id=owner.id,
        paid_by=owner,
        created_by_user=owner,
        version=1,
    )
# Tests for get_users_for_splitting (indirectly tested via create_expense, but stubs for direct if needed)
@pytest.mark.asyncio
async def test_get_users_for_splitting_group_context(mock_db_session, basic_group_model, basic_user_model, another_user_model):
    """Every member of the expense's group is returned as a split participant."""
    basic_group_model.member_associations = [
        UserGroupModel(user=basic_user_model, user_id=basic_user_model.id),
        UserGroupModel(user=another_user_model, user_id=another_user_model.id),
    ]
    # ``await session.execute()`` yields a synchronous Result object; with an
    # AsyncMock here ``result.scalars()`` would return a coroutine, not a mock.
    mock_result = MagicMock()
    scalars = mock_result.scalars.return_value
    scalars.unique.return_value = scalars  # tolerate an optional ``.unique()`` call
    scalars.first.return_value = basic_group_model
    mock_db_session.execute.return_value = mock_result

    users = await get_users_for_splitting(
        mock_db_session,
        expense_group_id=1,
        expense_list_id=None,
        expense_paid_by_user_id=1,
    )

    assert len(users) == 2
    assert basic_user_model in users
    assert another_user_model in users
# --- create_expense Tests ---
@pytest.mark.asyncio
async def test_create_expense_equal_split_group_success(mock_db_session, expense_create_data_equal_split_group_ctx, basic_user_model, basic_group_model, another_user_model):
    """An EQUAL split in a group context divides the total evenly between two members."""
    data = expense_create_data_equal_split_group_ctx
    mock_db_session.get.side_effect = [basic_user_model, basic_group_model]
    # Result objects are synchronous in SQLAlchemy's async API: use MagicMock so
    # ``result.scalar_one_or_none()`` returns the model rather than a coroutine.
    mock_result = MagicMock()
    mock_result.scalar_one_or_none.return_value = ExpenseModel(
        id=1,
        description=data.description,
        total_amount=data.total_amount,
        currency=data.currency,
        expense_date=data.expense_date,
        split_type=data.split_type,
        list_id=data.list_id,
        group_id=data.group_id,
        item_id=data.item_id,
        paid_by_user_id=data.paid_by_user_id,
        created_by_user_id=basic_user_model.id,
        version=1,
    )
    mock_db_session.execute.return_value = mock_result

    with patch("app.crud.expense.get_users_for_splitting", new_callable=AsyncMock) as mock_get_users:
        mock_get_users.return_value = [basic_user_model, another_user_model]
        created_expense = await create_expense(mock_db_session, data, current_user_id=1)

    mock_db_session.add.assert_called()
    mock_db_session.flush.assert_called_once()
    assert created_expense is not None
    assert created_expense.total_amount == data.total_amount
    assert created_expense.split_type == SplitTypeEnum.EQUAL
    assert len(created_expense.splits) == 2
    expected_amount_per_user = (data.total_amount / 2).quantize(Decimal("0.01"), rounding=ROUND_HALF_UP)
    for split in created_expense.splits:
        assert split.owed_amount == expected_amount_per_user
@pytest.mark.asyncio
async def test_create_expense_exact_split_success(mock_db_session, expense_create_data_exact_split, basic_user_model, basic_group_model, another_user_model):
    """An EXACT_AMOUNTS expense is created with one split per ``splits_in`` entry."""
    data = expense_create_data_exact_split
    mock_db_session.get.side_effect = [basic_user_model, basic_group_model]
    # Result objects are synchronous in SQLAlchemy's async API: use MagicMock so
    # ``result.scalar_one_or_none()`` returns the model rather than a coroutine.
    mock_result = MagicMock()
    mock_result.scalar_one_or_none.return_value = ExpenseModel(
        id=1,
        description=data.description,
        total_amount=data.total_amount,
        currency="USD",
        expense_date=data.expense_date,
        split_type=data.split_type,
        list_id=data.list_id,
        group_id=data.group_id,
        item_id=data.item_id,
        paid_by_user_id=data.paid_by_user_id,
        created_by_user_id=basic_user_model.id,
        version=1,
    )
    mock_db_session.execute.return_value = mock_result

    created_expense = await create_expense(mock_db_session, data, current_user_id=1)

    mock_db_session.add.assert_called()
    mock_db_session.flush.assert_called_once()
    assert created_expense is not None
    assert created_expense.split_type == SplitTypeEnum.EXACT_AMOUNTS
    assert len(created_expense.splits) == 2
@pytest.mark.asyncio
async def test_create_expense_payer_not_found(mock_db_session, expense_create_data_equal_split_group_ctx):
    """``create_expense`` raises UserNotFoundError when the payer lookup returns None."""
    mock_db_session.get.return_value = None  # payer not found
    with pytest.raises(UserNotFoundError):
        # Pass current_user_id by keyword, matching the other create_expense tests.
        await create_expense(mock_db_session, expense_create_data_equal_split_group_ctx, current_user_id=1)
@pytest.mark.asyncio
async def test_create_expense_no_list_or_group(mock_db_session, expense_create_data_equal_split_group_ctx, basic_user_model):
    """An expense with neither list_id nor group_id is rejected with InvalidOperationError."""
    mock_db_session.get.return_value = basic_user_model  # payer found
    # model_copy(update=...) builds the copy without re-validating, so the invalid
    # combination reaches the CRUD layer even if the schema validates on assignment.
    expense_data = expense_create_data_equal_split_group_ctx.model_copy(
        update={"list_id": None, "group_id": None}
    )
    with pytest.raises(InvalidOperationError, match="Expense must be associated with a list or a group"):
        await create_expense(mock_db_session, expense_data, current_user_id=1)
# --- get_expense_by_id Tests ---
@pytest.mark.asyncio
async def test_get_expense_by_id_found(mock_db_session, db_expense_model):
    """An existing expense id returns the model via a single query."""
    # Result is synchronous after ``await session.execute()``: use MagicMock.
    mock_result = MagicMock()
    scalars = mock_result.scalars.return_value
    scalars.unique.return_value = scalars  # tolerate an optional ``.unique()`` call
    scalars.first.return_value = db_expense_model
    mock_db_session.execute.return_value = mock_result

    expense = await get_expense_by_id(mock_db_session, 1)

    assert expense is not None
    assert expense.id == 1
    mock_db_session.execute.assert_called_once()
@pytest.mark.asyncio
async def test_get_expense_by_id_not_found(mock_db_session):
    """An unknown expense id yields None after a single query."""
    # Result is synchronous after ``await session.execute()``: use MagicMock.
    mock_result = MagicMock()
    scalars = mock_result.scalars.return_value
    scalars.unique.return_value = scalars  # tolerate an optional ``.unique()`` call
    scalars.first.return_value = None
    mock_db_session.execute.return_value = mock_result

    expense = await get_expense_by_id(mock_db_session, 999)

    assert expense is None
    mock_db_session.execute.assert_called_once()
# --- get_expenses_for_list Tests ---
@pytest.mark.asyncio
async def test_get_expenses_for_list_success(mock_db_session, db_expense_model):
    """Expenses belonging to a list are returned from a single query."""
    list_id = 1
    # db_expense_model mirrors the group-context payload (list_id=None); attach it
    # to the list so the list_id assertion below can actually hold.
    db_expense_model.list_id = list_id
    # Result is synchronous after ``await session.execute()``: use MagicMock.
    mock_result = MagicMock()
    scalars = mock_result.scalars.return_value
    scalars.unique.return_value = scalars  # tolerate an optional ``.unique()`` call
    scalars.all.return_value = [db_expense_model]
    mock_db_session.execute.return_value = mock_result

    expenses = await get_expenses_for_list(mock_db_session, list_id=list_id)

    assert len(expenses) == 1
    assert expenses[0].list_id == list_id
    mock_db_session.execute.assert_called_once()
# --- get_expenses_for_group Tests ---
@pytest.mark.asyncio
async def test_get_expenses_for_group_success(mock_db_session, db_expense_model):
    """Expenses belonging to a group are returned from a single query."""
    # Result is synchronous after ``await session.execute()``: use MagicMock.
    mock_result = MagicMock()
    scalars = mock_result.scalars.return_value
    scalars.unique.return_value = scalars  # tolerate an optional ``.unique()`` call
    scalars.all.return_value = [db_expense_model]
    mock_db_session.execute.return_value = mock_result

    expenses = await get_expenses_for_group(mock_db_session, group_id=1)

    assert len(expenses) == 1
    assert expenses[0].group_id == 1
    mock_db_session.execute.assert_called_once()
# --- Stubs for update_expense and delete_expense ---
# These will need more details once the actual implementation of update/delete is clear
# For example, how splits are handled on update, versioning, etc.
@pytest.mark.asyncio
@pytest.mark.skip(reason="Placeholder: update_expense contract (splits, versioning) not yet specified")
async def test_update_expense_stub(mock_db_session):
    """Placeholder for ``update_expense`` tests.

    Marked skip instead of ending in a bare ``pass`` so the suite reports it as
    pending rather than silently counting it as a passing test.
    """
    # Needs the ExpenseUpdate schema, an existing expense object, and mocking of
    # commit/refresh. Also depends on which fields are updatable and how splits
    # are managed.
    expense_to_update = MagicMock(spec=ExpenseModel)
    expense_to_update.version = 1
    update_payload = ExpenseUpdate(description="New description", version=1)  # add other fields per schema
    # Sketch, e.g. if update_expense loads, modifies, commits and refreshes:
    # mock_db_session.get.return_value = expense_to_update
    # updated_expense = await update_expense(mock_db_session, expense_to_update, update_payload)
    # assert updated_expense.description == "New description"
    # mock_db_session.commit.assert_called_once()
    # mock_db_session.refresh.assert_called_once()
@pytest.mark.asyncio
@pytest.mark.skip(reason="Placeholder: delete_expense contract (versioning, split cascade) not yet specified")
async def test_delete_expense_stub(mock_db_session):
    """Placeholder for ``delete_expense`` tests.

    Marked skip instead of ending in a bare ``pass`` so the suite reports it as
    pending rather than silently counting it as a passing test.
    """
    # Needs an existing expense object and mocking of delete/commit. Also consider
    # implications such as whether associated splits are deleted.
    expense_to_delete = MagicMock(spec=ExpenseModel)
    expense_to_delete.id = 1
    expense_to_delete.version = 1
    # Sketch of the intended assertions:
    # mock_db_session.get.return_value = expense_to_delete  # if it re-fetches
    # await delete_expense(mock_db_session, expense_to_delete, expected_version=1)
    # mock_db_session.delete.assert_called_once_with(expense_to_delete)
    # mock_db_session.commit.assert_called_once()
# TODO: Add more tests for create_expense covering:
# - List context success
# - Percentage, Shares, Item-based splits
# - Error cases for each split type (e.g., total mismatch, invalid inputs)
# - Validation of list_id/group_id consistency
# - User not found in splits_in
# - Item not found for ITEM_BASED split
# TODO: Flesh out update_expense tests:
# - Success case
# - Version mismatch
# - Trying to update immutable fields
# - How splits are handled (recalculated, deleted/recreated, or not changeable)
# TODO: Flesh out delete_expense tests:
# - Success case
# - Version mismatch (if applicable)
# - Ensure associated splits are also deleted (cascade behavior)