mitlist/be/app/crud/expense.py
google-labs-jules[bot] f1152c5745 feat: Implement traceable expense splitting and settlement activities
Backend:
- Added `SettlementActivity` model to track payments against specific expense shares.
- Added `status` and `paid_at` to `ExpenseSplit` model.
- Added `overall_settlement_status` to `Expense` model.
- Implemented CRUD for `SettlementActivity`, including logic to update parent expense/split statuses.
- Updated `Expense` CRUD to initialize new status fields.
- Defined Pydantic schemas for `SettlementActivity` and updated `Expense/ExpenseSplit` schemas.
- Exposed API endpoints for creating/listing settlement activities and settling shares.
- Adjusted group balance summary logic to include settlement activities.
- Added comprehensive backend unit and API tests for new functionality.

Frontend (foundation only; the remaining integration work could not be completed in this change and is left as TODOs):
- Created TypeScript interfaces for all new/updated models.
- Set up `listDetailStore.ts` with an action to handle `settleExpenseSplit` (API call is a placeholder) and refresh data.
- Created `SettleShareModal.vue` component for payment confirmation.
- Added unit tests for the new modal and store logic.
- Updated `ListDetailPage.vue` to display detailed expense/share statuses and settlement activities.
- `mitlist_doc.md` updated to reflect all backend changes and current frontend status.
- A TODO list (a new section of `mitlist_doc.md`, not a separate `TODO.md` file) outlines the manual frontend integration still needed in `api.ts` and `ListDetailPage.vue` to complete the 'Settle Share' UI flow.

This set of changes provides the core backend infrastructure for precise expense share tracking and settlement, and lays the groundwork for full frontend integration.
2025-05-22 07:05:31 +00:00

633 lines
29 KiB
Python

# app/crud/expense.py
import logging
# defaultdict lives in `collections`; `typing` only re-exports it as an
# implementation detail on some Python versions, so import it from its real home.
from collections import defaultdict
from datetime import datetime, timezone
from decimal import Decimal, ROUND_HALF_UP, InvalidOperation as DecimalInvalidOperation
from typing import Callable, List as PyList, Optional, Sequence, Dict, Any

from sqlalchemy.exc import SQLAlchemyError, IntegrityError, OperationalError
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy.orm import selectinload, joinedload

from app.models import (
    Expense as ExpenseModel,
    ExpenseSplit as ExpenseSplitModel,
    User as UserModel,
    List as ListModel,
    Group as GroupModel,
    UserGroup as UserGroupModel,
    SplitTypeEnum,
    Item as ItemModel,
    ExpenseOverallStatusEnum,
    ExpenseSplitStatusEnum,
)
from app.schemas.expense import ExpenseCreate, ExpenseSplitCreate, ExpenseUpdate
from app.core.exceptions import (
    ListNotFoundError,
    GroupNotFoundError,
    UserNotFoundError,
    InvalidOperationError,
    DatabaseConnectionError,
    DatabaseIntegrityError,
    DatabaseQueryError,
    DatabaseTransactionError,
    ExpenseOperationError,
)
# Module-level logger. The CRUD functions below log warnings (e.g. ignored
# split data) and database errors through it before re-raising them as
# domain exceptions.
logger = logging.getLogger(__name__)
def _round_money(amount: Decimal) -> Decimal:
    """Return *amount* quantized to whole cents (two decimal places).

    Ties are rounded away from zero (``ROUND_HALF_UP``), so ``1.005`` becomes
    ``1.01`` and ``-1.005`` becomes ``-1.01``.
    """
    two_places = Decimal("0.01")
    return amount.quantize(two_places, rounding=ROUND_HALF_UP)
