![google-labs-jules[bot]](/assets/img/avatar_default.png)
Backend: - Added `SettlementActivity` model to track payments against specific expense shares. - Added `status` and `paid_at` to `ExpenseSplit` model. - Added `overall_settlement_status` to `Expense` model. - Implemented CRUD for `SettlementActivity`, including logic to update parent expense/split statuses. - Updated `Expense` CRUD to initialize new status fields. - Defined Pydantic schemas for `SettlementActivity` and updated `Expense/ExpenseSplit` schemas. - Exposed API endpoints for creating/listing settlement activities and settling shares. - Adjusted group balance summary logic to include settlement activities. - Added comprehensive backend unit and API tests for new functionality. Frontend (Foundation & TODOs due to my current capabilities): - Created TypeScript interfaces for all new/updated models. - Set up `listDetailStore.ts` with an action to handle `settleExpenseSplit` (API call is a placeholder) and refresh data. - Created `SettleShareModal.vue` component for payment confirmation. - Added unit tests for the new modal and store logic. - Updated `ListDetailPage.vue` to display detailed expense/share statuses and settlement activities. - `mitlist_doc.md` updated to reflect all backend changes and current frontend status. - A `TODO.md` (implicitly within `mitlist_doc.md`'s new section) outlines necessary manual frontend integrations for `api.ts` and `ListDetailPage.vue` to complete the 'Settle Share' UI flow. This set of changes provides the core backend infrastructure for precise expense share tracking and settlement, and lays the groundwork for full frontend integration.
633 lines
29 KiB
Python
633 lines
29 KiB
Python
# app/crud/expense.py
|
|
import logging # Add logging import
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
from sqlalchemy.future import select
|
|
from sqlalchemy.orm import selectinload, joinedload
|
|
from sqlalchemy.exc import SQLAlchemyError, IntegrityError, OperationalError # Added import
|
|
from decimal import Decimal, ROUND_HALF_UP, InvalidOperation as DecimalInvalidOperation
|
|
from typing import Callable, List as PyList, Optional, Sequence, Dict, defaultdict, Any
|
|
from datetime import datetime, timezone # Added timezone
|
|
|
|
from app.models import (
|
|
Expense as ExpenseModel,
|
|
ExpenseSplit as ExpenseSplitModel,
|
|
User as UserModel,
|
|
List as ListModel,
|
|
Group as GroupModel,
|
|
UserGroup as UserGroupModel,
|
|
SplitTypeEnum,
|
|
Item as ItemModel,
|
|
ExpenseOverallStatusEnum, # Added
|
|
ExpenseSplitStatusEnum, # Added
|
|
)
|
|
from app.schemas.expense import ExpenseCreate, ExpenseSplitCreate, ExpenseUpdate # Removed unused ExpenseUpdate
|
|
from app.core.exceptions import (
|
|
# Using existing specific exceptions where possible
|
|
ListNotFoundError,
|
|
GroupNotFoundError,
|
|
UserNotFoundError,
|
|
InvalidOperationError, # Import the new exception
|
|
DatabaseConnectionError, # Added
|
|
DatabaseIntegrityError, # Added
|
|
DatabaseQueryError, # Added
|
|
DatabaseTransactionError,# Added
|
|
ExpenseOperationError # Added specific exception
|
|
)
|
|
|
|
# Placeholder for InvalidOperationError if not defined in app.core.exceptions
|
|
# This should be a proper HTTPException subclass if used in API layer
|
|
# class CrudInvalidOperationError(ValueError): # For internal CRUD validation logic
|
|
# pass
|
|
|
|
logger: logging.Logger = logging.getLogger(__name__)  # Module-level logger named after this module
|
|
|
|
def _round_money(amount: Decimal) -> Decimal:
|
|
"""Rounds a Decimal to two decimal places using ROUND_HALF_UP."""
|
|
return amount.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP)
|
|
|
|
async def get_users_for_splitting(db: AsyncSession, expense_group_id: Optional[int], expense_list_id: Optional[int], expense_paid_by_user_id: int) -> PyList[UserModel]:
    """
    Work out which users an expense is shared between.

    Resolution order:
      1. If ``expense_group_id`` is given: every member of that group.
      2. Otherwise, if ``expense_list_id`` is given: the members of the list's
         group, or the list's creator when the list has no group.
      3. If neither step produced anyone: the payer alone.

    Users are returned in first-seen order, de-duplicated by ``id``.

    Raises:
        GroupNotFoundError: ``expense_group_id`` matches no group.
        ListNotFoundError: ``expense_list_id`` matches no list.
        UserNotFoundError: the payer fallback is needed but the payer does not exist.
        InvalidOperationError: no users could be determined.
    """
    # A dict preserves insertion order, so it acts as an ordered de-duplicating set.
    participants: Dict[int, UserModel] = {}

    def _include(candidate: Optional[UserModel]) -> None:
        if candidate and candidate.id not in participants:
            participants[candidate.id] = candidate

    if expense_group_id:
        group_stmt = (
            select(GroupModel)
            .options(selectinload(GroupModel.member_associations).options(selectinload(UserGroupModel.user)))
            .where(GroupModel.id == expense_group_id)
        )
        group = (await db.execute(group_stmt)).scalars().first()
        if not group:
            raise GroupNotFoundError(expense_group_id)
        for membership in group.member_associations:
            _include(membership.user)

    elif expense_list_id:  # the list only matters when no group was given directly
        list_stmt = (
            select(ListModel)
            .options(
                selectinload(ListModel.group).options(
                    selectinload(GroupModel.member_associations).options(selectinload(UserGroupModel.user))
                ),
                selectinload(ListModel.creator),
            )
            .where(ListModel.id == expense_list_id)
        )
        db_list = (await db.execute(list_stmt)).scalars().first()
        if not db_list:
            raise ListNotFoundError(expense_list_id)

        if db_list.group:
            for membership in db_list.group.member_associations:
                _include(membership.user)
        elif db_list.creator:
            _include(db_list.creator)

    if not participants:
        payer = await db.get(UserModel, expense_paid_by_user_id)
        if not payer:
            # Normally the caller has already validated the payer.
            raise UserNotFoundError(user_id=expense_paid_by_user_id)
        _include(payer)

    if not participants:
        # Defensive: the payer fallback above should always yield someone.
        raise InvalidOperationError("Could not determine any users for splitting the expense.")

    return list(participants.values())
