mitlist/be/tests/crud/test_expense.py
google-labs-jules[bot] f1152c5745 feat: Implement traceable expense splitting and settlement activities
Backend:
- Added `SettlementActivity` model to track payments against specific expense shares.
- Added `status` and `paid_at` to `ExpenseSplit` model.
- Added `overall_settlement_status` to `Expense` model.
- Implemented CRUD for `SettlementActivity`, including logic to update parent expense/split statuses.
- Updated `Expense` CRUD to initialize new status fields.
- Defined Pydantic schemas for `SettlementActivity` and updated `Expense/ExpenseSplit` schemas.
- Exposed API endpoints for creating/listing settlement activities and settling shares.
- Adjusted group balance summary logic to include settlement activities.
- Added comprehensive backend unit and API tests for new functionality.

Frontend (foundation only; remaining integration work is left as TODOs that could not be completed in this change):
- Created TypeScript interfaces for all new/updated models.
- Set up `listDetailStore.ts` with an action to handle `settleExpenseSplit` (API call is a placeholder) and refresh data.
- Created `SettleShareModal.vue` component for payment confirmation.
- Added unit tests for the new modal and store logic.
- Updated `ListDetailPage.vue` to display detailed expense/share statuses and settlement activities.
- `mitlist_doc.md` updated to reflect all backend changes and current frontend status.
- A TODO section inside the new section of `mitlist_doc.md` (not a separate `TODO.md` file) outlines the manual frontend integrations still needed in `api.ts` and `ListDetailPage.vue` to complete the 'Settle Share' UI flow.

This set of changes provides the core backend infrastructure for precise expense share tracking and settlement, and lays the groundwork for full frontend integration.
2025-05-22 07:05:31 +00:00

369 lines
17 KiB
Python

import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from sqlalchemy.exc import IntegrityError, OperationalError
from decimal import Decimal, ROUND_HALF_UP
from datetime import datetime, timezone
from typing import List as PyList, Optional
from app.crud.expense import (
create_expense,
get_expense_by_id,
get_expenses_for_list,
get_expenses_for_group,
update_expense,
delete_expense,
get_users_for_splitting
)
from app.schemas.expense import ExpenseCreate, ExpenseUpdate, ExpenseSplitCreate, ExpenseRead
from app.models import (
Expense as ExpenseModel,
ExpenseSplit as ExpenseSplitModel,
User as UserModel,
List as ListModel,
Group as GroupModel,
UserGroup as UserGroupModel,
Item as ItemModel,
SplitTypeEnum,
ExpenseOverallStatusEnum, # Added
ExpenseSplitStatusEnum # Added
)
from app.core.exceptions import (
ListNotFoundError,
GroupNotFoundError,
UserNotFoundError,
InvalidOperationError,
ExpenseNotFoundError,
DatabaseTransactionError,
ConflictError
)
# General Fixtures
@pytest.fixture
def mock_db_session():
    """Return a mock standing in for an SQLAlchemy ``AsyncSession``.

    Coroutine methods (``commit``, ``rollback``, ``refresh``, ``delete``,
    ``execute``, ``get``, ``flush``) are ``AsyncMock`` so they can be awaited.
    Synchronous methods (``add``, ``in_transaction``) are ``MagicMock``.
    ``begin`` and ``begin_nested`` are synchronous calls whose return value is
    used with ``async with``, so they are ``MagicMock``s returning an
    ``AsyncMock``, which implements ``__aenter__``/``__aexit__``.
    """
    session = AsyncMock()
    session.commit = AsyncMock()
    session.rollback = AsyncMock()
    session.refresh = AsyncMock()
    session.add = MagicMock()
    # AsyncSession.delete is a coroutine method; a MagicMock here would make
    # ``await session.delete(obj)`` raise TypeError.
    session.delete = AsyncMock()
    session.execute = AsyncMock()
    session.get = AsyncMock()
    session.flush = AsyncMock()
    session.in_transaction = MagicMock(return_value=False)
    # begin()/begin_nested() return an async context manager. Calling an
    # AsyncMock would instead yield a bare coroutine, which does not support
    # ``async with``.
    mock_transaction_context = AsyncMock()
    session.begin = MagicMock(return_value=mock_transaction_context)
    mock_nested_transaction_context = AsyncMock()  # For nested transactions within functions
    session.begin_nested = MagicMock(return_value=mock_nested_transaction_context)
    return session
