602 lines
26 KiB
Python
602 lines
26 KiB
Python
# app/crud/expense.py
|
|
import logging
from collections import defaultdict  # moved from `typing` (not public API there; missing before 3.11)
from datetime import datetime, timezone
from decimal import Decimal, ROUND_HALF_UP, InvalidOperation as DecimalInvalidOperation
from typing import Callable, List as PyList, Optional, Sequence, Dict, Any

from sqlalchemy.exc import IntegrityError, OperationalError, SQLAlchemyError
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy.orm import selectinload, joinedload

from app.models import (
    Expense as ExpenseModel,
    ExpenseSplit as ExpenseSplitModel,
    User as UserModel,
    List as ListModel,
    Group as GroupModel,
    UserGroup as UserGroupModel,
    SplitTypeEnum,
    Item as ItemModel,
)
from app.schemas.expense import ExpenseCreate, ExpenseSplitCreate, ExpenseUpdate
from app.core.exceptions import (
    ListNotFoundError,
    GroupNotFoundError,
    UserNotFoundError,
    InvalidOperationError,
    DatabaseConnectionError,
    DatabaseIntegrityError,
    DatabaseQueryError,
    DatabaseTransactionError,
    ExpenseOperationError,
)
|
|
|
|
logger = logging.getLogger(__name__)


def _round_money(amount: Any) -> Decimal:
    """Quantize ``amount`` to two decimal places using ROUND_HALF_UP.

    Used to round the expense total in ``create_expense`` and, passed as
    ``round_money_func``, every split amount computed by the split helpers.
    ``Decimal`` inputs are used as-is; other inputs (int, float, str) are
    converted through ``str()`` first, so a float such as ``2.675`` rounds from
    its shortest repr (-> 2.68) rather than its exact binary value (-> 2.67).

    Raises:
        InvalidOperationError: ``amount`` is not a finite number, or cannot be
            quantized within the current decimal context precision.
    """
    try:
        value = amount if isinstance(amount, Decimal) else Decimal(str(amount))
        # NaN quantizes silently to NaN and Infinity cannot be quantized;
        # neither is a valid money value, so reject non-finite input up front.
        if not value.is_finite():
            raise InvalidOperationError(f"Invalid monetary amount: {amount!r}")
        return value.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP)
    except DecimalInvalidOperation as e:
        raise InvalidOperationError(f"Invalid monetary amount: {amount!r}") from e
|
|
|
|
async def get_users_for_splitting(db: AsyncSession, expense_group_id: Optional[int], expense_list_id: Optional[int], expense_paid_by_user_id: int) -> PyList[UserModel]:
    """
    Work out which users an expense is shared between.

    Resolution order:
      1. ``expense_group_id`` set -> every member of that group.
      2. otherwise ``expense_list_id`` set -> the members of the list's group,
         or the list's creator when the list has no group.
      3. if neither step produced anyone -> only the payer.

    Users are returned in the order they were discovered, de-duplicated by id.

    Raises:
        GroupNotFoundError: ``expense_group_id`` does not match a group.
        ListNotFoundError: ``expense_list_id`` does not match a list.
        UserNotFoundError: the payer fallback is needed but the payer is missing.
        InvalidOperationError: no users could be determined at all.
    """
    # Keyed by user id. Dicts preserve insertion order, so this de-duplicates
    # while keeping discovery order.
    unique_users: Dict[int, UserModel] = {}

    def remember(candidate: Optional[UserModel]) -> None:
        if candidate and candidate.id not in unique_users:
            unique_users[candidate.id] = candidate

    if expense_group_id:
        group_stmt = (
            select(GroupModel)
            .options(selectinload(GroupModel.member_associations).options(selectinload(UserGroupModel.user)))
            .where(GroupModel.id == expense_group_id)
        )
        group = (await db.execute(group_stmt)).scalars().first()
        if not group:
            raise GroupNotFoundError(expense_group_id)
        for membership in group.member_associations:
            remember(membership.user)

    elif expense_list_id:
        # The list is only consulted when no explicit group was supplied.
        list_stmt = (
            select(ListModel)
            .options(
                selectinload(ListModel.group).options(selectinload(GroupModel.member_associations).options(selectinload(UserGroupModel.user))),
                selectinload(ListModel.creator),
            )
            .where(ListModel.id == expense_list_id)
        )
        db_list = (await db.execute(list_stmt)).scalars().first()
        if not db_list:
            raise ListNotFoundError(expense_list_id)

        if db_list.group:
            for membership in db_list.group.member_associations:
                remember(membership.user)
        elif db_list.creator:
            remember(db_list.creator)

    if not unique_users:
        payer = await db.get(UserModel, expense_paid_by_user_id)
        if not payer:
            # Callers are expected to have validated the payer already.
            raise UserNotFoundError(user_id=expense_paid_by_user_id)
        remember(payer)

    if not unique_users:
        # Defensive: the payer fallback above should always contribute someone.
        raise InvalidOperationError("Could not determine any users for splitting the expense.")

    return list(unique_users.values())