|
|
|
|
|
|
async def create_expense(db: AsyncSession, expense_in: ExpenseCreate, current_user_id: int) -> ExpenseModel:
    """Create a new expense and its associated splits.

    All work happens inside ``db.begin()`` (or ``db.begin_nested()`` when the
    session is already in a transaction). An exception raised inside that
    block rolls it back.

    Args:
        db: Database session.
        expense_in: Expense creation data. It is not mutated. When the list has
            to be derived from ``item_id``, an updated copy is used internally.
        current_user_id: ID of the user creating the expense.

    Returns:
        The created expense, re-queried with paid_by_user, created_by_user,
        list, group, item and splits (with each split's user) eager-loaded.

    Raises:
        UserNotFoundError: If the payer or a split user doesn't exist.
        ListNotFoundError: If the specified list doesn't exist.
        GroupNotFoundError: If the specified group doesn't exist.
        InvalidOperationError: For validation failures, including an item that
            belongs to a different list than ``expense_in.list_id``.
        ExpenseOperationError: If the expense cannot be re-loaded after creation.
        DatabaseIntegrityError, DatabaseConnectionError, DatabaseTransactionError:
            For IntegrityError, OperationalError and other SQLAlchemyError
            failures respectively. The original error is chained as ``__cause__``.
    """
    try:
        async with db.begin_nested() if db.in_transaction() else db.begin():
            # 1. Validate payer
            payer = await db.get(UserModel, expense_in.paid_by_user_id)
            if not payer:
                raise UserNotFoundError(user_id=expense_in.paid_by_user_id, identifier="Payer")

            # 2. Resolve and validate the list / group / item context
            if not expense_in.list_id and not expense_in.group_id and not expense_in.item_id:
                raise InvalidOperationError("Expense must be associated with a list, a group, or an item.")

            final_group_id = await _resolve_expense_context(db, expense_in)

            if expense_in.item_id:
                db_item_instance = await db.get(ItemModel, expense_in.item_id)
                if not db_item_instance:
                    raise InvalidOperationError(f"Item with ID {expense_in.item_id} not found.")
                if db_item_instance.list_id:
                    if not expense_in.list_id:
                        # Derive the list from the item. Work on a copy so the
                        # caller's schema object is left untouched.
                        expense_in = expense_in.model_copy(update={"list_id": db_item_instance.list_id})
                        final_group_id = await _resolve_expense_context(db, expense_in)
                    elif db_item_instance.list_id != expense_in.list_id:
                        raise InvalidOperationError(
                            f"Item {expense_in.item_id} belongs to list {db_item_instance.list_id}, "
                            f"but expense was specified for list {expense_in.list_id}."
                        )

            # 3. Create the expense row
            db_expense = ExpenseModel(
                description=expense_in.description,
                total_amount=_round_money(expense_in.total_amount),
                currency=expense_in.currency or "USD",
                expense_date=expense_in.expense_date or datetime.now(timezone.utc),
                split_type=expense_in.split_type,
                list_id=expense_in.list_id,
                group_id=final_group_id,  # resolved (possibly inherited from the list)
                item_id=expense_in.item_id,
                paid_by_user_id=expense_in.paid_by_user_id,
                created_by_user_id=current_user_id,
                overall_settlement_status=ExpenseOverallStatusEnum.unpaid,
            )
            db.add(db_expense)
            await db.flush()  # assigns db_expense.id, needed for the splits' FK

            # 4. Generate and persist splits (current_user_id is forwarded for
            #    split types that may need creator information)
            splits_to_create = await _generate_expense_splits(
                db=db,
                expense_model=db_expense,
                expense_in=expense_in,
                current_user_id=current_user_id,
            )
            for split_model in splits_to_create:
                split_model.expense_id = db_expense.id
            db.add_all(splits_to_create)
            await db.flush()

            # 5. Re-fetch the expense with the relationships needed by the response
            stmt = (
                select(ExpenseModel)
                .where(ExpenseModel.id == db_expense.id)
                .options(
                    selectinload(ExpenseModel.paid_by_user),
                    selectinload(ExpenseModel.created_by_user),
                    selectinload(ExpenseModel.list),
                    selectinload(ExpenseModel.group),
                    selectinload(ExpenseModel.item),
                    selectinload(ExpenseModel.splits).selectinload(ExpenseSplitModel.user),
                )
            )
            result = await db.execute(stmt)
            loaded_expense = result.scalar_one_or_none()

            if loaded_expense is None:
                # Raising inside the context manager rolls the transaction back.
                raise ExpenseOperationError("Failed to load expense after creation.")

            return loaded_expense

    except (UserNotFoundError, ListNotFoundError, GroupNotFoundError, InvalidOperationError):
        # Business-rule violations propagate unchanged (already rolled back).
        raise
    except IntegrityError as e:
        logger.error(f"Database integrity error during expense creation: {str(e)}", exc_info=True)
        raise DatabaseIntegrityError(f"Failed to save expense due to database integrity issue: {str(e)}") from e
    except OperationalError as e:
        logger.error(f"Database connection error during expense creation: {str(e)}", exc_info=True)
        raise DatabaseConnectionError(f"Database connection error during expense creation: {str(e)}") from e
    except SQLAlchemyError as e:
        logger.error(f"Unexpected SQLAlchemy error during expense creation: {str(e)}", exc_info=True)
        raise DatabaseTransactionError(f"Failed to save expense due to a database transaction error: {str(e)}") from e