async def get_users_for_splitting(db: AsyncSession, expense_group_id: Optional[int], expense_list_id: Optional[int], expense_paid_by_user_id: int) -> PyList[UserModel]:
    """
    Work out which users an expense is shared between.

    Resolution order:
      1. ``expense_group_id`` set: every member of that group.
      2. otherwise ``expense_list_id`` set: the members of the list's group,
         or the list's creator when the list has no group.
      3. if neither step produced anyone: the payer alone.

    Users are de-duplicated by id, keeping first-seen order.

    Raises:
        GroupNotFoundError: the given group does not exist.
        ListNotFoundError: the given list does not exist.
        UserNotFoundError: the payer fallback is needed but the payer does not exist.
        InvalidOperationError: no users could be determined at all.
    """
    collected: PyList[UserModel] = []
    seen_ids = set()

    def remember(candidate: Optional[UserModel]) -> None:
        # Ignore missing users and users that were already collected.
        if not candidate or candidate.id in seen_ids:
            return
        collected.append(candidate)
        seen_ids.add(candidate.id)

    if expense_group_id:
        group_stmt = (
            select(GroupModel)
            .options(
                selectinload(GroupModel.member_associations).options(selectinload(UserGroupModel.user))
            )
            .where(GroupModel.id == expense_group_id)
        )
        group = (await db.execute(group_stmt)).scalars().first()
        if not group:
            raise GroupNotFoundError(expense_group_id)
        for membership in group.member_associations:
            remember(membership.user)
    elif expense_list_id:
        # Only consulted when no group was given explicitly.
        list_stmt = (
            select(ListModel)
            .options(
                selectinload(ListModel.group).options(
                    selectinload(GroupModel.member_associations).options(selectinload(UserGroupModel.user))
                ),
                selectinload(ListModel.creator),
            )
            .where(ListModel.id == expense_list_id)
        )
        db_list = (await db.execute(list_stmt)).scalars().first()
        if not db_list:
            raise ListNotFoundError(expense_list_id)
        if db_list.group:
            for membership in db_list.group.member_associations:
                remember(membership.user)
        elif db_list.creator:
            remember(db_list.creator)

    if not collected:
        payer = await db.get(UserModel, expense_paid_by_user_id)
        if not payer:
            # Callers are expected to have validated the payer already.
            raise UserNotFoundError(user_id=expense_paid_by_user_id)
        remember(payer)

    if not collected:
        raise InvalidOperationError("Could not determine any users for splitting the expense.")
    return collected