|
|
|
|
|
|
async def create_expense(db: AsyncSession, expense_in: ExpenseCreate, current_user_id: int) -> ExpenseModel:
    """Creates a new expense and its associated splits.

    The work runs inside a SAVEPOINT (``begin_nested``) when ``db`` already has
    an open transaction, otherwise inside a new transaction (``begin``). Leaving
    the ``async with`` block normally commits it (or releases the savepoint);
    any exception raised inside rolls it back.

    The caller's ``expense_in`` is never mutated. When ``list_id`` has to be
    derived from the item, a copy of the schema is used internally.

    Args:
        db: Database session
        expense_in: Expense creation data
        current_user_id: ID of the user creating the expense

    Returns:
        The created expense, re-loaded with payer, creator, list, group, item
        and splits (each split with its user).

    Raises:
        UserNotFoundError: If payer or split users don't exist
        ListNotFoundError: If specified list doesn't exist
        GroupNotFoundError: If specified group doesn't exist
        InvalidOperationError: For validation failures, including an item that
            belongs to a different list than the supplied ``list_id``
        ExpenseOperationError: If the expense cannot be re-loaded after creation
        DatabaseIntegrityError / DatabaseConnectionError / DatabaseTransactionError:
            For database failures (the original error is chained as ``__cause__``)
    """
    try:
        async with db.begin_nested() if db.in_transaction() else db.begin():
            # 1. Validate payer
            payer = await db.get(UserModel, expense_in.paid_by_user_id)
            if not payer:
                raise UserNotFoundError(user_id=expense_in.paid_by_user_id, identifier="Payer")

            # 2. Context resolution and validation
            if not expense_in.list_id and not expense_in.group_id and not expense_in.item_id:
                raise InvalidOperationError("Expense must be associated with a list, a group, or an item.")

            final_group_id = await _resolve_expense_context(db, expense_in)

            if expense_in.item_id:
                db_item_instance = await db.get(ItemModel, expense_in.item_id)
                if not db_item_instance:
                    raise InvalidOperationError(f"Item with ID {expense_in.item_id} not found.")

                if db_item_instance.list_id:
                    if not expense_in.list_id:
                        # Derive the list from the item on a copy, so the
                        # caller's schema object is left untouched.
                        expense_in = expense_in.model_copy(update={"list_id": db_item_instance.list_id})
                        # The derived list may carry (or conflict with) a group.
                        final_group_id = await _resolve_expense_context(db, expense_in)
                    elif expense_in.list_id != db_item_instance.list_id:
                        raise InvalidOperationError(
                            f"Item {expense_in.item_id} belongs to list {db_item_instance.list_id}, "
                            f"but expense was specified for list {expense_in.list_id}."
                        )

            # 3. Create the ExpenseModel instance
            db_expense = ExpenseModel(
                description=expense_in.description,
                total_amount=_round_money(expense_in.total_amount),
                currency=expense_in.currency or "USD",
                expense_date=expense_in.expense_date or datetime.now(timezone.utc),
                split_type=expense_in.split_type,
                list_id=expense_in.list_id,
                group_id=final_group_id,
                item_id=expense_in.item_id,
                paid_by_user_id=expense_in.paid_by_user_id,
                created_by_user_id=current_user_id,
            )
            db.add(db_expense)
            await db.flush()  # assigns db_expense.id

            # 4. Generate splits; current_user_id is forwarded for split types that need it
            splits_to_create = await _generate_expense_splits(
                db=db,
                expense_model=db_expense,
                expense_in=expense_in,
                current_user_id=current_user_id,
            )
            for split_model in splits_to_create:
                split_model.expense_id = db_expense.id  # FK is known only after the flush above
            db.add_all(splits_to_create)
            await db.flush()

            # 5. Re-fetch the expense with all relationships needed for the response
            stmt = (
                select(ExpenseModel)
                .where(ExpenseModel.id == db_expense.id)
                .options(
                    selectinload(ExpenseModel.paid_by_user),
                    selectinload(ExpenseModel.created_by_user),  # If you have this relationship
                    selectinload(ExpenseModel.list),
                    selectinload(ExpenseModel.group),
                    selectinload(ExpenseModel.item),
                    selectinload(ExpenseModel.splits).selectinload(ExpenseSplitModel.user),
                )
            )
            loaded_expense = (await db.execute(stmt)).scalar_one_or_none()

            if loaded_expense is None:
                # Raising inside the context manager rolls the transaction back.
                raise ExpenseOperationError("Failed to load expense after creation.")

            # Exiting the block normally commits the transaction / releases the savepoint.
            return loaded_expense

    except (UserNotFoundError, ListNotFoundError, GroupNotFoundError, InvalidOperationError):
        # Business validation errors pass through unchanged; rollback already happened.
        raise
    except IntegrityError as e:
        logger.exception("Database integrity error during expense creation: %s", e)
        raise DatabaseIntegrityError(f"Failed to save expense due to database integrity issue: {e}") from e
    except OperationalError as e:
        logger.exception("Database connection error during expense creation: %s", e)
        raise DatabaseConnectionError(f"Database connection error during expense creation: {e}") from e
    except SQLAlchemyError as e:
        logger.exception("Unexpected SQLAlchemy error during expense creation: %s", e)
        raise DatabaseTransactionError(f"Failed to save expense due to a database transaction error: {e}") from e