|
|
|
|
|
|
async def _resolve_expense_context(db: AsyncSession, expense_in: ExpenseCreate) -> Optional[int]:
    """Validate the expense's list/group context and return the group_id to store.

    - If ``list_id`` is given, the list must exist. When the list belongs to a
      group, that group is used, and any explicit ``group_id`` must match it.
    - An explicit ``group_id`` that is not inherited from the list (no list, or
      a list without a group) must refer to an existing group.

    Returns:
        The resolved group_id, or None when the expense has no group.

    Raises:
        ListNotFoundError: ``list_id`` matches no list.
        GroupNotFoundError: an explicit ``group_id`` matches no group.
        InvalidOperationError: ``group_id`` conflicts with the list's group.
    """
    if expense_in.list_id:
        list_obj = await db.get(ListModel, expense_in.list_id)
        if not list_obj:
            raise ListNotFoundError(expense_in.list_id)

        if list_obj.group_id:
            if expense_in.group_id and list_obj.group_id != expense_in.group_id:
                raise InvalidOperationError(
                    f"List {expense_in.list_id} belongs to group {list_obj.group_id}, "
                    f"but expense was specified for group {expense_in.group_id}."
                )
            # The list's group takes priority over (and equals) any explicit group_id.
            return list_obj.group_id

    # Either there is no list, or the list has no group. In both cases an
    # explicit group_id must still point to a real group.
    final_group_id = expense_in.group_id
    if final_group_id:
        group_obj = await db.get(GroupModel, final_group_id)
        if not group_obj:
            raise GroupNotFoundError(final_group_id)

    return final_group_id