@pytest.fixture
def basic_user_model():
    """Primary test user (id=1), used as payer and creator throughout the tests."""
    return UserModel(
        id=1,
        name="Test User",
        email="test@example.com",
        version=1,
    )
@pytest.fixture
def another_user_model():
    """Second test user (id=2), the other participant in two-way splits."""
    return UserModel(
        id=2,
        name="Another User",
        email="another@example.com",
        version=1,
    )
@pytest.fixture
def basic_group_model(basic_user_model, another_user_model):
    """Bare group (id=1) with no ``member_associations`` attached.

    Tests that need members (e.g. the ``get_users_for_splitting`` tests)
    assign ``member_associations`` themselves. The user fixtures are
    requested here but not used.
    """
    return GroupModel(
        id=1,
        name="Test Group",
        version=1,
    )
@pytest.fixture
def basic_list_model(basic_group_model, basic_user_model):
    """List (id=1) owned by ``basic_group_model`` and created by ``basic_user_model``."""
    owner_group = basic_group_model
    creator = basic_user_model
    return ListModel(
        id=1,
        name="Test List",
        group_id=owner_group.id,
        group=owner_group,
        created_by_id=creator.id,
        creator=creator,
        version=1,
    )
@pytest.fixture
def expense_create_data_equal_split_list_ctx(basic_list_model, basic_user_model):
    """Create payload: 30.00 USD equal split scoped to a list, paid by the basic user.

    ``group_id`` is left as None; it is meant to be derived from the list.
    """
    amount = Decimal("30.00")
    today_utc = datetime.now(timezone.utc).date()
    return ExpenseCreate(
        description="Grocery run",
        total_amount=amount,
        currency="USD",
        expense_date=today_utc,
        split_type=SplitTypeEnum.EQUAL,
        list_id=basic_list_model.id,
        group_id=None,
        item_id=None,
        paid_by_user_id=basic_user_model.id,
        splits_in=None,
    )
@pytest.fixture
def expense_create_data_equal_split_group_ctx(basic_group_model, basic_user_model):
    """Create payload: 50.00 USD equal split scoped directly to a group (no list)."""
    amount = Decimal("50.00")
    today_utc = datetime.now(timezone.utc).date()
    return ExpenseCreate(
        description="Movies",
        total_amount=amount,
        currency="USD",
        expense_date=today_utc,
        split_type=SplitTypeEnum.EQUAL,
        list_id=None,
        group_id=basic_group_model.id,
        item_id=None,
        paid_by_user_id=basic_user_model.id,
        splits_in=None,
    )
@pytest.fixture
def expense_create_data_exact_split(basic_group_model, basic_user_model, another_user_model):
    """Create payload: 100.00 USD group expense with explicit 60.00 / 40.00 shares."""
    shares = [
        ExpenseSplitCreate(user_id=basic_user_model.id, owed_amount=Decimal("60.00")),
        ExpenseSplitCreate(user_id=another_user_model.id, owed_amount=Decimal("40.00")),
    ]
    return ExpenseCreate(
        description="Dinner",
        total_amount=Decimal("100.00"),
        expense_date=datetime.now(timezone.utc).date(),
        currency="USD",
        split_type=SplitTypeEnum.EXACT_AMOUNTS,
        group_id=basic_group_model.id,
        paid_by_user_id=basic_user_model.id,
        splits_in=shares,
    )
@pytest.fixture
def expense_update_data():
    """Update payload changing description and amount.

    Carries ``version=1``; the update tests compare this against the stored
    expense's version.
    """
    changed_fields = {
        "description": "Updated Dinner",
        "total_amount": Decimal("120.00"),
    }
    return ExpenseUpdate(**changed_fields, version=1)