|
|
|
|
|
|
async def _resolve_expense_context(db: AsyncSession, expense_in: ExpenseCreate) -> Optional[int]:
    """Resolves and validates the expense's context (list and group).

    - If ``expense_in.list_id`` is set, the list must exist. When that list
      belongs to a group, the expense inherits the list's group, and an
      explicitly supplied ``group_id`` must match it.
    - Any ``group_id`` not inherited from the list (no list given, or the list
      has no group) must refer to an existing group.

    Returns:
        The group_id the expense should be stored with (may be ``None``).

    Raises:
        ListNotFoundError: ``list_id`` does not exist.
        GroupNotFoundError: an explicitly supplied ``group_id`` does not exist.
        InvalidOperationError: ``group_id`` conflicts with the list's group.
    """
    if expense_in.list_id:
        list_obj = await db.get(ListModel, expense_in.list_id)
        if not list_obj:
            raise ListNotFoundError(expense_in.list_id)

        if list_obj.group_id:
            if expense_in.group_id and list_obj.group_id != expense_in.group_id:
                raise InvalidOperationError(
                    f"List {expense_in.list_id} belongs to group {list_obj.group_id}, "
                    f"but expense was specified for group {expense_in.group_id}."
                )
            # The group comes from an existing list row, so no extra lookup is needed.
            return list_obj.group_id

    # Reached when no list was given or the list is not in a group. Any
    # explicit group_id has not been checked yet, so it must be validated
    # here, even when a group-less list was supplied.
    final_group_id = expense_in.group_id
    if final_group_id:
        group_obj = await db.get(GroupModel, final_group_id)
        if not group_obj:
            raise GroupNotFoundError(final_group_id)

    return final_group_id
|
|
|
|
|
|
async def _generate_expense_splits(
    db: AsyncSession,
    expense_model: ExpenseModel,
    expense_in: ExpenseCreate,
    **kwargs: Any
) -> PyList[ExpenseSplitModel]:
    """Generates appropriate expense splits based on split type.

    Dispatches on ``expense_in.split_type`` to the matching ``_create_*_splits``
    helper. Every helper receives ``db``, ``expense_model``, ``expense_in``,
    ``round_money_func=_round_money`` and any extra keyword arguments (e.g.
    ``current_user_id``), which land in the helper's ``**kwargs``.

    Raises:
        InvalidOperationError: unsupported split type, or the helper produced no splits.
    """
    # Ordered (member, handler) pairs; matched with ``==`` to keep the exact
    # comparison semantics of a plain if/elif chain.
    handlers = (
        (SplitTypeEnum.EQUAL, _create_equal_splits),
        (SplitTypeEnum.EXACT_AMOUNTS, _create_exact_amount_splits),
        (SplitTypeEnum.PERCENTAGE, _create_percentage_splits),
        (SplitTypeEnum.SHARES, _create_shares_splits),
        (SplitTypeEnum.ITEM_BASED, _create_item_based_splits),
    )
    split_type = expense_in.split_type
    handler = next((fn for member, fn in handlers if split_type == member), None)
    if handler is None:
        # getattr: a non-enum value (e.g. None) must not turn into an AttributeError here.
        raise InvalidOperationError(f"Unsupported split type: {getattr(split_type, 'value', split_type)}")

    # Spread kwargs so helpers see e.g. ``current_user_id`` directly, rather
    # than a nested ``kwargs={...}`` entry.
    splits_to_create = await handler(
        db=db,
        expense_model=expense_model,
        expense_in=expense_in,
        round_money_func=_round_money,
        **kwargs,
    )

    if not splits_to_create:
        raise InvalidOperationError("No expense splits were generated.")

    return splits_to_create