|
|
|
|
|
|
async def _generate_expense_splits(
    db: AsyncSession,
    expense_model: ExpenseModel,
    expense_in: ExpenseCreate,
    **kwargs: Any
) -> PyList[ExpenseSplitModel]:
    """Build the splits for ``expense_model`` according to ``expense_in.split_type``.

    The selected builder is called with ``db``, ``expense_model``, ``expense_in``,
    ``round_money_func=_round_money`` and any extra keyword arguments given here
    (e.g. ``current_user_id``) as its own keyword arguments.

    Raises:
        InvalidOperationError: the split type is unsupported or no splits were produced.
    """
    split_type = expense_in.split_type

    if split_type == SplitTypeEnum.EQUAL:
        builder = _create_equal_splits
    elif split_type == SplitTypeEnum.EXACT_AMOUNTS:
        builder = _create_exact_amount_splits
    elif split_type == SplitTypeEnum.PERCENTAGE:
        builder = _create_percentage_splits
    elif split_type == SplitTypeEnum.SHARES:
        builder = _create_shares_splits
    elif split_type == SplitTypeEnum.ITEM_BASED:
        builder = _create_item_based_splits
    else:
        # getattr keeps the message usable even if split_type is not an enum member.
        raise InvalidOperationError(f"Unsupported split type: {getattr(split_type, 'value', split_type)}")

    # Forward extra kwargs flat. Wrapping them in a single "kwargs" entry would
    # make them arrive as kwargs["kwargs"] inside the builders.
    splits_to_create = await builder(
        db=db,
        expense_model=expense_model,
        expense_in=expense_in,
        round_money_func=_round_money,
        **kwargs,
    )

    if not splits_to_create:
        raise InvalidOperationError("No expense splits were generated.")

    return splits_to_create
|
|
|
|
|
|
async def _create_equal_splits(db: AsyncSession, expense_model: ExpenseModel, expense_in: ExpenseCreate, round_money_func: Callable[[Decimal], Decimal], **kwargs: Any) -> PyList[ExpenseSplitModel]:
    """Divide the expense total evenly among the users from ``get_users_for_splitting``.

    Everyone owes the rounded per-head amount. Any rounding remainder (which may
    be negative) is folded into the first user's share to reconcile with
    ``expense_model.total_amount``.
    """
    participants = await get_users_for_splitting(
        db, expense_model.group_id, expense_model.list_id, expense_model.paid_by_user_id
    )
    if not participants:
        raise InvalidOperationError("No users found for EQUAL split.")

    total = expense_model.total_amount
    headcount = len(participants)
    base_share = round_money_func(total / Decimal(headcount))
    leftover = total - (base_share * headcount)

    first_share = base_share
    if leftover != Decimal('0'):
        first_share = round_money_func(base_share + leftover)

    return [
        ExpenseSplitModel(
            user_id=participant.id,
            owed_amount=first_share if position == 0 else base_share,
            status=ExpenseSplitStatusEnum.unpaid,
        )
        for position, participant in enumerate(participants)
    ]
|
|
|
|
|
|
async def _create_exact_amount_splits(db: AsyncSession, expense_model: ExpenseModel, expense_in: ExpenseCreate, round_money_func: Callable[[Decimal], Decimal], **kwargs: Any) -> PyList[ExpenseSplitModel]:
    """Build one split per ``expense_in.splits_in`` entry from its explicit owed amount.

    Each amount must be positive and is rounded to cents. The rounded amounts
    must add up exactly to ``expense_model.total_amount``.
    """
    entries = expense_in.splits_in
    if not entries:
        raise InvalidOperationError("Splits data is required for EXACT_AMOUNTS split type.")

    # Every referenced user must exist.
    await _validate_users_in_splits(db, entries)

    offender = next((entry for entry in entries if entry.owed_amount <= Decimal('0')), None)
    if offender is not None:
        raise InvalidOperationError(f"Owed amount for user {offender.user_id} must be positive.")

    rounded_amounts = [round_money_func(entry.owed_amount) for entry in entries]
    running_total = sum(rounded_amounts, Decimal("0.00"))

    splits = [
        ExpenseSplitModel(
            user_id=entry.user_id,
            owed_amount=amount,
            status=ExpenseSplitStatusEnum.unpaid,
        )
        for entry, amount in zip(entries, rounded_amounts)
    ]

    if round_money_func(running_total) != expense_model.total_amount:
        raise InvalidOperationError(
            f"Sum of exact split amounts ({running_total}) != expense total ({expense_model.total_amount})."
        )

    return splits