@pytest.fixture
def db_expense_model(expense_create_data_equal_split_group_ctx, basic_user_model, basic_group_model, another_user_model):
    """Existing-looking ``Expense`` row (id=1) built from the equal-split group payload.

    The 50.00 total is pre-split into two 25.00 shares, one for
    ``basic_user_model`` and one for ``another_user_model``. Split user ids
    come from the user fixtures rather than hard-coded numbers, so the data
    stays consistent if those fixtures change.
    """
    data = expense_create_data_equal_split_group_ctx
    # One timestamp so created_at and updated_at are identical, as for a freshly stored row.
    now = datetime.now(timezone.utc)
    expense = ExpenseModel(
        id=1,
        description=data.description,
        total_amount=data.total_amount,
        currency=data.currency,
        expense_date=data.expense_date,
        split_type=data.split_type,
        list_id=data.list_id,
        group_id=data.group_id,
        group=basic_group_model,  # Link to group fixture
        item_id=data.item_id,
        paid_by_user_id=data.paid_by_user_id,
        created_by_user_id=basic_user_model.id,
        paid_by=basic_user_model,
        created_by_user=basic_user_model,
        version=1,
        created_at=now,
        updated_at=now,
    )
    # Simulate splits for an existing expense
    expense.splits = [
        ExpenseSplitModel(id=1, expense_id=expense.id, user_id=basic_user_model.id, owed_amount=Decimal("25.00"), version=1),
        ExpenseSplitModel(id=2, expense_id=expense.id, user_id=another_user_model.id, owed_amount=Decimal("25.00"), version=1),
    ]
    return expense
# Tests for get_users_for_splitting
@pytest.mark.asyncio
async def test_get_users_for_splitting_group_context(mock_db_session, basic_group_model, basic_user_model, another_user_model):
    """In group context, both members of the group are returned."""
    group_id = basic_group_model.id
    basic_group_model.member_associations = [
        UserGroupModel(user=member, user_id=member.id, group_id=group_id)
        for member in (basic_user_model, another_user_model)
    ]
    # session.get resolves the group
    mock_db_session.get.return_value = basic_group_model

    users = await get_users_for_splitting(
        mock_db_session,
        expense_group_id=group_id,
        expense_list_id=None,
        expense_paid_by_user_id=basic_user_model.id,
    )

    assert len(users) == 2
    assert basic_user_model in users
    assert another_user_model in users
@pytest.mark.asyncio
async def test_get_users_for_splitting_list_context(mock_db_session, basic_list_model, basic_group_model, basic_user_model, another_user_model):
    """In list context, the members of the list's group are returned."""
    group_id = basic_group_model.id
    basic_group_model.member_associations = [
        UserGroupModel(user=member, user_id=member.id, group_id=group_id)
        for member in (basic_user_model, another_user_model)
    ]
    # Tie the list to the group whose members should be returned.
    basic_list_model.group = basic_group_model
    # session.get resolves the list
    mock_db_session.get.return_value = basic_list_model

    users = await get_users_for_splitting(
        mock_db_session,
        expense_group_id=None,
        expense_list_id=basic_list_model.id,
        expense_paid_by_user_id=basic_user_model.id,
    )

    assert len(users) == 2
    assert basic_user_model in users
    assert another_user_model in users