|
|
|
|
|
|
async def _create_equal_splits(db: AsyncSession, expense_model: ExpenseModel, expense_in: ExpenseCreate, round_money_func: Callable[[Decimal], Decimal], **kwargs: Any) -> PyList[ExpenseSplitModel]:
    """Split the expense total evenly among the users from ``get_users_for_splitting``.

    Every participant is assigned ``round_money_func(total / n)``. Whatever is
    left after rounding (positive or negative) is folded into the first
    participant's share, so the split amounts add up to the total.
    """
    participants = await get_users_for_splitting(
        db, expense_model.group_id, expense_model.list_id, expense_model.paid_by_user_id
    )
    if not participants:
        raise InvalidOperationError("No users found for EQUAL split.")

    total = expense_model.total_amount
    headcount = len(participants)
    base_share = round_money_func(total / Decimal(headcount))
    leftover = total - (base_share * headcount)

    # Only the first participant absorbs the rounding leftover; its amount is
    # re-rounded only when there is something to absorb.
    if leftover != Decimal('0'):
        first_share = round_money_func(base_share + leftover)
    else:
        first_share = base_share
    shares = [first_share] + [base_share] * (headcount - 1)

    return [
        ExpenseSplitModel(user_id=participant.id, owed_amount=share)
        for participant, share in zip(participants, shares)
    ]
|
|
|
|
|
|
async def _create_exact_amount_splits(db: AsyncSession, expense_model: ExpenseModel, expense_in: ExpenseCreate, round_money_func: Callable[[Decimal], Decimal], **kwargs: Any) -> PyList[ExpenseSplitModel]:
    """Creates splits with exact amounts.

    Each entry in ``expense_in.splits_in`` must reference an existing user and
    carry an ``owed_amount`` that is still positive after rounding. The rounded
    amounts must add up to ``expense_model.total_amount``.

    Raises:
        InvalidOperationError: missing split data, a missing/non-positive
            amount, or amounts that do not sum to the expense total.
        UserNotFoundError: a referenced user does not exist.
    """
    if not expense_in.splits_in:
        raise InvalidOperationError("Splits data is required for EXACT_AMOUNTS split type.")

    # Validate all users in splits exist
    await _validate_users_in_splits(db, expense_in.splits_in)

    current_total = Decimal("0.00")
    splits = []

    for split_in in expense_in.splits_in:
        # owed_amount may be absent for this split entry (NOTE(review): assumed
        # optional in the schema since other split types don't use it). Reject
        # it explicitly instead of failing with a TypeError on None <= 0.
        if split_in.owed_amount is None:
            raise InvalidOperationError(f"Owed amount for user {split_in.user_id} must be positive.")

        rounded_amount = round_money_func(split_in.owed_amount)
        # Check after rounding: e.g. 0.004 is positive but would be stored as 0.00.
        if rounded_amount <= Decimal('0'):
            raise InvalidOperationError(f"Owed amount for user {split_in.user_id} must be positive.")

        current_total += rounded_amount
        splits.append(ExpenseSplitModel(
            user_id=split_in.user_id,
            owed_amount=rounded_amount
        ))

    if round_money_func(current_total) != expense_model.total_amount:
        raise InvalidOperationError(
            f"Sum of exact split amounts ({current_total}) != expense total ({expense_model.total_amount})."
        )

    return splits
