![google-labs-jules[bot]](/assets/img/avatar_default.png)
Backend: - Added `SettlementActivity` model to track payments against specific expense shares. - Added `status` and `paid_at` to the `ExpenseSplit` model. - Added `overall_settlement_status` to the `Expense` model. - Implemented CRUD for `SettlementActivity`, including logic to update the parent expense/split statuses. - Updated `Expense` CRUD to initialize the new status fields. - Defined Pydantic schemas for `SettlementActivity` and updated the `Expense`/`ExpenseSplit` schemas. - Exposed API endpoints for creating/listing settlement activities and settling shares. - Adjusted group balance summary logic to include settlement activities. - Added comprehensive backend unit and API tests for the new functionality. Frontend (foundation only; remaining integration work is tracked as TODOs): - Created TypeScript interfaces for all new/updated models. - Set up `listDetailStore.ts` with an action to handle `settleExpenseSplit` (the API call is a placeholder) and refresh data. - Created `SettleShareModal.vue` component for payment confirmation. - Added unit tests for the new modal and store logic. - Updated `ListDetailPage.vue` to display detailed expense/share statuses and settlement activities. - Updated `mitlist_doc.md` to reflect all backend changes and the current frontend status. - A TODO list in a new section of `mitlist_doc.md` (rather than a separate `TODO.md`) outlines the manual frontend integrations still needed in `api.ts` and `ListDetailPage.vue` to complete the 'Settle Share' UI flow. This set of changes provides the core backend infrastructure for precise expense share tracking and settlement, and lays the groundwork for full frontend integration.
369 lines
17 KiB
Python
369 lines
17 KiB
Python
import pytest
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
from sqlalchemy.exc import IntegrityError, OperationalError
|
|
from decimal import Decimal, ROUND_HALF_UP
|
|
from datetime import datetime, timezone
|
|
from typing import List as PyList, Optional
|
|
|
|
from app.crud.expense import (
|
|
create_expense,
|
|
get_expense_by_id,
|
|
get_expenses_for_list,
|
|
get_expenses_for_group,
|
|
update_expense,
|
|
delete_expense,
|
|
get_users_for_splitting
|
|
)
|
|
from app.schemas.expense import ExpenseCreate, ExpenseUpdate, ExpenseSplitCreate, ExpenseRead
|
|
from app.models import (
|
|
Expense as ExpenseModel,
|
|
ExpenseSplit as ExpenseSplitModel,
|
|
User as UserModel,
|
|
List as ListModel,
|
|
Group as GroupModel,
|
|
UserGroup as UserGroupModel,
|
|
Item as ItemModel,
|
|
SplitTypeEnum,
|
|
ExpenseOverallStatusEnum, # Added
|
|
ExpenseSplitStatusEnum # Added
|
|
)
|
|
from app.core.exceptions import (
|
|
ListNotFoundError,
|
|
GroupNotFoundError,
|
|
UserNotFoundError,
|
|
InvalidOperationError,
|
|
ExpenseNotFoundError,
|
|
DatabaseTransactionError,
|
|
ConflictError
|
|
)
|
|
|
|
# General Fixtures
|
|
@pytest.fixture
def mock_db_session():
    """Provide a mock standing in for a SQLAlchemy ``AsyncSession``.

    Coroutine methods of ``AsyncSession`` (``commit``, ``rollback``,
    ``refresh``, ``execute``, ``get``, ``flush``, ``delete``) are ``AsyncMock``
    so they can be awaited. Synchronous methods (``add``, ``in_transaction``)
    are ``MagicMock``. ``begin()`` and ``begin_nested()`` are plain calls whose
    result is used with ``async with``, so each returns an ``AsyncMock``
    (which supports ``__aenter__``/``__aexit__``).
    """
    session = AsyncMock()

    # Awaitable session methods.
    session.commit = AsyncMock()
    session.rollback = AsyncMock()
    session.refresh = AsyncMock()
    session.execute = AsyncMock()
    session.get = AsyncMock()
    session.flush = AsyncMock()
    # AsyncSession.delete() is a coroutine; a MagicMock would hand back a
    # non-awaitable object and break `await session.delete(obj)`.
    session.delete = AsyncMock()

    # Synchronous session methods.
    session.add = MagicMock()
    session.in_transaction = MagicMock(return_value=False)

    # begin()/begin_nested() are called synchronously and the returned object is
    # used as an async context manager. Calling a bare AsyncMock would return a
    # coroutine, which cannot be used with `async with`.
    mock_transaction_context = AsyncMock()
    session.begin = MagicMock(return_value=mock_transaction_context)
    session.begin_nested = MagicMock(return_value=AsyncMock())

    return session