|
|
|
|
|
|
async def _create_percentage_splits(db: AsyncSession, expense_model: ExpenseModel, expense_in: ExpenseCreate, round_money_func: Callable[[Decimal], Decimal], **kwargs: Any) -> PyList[ExpenseSplitModel]:
    """Build splits where each user owes a percentage of the expense total.

    Every percentage must lie in (0, 100], and the percentages, rounded to two
    places, must sum to 100. Each owed amount is total * percentage / 100,
    rounded to cents. Any rounding drift is added to the last split to
    reconcile with ``expense_model.total_amount``.
    """
    entries = expense_in.splits_in
    if not entries:
        raise InvalidOperationError("Splits data is required for PERCENTAGE split type.")

    # Every referenced user must exist.
    await _validate_users_in_splits(db, entries)

    percent_sum = Decimal("0.00")
    allocated = Decimal("0.00")
    splits = []

    for entry in entries:
        pct = entry.share_percentage
        if not pct or not (Decimal("0") < pct <= Decimal("100")):
            raise InvalidOperationError(
                f"Invalid percentage {entry.share_percentage} for user {entry.user_id}."
            )

        percent_sum += pct
        owed = round_money_func(expense_model.total_amount * (pct / Decimal("100")))
        allocated += owed

        splits.append(
            ExpenseSplitModel(
                user_id=entry.user_id,
                owed_amount=owed,
                share_percentage=pct,
                status=ExpenseSplitStatusEnum.unpaid,
            )
        )

    if round_money_func(percent_sum) != Decimal("100.00"):
        raise InvalidOperationError(f"Sum of percentages ({percent_sum}%) is not 100%.")

    # Per-split rounding can leave the sum a cent or so off the total.
    if allocated != expense_model.total_amount and splits:
        drift = expense_model.total_amount - allocated
        splits[-1].owed_amount = round_money_func(splits[-1].owed_amount + drift)

    return splits
|
|
|
|
|
|
async def _create_shares_splits(db: AsyncSession, expense_model: ExpenseModel, expense_in: ExpenseCreate, round_money_func: Callable[[Decimal], Decimal], **kwargs: Any) -> PyList[ExpenseSplitModel]:
    """Build splits proportional to each user's share units.

    A user's owed amount is total * units / (sum of positive units), rounded to
    cents. Every entry must have positive units. Any rounding drift is added to
    the last split to reconcile with ``expense_model.total_amount``.
    """
    entries = expense_in.splits_in
    if not entries:
        raise InvalidOperationError("Splits data is required for SHARES split type.")

    # Every referenced user must exist.
    await _validate_users_in_splits(db, entries)

    # Only positive unit counts contribute to the denominator; entries with
    # non-positive units are rejected in the loop below.
    positive_units = [units for units in (entry.share_units for entry in entries) if units and units > 0]
    total_shares = sum(positive_units)
    if total_shares == 0:
        raise InvalidOperationError("Total shares cannot be zero for SHARES split.")

    denominator = Decimal(total_shares)
    allocated = Decimal("0.00")
    splits = []

    for entry in entries:
        units = entry.share_units
        if not units or not (units > 0):
            raise InvalidOperationError(f"Invalid share units for user {entry.user_id}.")

        owed = round_money_func(expense_model.total_amount * (Decimal(units) / denominator))
        allocated += owed

        splits.append(
            ExpenseSplitModel(
                user_id=entry.user_id,
                owed_amount=owed,
                share_units=units,
                status=ExpenseSplitStatusEnum.unpaid,
            )
        )

    # Per-split rounding can leave the sum a cent or so off the total.
    if allocated != expense_model.total_amount and splits:
        drift = expense_model.total_amount - allocated
        splits[-1].owed_amount = round_money_func(splits[-1].owed_amount + drift)

    return splits