|
|
|
|
|
|
async def _create_percentage_splits(db: AsyncSession, expense_model: ExpenseModel, expense_in: ExpenseCreate, round_money_func: Callable[[Decimal], Decimal], **kwargs: Any) -> PyList[ExpenseSplitModel]:
    """Split the expense total by per-user percentages.

    Every entry needs a ``share_percentage`` in (0, 100]. The percentages,
    rounded to two places, must total 100. Each owed amount is the rounded
    share of the total. Any cent difference left by rounding is added to the
    last split.
    """
    entries = expense_in.splits_in
    if not entries:
        raise InvalidOperationError("Splits data is required for PERCENTAGE split type.")

    # Every referenced user must exist before any amounts are computed.
    await _validate_users_in_splits(db, entries)

    total = expense_model.total_amount
    percent_sum = Decimal("0.00")
    allocated = Decimal("0.00")
    result: PyList[ExpenseSplitModel] = []

    for entry in entries:
        pct = entry.share_percentage
        if not pct or not (Decimal("0") < pct <= Decimal("100")):
            raise InvalidOperationError(
                f"Invalid percentage {entry.share_percentage} for user {entry.user_id}."
            )

        percent_sum += pct
        owed = round_money_func(total * (pct / Decimal("100")))
        allocated += owed

        result.append(ExpenseSplitModel(
            user_id=entry.user_id,
            owed_amount=owed,
            share_percentage=pct,
        ))

    if round_money_func(percent_sum) != Decimal("100.00"):
        raise InvalidOperationError(f"Sum of percentages ({percent_sum}%) is not 100%.")

    # Push any rounding drift onto the final split so the amounts match the total.
    if allocated != total and result:
        last = result[-1]
        last.owed_amount = round_money_func(last.owed_amount + (total - allocated))

    return result
|
|
|
|
|
|
async def _create_shares_splits(db: AsyncSession, expense_model: ExpenseModel, expense_in: ExpenseCreate, round_money_func: Callable[[Decimal], Decimal], **kwargs: Any) -> PyList[ExpenseSplitModel]:
    """Split the expense total in proportion to per-user share units.

    The positive ``share_units`` values are summed first, and a zero total is
    rejected. Then every entry must itself have positive units. Each user owes
    ``round(total * units / total_units)``. Any cent difference left by rounding
    is added to the last split.
    """
    entries = expense_in.splits_in
    if not entries:
        raise InvalidOperationError("Splits data is required for SHARES split type.")

    # Every referenced user must exist before any amounts are computed.
    await _validate_users_in_splits(db, entries)

    positive_units = (e.share_units for e in entries if e.share_units and e.share_units > 0)
    total_units = sum(positive_units)
    if total_units == 0:
        raise InvalidOperationError("Total shares cannot be zero for SHARES split.")

    total = expense_model.total_amount
    allocated = Decimal("0.00")
    result = []

    for entry in entries:
        units = entry.share_units
        if not units or not units > 0:
            raise InvalidOperationError(f"Invalid share units for user {entry.user_id}.")

        owed = round_money_func(total * (Decimal(units) / Decimal(total_units)))
        allocated += owed

        result.append(ExpenseSplitModel(
            user_id=entry.user_id,
            owed_amount=owed,
            share_units=units,
        ))

    # Push any rounding drift onto the final split so the amounts match the total.
    if allocated != total and result:
        last = result[-1]
        last.owed_amount = round_money_func(last.owed_amount + (total - allocated))

    return result