|
|
|
|
@pytest.fixture
def basic_user_model():
    """Primary test user (id 1); acts as payer/creator in the expense fixtures."""
    user = UserModel(
        id=1,
        name="Test User",
        email="test@example.com",
        version=1,
    )
    return user
|
|
|
|
@pytest.fixture
def another_user_model():
    """Second test user (id 2), used as the other participant in splits."""
    user = UserModel(
        id=2,
        name="Another User",
        email="another@example.com",
        version=1,
    )
    return user
|
|
|
|
@pytest.fixture
def basic_group_model(basic_user_model, another_user_model):
    """Bare test group (id 1) with no member associations attached.

    The user fixtures are requested but not used here. Tests that need
    memberships (e.g. the ``get_users_for_splitting`` tests) assign
    ``member_associations`` themselves.
    """
    group = GroupModel(
        id=1,
        name="Test Group",
        version=1,
    )
    return group
|
|
|
|
@pytest.fixture
def basic_list_model(basic_group_model, basic_user_model):
    """Test list (id 1) belonging to the test group and created by the test user."""
    owner_group = basic_group_model
    author = basic_user_model
    return ListModel(
        id=1,
        name="Test List",
        group_id=owner_group.id,
        group=owner_group,
        created_by_id=author.id,
        creator=author,
        version=1,
    )
|
|
|
|
@pytest.fixture
def expense_create_data_equal_split_list_ctx(basic_list_model, basic_user_model):
    """Create-payload: 30.00 USD equal-split expense attached to the test list.

    ``group_id`` is left as ``None``; the group is expected to be derived from
    the list.
    """
    today = datetime.now(timezone.utc).date()
    return ExpenseCreate(
        description="Grocery run",
        total_amount=Decimal("30.00"),
        currency="USD",
        expense_date=today,
        split_type=SplitTypeEnum.EQUAL,
        list_id=basic_list_model.id,
        group_id=None,
        item_id=None,
        paid_by_user_id=basic_user_model.id,
        splits_in=None,
    )
|
|
|
|
@pytest.fixture
def expense_create_data_equal_split_group_ctx(basic_group_model, basic_user_model):
    """Create-payload: 50.00 USD equal-split expense attached directly to the group."""
    today = datetime.now(timezone.utc).date()
    return ExpenseCreate(
        description="Movies",
        total_amount=Decimal("50.00"),
        currency="USD",
        expense_date=today,
        split_type=SplitTypeEnum.EQUAL,
        list_id=None,
        group_id=basic_group_model.id,
        item_id=None,
        paid_by_user_id=basic_user_model.id,
        splits_in=None,
    )
|
|
|
|
@pytest.fixture
def expense_create_data_exact_split(basic_group_model, basic_user_model, another_user_model):
    """Create-payload: 100.00 USD group expense split by exact amounts (60.00 / 40.00)."""
    shares = [
        ExpenseSplitCreate(user_id=basic_user_model.id, owed_amount=Decimal("60.00")),
        ExpenseSplitCreate(user_id=another_user_model.id, owed_amount=Decimal("40.00")),
    ]
    return ExpenseCreate(
        description="Dinner",
        total_amount=Decimal("100.00"),
        expense_date=datetime.now(timezone.utc).date(),
        currency="USD",
        split_type=SplitTypeEnum.EXACT_AMOUNTS,
        group_id=basic_group_model.id,
        paid_by_user_id=basic_user_model.id,
        splits_in=shares,
    )
|
|
|
|
@pytest.fixture
def expense_update_data():
    """Update-payload changing description and amount of a version-1 expense.

    ``version`` is carried so ``update_expense`` can detect version conflicts
    (see the version-conflict test).
    """
    payload = ExpenseUpdate(
        description="Updated Dinner",
        total_amount=Decimal("120.00"),
        version=1,
    )
    return payload