async def create_expense(db: AsyncSession, expense_in: ExpenseCreate, current_user_id: int) -> ExpenseModel:
    """Create a new expense together with its splits.

    All work runs inside one transaction. That is a SAVEPOINT
    (``begin_nested``) when the session is already in a transaction, otherwise
    a new ``begin``. The context manager rolls back on any exception.

    ``expense_in`` is never mutated. When the list has to be derived from
    ``item_id``, a copy of the payload with ``list_id`` filled in is used.

    Args:
        db: Database session.
        expense_in: Expense creation data.
        current_user_id: ID of the user creating the expense.

    Returns:
        The created expense, re-fetched with payer, creator, list, group,
        item and splits (each with its user) loaded via ``selectinload``.

    Raises:
        UserNotFoundError: If the payer or a split user doesn't exist.
        ListNotFoundError: If the specified list doesn't exist.
        GroupNotFoundError: If the specified group doesn't exist.
        InvalidOperationError: For validation failures. Examples are a missing
            context, an unknown item, an item belonging to a different list,
            or bad split data.
        ExpenseOperationError: If the expense cannot be re-loaded after creation.
        DatabaseIntegrityError, DatabaseConnectionError, DatabaseTransactionError:
            For database-level failures.
    """
    try:
        async with db.begin_nested() if db.in_transaction() else db.begin():
            # 1. Validate payer
            payer = await db.get(UserModel, expense_in.paid_by_user_id)
            if not payer:
                raise UserNotFoundError(user_id=expense_in.paid_by_user_id, identifier="Payer")

            # 2. Context resolution and validation
            if not expense_in.list_id and not expense_in.group_id and not expense_in.item_id:
                raise InvalidOperationError("Expense must be associated with a list, a group, or an item.")
            payload = expense_in
            final_group_id = await _resolve_expense_context(db, payload)

            if payload.item_id:
                db_item = await db.get(ItemModel, payload.item_id)
                if not db_item:
                    raise InvalidOperationError(f"Item with ID {payload.item_id} not found.")
                if db_item.list_id:
                    if not payload.list_id:
                        # Derive the list from the item on a copy, so the caller's
                        # object is left untouched. Then re-resolve the group for
                        # that list.
                        payload = payload.model_copy(update={"list_id": db_item.list_id})
                        final_group_id = await _resolve_expense_context(db, payload)
                    elif db_item.list_id != payload.list_id:
                        raise InvalidOperationError(
                            f"Item {payload.item_id} belongs to list {db_item.list_id}, "
                            f"but expense was specified for list {payload.list_id}."
                        )

            # 3. Create the expense row
            db_expense = ExpenseModel(
                description=payload.description,
                total_amount=_round_money(payload.total_amount),
                currency=payload.currency or "USD",
                expense_date=payload.expense_date or datetime.now(timezone.utc),
                split_type=payload.split_type,
                list_id=payload.list_id,
                group_id=final_group_id,  # resolved (possibly list-derived) group
                item_id=payload.item_id,
                paid_by_user_id=payload.paid_by_user_id,
                created_by_user_id=current_user_id,
                overall_settlement_status=ExpenseOverallStatusEnum.unpaid,
            )
            db.add(db_expense)
            await db.flush()  # assigns db_expense.id

            # 4. Generate and persist the splits
            splits_to_create = await _generate_expense_splits(
                db=db,
                expense_model=db_expense,
                expense_in=payload,
                current_user_id=current_user_id,
            )
            for split_model in splits_to_create:
                split_model.expense_id = db_expense.id
            db.add_all(splits_to_create)
            await db.flush()

            # 5. Re-fetch the expense with the relationships needed for the response
            stmt = (
                select(ExpenseModel)
                .where(ExpenseModel.id == db_expense.id)
                .options(
                    selectinload(ExpenseModel.paid_by_user),
                    selectinload(ExpenseModel.created_by_user),
                    selectinload(ExpenseModel.list),
                    selectinload(ExpenseModel.group),
                    selectinload(ExpenseModel.item),
                    selectinload(ExpenseModel.splits).selectinload(ExpenseSplitModel.user),
                )
            )
            loaded_expense = (await db.execute(stmt)).scalar_one_or_none()
            if loaded_expense is None:
                # Raising inside the context manager triggers the rollback.
                raise ExpenseOperationError("Failed to load expense after creation.")
            return loaded_expense
    except (UserNotFoundError, ListNotFoundError, GroupNotFoundError, InvalidOperationError):
        # Business validation errors propagate unchanged; rollback already happened.
        raise
    except IntegrityError as e:
        logger.error(f"Database integrity error during expense creation: {str(e)}", exc_info=True)
        raise DatabaseIntegrityError(f"Failed to save expense due to database integrity issue: {str(e)}") from e
    except OperationalError as e:
        logger.error(f"Database connection error during expense creation: {str(e)}", exc_info=True)
        raise DatabaseConnectionError(f"Database connection error during expense creation: {str(e)}") from e
    except SQLAlchemyError as e:
        logger.error(f"Unexpected SQLAlchemy error during expense creation: {str(e)}", exc_info=True)
        raise DatabaseTransactionError(f"Failed to save expense due to a database transaction error: {str(e)}") from e
async def _resolve_expense_context(db: AsyncSession, expense_in: ExpenseCreate) -> Optional[int]:
    """Validate the expense's list/group context and return the group id to store.

    - ``list_id`` set: the list must exist. If the list belongs to a group,
      that group is used, and an explicit ``group_id`` that disagrees is rejected.
    - Any ``group_id`` not taken from the list must refer to an existing group.
      This covers both "no list" and "list without a group".
      Previously a group-less list let an unchecked ``group_id`` through.

    Returns:
        The resolved group id, or ``None`` when the expense has no group.

    Raises:
        ListNotFoundError: the list does not exist.
        GroupNotFoundError: an explicitly given group does not exist.
        InvalidOperationError: the list's group conflicts with ``group_id``.
    """
    final_group_id = expense_in.group_id

    if expense_in.list_id:
        list_obj = await db.get(ListModel, expense_in.list_id)
        if not list_obj:
            raise ListNotFoundError(expense_in.list_id)
        if list_obj.group_id:
            if expense_in.group_id and list_obj.group_id != expense_in.group_id:
                raise InvalidOperationError(
                    f"List {expense_in.list_id} belongs to group {list_obj.group_id}, "
                    f"but expense was specified for group {expense_in.group_id}."
                )
            # The list's own group takes priority.
            return list_obj.group_id

    # Either no list was given or the list has no group, so an explicit
    # group_id has not been validated yet.
    if final_group_id:
        group_obj = await db.get(GroupModel, final_group_id)
        if not group_obj:
            raise GroupNotFoundError(final_group_id)
    return final_group_id