|
|
|
|
|
|
async def _create_item_based_splits(db: AsyncSession, expense_model: ExpenseModel, expense_in: ExpenseCreate, round_money_func: Callable[[Decimal], Decimal], **kwargs: Any) -> PyList[ExpenseSplitModel]:
    """Create splits from the priced items of the expense's shopping list.

    If ``expense_model.item_id`` is set, only that item is used (it must be in
    the list and have a positive price). Otherwise every item in the list with
    a positive price is used. Each item's price is attributed to the user who
    added it, and the amounts are aggregated per user. The item total must
    match ``expense_model.total_amount``. Any ``expense_in.splits_in`` data is
    ignored, and a warning is logged.

    Raises:
        InvalidOperationError: No list_id, no usable items, an item without
            adder information, or a total that does not match the item prices.
    """
    # Use collections.defaultdict. The module-level import takes `defaultdict`
    # from `typing`, which does not document such an export.
    from collections import defaultdict

    if not expense_model.list_id:
        raise InvalidOperationError("ITEM_BASED expenses must be associated with a list_id.")

    if expense_in.splits_in:
        logger.warning("Ignoring provided 'splits_in' data for ITEM_BASED expense.")

    # Build the query for the relevant items.
    items_query = select(ItemModel).where(ItemModel.list_id == expense_model.list_id)
    if expense_model.item_id:
        items_query = items_query.where(ItemModel.id == expense_model.item_id)
    else:
        items_query = items_query.where(ItemModel.price.isnot(None) & (ItemModel.price > Decimal("0")))

    # Load the items together with the users who added them.
    items_result = await db.execute(items_query.options(selectinload(ItemModel.added_by_user)))
    relevant_items = items_result.scalars().all()

    if not relevant_items:
        error_msg = (
            f"Specified item ID {expense_model.item_id} not found in list {expense_model.list_id}."
            if expense_model.item_id else
            f"List {expense_model.list_id} has no priced items to base the expense on."
        )
        raise InvalidOperationError(error_msg)

    # Aggregate owed amounts per adding user.
    calculated_total = Decimal("0.00")
    user_owed_amounts: Dict[int, Decimal] = defaultdict(Decimal)
    processed_items = 0

    for item in relevant_items:
        if item.price is None or item.price <= Decimal("0"):
            if expense_model.item_id:
                raise InvalidOperationError(
                    f"Item ID {expense_model.item_id} must have a positive price for ITEM_BASED expense."
                )
            continue

        if not item.added_by_user:
            logger.error(f"Item ID {item.id} is missing added_by_user relationship.")
            raise InvalidOperationError(f"Data integrity issue: Item {item.id} is missing adder information.")

        calculated_total += item.price
        user_owed_amounts[item.added_by_user.id] += item.price
        processed_items += 1

    if processed_items == 0:
        raise InvalidOperationError(
            f"No items with positive prices found in list {expense_model.list_id} to create ITEM_BASED expense."
        )

    # The expense total must match the sum of the item prices.
    if round_money_func(calculated_total) != expense_model.total_amount:
        raise InvalidOperationError(
            f"Expense total amount ({expense_model.total_amount}) does not match the "
            f"calculated total from item prices ({calculated_total})."
        )

    splits = [
        ExpenseSplitModel(
            user_id=user_id,
            owed_amount=round_money_func(owed_amount),
            status=ExpenseSplitStatusEnum.unpaid,
        )
        for user_id, owed_amount in user_owed_amounts.items()
    ]

    # Rounding each user's aggregate separately can drift from the total when
    # prices carry sub-cent precision. Absorb the difference in the last split,
    # as the percentage and shares builders do.
    allocated = sum((split.owed_amount for split in splits), Decimal("0.00"))
    if allocated != expense_model.total_amount:
        drift = expense_model.total_amount - allocated
        splits[-1].owed_amount = round_money_func(splits[-1].owed_amount + drift)

    return splits