# --- create_expense Tests ---
@pytest.mark.asyncio
async def test_create_expense_equal_split_group_success(mock_db_session, expense_create_data_equal_split_group_ctx, basic_user_model, basic_group_model, another_user_model):
    """Equal split of the group expense across two users produces two equal, unpaid shares.

    ``get_users_for_splitting`` is patched to return both users, and
    ``session.refresh`` is replaced by a fake that fills in id, version,
    timestamps and the ``splits`` relationship on the expense.
    """
    # Setup mocks
    mock_db_session.get.side_effect = [basic_user_model, basic_group_model] # paid_by_user, then group
    # Mock get_users_for_splitting directly
    with patch('app.crud.expense.get_users_for_splitting', new_callable=AsyncMock) as mock_get_users:
        mock_get_users.return_value = [basic_user_model, another_user_model]
        async def mock_refresh(instance, attribute_names=None, with_for_update=None):
            # Fake AsyncSession.refresh: only touches ExpenseModel instances.
            if isinstance(instance, ExpenseModel):
                instance.id = 1 # Simulate ID assignment after flush
                instance.version = 1
                instance.created_at = datetime.now(timezone.utc)
                instance.updated_at = datetime.now(timezone.utc)
                # Simulate splits being added to the session and linked by refresh
                # NOTE(review): these objects replace whatever splits create_expense
                # built, and they do not set ``status``. The owed_amount/status
                # assertions below therefore presumably check this fake rather than
                # create_expense's own output. Confirm the ExpenseSplit model supplies
                # a Python-side default for ``status``; an SQLAlchemy Column
                # ``default=`` is only applied at INSERT time.
                instance.splits = [
                    ExpenseSplitModel(expense_id=instance.id, user_id=basic_user_model.id, owed_amount=Decimal("25.00"), version=1),
                    ExpenseSplitModel(expense_id=instance.id, user_id=another_user_model.id, owed_amount=Decimal("25.00"), version=1)
                ]
            return None
        mock_db_session.refresh.side_effect = mock_refresh
        created_expense = await create_expense(mock_db_session, expense_create_data_equal_split_group_ctx, current_user_id=basic_user_model.id)
        mock_db_session.add.assert_called()
        mock_db_session.flush.assert_called_once()
        mock_db_session.refresh.assert_called_once()
        assert created_expense is not None
        assert created_expense.total_amount == expense_create_data_equal_split_group_ctx.total_amount
        assert created_expense.split_type == SplitTypeEnum.EQUAL
        assert len(created_expense.splits) == 2
        # Two participants: each owes half the total, rounded to cents.
        expected_amount_per_user = (expense_create_data_equal_split_group_ctx.total_amount / 2).quantize(Decimal("0.01"), rounding=ROUND_HALF_UP)
        for split in created_expense.splits:
            assert split.owed_amount == expected_amount_per_user
            assert split.status == ExpenseSplitStatusEnum.unpaid # Verify initial split status
        assert created_expense.overall_settlement_status == ExpenseOverallStatusEnum.unpaid # Verify initial expense status
@pytest.mark.asyncio
async def test_create_expense_exact_split_success(mock_db_session, expense_create_data_exact_split, basic_user_model, basic_group_model, another_user_model):
    """Exact-amount split keeps the supplied 60.00 / 40.00 shares, all initially unpaid.

    ``session.get`` yields payer, group, then each split user in order, and
    ``session.refresh`` is replaced by a fake that fills in id, version and
    the ``splits`` relationship on the expense.
    """
    mock_db_session.get.side_effect = [basic_user_model, basic_group_model, basic_user_model, another_user_model] # Payer, Group, User1 in split, User2 in split
    async def mock_refresh(instance, attribute_names=None, with_for_update=None):
        # Fake AsyncSession.refresh: only touches ExpenseModel instances.
        if isinstance(instance, ExpenseModel):
            instance.id = 2
            instance.version = 1
            # NOTE(review): these splits replace create_expense's own and carry no
            # ``status``. The amount/status assertions below presumably verify this
            # fake, not the CRUD output. Confirm ExpenseSplit has a Python-side
            # default for ``status``.
            instance.splits = [
                ExpenseSplitModel(expense_id=instance.id, user_id=basic_user_model.id, owed_amount=Decimal("60.00")),
                ExpenseSplitModel(expense_id=instance.id, user_id=another_user_model.id, owed_amount=Decimal("40.00"))
            ]
        return None
    mock_db_session.refresh.side_effect = mock_refresh
    created_expense = await create_expense(mock_db_session, expense_create_data_exact_split, current_user_id=basic_user_model.id)
    mock_db_session.add.assert_called()
    mock_db_session.flush.assert_called_once()
    assert created_expense is not None
    assert created_expense.split_type == SplitTypeEnum.EXACT_AMOUNTS
    assert len(created_expense.splits) == 2
    assert created_expense.splits[0].owed_amount == Decimal("60.00")
    assert created_expense.splits[1].owed_amount == Decimal("40.00")
    for split in created_expense.splits:
        assert split.status == ExpenseSplitStatusEnum.unpaid # Verify initial split status
    assert created_expense.overall_settlement_status == ExpenseOverallStatusEnum.unpaid # Verify initial expense status
