"""Unit tests for the expense CRUD helpers in ``app.crud.expense``."""
import pytest
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
from sqlalchemy.exc import IntegrityError, OperationalError
|
|
from decimal import Decimal, ROUND_HALF_UP
|
|
from datetime import datetime, timezone
|
|
from typing import List as PyList, Optional
|
|
|
|
from app.crud.expense import (
|
|
create_expense,
|
|
get_expense_by_id,
|
|
get_expenses_for_list,
|
|
get_expenses_for_group,
|
|
update_expense, # Assuming update_expense exists
|
|
delete_expense, # Assuming delete_expense exists
|
|
get_users_for_splitting # Helper, might test indirectly
|
|
)
|
|
from app.schemas.expense import ExpenseCreate, ExpenseUpdate, ExpenseSplitCreate
|
|
from app.models import (
|
|
Expense as ExpenseModel,
|
|
ExpenseSplit as ExpenseSplitModel,
|
|
User as UserModel,
|
|
List as ListModel,
|
|
Group as GroupModel,
|
|
UserGroup as UserGroupModel,
|
|
Item as ItemModel,
|
|
SplitTypeEnum
|
|
)
|
|
from app.core.exceptions import (
|
|
ListNotFoundError,
|
|
GroupNotFoundError,
|
|
UserNotFoundError,
|
|
InvalidOperationError
|
|
)
|
|
|
|
# General Fixtures
|
|
@pytest.fixture
def mock_db_session():
    """Build a mocked async database session for the CRUD tests.

    ``commit``, ``rollback``, ``refresh``, ``execute``, ``get`` and ``flush``
    are ``AsyncMock`` instances; ``add`` and ``delete`` are plain
    ``MagicMock`` instances.
    """
    db = AsyncMock()

    # Awaitable session methods (flush is included because create_expense uses it).
    for awaitable_name in ("commit", "rollback", "refresh", "execute", "get", "flush"):
        setattr(db, awaitable_name, AsyncMock())

    # Synchronous session methods.
    for sync_name in ("add", "delete"):
        setattr(db, sync_name, MagicMock())

    return db
|
|
|
|
@pytest.fixture
def basic_user_model():
    """Provide the primary test user (id 1)."""
    primary_user = UserModel(
        id=1,
        name="Test User",
        email="test@example.com",
    )
    return primary_user
|
|
|
|
@pytest.fixture
def another_user_model():
    """Provide a second test user (id 2)."""
    secondary_user = UserModel(
        id=2,
        name="Another User",
        email="another@example.com",
    )
    return secondary_user
|
|
|
|
@pytest.fixture
def basic_group_model():
    """Provide the test group (id 1); ``member_associations`` is not set here."""
    # Tests that call get_users_for_splitting directly can attach
    # UserGroupModel rows to ``member_associations`` themselves.
    test_group = GroupModel(id=1, name="Test Group")
    return test_group
|
|
|
|
@pytest.fixture
def basic_list_model(basic_group_model, basic_user_model):
    """Provide a list (id 1) linked to the test group and created by the test user."""
    return ListModel(
        id=1,
        name="Test List",
        group_id=basic_group_model.id,
        group=basic_group_model,
        creator_id=basic_user_model.id,
        creator=basic_user_model,
    )
|
|
|
|
@pytest.fixture
def expense_create_data_equal_split_list_ctx(basic_list_model, basic_user_model):
    """Provide an EQUAL-split ExpenseCreate payload attached to a list (no group id)."""
    payload_fields = {
        "description": "Grocery run",
        "total_amount": Decimal("30.00"),
        "currency": "USD",
        "expense_date": datetime.now(timezone.utc),
        "split_type": SplitTypeEnum.EQUAL,
        "list_id": basic_list_model.id,
        "group_id": None,  # Derived from list
        "item_id": None,
        "paid_by_user_id": basic_user_model.id,
        "splits_in": None,
    }
    return ExpenseCreate(**payload_fields)
|
|
|
|
@pytest.fixture
def expense_create_data_equal_split_group_ctx(basic_group_model, basic_user_model):
    """Provide an EQUAL-split ExpenseCreate payload attached to a group (no list id)."""
    payload_fields = {
        "description": "Movies",
        "total_amount": Decimal("50.00"),
        "currency": "USD",
        "expense_date": datetime.now(timezone.utc),
        "split_type": SplitTypeEnum.EQUAL,
        "list_id": None,
        "group_id": basic_group_model.id,
        "item_id": None,
        "paid_by_user_id": basic_user_model.id,
        "splits_in": None,
    }
    return ExpenseCreate(**payload_fields)