|
|
|
|
|
|
async def _validate_users_in_splits(db: AsyncSession, splits_in: PyList[ExpenseSplitCreate]) -> None:
    """Validate that every user referenced in ``splits_in`` exists and appears only once.

    Raises:
        InvalidOperationError: If the same user_id is listed in more than one split.
        UserNotFoundError: If any referenced user_id has no matching user.
    """
    user_ids_in_split = [s.user_id for s in splits_in]
    if not user_ids_in_split:
        return  # nothing to check; avoids an empty IN () query

    seen_ids: set = set()
    duplicate_ids: set = set()
    for user_id in user_ids_in_split:
        if user_id in seen_ids:
            duplicate_ids.add(user_id)
        seen_ids.add(user_id)

    if duplicate_ids:
        # A plain length comparison would report duplicates as "missing" users
        # (with an empty missing list), so reject them explicitly.
        raise InvalidOperationError(f"Duplicate users in split data: {sorted(duplicate_ids)}")

    user_results = await db.execute(select(UserModel.id).where(UserModel.id.in_(sorted(seen_ids))))
    found_user_ids = {row[0] for row in user_results}

    missing_user_ids = seen_ids - found_user_ids
    if missing_user_ids:
        raise UserNotFoundError(identifier=f"users in split data: {sorted(missing_user_ids)}")
|
|
|
|
|
|
async def get_expense_by_id(db: AsyncSession, expense_id: int) -> Optional[ExpenseModel]:
    """Fetch a single expense by id, or None if no such expense exists.

    Splits (with each split's user), payer, list, group and item are
    eager-loaded with ``selectinload``.
    """
    eager_loads = (
        selectinload(ExpenseModel.splits).options(selectinload(ExpenseSplitModel.user)),
        selectinload(ExpenseModel.paid_by_user),
        selectinload(ExpenseModel.list),
        selectinload(ExpenseModel.group),
        selectinload(ExpenseModel.item),
    )
    stmt = select(ExpenseModel).options(*eager_loads).where(ExpenseModel.id == expense_id)
    rows = await db.execute(stmt)
    return rows.scalars().first()
|
|
|
|
async def get_expenses_for_list(db: AsyncSession, list_id: int, skip: int = 0, limit: int = 100) -> Sequence[ExpenseModel]:
    """Return one page of the expenses attached to ``list_id``, newest first.

    Rows are ordered by ``expense_date`` and then ``created_at`` (both
    descending). ``skip``/``limit`` select the page. Splits and each split's
    user are eager-loaded.
    """
    newest_first = (ExpenseModel.expense_date.desc(), ExpenseModel.created_at.desc())
    splits_with_users = selectinload(ExpenseModel.splits).options(selectinload(ExpenseSplitModel.user))

    stmt = (
        select(ExpenseModel)
        .where(ExpenseModel.list_id == list_id)
        .order_by(*newest_first)
        .offset(skip)
        .limit(limit)
        .options(splits_with_users)
    )
    page = await db.execute(stmt)
    return page.scalars().all()
|
|
|
|
async def get_expenses_for_group(db: AsyncSession, group_id: int, skip: int = 0, limit: int = 100) -> Sequence[ExpenseModel]:
    """Return one page of the expenses attached to ``group_id``, newest first.

    Rows are ordered by ``expense_date`` and then ``created_at`` (both
    descending). ``skip``/``limit`` select the page. Splits and each split's
    user are eager-loaded.
    """
    newest_first = (ExpenseModel.expense_date.desc(), ExpenseModel.created_at.desc())
    splits_with_users = selectinload(ExpenseModel.splits).options(selectinload(ExpenseSplitModel.user))

    stmt = (
        select(ExpenseModel)
        .where(ExpenseModel.group_id == group_id)
        .order_by(*newest_first)
        .offset(skip)
        .limit(limit)
        .options(splits_with_users)
    )
    page = await db.execute(stmt)
    return page.scalars().all()