async def _generate_expense_splits(
    db: AsyncSession,
    expense_model: ExpenseModel,
    expense_in: ExpenseCreate,
    **kwargs: Any
) -> PyList[ExpenseSplitModel]:
    """Build (unsaved) split rows for ``expense_model`` according to ``expense_in.split_type``.

    Extra keyword arguments (e.g. ``current_user_id``) are forwarded unchanged
    to the split helper. Previously they were passed as a single keyword named
    ``kwargs``, so helpers received ``{'kwargs': {...}}``.

    Raises:
        InvalidOperationError: unsupported split type, or no splits were produced.
    """
    split_builders = {
        SplitTypeEnum.EQUAL: _create_equal_splits,
        SplitTypeEnum.EXACT_AMOUNTS: _create_exact_amount_splits,
        SplitTypeEnum.PERCENTAGE: _create_percentage_splits,
        SplitTypeEnum.SHARES: _create_shares_splits,
        SplitTypeEnum.ITEM_BASED: _create_item_based_splits,
    }
    split_type = expense_in.split_type
    builder = split_builders.get(split_type)
    if builder is None:
        # Tolerate non-enum values in the message instead of failing on `.value`.
        raise InvalidOperationError(f"Unsupported split type: {getattr(split_type, 'value', split_type)}")

    splits_to_create = await builder(
        db=db,
        expense_model=expense_model,
        expense_in=expense_in,
        round_money_func=_round_money,
        **kwargs,
    )
    if not splits_to_create:
        raise InvalidOperationError("No expense splits were generated.")
    return splits_to_create
async def _create_equal_splits(db: AsyncSession, expense_model: ExpenseModel, expense_in: ExpenseCreate, round_money_func: Callable[[Decimal], Decimal], **kwargs: Any) -> PyList[ExpenseSplitModel]:
    """Split the expense total evenly between the users from ``get_users_for_splitting``.

    The calculation is done in whole cents:
    - each user gets ``floor(total_cents / n)`` cents;
    - the leftover cents (always fewer than ``n``) go one each to the first
      users in the returned order.

    The splits therefore sum exactly to the (cent-rounded) total and differ by
    at most one cent. A non-negative total never produces a negative split.
    The previous approach put the whole rounding remainder on the first user,
    which could go negative: 0.10 split 15 ways gave that user -0.04.

    Raises:
        InvalidOperationError: no users to split between.
    """
    users_for_splitting = await get_users_for_splitting(
        db, expense_model.group_id, expense_model.list_id, expense_model.paid_by_user_id
    )
    if not users_for_splitting:
        raise InvalidOperationError("No users found for EQUAL split.")

    num_users = len(users_for_splitting)
    total_cents = int(round_money_func(expense_model.total_amount) * 100)
    # divmod floors, so leftover_cents is in [0, num_users) even for negative totals.
    base_cents, leftover_cents = divmod(total_cents, num_users)

    splits = []
    for index, user in enumerate(users_for_splitting):
        share_cents = base_cents + (1 if index < leftover_cents else 0)
        splits.append(ExpenseSplitModel(
            user_id=user.id,
            owed_amount=round_money_func(Decimal(share_cents) / 100),
            status=ExpenseSplitStatusEnum.unpaid,
        ))
    return splits