|
|
|
|
|
|
async def _create_item_based_splits(db: AsyncSession, expense_model: ExpenseModel, expense_in: ExpenseCreate, round_money_func: Callable[[Decimal], Decimal], **kwargs: Any) -> PyList[ExpenseSplitModel]:
    """Split the expense by who added the priced items in its shopping list.

    When ``expense_model.item_id`` is set, only that item (which must be in the
    expense's list and have a positive price) is used. Otherwise every item in
    the list with a positive price is used. Each item's price is charged to the
    user who added it, and one split is produced per user. The sum of prices
    (rounded) must equal the expense total. Any ``splits_in`` data is ignored,
    with a warning.
    """
    list_id = expense_model.list_id
    item_id = expense_model.item_id

    if not list_id:
        raise InvalidOperationError("ITEM_BASED expenses must be associated with a list_id.")

    if expense_in.splits_in:
        logger.warning("Ignoring provided 'splits_in' data for ITEM_BASED expense.")

    query = select(ItemModel).where(ItemModel.list_id == list_id)
    if item_id:
        query = query.where(ItemModel.id == item_id)
    else:
        query = query.where(ItemModel.price.isnot(None) & (ItemModel.price > Decimal("0")))

    # Eager-load each item's adder so it can be read below without lazy loading.
    item_rows = await db.execute(query.options(selectinload(ItemModel.added_by_user)))
    items = item_rows.scalars().all()

    if not items:
        if item_id:
            raise InvalidOperationError(f"Specified item ID {item_id} not found in list {list_id}.")
        raise InvalidOperationError(f"List {list_id} has no priced items to base the expense on.")

    owed_by_user: Dict[int, Decimal] = defaultdict(Decimal)
    running_total = Decimal("0.00")
    counted = 0

    for item in items:
        price = item.price
        if price is None or price <= Decimal("0"):
            # A single explicitly chosen item must be priced; otherwise just skip it.
            if item_id:
                raise InvalidOperationError(
                    f"Item ID {item_id} must have a positive price for ITEM_BASED expense."
                )
            continue

        adder = item.added_by_user
        if not adder:
            logger.error(f"Item ID {item.id} is missing added_by_user relationship.")
            raise InvalidOperationError(f"Data integrity issue: Item {item.id} is missing adder information.")

        running_total += price
        owed_by_user[adder.id] += price
        counted += 1

    if counted == 0:
        raise InvalidOperationError(
            f"No items with positive prices found in list {list_id} to create ITEM_BASED expense."
        )

    if round_money_func(running_total) != expense_model.total_amount:
        raise InvalidOperationError(
            f"Expense total amount ({expense_model.total_amount}) does not match the "
            f"calculated total from item prices ({running_total})."
        )

    return [
        ExpenseSplitModel(user_id=adder_id, owed_amount=round_money_func(amount))
        for adder_id, amount in owed_by_user.items()
    ]
|
|
|
|
|
|
async def _validate_users_in_splits(db: AsyncSession, splits_in: PyList[ExpenseSplitCreate]) -> None:
    """Validates that all users in the splits exist.

    IDs are compared as sets, so a user listed in several split entries is not
    mistaken for a missing user.

    Raises:
        UserNotFoundError: one or more referenced user IDs do not exist; the
            missing IDs are listed in the error identifier.
    """
    requested_ids = {s.user_id for s in splits_in}
    if not requested_ids:
        # Nothing to check; avoids issuing an empty IN () query.
        return

    user_results = await db.execute(select(UserModel.id).where(UserModel.id.in_(list(requested_ids))))
    found_user_ids = {row[0] for row in user_results}

    missing_user_ids = requested_ids - found_user_ids
    if missing_user_ids:
        raise UserNotFoundError(identifier=f"users in split data: {list(missing_user_ids)}")
|
|
|
|
|
|
async def get_expense_by_id(db: AsyncSession, expense_id: int) -> Optional[ExpenseModel]:
    """Fetch a single expense by primary key with its related objects eagerly loaded.

    Splits (each with its user), the payer, list, group and item are loaded
    via ``selectinload``. Returns ``None`` when no expense has ``expense_id``.
    """
    eager_loads = (
        selectinload(ExpenseModel.splits).options(selectinload(ExpenseSplitModel.user)),
        selectinload(ExpenseModel.paid_by_user),
        selectinload(ExpenseModel.list),
        selectinload(ExpenseModel.group),
        selectinload(ExpenseModel.item),
    )
    stmt = select(ExpenseModel).where(ExpenseModel.id == expense_id).options(*eager_loads)
    rows = await db.execute(stmt)
    return rows.scalars().first()
|
|
|
|
async def get_expenses_for_list(db: AsyncSession, list_id: int, skip: int = 0, limit: int = 100) -> Sequence[ExpenseModel]:
    """Return one page of the expenses attached to ``list_id``, newest first.

    Rows are ordered by ``expense_date`` and then ``created_at`` (both
    descending), paginated with ``skip``/``limit``. Each expense's splits, and
    each split's user, are eagerly loaded.
    """
    newest_first = (ExpenseModel.expense_date.desc(), ExpenseModel.created_at.desc())
    stmt = (
        select(ExpenseModel)
        .where(ExpenseModel.list_id == list_id)
        .options(selectinload(ExpenseModel.splits).options(selectinload(ExpenseSplitModel.user)))
        .order_by(*newest_first)
        .offset(skip)
        .limit(limit)
    )
    rows = await db.execute(stmt)
    return rows.scalars().all()