|
|
|
|
@pytest.fixture
def db_expense_model(expense_create_data_equal_split_group_ctx, basic_user_model, basic_group_model, another_user_model):
    """Return an already-stored-looking group expense (id 1, version 1).

    Fields are copied from ``expense_create_data_equal_split_group_ctx``
    (50.00 USD, EQUAL split). The payer and creator are ``basic_user_model``.
    Two existing 25.00 splits are attached, one per user fixture.
    """
    source = expense_create_data_equal_split_group_ctx
    # One timestamp so created_at and updated_at are identical for a fresh row.
    now = datetime.now(timezone.utc)
    expense = ExpenseModel(
        id=1,
        description=source.description,
        total_amount=source.total_amount,
        currency=source.currency,
        expense_date=source.expense_date,
        split_type=source.split_type,
        list_id=source.list_id,
        group_id=source.group_id,
        group=basic_group_model,
        item_id=source.item_id,
        paid_by_user_id=source.paid_by_user_id,
        created_by_user_id=basic_user_model.id,
        paid_by=basic_user_model,
        created_by_user=basic_user_model,
        version=1,
        created_at=now,
        updated_at=now,
    )
    # Existing split rows. User ids come from the fixtures instead of being
    # hard-coded, so they stay correct if the user fixtures change.
    expense.splits = [
        ExpenseSplitModel(id=1, expense_id=expense.id, user_id=basic_user_model.id, owed_amount=Decimal("25.00"), version=1),
        ExpenseSplitModel(id=2, expense_id=expense.id, user_id=another_user_model.id, owed_amount=Decimal("25.00"), version=1),
    ]
    return expense
|
|
|
|
# Tests for get_users_for_splitting
|
|
@pytest.mark.asyncio
async def test_get_users_for_splitting_group_context(mock_db_session, basic_group_model, basic_user_model, another_user_model):
    """Group context: both group members are returned as split participants."""
    basic_group_model.member_associations = [
        UserGroupModel(user=member, user_id=member.id, group_id=basic_group_model.id)
        for member in (basic_user_model, another_user_model)
    ]
    # The group is looked up through session.get.
    mock_db_session.get.return_value = basic_group_model

    participants = await get_users_for_splitting(
        mock_db_session,
        expense_group_id=basic_group_model.id,
        expense_list_id=None,
        expense_paid_by_user_id=basic_user_model.id,
    )

    assert len(participants) == 2
    assert basic_user_model in participants
    assert another_user_model in participants
|
|
|
|
@pytest.mark.asyncio
async def test_get_users_for_splitting_list_context(mock_db_session, basic_list_model, basic_group_model, basic_user_model, another_user_model):
    """List context: participants are the members of the list's group."""
    group_id = basic_group_model.id
    basic_group_model.member_associations = [
        UserGroupModel(user=member, user_id=member.id, group_id=group_id)
        for member in (basic_user_model, another_user_model)
    ]
    # Point the list at the group whose memberships were just populated.
    basic_list_model.group = basic_group_model

    # The list is looked up through session.get.
    mock_db_session.get.return_value = basic_list_model

    participants = await get_users_for_splitting(
        mock_db_session,
        expense_group_id=None,
        expense_list_id=basic_list_model.id,
        expense_paid_by_user_id=basic_user_model.id,
    )

    assert len(participants) == 2
    assert basic_user_model in participants
    assert another_user_model in participants