async def _create_exact_amount_splits(db: AsyncSession, expense_model: ExpenseModel, expense_in: ExpenseCreate, round_money_func: Callable[[Decimal], Decimal], **kwargs: Any) -> PyList[ExpenseSplitModel]:
    """Build splits from explicit per-user amounts.

    Every ``splits_in`` entry must have a positive ``owed_amount``. Each amount
    is rounded to cents, and the rounded amounts must add up exactly to the
    expense total.

    Raises:
        InvalidOperationError: no split data, a missing or non-positive amount,
            or amounts that do not sum to the total.
        UserNotFoundError: a split references an unknown user.
    """
    if not expense_in.splits_in:
        raise InvalidOperationError("Splits data is required for EXACT_AMOUNTS split type.")
    await _validate_users_in_splits(db, expense_in.splits_in)

    current_total = Decimal("0.00")
    splits = []
    for split_in in expense_in.splits_in:
        owed = split_in.owed_amount
        # Guard None explicitly so a missing amount becomes a validation error
        # instead of a TypeError from the comparison.
        if owed is None or owed <= Decimal('0'):
            raise InvalidOperationError(f"Owed amount for user {split_in.user_id} must be positive.")
        rounded_amount = round_money_func(owed)
        current_total += rounded_amount
        splits.append(ExpenseSplitModel(
            user_id=split_in.user_id,
            owed_amount=rounded_amount,
            status=ExpenseSplitStatusEnum.unpaid,
        ))

    if round_money_func(current_total) != expense_model.total_amount:
        raise InvalidOperationError(
            f"Sum of exact split amounts ({current_total}) != expense total ({expense_model.total_amount})."
        )
    return splits
async def _create_percentage_splits(db: AsyncSession, expense_model: ExpenseModel, expense_in: ExpenseCreate, round_money_func: Callable[[Decimal], Decimal], **kwargs: Any) -> PyList[ExpenseSplitModel]:
    """Build splits where each user owes a percentage of the total.

    Each ``splits_in`` entry needs a ``share_percentage`` in (0, 100]. The
    percentages must total 100 once rounded to two places. Each user's amount
    is rounded to cents. Any difference left by that rounding is added to the
    last split, so the splits sum exactly to the expense total.

    Raises:
        InvalidOperationError: no split data, an invalid percentage, or
            percentages that do not total 100.
        UserNotFoundError: a split references an unknown user.
    """
    entries = expense_in.splits_in
    if not entries:
        raise InvalidOperationError("Splits data is required for PERCENTAGE split type.")
    await _validate_users_in_splits(db, entries)

    hundred = Decimal("100")
    percent_sum = Decimal("0.00")
    amount_sum = Decimal("0.00")
    result: PyList[ExpenseSplitModel] = []
    for entry in entries:
        pct = entry.share_percentage
        if not pct or not (Decimal("0") < pct <= hundred):
            raise InvalidOperationError(
                f"Invalid percentage {pct} for user {entry.user_id}."
            )
        percent_sum += pct
        owed = round_money_func(expense_model.total_amount * (pct / hundred))
        amount_sum += owed
        result.append(ExpenseSplitModel(
            user_id=entry.user_id,
            owed_amount=owed,
            share_percentage=pct,
            status=ExpenseSplitStatusEnum.unpaid,
        ))

    if round_money_func(percent_sum) != Decimal("100.00"):
        raise InvalidOperationError(f"Sum of percentages ({percent_sum}%) is not 100%.")

    # The last split absorbs the cent-level rounding difference.
    if amount_sum != expense_model.total_amount and result:
        last = result[-1]
        last.owed_amount = round_money_func(last.owed_amount + (expense_model.total_amount - amount_sum))
    return result