|
|
|
|
@pytest.fixture
def expense_create_data_exact_split(basic_group_model, basic_user_model, another_user_model):
    """Provide an EXACT_AMOUNTS ExpenseCreate payload: 100.00 split as 60.00 / 40.00."""
    payer_share = ExpenseSplitCreate(
        user_id=basic_user_model.id,
        owed_amount=Decimal("60.00"),
    )
    other_share = ExpenseSplitCreate(
        user_id=another_user_model.id,
        owed_amount=Decimal("40.00"),
    )
    return ExpenseCreate(
        description="Dinner",
        total_amount=Decimal("100.00"),
        split_type=SplitTypeEnum.EXACT_AMOUNTS,
        group_id=basic_group_model.id,
        paid_by_user_id=basic_user_model.id,
        splits_in=[payer_share, other_share],
    )
|
|
|
|
@pytest.fixture
def db_expense_model(expense_create_data_equal_split_group_ctx, basic_user_model):
    """Provide a persisted-looking ExpenseModel (id 1, version 1).

    Field values are copied from the group-context EQUAL-split payload; the
    test user is both payer and creator. ``splits`` is not populated here.
    """
    src = expense_create_data_equal_split_group_ctx
    creator = basic_user_model
    return ExpenseModel(
        id=1,
        description=src.description,
        total_amount=src.total_amount,
        currency=src.currency,
        expense_date=src.expense_date,
        split_type=src.split_type,
        list_id=src.list_id,
        group_id=src.group_id,
        item_id=src.item_id,
        paid_by_user_id=src.paid_by_user_id,
        created_by_user_id=creator.id,
        paid_by=creator,  # relation set directly, as if already loaded
        created_by_user=creator,  # relation set directly, as if already loaded
        version=1,
    )
|
|
|
|
# --- get_users_for_splitting Tests (the helper is also exercised indirectly through create_expense) ---
|
|
@pytest.mark.asyncio
async def test_get_users_for_splitting_group_context(mock_db_session, basic_group_model, basic_user_model, another_user_model):
    """In a group context, get_users_for_splitting returns both group members."""
    # Attach two membership rows to the group the query will return.
    basic_group_model.member_associations = [
        UserGroupModel(user=basic_user_model, user_id=basic_user_model.id),
        UserGroupModel(user=another_user_model, user_id=another_user_model.id),
    ]

    # The object produced by ``await db.execute(...)`` must be a synchronous
    # mock. Children of an AsyncMock are AsyncMocks, so with an AsyncMock
    # result ``.scalars()`` would return a coroutine and the value configured
    # on ``.scalars().first()`` would never be reached.
    # NOTE(review): assumes the helper reads the group via result.scalars().first() - confirm.
    query_result = MagicMock()
    query_result.scalars.return_value.first.return_value = basic_group_model
    mock_db_session.execute.return_value = query_result

    users = await get_users_for_splitting(
        mock_db_session,
        expense_group_id=1,
        expense_list_id=None,
        expense_paid_by_user_id=1,
    )

    assert len(users) == 2
    assert basic_user_model in users
    assert another_user_model in users
|
|
|
|
# --- create_expense Tests ---
|
|
@pytest.mark.asyncio
async def test_create_expense_equal_split_group_success(
    mock_db_session,
    expense_create_data_equal_split_group_ctx,
    basic_user_model,
    basic_group_model,
    another_user_model,
):
    """An EQUAL split in a group context yields one equal share per participant."""
    payload = expense_create_data_equal_split_group_ctx
    participants = [basic_user_model, another_user_model]
    mock_db_session.get.side_effect = [basic_user_model, basic_group_model]  # Payer, Group

    # get_users_for_splitting is called from inside create_expense, so it is
    # patched at module level rather than driven through the session mock.
    with patch('app.crud.expense.get_users_for_splitting', new_callable=AsyncMock) as mock_get_users:
        mock_get_users.return_value = participants
        created_expense = await create_expense(mock_db_session, payload, current_user_id=1)

    mock_db_session.add.assert_called()
    mock_db_session.flush.assert_called_once()
    # commit/refresh are deliberately not asserted: create_expense does not
    # commit or refresh by itself.

    assert created_expense is not None
    assert created_expense.total_amount == payload.total_amount
    assert created_expense.split_type == SplitTypeEnum.EQUAL
    assert len(created_expense.splits) == 2  # splits are expected on the model instance

    # Every participant owes the same amount, rounded half-up to cents.
    cent = Decimal("0.01")
    share = (payload.total_amount / len(participants)).quantize(cent, rounding=ROUND_HALF_UP)
    for expense_split in created_expense.splits:
        assert expense_split.owed_amount == share