@pytest.mark.asyncio
async def test_create_expense_payer_not_found(mock_db_session, expense_create_data_equal_split_group_ctx):
    """create_expense raises UserNotFoundError when the payer lookup yields None."""
    # Only the payer lookup is expected; the group lookup never happens.
    mock_db_session.get.side_effect = [None]
    # This is the creator id; the payer id comes from the schema.
    creator_id = 999
    with pytest.raises(UserNotFoundError):
        await create_expense(mock_db_session, expense_create_data_equal_split_group_ctx, creator_id)
@pytest.mark.asyncio
async def test_create_expense_no_list_or_group(mock_db_session, expense_create_data_equal_split_group_ctx, basic_user_model):
    """An expense tied to neither a list nor a group is rejected."""
    # Payer lookup succeeds
    mock_db_session.get.return_value = basic_user_model
    orphan_expense = expense_create_data_equal_split_group_ctx.model_copy()
    orphan_expense.list_id = None
    orphan_expense.group_id = None
    expected_message = "Expense must be associated with a list or a group"
    with pytest.raises(InvalidOperationError, match=expected_message):
        await create_expense(mock_db_session, orphan_expense, basic_user_model.id)
# --- get_expense_by_id Tests ---
@pytest.mark.asyncio
async def test_get_expense_by_id_found(mock_db_session, db_expense_model):
    """get_expense_by_id returns the expense that session.get yields.

    Only the entity class and primary key passed to ``session.get`` are
    checked. Any loader options are built inside the CRUD function and cannot
    be matched by placeholder ``MagicMock()`` instances, which compare equal
    only to themselves.
    """
    mock_db_session.get.return_value = db_expense_model
    expense = await get_expense_by_id(mock_db_session, db_expense_model.id)
    assert expense is not None
    assert expense.id == db_expense_model.id
    mock_db_session.get.assert_called_once()
    positional_args = mock_db_session.get.call_args[0]
    assert positional_args[:2] == (ExpenseModel, db_expense_model.id)
@pytest.mark.asyncio
async def test_get_expense_by_id_not_found(mock_db_session):
    """A missing expense yields None rather than raising."""
    mock_db_session.get.return_value = None
    unknown_id = 999
    result = await get_expense_by_id(mock_db_session, unknown_id)
    assert result is None
# --- get_expenses_for_list Tests ---
@pytest.mark.asyncio
async def test_get_expenses_for_list_success(mock_db_session, db_expense_model, basic_list_model):
    """get_expenses_for_list returns the expenses produced by a single query."""
    # ``await session.execute(...)`` yields a synchronous Result object, so the
    # fake result must be a MagicMock. Children of an AsyncMock are AsyncMocks,
    # so ``result.scalars()`` would return a coroutine that has no ``.all()``.
    mock_result = MagicMock()
    mock_result.scalars.return_value.all.return_value = [db_expense_model]
    mock_db_session.execute.return_value = mock_result
    expenses = await get_expenses_for_list(mock_db_session, basic_list_model.id)
    assert len(expenses) == 1
    assert expenses[0].id == db_expense_model.id
    mock_db_session.execute.assert_called_once()
# --- get_expenses_for_group Tests ---
@pytest.mark.asyncio
async def test_get_expenses_for_group_success(mock_db_session, db_expense_model, basic_group_model):
    """get_expenses_for_group returns the expenses produced by a single query."""
    # ``await session.execute(...)`` yields a synchronous Result object, so the
    # fake result must be a MagicMock. Children of an AsyncMock are AsyncMocks,
    # so ``result.scalars()`` would return a coroutine that has no ``.all()``.
    mock_result = MagicMock()
    mock_result.scalars.return_value.all.return_value = [db_expense_model]
    mock_db_session.execute.return_value = mock_result
    expenses = await get_expenses_for_group(mock_db_session, basic_group_model.id)
    assert len(expenses) == 1
    assert expenses[0].id == db_expense_model.id
    mock_db_session.execute.assert_called_once()