async def _create_shares_splits(db: AsyncSession, expense_model: ExpenseModel, expense_in: ExpenseCreate, round_money_func: Callable[[Decimal], Decimal], **kwargs: Any) -> PyList[ExpenseSplitModel]:
    """Build splits proportional to per-user share units.

    Each user owes ``total * units / sum(units)``, rounded to cents. Any
    rounding difference is added to the last split, so the splits sum
    exactly to the expense total.

    Raises:
        InvalidOperationError: no split data, zero total units, or an entry
            with missing or non-positive units.
        UserNotFoundError: a split references an unknown user.
    """
    entries = expense_in.splits_in
    if not entries:
        raise InvalidOperationError("Splits data is required for SHARES split type.")
    await _validate_users_in_splits(db, entries)

    positive_units = [e.share_units for e in entries if e.share_units and e.share_units > 0]
    unit_total = sum(positive_units)
    if unit_total == 0:
        raise InvalidOperationError("Total shares cannot be zero for SHARES split.")

    denominator = Decimal(unit_total)
    running = Decimal("0.00")
    built: PyList[ExpenseSplitModel] = []
    for entry in entries:
        units = entry.share_units
        if not units or units <= 0:
            raise InvalidOperationError(f"Invalid share units for user {entry.user_id}.")
        owed = round_money_func(expense_model.total_amount * (Decimal(units) / denominator))
        running += owed
        built.append(ExpenseSplitModel(
            user_id=entry.user_id,
            owed_amount=owed,
            share_units=units,
            status=ExpenseSplitStatusEnum.unpaid,
        ))

    # The last split absorbs the cent-level rounding difference.
    if running != expense_model.total_amount and built:
        built[-1].owed_amount = round_money_func(built[-1].owed_amount + (expense_model.total_amount - running))
    return built
async def _create_item_based_splits(db: AsyncSession, expense_model: ExpenseModel, expense_in: ExpenseCreate, round_money_func: Callable[[Decimal], Decimal], **kwargs: Any) -> PyList[ExpenseSplitModel]:
    """Build splits from item prices: each user owes the prices of the items they added.

    If ``expense_model.item_id`` is set, only that item (on the expense's list)
    is used, and it must have a positive price. Otherwise every item on the
    list with a positive price is used. The summed prices, rounded to cents,
    must equal the expense total. ``splits_in`` is ignored, with a warning.

    Raises:
        InvalidOperationError: no list, no matching or priced items, an item
            without its adder loaded, or a total mismatch.
    """
    list_id = expense_model.list_id
    target_item_id = expense_model.item_id
    if not list_id:
        raise InvalidOperationError("ITEM_BASED expenses must be associated with a list_id.")
    if expense_in.splits_in:
        logger.warning("Ignoring provided 'splits_in' data for ITEM_BASED expense.")

    if target_item_id:
        item_filter = ItemModel.id == target_item_id
    else:
        item_filter = ItemModel.price.isnot(None) & (ItemModel.price > Decimal("0"))
    stmt = (
        select(ItemModel)
        .where(ItemModel.list_id == list_id)
        .where(item_filter)
        .options(selectinload(ItemModel.added_by_user))
    )
    items = (await db.execute(stmt)).scalars().all()

    if not items:
        if target_item_id:
            raise InvalidOperationError(f"Specified item ID {target_item_id} not found in list {list_id}.")
        raise InvalidOperationError(f"List {list_id} has no priced items to base the expense on.")

    owed_by_user = defaultdict(Decimal)
    items_total = Decimal("0.00")
    for item in items:
        price = item.price
        if price is None or price <= Decimal("0"):
            # Only fatal when the expense targets this specific item.
            if target_item_id:
                raise InvalidOperationError(
                    f"Item ID {target_item_id} must have a positive price for ITEM_BASED expense."
                )
            continue
        adder = item.added_by_user
        if not adder:
            logger.error(f"Item ID {item.id} is missing added_by_user relationship.")
            raise InvalidOperationError(f"Data integrity issue: Item {item.id} is missing adder information.")
        items_total += price
        owed_by_user[adder.id] += price

    if not owed_by_user:
        raise InvalidOperationError(
            f"No items with positive prices found in list {list_id} to create ITEM_BASED expense."
        )

    if round_money_func(items_total) != expense_model.total_amount:
        raise InvalidOperationError(
            f"Expense total amount ({expense_model.total_amount}) does not match the "
            f"calculated total from item prices ({items_total})."
        )

    return [
        ExpenseSplitModel(
            user_id=uid,
            owed_amount=round_money_func(amount),
            status=ExpenseSplitStatusEnum.unpaid,
        )
        for uid, amount in owed_by_user.items()
    ]