|
|
|
|
async def update_expense(db: AsyncSession, expense_db: ExpenseModel, expense_in: ExpenseUpdate) -> ExpenseModel:
    """
    Update the editable metadata of an existing expense.

    Only ``description``, ``currency`` and ``expense_date`` may be changed, which
    keeps splits consistent. ``expense_in.version`` must equal the stored version
    (optimistic locking). On success the version is incremented and
    ``updated_at`` is set to the current UTC time, even if the payload carried
    no editable fields.

    The payload is fully validated before anything is applied, so a rejected
    request leaves ``expense_db`` untouched.

    Raises:
        InvalidOperationError: Version mismatch, or a non-editable field in the payload.
        DatabaseIntegrityError: The flush violated a database constraint.
        DatabaseTransactionError: Any other SQLAlchemy error.
    """
    if expense_db.version != expense_in.version:
        raise InvalidOperationError(
            f"Expense '{expense_db.description}' (ID: {expense_db.id}) has been modified. "
            f"Your version is {expense_in.version}, current version is {expense_db.version}. Please refresh.",
            # The API layer is responsible for mapping this to HTTP 409 Conflict.
        )

    update_data = expense_in.model_dump(exclude_unset=True, exclude={"version"})  # version is not a data field

    # Fields that are safe to update without affecting splits or core logic.
    allowed_to_update = {"description", "currency", "expense_date"}

    # Reject the whole payload up front. Applying the allowed fields before
    # reaching a disallowed one would leave expense_db dirty in the session.
    disallowed = [field for field in update_data if field not in allowed_to_update]
    if disallowed:
        raise InvalidOperationError(
            f"Field '{disallowed[0]}' cannot be updated. Only {', '.join(sorted(allowed_to_update))} are allowed."
        )

    try:
        async with db.begin_nested() if db.in_transaction() else db.begin():
            for field, value in update_data.items():
                setattr(expense_db, field, value)
            expense_db.version += 1
            expense_db.updated_at = datetime.now(timezone.utc)  # set explicitly here

            await db.flush()  # send the UPDATE and surface constraint errors now
            await db.refresh(expense_db)
            return expense_db
    except IntegrityError as e:
        logger.error(f"Database integrity error during expense update for ID {expense_db.id}: {str(e)}", exc_info=True)
        # The transaction context manager (begin_nested/begin) handles rollback.
        raise DatabaseIntegrityError(f"Failed to update expense ID {expense_db.id} due to database integrity issue.") from e
    except SQLAlchemyError as e:
        logger.error(f"Database transaction error during expense update for ID {expense_db.id}: {str(e)}", exc_info=True)
        # The transaction context manager (begin_nested/begin) handles rollback.
        raise DatabaseTransactionError(f"Failed to update expense ID {expense_db.id} due to a database transaction error.") from e
    # Non-SQLAlchemy errors propagate unchanged.
|
|
|
|
|
|
async def delete_expense(db: AsyncSession, expense_db: ExpenseModel, expected_version: Optional[int] = None) -> None:
    """
    Delete ``expense_db`` and flush the DELETE.

    When ``expected_version`` is supplied, it must equal the stored version
    (optimistic locking). Pass None to skip the check. The delete runs inside
    ``db.begin()``, or ``db.begin_nested()`` if the session is already in a
    transaction.

    NOTE(review): associated ExpenseSplits are assumed to be removed by a
    cascading foreign key in the database. Confirm this against the models and
    migrations, including any newer dependents such as settlement activities.

    Raises:
        InvalidOperationError: ``expected_version`` is given and does not match.
        DatabaseIntegrityError: The delete violated a database constraint.
        DatabaseTransactionError: Any other SQLAlchemy error.
    """
    version_conflict = expected_version is not None and expense_db.version != expected_version
    if version_conflict:
        raise InvalidOperationError(
            f"Expense '{expense_db.description}' (ID: {expense_db.id}) cannot be deleted. "
            f"Your expected version {expected_version} does not match current version {expense_db.version}. Please refresh.",
            # The API layer is responsible for mapping this to HTTP 409 Conflict.
        )

    try:
        async with db.begin_nested() if db.in_transaction() else db.begin():
            await db.delete(expense_db)
            await db.flush()  # emit the DELETE now so constraint errors surface here
    except InvalidOperationError:
        raise
    except IntegrityError as e:
        logger.error(f"Database integrity error during expense deletion for ID {expense_db.id}: {str(e)}", exc_info=True)
        # Rollback is handled by the begin/begin_nested context manager.
        raise DatabaseIntegrityError(f"Failed to delete expense ID {expense_db.id} due to database integrity issue.") from e
    except SQLAlchemyError as e:
        logger.error(f"Database transaction error during expense deletion for ID {expense_db.id}: {str(e)}", exc_info=True)
        # Rollback is handled by the begin/begin_nested context manager.
        raise DatabaseTransactionError(f"Failed to delete expense ID {expense_db.id} due to a database transaction error.") from e
|
|
|
|
# Note: InvalidOperationError and the Database*Error types raised in this module are
# imported from app.core.exceptions. API endpoints should translate them into appropriate
# HTTP responses (e.g. HTTPExceptions) unless those classes already do so.