# --- update_expense Tests ---
@pytest.mark.asyncio
async def test_update_expense_success(mock_db_session, db_expense_model, expense_update_data, basic_user_model):
    """A matching-version update applies the new fields and increments the version."""
    expense_update_data.version = db_expense_model.version  # Match version
    # Capture the version before the call. add() is expected to receive
    # db_expense_model itself, so the returned object is presumably that same
    # instance. Comparing against db_expense_model.version afterwards would
    # therefore be tautological.
    original_version = db_expense_model.version
    # Simulate that the db_expense_model is returned by session.get
    mock_db_session.get.return_value = db_expense_model
    updated_expense = await update_expense(mock_db_session, db_expense_model.id, expense_update_data, basic_user_model.id)
    mock_db_session.add.assert_called_once_with(db_expense_model)
    mock_db_session.flush.assert_called_once()
    mock_db_session.refresh.assert_called_once_with(db_expense_model)
    assert updated_expense.description == expense_update_data.description
    assert updated_expense.total_amount == expense_update_data.total_amount
    assert updated_expense.version == original_version + 1  # Version incremented by the function
@pytest.mark.asyncio
async def test_update_expense_not_found(mock_db_session, expense_update_data, basic_user_model):
    """Updating a nonexistent expense raises ExpenseNotFoundError."""
    # Expense lookup finds nothing
    mock_db_session.get.return_value = None
    unknown_id = 999
    with pytest.raises(ExpenseNotFoundError):
        await update_expense(mock_db_session, unknown_id, expense_update_data, basic_user_model.id)
@pytest.mark.asyncio
async def test_update_expense_version_conflict(mock_db_session, db_expense_model, expense_update_data, basic_user_model):
    """A payload version that differs from the stored one raises ConflictError and rolls back."""
    stale_version = db_expense_model.version + 1
    expense_update_data.version = stale_version  # deliberately mismatched
    mock_db_session.get.return_value = db_expense_model
    with pytest.raises(ConflictError):
        await update_expense(mock_db_session, db_expense_model.id, expense_update_data, basic_user_model.id)
    mock_db_session.rollback.assert_called_once()
# --- delete_expense Tests ---
@pytest.mark.asyncio
async def test_delete_expense_success(mock_db_session, db_expense_model, basic_user_model):
    """Deleting an existing expense removes it and commits the transaction."""
    mock_db_session.get.return_value = db_expense_model  # Simulate expense found
    await delete_expense(mock_db_session, db_expense_model.id, basic_user_model.id)
    mock_db_session.delete.assert_called_once_with(db_expense_model)
    # Assuming delete_expense uses session.begin() and commits. Inspect the
    # object begin() handed out via return_value; calling begin() here would
    # record a spurious extra call on the mock.
    transaction = mock_db_session.begin.return_value
    transaction.commit.assert_called_once()
@pytest.mark.asyncio
async def test_delete_expense_not_found(mock_db_session, basic_user_model):
    """Deleting a nonexistent expense raises ExpenseNotFoundError without a session rollback."""
    # Expense lookup finds nothing
    mock_db_session.get.return_value = None
    unknown_id = 999
    with pytest.raises(ExpenseNotFoundError):
        await delete_expense(mock_db_session, unknown_id, basic_user_model.id)
    # Any rollback would come from the begin() context manager's exit, not the session itself.
    mock_db_session.rollback.assert_not_called()
@pytest.mark.asyncio
async def test_delete_expense_db_error(mock_db_session, db_expense_model, basic_user_model):
    """A database error during delete surfaces as DatabaseTransactionError and rolls back."""
    mock_db_session.get.return_value = db_expense_model
    mock_db_session.delete.side_effect = OperationalError("mock op error", "params", "orig")
    with pytest.raises(DatabaseTransactionError):
        await delete_expense(mock_db_session, db_expense_model.id, basic_user_model.id)
    # Rollback from the transaction context. Read it via return_value so the
    # assertion does not add an extra call to begin().
    transaction = mock_db_session.begin.return_value
    transaction.rollback.assert_called_once()