|
|
|
|
@pytest.mark.asyncio
async def test_create_expense_exact_split_success(
    mock_db_session,
    expense_create_data_exact_split,
    basic_user_model,
    basic_group_model,
    another_user_model,
):
    """An EXACT_AMOUNTS split keeps the 60.00 / 40.00 amounts supplied in the payload."""
    mock_db_session.get.side_effect = [basic_user_model, basic_group_model]  # Payer, Group

    # Result of the user-ID lookup used to validate the split participants.
    # NOTE(review): assumes create_expense iterates the result as one-element
    # rows, e.g. ``{row[0] for row in result}`` - confirm against the implementation.
    # A list (not a one-shot iterator) backs __iter__ so the rows can be
    # iterated more than once without silently coming back empty.
    user_id_rows = MagicMock()
    user_id_rows.__iter__.return_value = [
        (basic_user_model.id,),
        (another_user_model.id,),
    ]
    mock_db_session.execute.return_value = user_id_rows

    created_expense = await create_expense(
        mock_db_session,
        expense_create_data_exact_split,
        current_user_id=1,
    )

    mock_db_session.add.assert_called()
    mock_db_session.flush.assert_called_once()

    assert created_expense is not None
    assert created_expense.split_type == SplitTypeEnum.EXACT_AMOUNTS
    assert len(created_expense.splits) == 2
    assert created_expense.splits[0].owed_amount == Decimal("60.00")
    assert created_expense.splits[1].owed_amount == Decimal("40.00")
|
|
|
|
@pytest.mark.asyncio
async def test_create_expense_payer_not_found(mock_db_session, expense_create_data_equal_split_group_ctx):
    """create_expense raises UserNotFoundError when the session lookup returns None."""
    payload = expense_create_data_equal_split_group_ctx
    mock_db_session.get.return_value = None  # Payer not found

    with pytest.raises(UserNotFoundError):
        await create_expense(mock_db_session, payload, 1)
|
|
|
|
@pytest.mark.asyncio
async def test_create_expense_no_list_or_group(mock_db_session, expense_create_data_equal_split_group_ctx, basic_user_model):
    """create_expense rejects a payload that has neither a list id nor a group id."""
    mock_db_session.get.return_value = basic_user_model  # Payer found

    # Work on a copy so the shared fixture payload is left untouched.
    orphan_payload = expense_create_data_equal_split_group_ctx.model_copy()
    orphan_payload.list_id = None
    orphan_payload.group_id = None

    expected_message = "Expense must be associated with a list or a group"
    with pytest.raises(InvalidOperationError, match=expected_message):
        await create_expense(mock_db_session, orphan_payload, 1)
|
|
|
|
# --- get_expense_by_id Tests ---
|
|
@pytest.mark.asyncio
async def test_get_expense_by_id_found(mock_db_session, db_expense_model):
    """get_expense_by_id returns the expense produced by the query result."""
    # The awaited execute() result must be a synchronous mock: on an AsyncMock
    # result, ``.scalars()`` would return a coroutine and the configured
    # ``.scalars().first()`` value would never be reached.
    # NOTE(review): assumes the function reads result.scalars().first() - confirm.
    query_result = MagicMock()
    query_result.scalars.return_value.first.return_value = db_expense_model
    mock_db_session.execute.return_value = query_result

    expense = await get_expense_by_id(mock_db_session, 1)

    assert expense is not None
    assert expense.id == 1
    mock_db_session.execute.assert_called_once()
|
|
|
|
@pytest.mark.asyncio
async def test_get_expense_by_id_not_found(mock_db_session):
    """get_expense_by_id returns None when the query result has no row."""
    # The awaited execute() result must be a synchronous mock: on an AsyncMock
    # result, ``.scalars()`` would return a coroutine and the configured
    # ``.scalars().first()`` value would never be reached.
    # NOTE(review): assumes the function reads result.scalars().first() - confirm.
    query_result = MagicMock()
    query_result.scalars.return_value.first.return_value = None
    mock_db_session.execute.return_value = query_result

    expense = await get_expense_by_id(mock_db_session, 999)

    assert expense is None
