mitlist/be/app/crud/expense.py
mohamad 81577ac7e8 feat: Add Recurrence Pattern and Update Expense Schema
- Introduced a new `RecurrencePattern` model to manage recurrence details for expenses, allowing for daily, weekly, monthly, and yearly patterns.
- Updated the `Expense` model to include fields for recurrence management, such as `is_recurring`, `recurrence_pattern_id`, and `next_occurrence`.
- Modified the database schema to reflect these changes, including alterations to existing columns and the removal of obsolete fields.
- Enhanced the expense creation logic to accommodate recurring expenses and updated related CRUD operations accordingly.
- Implemented necessary migrations to ensure database integrity and support for the new features.
2025-05-23 21:01:49 +02:00

652 lines
30 KiB
Python

# app/crud/expense.py
import logging # Add logging import
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy.orm import selectinload, joinedload
from sqlalchemy.exc import SQLAlchemyError, IntegrityError, OperationalError # Added import
from decimal import Decimal, ROUND_HALF_UP, InvalidOperation as DecimalInvalidOperation
from typing import Callable, List as PyList, Optional, Sequence, Dict, defaultdict, Any
from datetime import datetime, timezone # Added timezone
from app.models import (
Expense as ExpenseModel,
ExpenseSplit as ExpenseSplitModel,
User as UserModel,
List as ListModel,
Group as GroupModel,
UserGroup as UserGroupModel,
SplitTypeEnum,
Item as ItemModel,
ExpenseOverallStatusEnum, # Added
ExpenseSplitStatusEnum, # Added
)
from app.schemas.expense import ExpenseCreate, ExpenseSplitCreate, ExpenseUpdate # Removed unused ExpenseUpdate
from app.core.exceptions import (
# Using existing specific exceptions where possible
ListNotFoundError,
GroupNotFoundError,
UserNotFoundError,
InvalidOperationError, # Import the new exception
DatabaseConnectionError, # Added
DatabaseIntegrityError, # Added
DatabaseQueryError, # Added
DatabaseTransactionError,# Added
ExpenseOperationError # Added specific exception
)
from app.models import RecurrencePattern
# InvalidOperationError (and the other domain exceptions used below) are imported
# from app.core.exceptions above, so no local placeholder exception class is needed.
# When these exceptions surface at the API layer they should be translated there
# into appropriate HTTP errors.
# Module-level logger used by the CRUD functions below to report errors and warnings.
logger = logging.getLogger(__name__) # Initialize logger
def _round_money(amount: Decimal) -> Decimal:
"""Rounds a Decimal to two decimal places using ROUND_HALF_UP."""
return amount.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP)
async def get_users_for_splitting(db: AsyncSession, expense_group_id: Optional[int], expense_list_id: Optional[int], expense_paid_by_user_id: int) -> PyList[UserModel]:
    """
    Work out which users an expense should be shared between.

    Resolution order:
      1. If ``expense_group_id`` is set: every member of that group.
      2. Otherwise, if ``expense_list_id`` is set: the members of the list's group,
         or, when the list has no group, the list's creator.
      3. If neither step produced anyone: only the payer.

    Users are returned in discovery order, each at most once.

    Raises:
        GroupNotFoundError: ``expense_group_id`` does not match a group.
        ListNotFoundError: ``expense_list_id`` does not match a list.
        UserNotFoundError: the payer fallback is needed but the payer does not exist.
        InvalidOperationError: no user could be determined at all.
    """
    # Insertion-ordered mapping keyed by user id; doubles as the de-duplication set.
    collected: Dict[int, UserModel] = {}

    def remember(candidate: Optional[UserModel]) -> None:
        if candidate and candidate.id not in collected:
            collected[candidate.id] = candidate

    if expense_group_id:
        group_stmt = (
            select(GroupModel)
            .options(selectinload(GroupModel.member_associations).options(selectinload(UserGroupModel.user)))
            .where(GroupModel.id == expense_group_id)
        )
        group = (await db.execute(group_stmt)).scalars().first()
        if not group:
            raise GroupNotFoundError(expense_group_id)
        for membership in group.member_associations:
            remember(membership.user)
    elif expense_list_id:
        # The list is only consulted when no group was given directly.
        list_stmt = (
            select(ListModel)
            .options(
                selectinload(ListModel.group).options(selectinload(GroupModel.member_associations).options(selectinload(UserGroupModel.user))),
                selectinload(ListModel.creator),
            )
            .where(ListModel.id == expense_list_id)
        )
        owning_list = (await db.execute(list_stmt)).scalars().first()
        if not owning_list:
            raise ListNotFoundError(expense_list_id)
        if owning_list.group:
            for membership in owning_list.group.member_associations:
                remember(membership.user)
        elif owning_list.creator:
            remember(owning_list.creator)

    if not collected:
        payer = await db.get(UserModel, expense_paid_by_user_id)
        if not payer:
            # Callers are expected to have validated the payer before reaching here.
            raise UserNotFoundError(user_id=expense_paid_by_user_id)
        remember(payer)

    if not collected:
        # Defensive: the payer fallback above should always populate the mapping.
        raise InvalidOperationError("Could not determine any users for splitting the expense.")

    return list(collected.values())
async def create_expense(db: AsyncSession, expense_in: ExpenseCreate, current_user_id: int) -> ExpenseModel:
    """Create a new expense and its associated splits.

    All work runs inside ``db.begin_nested()`` when the session already has an open
    transaction, otherwise inside ``db.begin()``; the context manager rolls back
    if anything raises.

    Args:
        db: Database session.
        expense_in: Expense creation data. NOTE: when the expense references an item
            that belongs to a list and no ``list_id`` was given, ``expense_in.list_id``
            is filled in from the item (the input object is mutated).
        current_user_id: ID of the user creating the expense (stored as the creator).

    Returns:
        The created expense, re-fetched with payer, creator, list, group, item and
        splits (including each split's user) eagerly loaded.

    Raises:
        UserNotFoundError: If the payer or a user referenced in split data doesn't exist.
        ListNotFoundError: If the specified list doesn't exist.
        GroupNotFoundError: If the specified group doesn't exist.
        InvalidOperationError: For various validation failures.
        ExpenseOperationError: If the expense cannot be re-loaded after creation.
        DatabaseIntegrityError: On an IntegrityError while persisting.
        DatabaseConnectionError: On an OperationalError while persisting.
        DatabaseTransactionError: On any other SQLAlchemyError.
    """
    try:
        async with db.begin_nested() if db.in_transaction() else db.begin():
            # 1. Validate payer
            payer = await db.get(UserModel, expense_in.paid_by_user_id)
            if not payer:
                raise UserNotFoundError(user_id=expense_in.paid_by_user_id, identifier="Payer")

            # 2. Context resolution and validation
            if not expense_in.list_id and not expense_in.group_id and not expense_in.item_id:
                raise InvalidOperationError("Expense must be associated with a list, a group, or an item.")
            final_group_id = await _resolve_expense_context(db, expense_in)

            if expense_in.item_id:
                db_item_instance = await db.get(ItemModel, expense_in.item_id)
                if not db_item_instance:
                    raise InvalidOperationError(f"Item with ID {expense_in.item_id} not found.")
                # Inherit the item's list when none was given, then re-resolve so the
                # group follows the newly derived list.
                if db_item_instance.list_id and not expense_in.list_id:
                    expense_in.list_id = db_item_instance.list_id
                    final_group_id = await _resolve_expense_context(db, expense_in)

            # One timestamp for everything created by this call. The expense date is
            # resolved once so that next_occurrence agrees with the stored expense_date
            # (using the raw input left next_occurrence as None when no date was sent).
            now = datetime.now(timezone.utc)
            expense_date = expense_in.expense_date or now

            # 3. Recurrence pattern: only when the flag AND the pattern data are present.
            recurrence_pattern = None
            if expense_in.is_recurring and expense_in.recurrence_pattern:
                pattern_in = expense_in.recurrence_pattern
                recurrence_pattern = RecurrencePattern(
                    type=pattern_in.type,
                    interval=pattern_in.interval,
                    days_of_week=pattern_in.days_of_week,
                    end_date=pattern_in.end_date,
                    max_occurrences=pattern_in.max_occurrences,
                    created_at=now,
                    updated_at=now,
                )
                db.add(recurrence_pattern)
                await db.flush()

            # 4. The expense itself
            db_expense = ExpenseModel(
                description=expense_in.description,
                total_amount=_round_money(expense_in.total_amount),
                currency=expense_in.currency or "USD",
                expense_date=expense_date,
                split_type=expense_in.split_type,
                list_id=expense_in.list_id,
                group_id=final_group_id,  # resolved (possibly list-derived) group
                item_id=expense_in.item_id,
                paid_by_user_id=expense_in.paid_by_user_id,
                created_by_user_id=current_user_id,
                overall_settlement_status=ExpenseOverallStatusEnum.unpaid,
                is_recurring=expense_in.is_recurring,
                recurrence_pattern=recurrence_pattern,
                next_occurrence=expense_date if expense_in.is_recurring else None,
            )
            db.add(db_expense)
            await db.flush()  # assigns db_expense.id, needed for the split FKs

            # 5. Splits
            splits_to_create = await _generate_expense_splits(
                db=db,
                expense_model=db_expense,
                expense_in=expense_in,
                current_user_id=current_user_id,
            )
            for split_model in splits_to_create:
                split_model.expense_id = db_expense.id
            db.add_all(splits_to_create)
            await db.flush()

            # 6. Re-fetch with the relationships the response needs eagerly loaded.
            stmt = (
                select(ExpenseModel)
                .where(ExpenseModel.id == db_expense.id)
                .options(
                    selectinload(ExpenseModel.paid_by_user),
                    selectinload(ExpenseModel.created_by_user),
                    selectinload(ExpenseModel.list),
                    selectinload(ExpenseModel.group),
                    selectinload(ExpenseModel.item),
                    selectinload(ExpenseModel.splits).selectinload(ExpenseSplitModel.user),
                )
            )
            result = await db.execute(stmt)
            loaded_expense = result.scalar_one_or_none()
            if loaded_expense is None:
                # Raising inside the context manager triggers its rollback.
                raise ExpenseOperationError("Failed to load expense after creation.")
            return loaded_expense
    except (UserNotFoundError, ListNotFoundError, GroupNotFoundError, InvalidOperationError):
        # Business validation errors propagate unchanged; the context manager has rolled back.
        raise
    except IntegrityError as e:
        logger.error(f"Database integrity error during expense creation: {str(e)}", exc_info=True)
        raise DatabaseIntegrityError(f"Failed to save expense due to database integrity issue: {str(e)}") from e
    except OperationalError as e:
        logger.error(f"Database connection error during expense creation: {str(e)}", exc_info=True)
        raise DatabaseConnectionError(f"Database connection error during expense creation: {str(e)}") from e
    except SQLAlchemyError as e:
        logger.error(f"Unexpected SQLAlchemy error during expense creation: {str(e)}", exc_info=True)
        raise DatabaseTransactionError(f"Failed to save expense due to a database transaction error: {str(e)}") from e
async def _resolve_expense_context(db: AsyncSession, expense_in: ExpenseCreate) -> Optional[int]:
    """Resolve and validate the expense's list/group context.

    - If ``list_id`` is given, the list must exist. When the list belongs to a group,
      that group wins, and a conflicting explicit ``group_id`` is rejected.
    - Any ``group_id`` that was NOT taken from the list is checked for existence.
      Previously a ``group_id`` sent alongside a list without a group skipped
      this check.

    Returns:
        The group id the expense should be stored with (may be None).

    Raises:
        ListNotFoundError: ``list_id`` does not match a list.
        GroupNotFoundError: an explicit ``group_id`` does not match a group.
        InvalidOperationError: the list belongs to a different group than ``group_id``.
    """
    final_group_id = expense_in.group_id
    group_came_from_list = False

    if expense_in.list_id:
        list_obj = await db.get(ListModel, expense_in.list_id)
        if not list_obj:
            raise ListNotFoundError(expense_in.list_id)
        if list_obj.group_id:
            if expense_in.group_id and list_obj.group_id != expense_in.group_id:
                raise InvalidOperationError(
                    f"List {expense_in.list_id} belongs to group {list_obj.group_id}, "
                    f"but expense was specified for group {expense_in.group_id}."
                )
            final_group_id = list_obj.group_id  # the list's group takes priority
            # Trusted as-is (presumably FK-backed), matching previous behaviour.
            group_came_from_list = True

    if final_group_id and not group_came_from_list:
        group_obj = await db.get(GroupModel, final_group_id)
        if not group_obj:
            raise GroupNotFoundError(final_group_id)

    return final_group_id
async def _generate_expense_splits(
    db: AsyncSession,
    expense_model: ExpenseModel,
    expense_in: ExpenseCreate,
    **kwargs: Any
) -> PyList[ExpenseSplitModel]:
    """Build the split rows for an expense according to its split type.

    Extra keyword arguments (e.g. ``current_user_id``) are forwarded as real keyword
    arguments to the split helper. Previously they were wrapped in a single keyword
    literally named ``kwargs``.

    Raises:
        InvalidOperationError: unsupported split type, or no splits were produced.
    """
    # Explicit arguments are listed last so they cannot be overridden by **kwargs.
    common_args = {
        **kwargs,
        "db": db,
        "expense_model": expense_model,
        "expense_in": expense_in,
        "round_money_func": _round_money,
    }

    split_type = expense_in.split_type
    if split_type == SplitTypeEnum.EQUAL:
        splits_to_create = await _create_equal_splits(**common_args)
    elif split_type == SplitTypeEnum.EXACT_AMOUNTS:
        splits_to_create = await _create_exact_amount_splits(**common_args)
    elif split_type == SplitTypeEnum.PERCENTAGE:
        splits_to_create = await _create_percentage_splits(**common_args)
    elif split_type == SplitTypeEnum.SHARES:
        splits_to_create = await _create_shares_splits(**common_args)
    elif split_type == SplitTypeEnum.ITEM_BASED:
        splits_to_create = await _create_item_based_splits(**common_args)
    else:
        # Fall back to the raw value if split_type is not an enum member (no .value).
        shown = getattr(split_type, "value", split_type)
        raise InvalidOperationError(f"Unsupported split type: {shown}")

    if not splits_to_create:
        raise InvalidOperationError("No expense splits were generated.")
    return splits_to_create
async def _create_equal_splits(db: AsyncSession, expense_model: ExpenseModel, expense_in: ExpenseCreate, round_money_func: Callable[[Decimal], Decimal], **kwargs: Any) -> PyList[ExpenseSplitModel]:
    """Split the expense total equally among the users from get_users_for_splitting.

    The total is handled in whole cents. Every user gets ``total // n`` cents, and
    the ``total % n`` leftover cents go one each to the first users. Amounts
    therefore differ by at most 0.01, always sum to the total, and are never
    negative for a non-negative total.

    The previous approach rounded each share half-up and pushed the whole
    remainder onto the first user. That could make the first share negative,
    e.g. 0.05 split 10 ways gave -0.04.
    """
    users_for_splitting = await get_users_for_splitting(
        db, expense_model.group_id, expense_model.list_id, expense_model.paid_by_user_id
    )
    if not users_for_splitting:
        raise InvalidOperationError("No users found for EQUAL split.")

    num_users = len(users_for_splitting)
    total_cents = int(round_money_func(expense_model.total_amount) * 100)
    # divmod floors, so leftover_cents is always in [0, num_users) even for negatives.
    base_cents, leftover_cents = divmod(total_cents, num_users)

    splits = []
    for i, user in enumerate(users_for_splitting):
        cents = base_cents + (1 if i < leftover_cents else 0)
        splits.append(ExpenseSplitModel(
            user_id=user.id,
            owed_amount=round_money_func(Decimal(cents) / Decimal(100)),
            status=ExpenseSplitStatusEnum.unpaid  # Explicitly set default status
        ))
    return splits
async def _create_exact_amount_splits(db: AsyncSession, expense_model: ExpenseModel, expense_in: ExpenseCreate, round_money_func: Callable[[Decimal], Decimal], **kwargs: Any) -> PyList[ExpenseSplitModel]:
    """Build splits from caller-supplied exact owed amounts.

    Every amount must be strictly positive. Amounts are rounded to cents, and their
    sum must equal the expense total exactly.

    Raises:
        InvalidOperationError: missing split data, a non-positive amount, or a sum
            that does not match the expense total.
    """
    entries = expense_in.splits_in
    if not entries:
        raise InvalidOperationError("Splits data is required for EXACT_AMOUNTS split type.")

    # Every referenced user must exist.
    await _validate_users_in_splits(db, entries)

    running_sum = Decimal("0.00")
    result: PyList[ExpenseSplitModel] = []
    for entry in entries:
        if entry.owed_amount <= Decimal('0'):
            raise InvalidOperationError(f"Owed amount for user {entry.user_id} must be positive.")
        amount = round_money_func(entry.owed_amount)
        running_sum += amount
        result.append(
            ExpenseSplitModel(
                user_id=entry.user_id,
                owed_amount=amount,
                status=ExpenseSplitStatusEnum.unpaid,
            )
        )

    if round_money_func(running_sum) != expense_model.total_amount:
        raise InvalidOperationError(
            f"Sum of exact split amounts ({running_sum}) != expense total ({expense_model.total_amount})."
        )
    return result
async def _create_percentage_splits(db: AsyncSession, expense_model: ExpenseModel, expense_in: ExpenseCreate, round_money_func: Callable[[Decimal], Decimal], **kwargs: Any) -> PyList[ExpenseSplitModel]:
    """Split the expense total according to per-user percentages.

    Each ``share_percentage`` must lie in (0, 100], and the percentages must total
    100 once rounded to two places. Each owed amount is rounded on its own; any
    remaining rounding drift against the expense total is added to the last split.

    Raises:
        InvalidOperationError: missing split data, an out-of-range percentage, or
            percentages that do not total 100.
    """
    entries = expense_in.splits_in
    if not entries:
        raise InvalidOperationError("Splits data is required for PERCENTAGE split type.")

    # Every referenced user must exist.
    await _validate_users_in_splits(db, entries)

    hundred = Decimal("100")
    total = expense_model.total_amount
    pct_sum = Decimal("0.00")
    allocated = Decimal("0.00")
    result: PyList[ExpenseSplitModel] = []

    for entry in entries:
        pct = entry.share_percentage
        if not pct or not (Decimal("0") < pct <= hundred):
            raise InvalidOperationError(
                f"Invalid percentage {entry.share_percentage} for user {entry.user_id}."
            )
        pct_sum += pct
        amount = round_money_func(total * (pct / hundred))
        allocated += amount
        result.append(
            ExpenseSplitModel(
                user_id=entry.user_id,
                owed_amount=amount,
                share_percentage=pct,
                status=ExpenseSplitStatusEnum.unpaid,
            )
        )

    if round_money_func(pct_sum) != Decimal("100.00"):
        raise InvalidOperationError(f"Sum of percentages ({pct_sum}%) is not 100%.")

    # The last split absorbs any per-share rounding drift.
    if allocated != total and result:
        drift = total - allocated
        result[-1].owed_amount = round_money_func(result[-1].owed_amount + drift)
    return result
async def _create_shares_splits(db: AsyncSession, expense_model: ExpenseModel, expense_in: ExpenseCreate, round_money_func: Callable[[Decimal], Decimal], **kwargs: Any) -> PyList[ExpenseSplitModel]:
    """Split the expense total in proportion to per-user share units.

    Each user owes ``total * units / sum(units)``, rounded to cents. Any rounding
    drift against the expense total is added to the last split.

    Raises:
        InvalidOperationError: missing split data, a zero total of positive share
            units, or any entry with non-positive share units.
    """
    entries = expense_in.splits_in
    if not entries:
        raise InvalidOperationError("Splits data is required for SHARES split type.")

    # Every referenced user must exist.
    await _validate_users_in_splits(db, entries)

    # Only positive units count towards the denominator. Invalid entries are
    # rejected in the loop below, but only after this zero-total check.
    total_units = sum(e.share_units for e in entries if e.share_units and e.share_units > 0)
    if total_units == 0:
        raise InvalidOperationError("Total shares cannot be zero for SHARES split.")
    denominator = Decimal(total_units)

    total = expense_model.total_amount
    allocated = Decimal("0.00")
    result: PyList[ExpenseSplitModel] = []
    for entry in entries:
        units = entry.share_units
        if not units or not (units > 0):
            raise InvalidOperationError(f"Invalid share units for user {entry.user_id}.")
        amount = round_money_func(total * (Decimal(units) / denominator))
        allocated += amount
        result.append(
            ExpenseSplitModel(
                user_id=entry.user_id,
                owed_amount=amount,
                share_units=units,
                status=ExpenseSplitStatusEnum.unpaid,
            )
        )

    # The last split absorbs any per-share rounding drift.
    if allocated != total and result:
        drift = total - allocated
        result[-1].owed_amount = round_money_func(result[-1].owed_amount + drift)
    return result
async def _create_item_based_splits(db: AsyncSession, expense_model: ExpenseModel, expense_in: ExpenseCreate, round_money_func: Callable[[Decimal], Decimal], **kwargs: Any) -> PyList[ExpenseSplitModel]:
    """Create splits from the items of the expense's list.

    Each item's price is owed by the user who added the item. If the expense has an
    ``item_id``, only that item (which must be in the list and have a positive
    price) is used. Otherwise all positively-priced items in the list are used.
    The summed item prices must equal the expense total.

    Any ``splits_in`` data is ignored (a warning is logged).

    Raises:
        InvalidOperationError: no list, no qualifying items, an item without adder
            information, or a total that does not match the item prices.
    """
    # Use collections.defaultdict directly. The module-level name comes from
    # ``typing``, which only re-exports it incidentally (Python 3.11+).
    from collections import defaultdict

    if not expense_model.list_id:
        raise InvalidOperationError("ITEM_BASED expenses must be associated with a list_id.")
    if expense_in.splits_in:
        logger.warning("Ignoring provided 'splits_in' data for ITEM_BASED expense.")

    # Fetch the relevant items together with the users who added them.
    items_query = select(ItemModel).where(ItemModel.list_id == expense_model.list_id)
    if expense_model.item_id:
        items_query = items_query.where(ItemModel.id == expense_model.item_id)
    else:
        items_query = items_query.where(ItemModel.price.isnot(None) & (ItemModel.price > Decimal("0")))
    items_result = await db.execute(items_query.options(selectinload(ItemModel.added_by_user)))
    relevant_items = items_result.scalars().all()

    if not relevant_items:
        error_msg = (
            f"Specified item ID {expense_model.item_id} not found in list {expense_model.list_id}."
            if expense_model.item_id else
            f"List {expense_model.list_id} has no priced items to base the expense on."
        )
        raise InvalidOperationError(error_msg)

    # Aggregate owed amounts per adding user.
    calculated_total = Decimal("0.00")
    user_owed_amounts: Dict[int, Decimal] = defaultdict(Decimal)
    processed_items = 0
    for item in relevant_items:
        if item.price is None or item.price <= Decimal("0"):
            if expense_model.item_id:
                raise InvalidOperationError(
                    f"Item ID {expense_model.item_id} must have a positive price for ITEM_BASED expense."
                )
            continue
        if not item.added_by_user:
            logger.error(f"Item ID {item.id} is missing added_by_user relationship.")
            raise InvalidOperationError(f"Data integrity issue: Item {item.id} is missing adder information.")
        calculated_total += item.price
        user_owed_amounts[item.added_by_user.id] += item.price
        processed_items += 1

    if processed_items == 0:
        raise InvalidOperationError(
            f"No items with positive prices found in list {expense_model.list_id} to create ITEM_BASED expense."
        )

    if round_money_func(calculated_total) != expense_model.total_amount:
        raise InvalidOperationError(
            f"Expense total amount ({expense_model.total_amount}) does not match the "
            f"calculated total from item prices ({calculated_total})."
        )

    return [
        ExpenseSplitModel(
            user_id=user_id,
            owed_amount=round_money_func(owed_amount),
            status=ExpenseSplitStatusEnum.unpaid,  # Explicitly set default status
        )
        for user_id, owed_amount in user_owed_amounts.items()
    ]
async def _validate_users_in_splits(db: AsyncSession, splits_in: PyList[ExpenseSplitCreate]) -> None:
    """Check that split data names each user at most once and that all named users exist.

    Previously a duplicated user id made the length comparison fail, and the
    function raised UserNotFoundError listing no missing users at all.

    Raises:
        InvalidOperationError: the same user id appears more than once.
        UserNotFoundError: one or more referenced user ids do not exist.
    """
    from collections import Counter

    user_ids_in_split = [s.user_id for s in splits_in]
    id_counts = Counter(user_ids_in_split)
    duplicate_ids = sorted(uid for uid, count in id_counts.items() if count > 1)
    if duplicate_ids:
        raise InvalidOperationError(f"Users appear more than once in split data: {duplicate_ids}")

    unique_ids = list(id_counts)
    user_results = await db.execute(select(UserModel.id).where(UserModel.id.in_(unique_ids)))
    found_user_ids = {row[0] for row in user_results}
    missing_user_ids = set(unique_ids) - found_user_ids
    if missing_user_ids:
        raise UserNotFoundError(identifier=f"users in split data: {sorted(missing_user_ids)}")
async def get_expense_by_id(db: AsyncSession, expense_id: int) -> Optional[ExpenseModel]:
    """Fetch one expense by primary key, or None if it does not exist.

    Eagerly loads the splits (with each split's user), the payer, the list, the
    group and the item.
    """
    eager_loads = (
        selectinload(ExpenseModel.splits).options(selectinload(ExpenseSplitModel.user)),
        selectinload(ExpenseModel.paid_by_user),
        selectinload(ExpenseModel.list),
        selectinload(ExpenseModel.group),
        selectinload(ExpenseModel.item),
    )
    stmt = select(ExpenseModel).options(*eager_loads).where(ExpenseModel.id == expense_id)
    rows = await db.execute(stmt)
    return rows.scalars().first()
async def get_expenses_for_list(db: AsyncSession, list_id: int, skip: int = 0, limit: int = 100) -> Sequence[ExpenseModel]:
    """Return one page of the expenses attached to a list, newest first.

    Ordering is by expense_date, then created_at (both descending). Each expense's
    splits, and each split's user, are eagerly loaded.
    """
    newest_first = (ExpenseModel.expense_date.desc(), ExpenseModel.created_at.desc())
    split_loader = selectinload(ExpenseModel.splits).options(selectinload(ExpenseSplitModel.user))
    stmt = (
        select(ExpenseModel)
        .where(ExpenseModel.list_id == list_id)
        .order_by(*newest_first)
        .offset(skip)
        .limit(limit)
        .options(split_loader)
    )
    rows = await db.execute(stmt)
    return rows.scalars().all()
async def get_expenses_for_group(db: AsyncSession, group_id: int, skip: int = 0, limit: int = 100) -> Sequence[ExpenseModel]:
    """Return one page of the expenses attached to a group, newest first.

    Ordering is by expense_date, then created_at (both descending). Each expense's
    splits, and each split's user, are eagerly loaded.
    """
    newest_first = (ExpenseModel.expense_date.desc(), ExpenseModel.created_at.desc())
    split_loader = selectinload(ExpenseModel.splits).options(selectinload(ExpenseSplitModel.user))
    stmt = (
        select(ExpenseModel)
        .where(ExpenseModel.group_id == group_id)
        .order_by(*newest_first)
        .offset(skip)
        .limit(limit)
        .options(split_loader)
    )
    rows = await db.execute(stmt)
    return rows.scalars().all()
async def update_expense(db: AsyncSession, expense_db: ExpenseModel, expense_in: ExpenseUpdate) -> ExpenseModel:
    """
    Update an existing expense.

    Only ``description``, ``currency`` and ``expense_date`` may be changed, which
    avoids having to recompute splits. ``expense_in.version`` must equal the stored
    version (optimistic locking, checked in memory). On success the version is
    incremented and ``updated_at`` is set to the current UTC time.

    The whole payload is validated BEFORE ``expense_db`` is touched. Previously
    allowed fields were assigned one by one, so a later disallowed field left the
    ORM object partially modified.

    Raises:
        InvalidOperationError: version mismatch or a non-updatable field in the payload.
        DatabaseIntegrityError: on an IntegrityError while flushing.
        DatabaseTransactionError: on any other SQLAlchemyError.
    """
    if expense_db.version != expense_in.version:
        raise InvalidOperationError(
            f"Expense '{expense_db.description}' (ID: {expense_db.id}) has been modified. "
            f"Your version is {expense_in.version}, current version is {expense_db.version}. Please refresh.",
            # status_code=status.HTTP_409_CONFLICT # This would be for the API layer to set
        )

    update_data = expense_in.model_dump(exclude_unset=True, exclude={"version"})

    # Fields that can change without affecting splits or core logic.
    allowed_to_update = {"description", "currency", "expense_date"}
    disallowed = [field for field in update_data if field not in allowed_to_update]
    if disallowed:
        # Report the first offending field (payload order), listing allowed fields in a stable order.
        raise InvalidOperationError(
            f"Field '{disallowed[0]}' cannot be updated. Only {', '.join(sorted(allowed_to_update))} are allowed."
        )

    # Capture the id up front. After a rollback the instance may be expired, and
    # reading its attributes in the except handlers would need implicit IO, which
    # AsyncSession does not allow.
    expense_id = expense_db.id

    # A payload containing only ``version`` is accepted and still bumps the version.
    try:
        async with db.begin_nested() if db.in_transaction() else db.begin():
            for field, value in update_data.items():
                setattr(expense_db, field, value)
            expense_db.version += 1
            expense_db.updated_at = datetime.now(timezone.utc)  # set manually
            await db.flush()  # send the UPDATE and run constraints
            await db.refresh(expense_db)
            return expense_db
    except IntegrityError as e:
        logger.error(f"Database integrity error during expense update for ID {expense_id}: {str(e)}", exc_info=True)
        raise DatabaseIntegrityError(f"Failed to update expense ID {expense_id} due to database integrity issue.") from e
    except SQLAlchemyError as e:
        logger.error(f"Database transaction error during expense update for ID {expense_id}: {str(e)}", exc_info=True)
        raise DatabaseTransactionError(f"Failed to update expense ID {expense_id} due to a database transaction error.") from e
async def delete_expense(db: AsyncSession, expense_db: ExpenseModel, expected_version: Optional[int] = None) -> None:
    """
    Delete an expense.

    If ``expected_version`` is given, it must equal the stored version (checked in
    memory). Removal of the associated ExpenseSplits is left to the database.
    NOTE(review): the previous docstring said this relies on an FK cascade;
    confirm against the schema.

    Raises:
        InvalidOperationError: ``expected_version`` does not match the stored version.
        DatabaseIntegrityError: on an IntegrityError while flushing the delete.
        DatabaseTransactionError: on any other SQLAlchemyError.
    """
    if expected_version is not None and expense_db.version != expected_version:
        raise InvalidOperationError(
            f"Expense '{expense_db.description}' (ID: {expense_db.id}) cannot be deleted. "
            f"Your expected version {expected_version} does not match current version {expense_db.version}. Please refresh.",
            # status_code=status.HTTP_409_CONFLICT
        )

    # Capture the id before the transaction. After a rollback the instance may be
    # expired, and reading expense_db.id in the handlers would need implicit IO.
    expense_id = expense_db.id

    try:
        async with db.begin_nested() if db.in_transaction() else db.begin():
            await db.delete(expense_db)
            await db.flush()  # send the DELETE now so errors surface inside this try
    except IntegrityError as e:
        logger.error(f"Database integrity error during expense deletion for ID {expense_id}: {str(e)}", exc_info=True)
        raise DatabaseIntegrityError(f"Failed to delete expense ID {expense_id} due to database integrity issue.") from e
    except SQLAlchemyError as e:
        logger.error(f"Database transaction error during expense deletion for ID {expense_id}: {str(e)}", exc_info=True)
        raise DatabaseTransactionError(f"Failed to delete expense ID {expense_id} due to a database transaction error.") from e
    return None
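# Note: InvalidOperationError and the other domain exceptions raised in this module are
# imported from app.core.exceptions. (An earlier note described InvalidOperationError as a
# simple ValueError placeholder; NOTE(review): confirm its actual base class there.)
# For API endpoints, these should be translated to appropriate HTTPExceptions.
# Ensure app.core.exceptions has proper HTTP error classes if needed.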