|
|
|
|
async def get_expenses_for_group(db: AsyncSession, group_id: int, skip: int = 0, limit: int = 100) -> Sequence[ExpenseModel]:
    """Return one page of the expenses attached to ``group_id``, newest first.

    Rows are ordered by ``expense_date`` and then ``created_at`` (both
    descending), paginated with ``skip``/``limit``. Each expense's splits, and
    each split's user, are eagerly loaded.
    """
    newest_first = (ExpenseModel.expense_date.desc(), ExpenseModel.created_at.desc())
    stmt = (
        select(ExpenseModel)
        .where(ExpenseModel.group_id == group_id)
        .options(selectinload(ExpenseModel.splits).options(selectinload(ExpenseSplitModel.user)))
        .order_by(*newest_first)
        .offset(skip)
        .limit(limit)
    )
    rows = await db.execute(stmt)
    return rows.scalars().all()
|
|
|
|
async def update_expense(db: AsyncSession, expense_db: ExpenseModel, expense_in: ExpenseUpdate) -> ExpenseModel:
    """
    Updates an existing expense's editable metadata.

    Only ``description``, ``currency`` and ``expense_date`` may be changed, to
    avoid split complexities. ``expense_in.version`` must equal the stored
    version (optimistic locking). On success the version is incremented,
    ``updated_at`` is set to the current UTC time, and the change is committed
    and refreshed. A payload that carries only ``version`` still bumps the version.

    Raises:
        InvalidOperationError: version mismatch, a non-editable field in the
            payload (nothing is modified in that case), or the commit/refresh
            failed (the session is rolled back; the cause is chained).
    """
    if expense_db.version != expense_in.version:
        raise InvalidOperationError(
            f"Expense '{expense_db.description}' (ID: {expense_db.id}) has been modified. "
            f"Your version is {expense_in.version}, current version is {expense_db.version}. Please refresh.",
            # status_code=status.HTTP_409_CONFLICT # This would be for the API layer to set
        )

    # Only explicitly-sent fields; the version itself is not a data field.
    update_data = expense_in.model_dump(exclude_unset=True, exclude={"version"})

    # Fields that are safe to update without affecting splits or core logic
    allowed_to_update = {"description", "currency", "expense_date"}

    # Validate the whole payload before touching the ORM object, so a rejected
    # request does not leave expense_db half-modified in the session.
    for field in update_data:
        if field not in allowed_to_update:
            raise InvalidOperationError(
                f"Field '{field}' cannot be updated. Only {', '.join(sorted(allowed_to_update))} are allowed."
            )

    for field, value in update_data.items():
        setattr(expense_db, field, value)

    expense_db.version += 1
    expense_db.updated_at = datetime.now(timezone.utc)  # Manually update timestamp

    try:
        await db.commit()
        await db.refresh(expense_db)
    except Exception as e:
        await db.rollback()
        raise InvalidOperationError(f"Failed to update expense: {str(e)}") from e

    return expense_db
|
|
|
|
async def delete_expense(db: AsyncSession, expense_db: ExpenseModel, expected_version: Optional[int] = None) -> None:
    """
    Deletes an expense and commits.

    When ``expected_version`` is provided it must equal ``expense_db.version``
    (optimistic locking). Removal of the associated ExpenseSplit rows is left
    to the database / ORM cascade configuration (NOTE(review): presumably an
    ON DELETE CASCADE foreign key; it is not visible here, so confirm against the models).

    Raises:
        InvalidOperationError: version mismatch, or the delete/commit failed
            (the session is rolled back; the cause is chained).
    """
    if expected_version is not None and expense_db.version != expected_version:
        raise InvalidOperationError(
            f"Expense '{expense_db.description}' (ID: {expense_db.id}) cannot be deleted. "
            f"Your expected version {expected_version} does not match current version {expense_db.version}. Please refresh.",
            # status_code=status.HTTP_409_CONFLICT
        )

    try:
        # AsyncSession.delete may itself emit SQL (e.g. to load cascaded
        # relationships), so it is covered by the same rollback/wrapping as the commit.
        await db.delete(expense_db)
        await db.commit()
    except Exception as e:
        await db.rollback()
        raise InvalidOperationError(f"Failed to delete expense: {str(e)}") from e
|
|
|
|
# Note: InvalidOperationError and the other exceptions raised in this module are
# imported from app.core.exceptions. API endpoints should translate them into
# appropriate HTTPExceptions (e.g. 409 Conflict for optimistic-locking version mismatches).