|
|
|
|
# --- get_expenses_for_list Tests ---
|
|
@pytest.mark.asyncio
async def test_get_expenses_for_list_success(mock_db_session, db_expense_model):
    """get_expenses_for_list returns the expenses produced by the query result."""
    # The awaited execute() result must be a synchronous mock: on an AsyncMock
    # result, ``.scalars()`` would return a coroutine and the configured
    # ``.scalars().all()`` value would never be reached.
    # NOTE(review): assumes the function reads result.scalars().all() - confirm.
    query_result = MagicMock()
    query_result.scalars.return_value.all.return_value = [db_expense_model]
    mock_db_session.execute.return_value = query_result

    expenses = await get_expenses_for_list(mock_db_session, list_id=1)

    assert len(expenses) == 1
    assert expenses[0].id == db_expense_model.id
    mock_db_session.execute.assert_called_once()
|
|
|
|
# --- get_expenses_for_group Tests ---
|
|
@pytest.mark.asyncio
async def test_get_expenses_for_group_success(mock_db_session, db_expense_model):
    """get_expenses_for_group returns the expenses produced by the query result."""
    # The awaited execute() result must be a synchronous mock: on an AsyncMock
    # result, ``.scalars()`` would return a coroutine and the configured
    # ``.scalars().all()`` value would never be reached.
    # NOTE(review): assumes the function reads result.scalars().all() - confirm.
    query_result = MagicMock()
    query_result.scalars.return_value.all.return_value = [db_expense_model]
    mock_db_session.execute.return_value = query_result

    expenses = await get_expenses_for_group(mock_db_session, group_id=1)

    assert len(expenses) == 1
    assert expenses[0].id == db_expense_model.id
    mock_db_session.execute.assert_called_once()
|
|
|
|
# --- Stubs for update_expense and delete_expense ---
|
|
# These will need more details once the actual implementation of update/delete is clear
|
|
# For example, how splits are handled on update, versioning, etc.
|
|
|
|
@pytest.mark.asyncio
async def test_update_expense_stub(mock_db_session):
    """Placeholder for update_expense coverage; reported as skipped, not passed.

    The real test needs an ExpenseUpdate payload, an existing expense object
    and commit/refresh mocking. Its shape depends on which fields are
    updatable and how splits are managed on update.
    """
    expense_to_update = MagicMock(spec=ExpenseModel)
    expense_to_update.version = 1
    update_payload = ExpenseUpdate(description="New description", version=1)  # Add other fields as per schema definition

    # Intended flow once update_expense's contract is settled, e.g. if it
    # loads the expense, modifies it, commits and refreshes:
    #   mock_db_session.get.return_value = expense_to_update
    #   updated_expense = await update_expense(mock_db_session, expense_to_update, update_payload)
    #   assert updated_expense.description == "New description"
    #   mock_db_session.commit.assert_called_once()
    #   mock_db_session.refresh.assert_called_once()

    # Skip explicitly so the placeholder does not count as a passing test.
    pytest.skip("update_expense test is a placeholder and has no assertions yet")
|
|
|
|
@pytest.mark.asyncio
async def test_delete_expense_stub(mock_db_session):
    """Placeholder for delete_expense coverage; reported as skipped, not passed.

    The real test needs an existing expense object and delete/commit mocking,
    and should consider the implications of deletion (e.g. whether splits are
    deleted as well).
    """
    expense_to_delete = MagicMock(spec=ExpenseModel)
    expense_to_delete.id = 1
    expense_to_delete.version = 1

    # Intended flow once delete_expense's contract is settled:
    #   mock_db_session.get.return_value = expense_to_delete  # If it re-fetches
    #   await delete_expense(mock_db_session, expense_to_delete, expected_version=1)
    #   mock_db_session.delete.assert_called_once_with(expense_to_delete)
    #   mock_db_session.commit.assert_called_once()

    # Skip explicitly so the placeholder does not count as a passing test.
    pytest.skip("delete_expense test is a placeholder and has no assertions yet")
|
|
|
|
# TODO: Add more tests for create_expense covering:
|
|
# - List context success
|
|
# - Percentage, Shares, Item-based splits
|
|
# - Error cases for each split type (e.g., total mismatch, invalid inputs)
|
|
# - Validation of list_id/group_id consistency
|
|
# - User not found in splits_in
|
|
# - Item not found for ITEM_BASED split
|
|
|
|
# TODO: Flesh out update_expense tests:
|
|
# - Success case
|
|
# - Version mismatch
|
|
# - Trying to update immutable fields
|
|
# - How splits are handled (recalculated, deleted/recreated, or not changeable)
|
|
|
|
# TODO: Flesh out delete_expense tests:
|
|
# - Success case
|
|
# - Version mismatch (if applicable)
|
|
# - Ensure associated splits are also deleted (cascade behavior)