async def _validate_users_in_splits(db: AsyncSession, splits_in: PyList[ExpenseSplitCreate]) -> None:
    """Check that each user appears at most once in the split data and that every user exists.

    Previously a duplicated user_id made the length check fail, which raised
    UserNotFoundError with an empty "missing" list. Duplicates now get their
    own validation error.

    Raises:
        InvalidOperationError: the same user_id appears in more than one split.
        UserNotFoundError: one or more user_ids do not exist.
    """
    user_ids_in_split = [s.user_id for s in splits_in]

    seen_ids = set()
    duplicate_ids = set()
    for uid in user_ids_in_split:
        if uid in seen_ids:
            duplicate_ids.add(uid)
        seen_ids.add(uid)
    if duplicate_ids:
        raise InvalidOperationError(f"Duplicate user IDs in split data: {sorted(duplicate_ids)}")

    user_results = await db.execute(select(UserModel.id).where(UserModel.id.in_(user_ids_in_split)))
    found_user_ids = {row[0] for row in user_results}
    missing_user_ids = seen_ids - found_user_ids
    if missing_user_ids:
        # Sorted so the error message is deterministic.
        raise UserNotFoundError(identifier=f"users in split data: {sorted(missing_user_ids)}")
async def get_expense_by_id(db: AsyncSession, expense_id: int) -> Optional[ExpenseModel]:
    """Fetch one expense by id, or return ``None`` if it does not exist.

    Loads, via ``selectinload``: the splits (each with its user), the payer,
    the list, the group and the item.
    """
    eager_loads = (
        selectinload(ExpenseModel.splits).options(selectinload(ExpenseSplitModel.user)),
        selectinload(ExpenseModel.paid_by_user),
        selectinload(ExpenseModel.list),
        selectinload(ExpenseModel.group),
        selectinload(ExpenseModel.item),
    )
    stmt = select(ExpenseModel).options(*eager_loads).where(ExpenseModel.id == expense_id)
    rows = await db.execute(stmt)
    return rows.scalars().first()
async def get_expenses_for_list(db: AsyncSession, list_id: int, skip: int = 0, limit: int = 100) -> Sequence[ExpenseModel]:
    """Return one page of the expenses attached to ``list_id``.

    Results are ordered newest first by ``expense_date``, then ``created_at``.
    ``skip`` and ``limit`` become OFFSET and LIMIT. Each expense's splits, and
    each split's user, are loaded via ``selectinload``.
    """
    newest_first = (ExpenseModel.expense_date.desc(), ExpenseModel.created_at.desc())
    query = (
        select(ExpenseModel)
        .where(ExpenseModel.list_id == list_id)
        .order_by(*newest_first)
        .offset(skip)
        .limit(limit)
        .options(selectinload(ExpenseModel.splits).options(selectinload(ExpenseSplitModel.user)))
    )
    page = await db.execute(query)
    return page.scalars().all()
async def get_expenses_for_group(db: AsyncSession, group_id: int, skip: int = 0, limit: int = 100) -> Sequence[ExpenseModel]:
    """Return one page of the expenses attached to ``group_id``.

    Results are ordered newest first by ``expense_date``, then ``created_at``.
    ``skip`` and ``limit`` become OFFSET and LIMIT. Each expense's splits, and
    each split's user, are loaded via ``selectinload``.
    """
    splits_with_users = selectinload(ExpenseModel.splits).options(selectinload(ExpenseSplitModel.user))
    query = (
        select(ExpenseModel)
        .where(ExpenseModel.group_id == group_id)
        .order_by(ExpenseModel.expense_date.desc(), ExpenseModel.created_at.desc())
        .offset(skip)
        .limit(limit)
        .options(splits_with_users)
    )
    page = await db.execute(query)
    return page.scalars().all()