|
|
|
|
# --- create_expense Tests ---
|
|
@pytest.mark.asyncio
async def test_create_expense_equal_split_group_success(mock_db_session, expense_create_data_equal_split_group_ctx, basic_user_model, basic_group_model, another_user_model):
    """An EQUAL split of 50.00 over two group members yields two unpaid 25.00 shares.

    The refresh mock links the split rows that ``create_expense`` itself built,
    rather than fabricating new ones. The amount/status assertions therefore
    check the function's output, not values invented by the mock.
    """
    # session.get is answered with the payer first, then the group.
    mock_db_session.get.side_effect = [basic_user_model, basic_group_model]

    # Control the participant list directly.
    with patch('app.crud.expense.get_users_for_splitting', new_callable=AsyncMock) as mock_get_users:
        mock_get_users.return_value = [basic_user_model, another_user_model]

        def _splits_added_to_session():
            """Collect ExpenseSplitModel objects passed to session.add / add_all."""
            added = [c.args[0] for c in mock_db_session.add.call_args_list if c.args]
            for c in mock_db_session.add_all.call_args_list:
                if c.args:
                    added.extend(c.args[0])
            return [obj for obj in added if isinstance(obj, ExpenseSplitModel)]

        async def mock_refresh(instance, attribute_names=None, with_for_update=None):
            if isinstance(instance, ExpenseModel):
                # Simulate DB-populated columns after flush/refresh.
                instance.id = 1
                instance.version = 1
                instance.created_at = datetime.now(timezone.utc)
                instance.updated_at = datetime.now(timezone.utc)
                # Simulate loading the relationship. Keep splits the function
                # attached directly; otherwise use the ones it added to the session.
                if not instance.splits:
                    instance.splits = _splits_added_to_session()
            return None

        mock_db_session.refresh.side_effect = mock_refresh

        created_expense = await create_expense(mock_db_session, expense_create_data_equal_split_group_ctx, current_user_id=basic_user_model.id)

    mock_db_session.add.assert_called()
    mock_db_session.flush.assert_called_once()
    mock_db_session.refresh.assert_called_once()
    assert created_expense is not None
    assert created_expense.total_amount == expense_create_data_equal_split_group_ctx.total_amount
    assert created_expense.split_type == SplitTypeEnum.EQUAL
    assert len(created_expense.splits) == 2

    expected_amount_per_user = (expense_create_data_equal_split_group_ctx.total_amount / 2).quantize(Decimal("0.01"), rounding=ROUND_HALF_UP)
    for split in created_expense.splits:
        assert split.owed_amount == expected_amount_per_user
        assert split.status == ExpenseSplitStatusEnum.unpaid  # initial split status

    assert created_expense.overall_settlement_status == ExpenseOverallStatusEnum.unpaid  # initial expense status
|
|
|
|
@pytest.mark.asyncio
async def test_create_expense_exact_split_success(mock_db_session, expense_create_data_exact_split, basic_user_model, basic_group_model, another_user_model):
    """An EXACT_AMOUNTS split keeps the requested 60.00 / 40.00 shares, all unpaid.

    The refresh mock links the split rows that ``create_expense`` itself built,
    rather than fabricating new ones. The amount/status assertions therefore
    check the function's output, not values invented by the mock.
    """
    # session.get answers: payer, group, then each user named in splits_in.
    mock_db_session.get.side_effect = [basic_user_model, basic_group_model, basic_user_model, another_user_model]

    def _splits_added_to_session():
        """Collect ExpenseSplitModel objects passed to session.add / add_all."""
        added = [c.args[0] for c in mock_db_session.add.call_args_list if c.args]
        for c in mock_db_session.add_all.call_args_list:
            if c.args:
                added.extend(c.args[0])
        return [obj for obj in added if isinstance(obj, ExpenseSplitModel)]

    async def mock_refresh(instance, attribute_names=None, with_for_update=None):
        if isinstance(instance, ExpenseModel):
            # Simulate DB-populated columns after flush/refresh.
            instance.id = 2
            instance.version = 1
            # Simulate loading the relationship. Keep splits the function
            # attached directly; otherwise use the ones it added to the session.
            if not instance.splits:
                instance.splits = _splits_added_to_session()
        return None

    mock_db_session.refresh.side_effect = mock_refresh

    created_expense = await create_expense(mock_db_session, expense_create_data_exact_split, current_user_id=basic_user_model.id)

    mock_db_session.add.assert_called()
    mock_db_session.flush.assert_called_once()
    assert created_expense is not None
    assert created_expense.split_type == SplitTypeEnum.EXACT_AMOUNTS
    assert len(created_expense.splits) == 2
    assert created_expense.splits[0].owed_amount == Decimal("60.00")
    assert created_expense.splits[1].owed_amount == Decimal("40.00")
    for split in created_expense.splits:
        assert split.status == ExpenseSplitStatusEnum.unpaid  # initial split status

    assert created_expense.overall_settlement_status == ExpenseOverallStatusEnum.unpaid  # initial expense status
|
|
|
|
@pytest.mark.asyncio
async def test_create_expense_payer_not_found(mock_db_session, expense_create_data_equal_split_group_ctx):
    """A payer id that resolves to no user raises UserNotFoundError."""
    # Only the payer lookup is answered, and it finds nothing.
    mock_db_session.get.side_effect = [None]
    # This is the creator id; the payer comes from paid_by_user_id in the payload.
    creator_id = 999
    with pytest.raises(UserNotFoundError):
        await create_expense(mock_db_session, expense_create_data_equal_split_group_ctx, creator_id)
|
|
|
|
@pytest.mark.asyncio
async def test_create_expense_no_list_or_group(mock_db_session, expense_create_data_equal_split_group_ctx, basic_user_model):
    """An expense with neither list_id nor group_id is rejected."""
    # The payer lookup succeeds.
    mock_db_session.get.return_value = basic_user_model

    orphan_expense = expense_create_data_equal_split_group_ctx.model_copy()
    orphan_expense.list_id = None
    orphan_expense.group_id = None

    expected_message = "Expense must be associated with a list or a group"
    with pytest.raises(InvalidOperationError, match=expected_message):
        await create_expense(mock_db_session, orphan_expense, basic_user_model.id)
|
|
|
|
# --- get_expense_by_id Tests ---
|
|
@pytest.mark.asyncio
async def test_get_expense_by_id_found(mock_db_session, db_expense_model):
    """An existing id returns the expense loaded through session.get."""
    mock_db_session.get.return_value = db_expense_model

    expense = await get_expense_by_id(mock_db_session, db_expense_model.id)

    assert expense is not None
    assert expense.id == db_expense_model.id
    # A fresh MagicMock() only compares equal to itself, so placeholder
    # MagicMocks can never match the real loader options. Check the model
    # class and primary key instead, and leave the eager-load options unconstrained.
    mock_db_session.get.assert_called_once()
    get_call = mock_db_session.get.call_args
    assert get_call.args[:2] == (ExpenseModel, db_expense_model.id)
|
|
|
|
@pytest.mark.asyncio
async def test_get_expense_by_id_not_found(mock_db_session):
    """An unknown id yields None rather than raising."""
    missing_id = 999
    mock_db_session.get.return_value = None
    result = await get_expense_by_id(mock_db_session, missing_id)
    assert result is None
|
|
|
|
# --- get_expenses_for_list Tests ---
|
|
@pytest.mark.asyncio
async def test_get_expenses_for_list_success(mock_db_session, db_expense_model, basic_list_model):
    """The expenses selected for a list are returned unchanged."""
    # session.execute is awaited, but the Result it returns is synchronous.
    # An AsyncMock's child attributes are AsyncMocks too, so result.scalars()
    # would return a coroutine. Use a MagicMock for the result instead.
    mock_result = MagicMock()
    scalars = mock_result.scalars.return_value
    scalars.all.return_value = [db_expense_model]
    # Let .unique() (if the query uses it) return the same scalars object.
    scalars.unique.return_value = scalars
    mock_db_session.execute.return_value = mock_result

    expenses = await get_expenses_for_list(mock_db_session, basic_list_model.id)

    assert len(expenses) == 1
    assert expenses[0].id == db_expense_model.id
    mock_db_session.execute.assert_called_once()
|
|
|
|
# --- get_expenses_for_group Tests ---
|
|
@pytest.mark.asyncio
async def test_get_expenses_for_group_success(mock_db_session, db_expense_model, basic_group_model):
    """The expenses selected for a group are returned unchanged."""
    # session.execute is awaited, but the Result it returns is synchronous.
    # An AsyncMock's child attributes are AsyncMocks too, so result.scalars()
    # would return a coroutine. Use a MagicMock for the result instead.
    mock_result = MagicMock()
    scalars = mock_result.scalars.return_value
    scalars.all.return_value = [db_expense_model]
    # Let .unique() (if the query uses it) return the same scalars object.
    scalars.unique.return_value = scalars
    mock_db_session.execute.return_value = mock_result

    expenses = await get_expenses_for_group(mock_db_session, basic_group_model.id)

    assert len(expenses) == 1
    assert expenses[0].id == db_expense_model.id
    mock_db_session.execute.assert_called_once()