async def update_expense(db: AsyncSession, expense_db: ExpenseModel, expense_in: ExpenseUpdate) -> ExpenseModel:
    """
    Update the editable metadata of an existing expense.

    Only ``description``, ``currency`` and ``expense_date`` may change. Any
    other field in the payload is rejected, to avoid split complexities.

    Optimistic locking: ``expense_in.version`` must equal the stored version.
    The stored version is incremented on every successful call, even when the
    payload contains only ``version``.

    The whole payload is validated before ``expense_db`` is modified, so a
    rejected update leaves the session-tracked object clean. Previously,
    allowed fields that came before a disallowed one had already been
    ``setattr``'d when the error was raised. That dirty state could then be
    persisted by a later flush or commit of the same session, without a
    version bump.

    Raises:
        InvalidOperationError: version mismatch, or a non-updatable field in the payload.
        DatabaseIntegrityError: integrity violation while flushing.
        DatabaseTransactionError: any other SQLAlchemy error.
    """
    if expense_db.version != expense_in.version:
        raise InvalidOperationError(
            f"Expense '{expense_db.description}' (ID: {expense_db.id}) has been modified. "
            f"Your version is {expense_in.version}, current version is {expense_db.version}. Please refresh.",
        )

    update_data = expense_in.model_dump(exclude_unset=True, exclude={"version"})
    allowed_to_update = {"description", "currency", "expense_date"}

    # Validate every field before touching the ORM object.
    for field in update_data:
        if field not in allowed_to_update:
            # sorted() keeps the message deterministic (set iteration order is not).
            raise InvalidOperationError(
                f"Field '{field}' cannot be updated. Only {', '.join(sorted(allowed_to_update))} are allowed."
            )

    try:
        async with db.begin_nested() if db.in_transaction() else db.begin():
            for field, value in update_data.items():
                setattr(expense_db, field, value)
            expense_db.version += 1
            expense_db.updated_at = datetime.now(timezone.utc)  # set explicitly here
            await db.flush()  # send the UPDATE so constraints are checked now
            await db.refresh(expense_db)
            return expense_db
    except IntegrityError as e:
        logger.error(f"Database integrity error during expense update for ID {expense_db.id}: {str(e)}", exc_info=True)
        # The transaction context manager has already rolled back.
        raise DatabaseIntegrityError(f"Failed to update expense ID {expense_db.id} due to database integrity issue.") from e
    except SQLAlchemyError as e:
        logger.error(f"Database transaction error during expense update for ID {expense_db.id}: {str(e)}", exc_info=True)
        raise DatabaseTransactionError(f"Failed to update expense ID {expense_db.id} due to a database transaction error.") from e
async def delete_expense(db: AsyncSession, expense_db: ExpenseModel, expected_version: Optional[int] = None) -> None:
    """
    Delete ``expense_db``, optionally guarded by an optimistic-lock version.

    When ``expected_version`` is given, it must equal ``expense_db.version``.
    The delete is flushed inside a SAVEPOINT if the session is already in a
    transaction, otherwise inside a new transaction.

    NOTE(review): the previous docstring stated that the expense's splits are
    cascade-deleted by a database foreign-key constraint. That is decided by
    the models/migrations, which are not visible here; confirm there.

    Raises:
        InvalidOperationError: version mismatch.
        DatabaseIntegrityError: integrity violation while deleting.
        DatabaseTransactionError: any other SQLAlchemy error.
    """
    version_conflict = expected_version is not None and expense_db.version != expected_version
    if version_conflict:
        raise InvalidOperationError(
            f"Expense '{expense_db.description}' (ID: {expense_db.id}) cannot be deleted. "
            f"Your expected version {expected_version} does not match current version {expense_db.version}. Please refresh.",
        )

    try:
        async with db.begin_nested() if db.in_transaction() else db.begin():
            await db.delete(expense_db)
            await db.flush()  # emit the DELETE now so errors surface here
    except IntegrityError as e:
        logger.error(f"Database integrity error during expense deletion for ID {expense_db.id}: {str(e)}", exc_info=True)
        raise DatabaseIntegrityError(f"Failed to delete expense ID {expense_db.id} due to database integrity issue.") from e
    except SQLAlchemyError as e:
        logger.error(f"Database transaction error during expense deletion for ID {expense_db.id}: {str(e)}", exc_info=True)
        raise DatabaseTransactionError(f"Failed to delete expense ID {expense_db.id} due to a database transaction error.") from e
# Note: InvalidOperationError and the other domain exceptions raised here are
# imported from app.core.exceptions. API endpoints should translate them into
# appropriate HTTP responses, unless they are already HTTPException subclasses;
# confirm in app.core.exceptions.