|
|
|
|
# --- update_expense Tests ---
|
|
@pytest.mark.asyncio
async def test_update_expense_success(mock_db_session, db_expense_model, expense_update_data, basic_user_model):
    """A matching version applies the update and bumps the expense version."""
    expense_update_data.version = db_expense_model.version  # match the stored version
    # Capture before the call. The returned object is presumably
    # db_expense_model itself, so comparing against db_expense_model.version
    # afterwards would be a tautology.
    version_before = db_expense_model.version

    # update_expense loads the expense through session.get.
    mock_db_session.get.return_value = db_expense_model

    updated_expense = await update_expense(mock_db_session, db_expense_model.id, expense_update_data, basic_user_model.id)

    mock_db_session.add.assert_called_once_with(db_expense_model)
    mock_db_session.flush.assert_called_once()
    mock_db_session.refresh.assert_called_once_with(db_expense_model)
    assert updated_expense.description == expense_update_data.description
    assert updated_expense.total_amount == expense_update_data.total_amount
    assert updated_expense.version == version_before + 1  # version incremented by the function
|
|
|
|
@pytest.mark.asyncio
async def test_update_expense_not_found(mock_db_session, expense_update_data, basic_user_model):
    """Updating an expense that does not exist raises ExpenseNotFoundError."""
    missing_id = 999
    mock_db_session.get.return_value = None
    with pytest.raises(ExpenseNotFoundError):
        await update_expense(mock_db_session, missing_id, expense_update_data, basic_user_model.id)
|
|
|
|
@pytest.mark.asyncio
async def test_update_expense_version_conflict(mock_db_session, db_expense_model, expense_update_data, basic_user_model):
    """A payload version that differs from the stored one raises ConflictError."""
    stale_version = db_expense_model.version + 1
    expense_update_data.version = stale_version
    mock_db_session.get.return_value = db_expense_model

    with pytest.raises(ConflictError):
        await update_expense(mock_db_session, db_expense_model.id, expense_update_data, basic_user_model.id)

    mock_db_session.rollback.assert_called_once()
|
|
|
|
# --- delete_expense Tests ---
|
|
@pytest.mark.asyncio
async def test_delete_expense_success(mock_db_session, db_expense_model, basic_user_model):
    """Deleting an existing expense removes it inside a cleanly exited transaction."""
    mock_db_session.get.return_value = db_expense_model  # expense found

    await delete_expense(mock_db_session, db_expense_model.id, basic_user_model.id)

    mock_db_session.delete.assert_called_once_with(db_expense_model)
    # The fixture's session.begin() returns an async context manager. With
    # `async with session.begin():` the commit happens in __aexit__ when no
    # exception escapes; no `.commit` attribute on the context object is called.
    # Use begin.return_value rather than calling begin() again, which would
    # inflate begin's call count.
    transaction = mock_db_session.begin.return_value
    transaction.__aexit__.assert_awaited_once()
    exc_type = transaction.__aexit__.await_args.args[0]
    assert exc_type is None
|
|
|
|
@pytest.mark.asyncio
async def test_delete_expense_not_found(mock_db_session, basic_user_model):
    """Deleting an expense that does not exist raises ExpenseNotFoundError."""
    missing_id = 999
    mock_db_session.get.return_value = None

    with pytest.raises(ExpenseNotFoundError):
        await delete_expense(mock_db_session, missing_id, basic_user_model.id)

    # session.rollback itself must not be called. Any rollback would come from
    # exiting the mocked session.begin() context manager instead.
    mock_db_session.rollback.assert_not_called()
|
|
|
|
@pytest.mark.asyncio
async def test_delete_expense_db_error(mock_db_session, db_expense_model, basic_user_model):
    """A DB failure during delete surfaces as DatabaseTransactionError and rolls back."""
    mock_db_session.get.return_value = db_expense_model
    mock_db_session.delete.side_effect = OperationalError("mock op error", "params", "orig")

    with pytest.raises(DatabaseTransactionError):
        await delete_expense(mock_db_session, db_expense_model.id, basic_user_model.id)

    # With `async with session.begin():` the rollback happens in __aexit__ when
    # an exception escapes the block; no `.rollback` attribute on the context
    # object is called. Use begin.return_value rather than calling begin() again.
    transaction = mock_db_session.begin.return_value
    transaction.__aexit__.assert_awaited_once()
    exc_type = transaction.__aexit__.await_args.args[0]
    assert exc_type is not None