# app/crud/expense.py
"""CRUD helpers for expenses and their splits.

Every function takes an ``AsyncSession`` supplied by the caller.  Write
operations run inside ``db.begin_nested()`` when the session already has a
transaction open, otherwise inside ``db.begin()``.
"""
import logging
from collections import defaultdict
from datetime import datetime, timezone
from decimal import ROUND_HALF_UP, Decimal
from decimal import InvalidOperation as DecimalInvalidOperation
from typing import Any, Callable, Dict, Optional, Sequence
from typing import List as PyList

from sqlalchemy.exc import IntegrityError, OperationalError, SQLAlchemyError
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy.orm import joinedload, selectinload

from app.core.exceptions import (
    DatabaseConnectionError,
    DatabaseIntegrityError,
    DatabaseQueryError,
    DatabaseTransactionError,
    ExpenseOperationError,
    GroupNotFoundError,
    InvalidOperationError,
    ListNotFoundError,
    UserNotFoundError,
)
from app.models import (
    Expense as ExpenseModel,
    ExpenseSplit as ExpenseSplitModel,
    Group as GroupModel,
    Item as ItemModel,
    List as ListModel,
    SplitTypeEnum,
    User as UserModel,
    UserGroup as UserGroupModel,
)
from app.schemas.expense import ExpenseCreate, ExpenseSplitCreate, ExpenseUpdate

logger = logging.getLogger(__name__)

# Fields update_expense() may change without touching splits or totals.
_UPDATABLE_EXPENSE_FIELDS = ("description", "currency", "expense_date")


def _round_money(amount: Decimal) -> Decimal:
    """Round ``amount`` to two decimal places using ROUND_HALF_UP."""
    return amount.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP)


async def get_users_for_splitting(
    db: AsyncSession,
    expense_group_id: Optional[int],
    expense_list_id: Optional[int],
    expense_paid_by_user_id: int,
) -> PyList[UserModel]:
    """Determine the users an expense should be split amongst.

    Priority: members of the group (if ``expense_group_id`` is given), then
    the list's group members or, failing that, the list's creator (if
    ``expense_list_id`` is given).  When no context yields any user, the
    payer alone is returned.

    Raises:
        GroupNotFoundError: ``expense_group_id`` does not match a group.
        ListNotFoundError: ``expense_list_id`` does not match a list.
        UserNotFoundError: the payer fallback was needed but the payer
            does not exist.
        InvalidOperationError: no user could be determined at all.
    """
    users_to_split_with: PyList[UserModel] = []
    processed_user_ids = set()

    def _add_user(user: Optional[UserModel]) -> None:
        # De-duplicate by id while preserving first-seen order.
        if user and user.id not in processed_user_ids:
            users_to_split_with.append(user)
            processed_user_ids.add(user.id)

    if expense_group_id:
        group_result = await db.execute(
            select(GroupModel)
            .options(
                selectinload(GroupModel.member_associations).options(
                    selectinload(UserGroupModel.user)
                )
            )
            .where(GroupModel.id == expense_group_id)
        )
        group = group_result.scalars().first()
        if not group:
            raise GroupNotFoundError(expense_group_id)
        for assoc in group.member_associations:
            _add_user(assoc.user)
    elif expense_list_id:
        # The list is only consulted when no group id was supplied.
        list_result = await db.execute(
            select(ListModel)
            .options(
                selectinload(ListModel.group).options(
                    selectinload(GroupModel.member_associations).options(
                        selectinload(UserGroupModel.user)
                    )
                ),
                selectinload(ListModel.creator),
            )
            .where(ListModel.id == expense_list_id)
        )
        db_list = list_result.scalars().first()
        if not db_list:
            raise ListNotFoundError(expense_list_id)
        if db_list.group:
            for assoc in db_list.group.member_associations:
                _add_user(assoc.user)
        elif db_list.creator:
            _add_user(db_list.creator)

    if not users_to_split_with:
        # Fallback: split with the payer only.
        payer_user = await db.get(UserModel, expense_paid_by_user_id)
        if not payer_user:
            raise UserNotFoundError(user_id=expense_paid_by_user_id)
        _add_user(payer_user)

    if not users_to_split_with:
        # Defensive guard; the payer fallback above normally fills the list.
        raise InvalidOperationError(
            "Could not determine any users for splitting the expense."
        )

    return users_to_split_with


async def create_expense(
    db: AsyncSession, expense_in: ExpenseCreate, current_user_id: int
) -> ExpenseModel:
    """Create a new expense together with its splits.

    Args:
        db: Database session.
        expense_in: Expense creation data.  Its ``list_id`` is filled in
            place from the referenced item when it was not supplied.
        current_user_id: ID of the user creating the expense.

    Returns:
        The created expense, re-fetched with its relationships loaded.

    Raises:
        UserNotFoundError: If the payer or split users don't exist.
        ListNotFoundError: If the specified list doesn't exist.
        GroupNotFoundError: If the specified group doesn't exist.
        InvalidOperationError: For validation failures.
        ExpenseOperationError: If the expense cannot be re-loaded after
            creation.
        DatabaseIntegrityError: On ``IntegrityError``.
        DatabaseConnectionError: On ``OperationalError``.
        DatabaseTransactionError: On any other ``SQLAlchemyError``.
    """
    try:
        # Use a SAVEPOINT when a transaction is already open, otherwise
        # begin a new one.
        async with (db.begin_nested() if db.in_transaction() else db.begin()):
            # 1. Validate payer.
            payer = await db.get(UserModel, expense_in.paid_by_user_id)
            if not payer:
                raise UserNotFoundError(
                    user_id=expense_in.paid_by_user_id, identifier="Payer"
                )

            # 2. Context resolution and validation.
            if (
                not expense_in.list_id
                and not expense_in.group_id
                and not expense_in.item_id
            ):
                raise InvalidOperationError(
                    "Expense must be associated with a list, a group, or an item."
                )

            final_group_id = await _resolve_expense_context(db, expense_in)

            if expense_in.item_id:
                db_item_instance = await db.get(ItemModel, expense_in.item_id)
                if not db_item_instance:
                    raise InvalidOperationError(
                        f"Item with ID {expense_in.item_id} not found."
                    )
                # Inherit the item's list when none was given, then
                # re-resolve the group because it depends on the list.
                if db_item_instance.list_id and not expense_in.list_id:
                    expense_in.list_id = db_item_instance.list_id
                    final_group_id = await _resolve_expense_context(db, expense_in)

            # 3. Create the expense row.
            db_expense = ExpenseModel(
                description=expense_in.description,
                total_amount=_round_money(expense_in.total_amount),
                currency=expense_in.currency or "USD",
                expense_date=expense_in.expense_date or datetime.now(timezone.utc),
                split_type=expense_in.split_type,
                list_id=expense_in.list_id,
                group_id=final_group_id,
                item_id=expense_in.item_id,
                paid_by_user_id=expense_in.paid_by_user_id,
                created_by_user_id=current_user_id,
            )
            db.add(db_expense)
            await db.flush()  # Assigns db_expense.id.

            # 4. Generate the splits and attach them to the new expense.
            splits_to_create = await _generate_expense_splits(
                db=db,
                expense_model=db_expense,
                expense_in=expense_in,
                current_user_id=current_user_id,
            )
            for split_model in splits_to_create:
                split_model.expense_id = db_expense.id

            db.add_all(splits_to_create)
            await db.flush()

            # 5. Re-fetch with the relationships needed for the response.
            stmt = (
                select(ExpenseModel)
                .where(ExpenseModel.id == db_expense.id)
                .options(
                    selectinload(ExpenseModel.paid_by_user),
                    selectinload(ExpenseModel.created_by_user),
                    selectinload(ExpenseModel.list),
                    selectinload(ExpenseModel.group),
                    selectinload(ExpenseModel.item),
                    selectinload(ExpenseModel.splits).selectinload(
                        ExpenseSplitModel.user
                    ),
                )
            )
            result = await db.execute(stmt)
            loaded_expense = result.scalar_one_or_none()

            if loaded_expense is None:
                # Raising inside the context manager rolls the work back.
                raise ExpenseOperationError("Failed to load expense after creation.")

            return loaded_expense

    except (
        UserNotFoundError,
        ListNotFoundError,
        GroupNotFoundError,
        InvalidOperationError,
    ):
        # Business-rule failures propagate unchanged.
        raise
    except IntegrityError as e:
        logger.error(
            "Database integrity error during expense creation: %s", e, exc_info=True
        )
        raise DatabaseIntegrityError(
            f"Failed to save expense due to database integrity issue: {str(e)}"
        ) from e
    except OperationalError as e:
        logger.error(
            "Database connection error during expense creation: %s", e, exc_info=True
        )
        raise DatabaseConnectionError(
            f"Database connection error during expense creation: {str(e)}"
        ) from e
    except SQLAlchemyError as e:
        logger.error(
            "Unexpected SQLAlchemy error during expense creation: %s", e, exc_info=True
        )
        raise DatabaseTransactionError(
            f"Failed to save expense due to a database transaction error: {str(e)}"
        ) from e


async def _resolve_expense_context(
    db: AsyncSession, expense_in: ExpenseCreate
) -> Optional[int]:
    """Validate the expense's list/group context and return its group id.

    When ``list_id`` is set, the list must exist; if that list belongs to a
    group, the list's group wins and must not contradict an explicit
    ``group_id``.  When only ``group_id`` is set, the group must exist.

    Raises:
        ListNotFoundError: ``list_id`` does not match a list.
        GroupNotFoundError: ``group_id`` does not match a group.
        InvalidOperationError: list and group disagree.
    """
    final_group_id = expense_in.group_id

    if expense_in.list_id:
        list_obj = await db.get(ListModel, expense_in.list_id)
        if not list_obj:
            raise ListNotFoundError(expense_in.list_id)

        if list_obj.group_id:
            if expense_in.group_id and list_obj.group_id != expense_in.group_id:
                raise InvalidOperationError(
                    f"List {expense_in.list_id} belongs to group {list_obj.group_id}, "
                    f"but expense was specified for group {expense_in.group_id}."
                )
            final_group_id = list_obj.group_id  # The list's group takes priority.
    elif final_group_id:
        # Only a group id was provided; make sure it exists.
        group_obj = await db.get(GroupModel, final_group_id)
        if not group_obj:
            raise GroupNotFoundError(final_group_id)

    return final_group_id


async def _generate_expense_splits(
    db: AsyncSession,
    expense_model: ExpenseModel,
    expense_in: ExpenseCreate,
    **kwargs: Any,
) -> PyList[ExpenseSplitModel]:
    """Build the split rows for ``expense_in.split_type``.

    Extra keyword arguments are handed to the split helper bundled under a
    single ``kwargs`` entry; none of the helpers in this module read them.

    Raises:
        InvalidOperationError: unsupported split type, or no splits were
            produced.
    """
    splits_to_create: PyList[ExpenseSplitModel] = []

    common_args = {
        "db": db,
        "expense_model": expense_model,
        "expense_in": expense_in,
        "round_money_func": _round_money,
        "kwargs": kwargs,
    }

    if expense_in.split_type == SplitTypeEnum.EQUAL:
        splits_to_create = await _create_equal_splits(**common_args)
    elif expense_in.split_type == SplitTypeEnum.EXACT_AMOUNTS:
        splits_to_create = await _create_exact_amount_splits(**common_args)
    elif expense_in.split_type == SplitTypeEnum.PERCENTAGE:
        splits_to_create = await _create_percentage_splits(**common_args)
    elif expense_in.split_type == SplitTypeEnum.SHARES:
        splits_to_create = await _create_shares_splits(**common_args)
    elif expense_in.split_type == SplitTypeEnum.ITEM_BASED:
        splits_to_create = await _create_item_based_splits(**common_args)
    else:
        raise InvalidOperationError(
            f"Unsupported split type: {expense_in.split_type.value}"
        )

    if not splits_to_create:
        raise InvalidOperationError("No expense splits were generated.")

    return splits_to_create


async def _create_equal_splits(
    db: AsyncSession,
    expense_model: ExpenseModel,
    expense_in: ExpenseCreate,
    round_money_func: Callable[[Decimal], Decimal],
    **kwargs: Any,
) -> PyList[ExpenseSplitModel]:
    """Split the total equally among the users from ``get_users_for_splitting``.

    Any rounding remainder is added to the first user's share so the splits
    sum to the expense total.
    """
    users_for_splitting = await get_users_for_splitting(
        db,
        expense_model.group_id,
        expense_model.list_id,
        expense_model.paid_by_user_id,
    )
    if not users_for_splitting:
        raise InvalidOperationError("No users found for EQUAL split.")

    num_users = len(users_for_splitting)
    amount_per_user = round_money_func(expense_model.total_amount / Decimal(num_users))
    remainder = expense_model.total_amount - (amount_per_user * num_users)

    splits = []
    for i, user in enumerate(users_for_splitting):
        split_amount = amount_per_user
        if i == 0 and remainder != Decimal("0"):
            split_amount = round_money_func(amount_per_user + remainder)
        splits.append(ExpenseSplitModel(user_id=user.id, owed_amount=split_amount))

    return splits


async def _create_exact_amount_splits(
    db: AsyncSession,
    expense_model: ExpenseModel,
    expense_in: ExpenseCreate,
    round_money_func: Callable[[Decimal], Decimal],
    **kwargs: Any,
) -> PyList[ExpenseSplitModel]:
    """Create splits from the exact amounts in ``expense_in.splits_in``.

    Every amount must be positive and the rounded amounts must add up to
    the expense total.
    """
    if not expense_in.splits_in:
        raise InvalidOperationError(
            "Splits data is required for EXACT_AMOUNTS split type."
        )

    await _validate_users_in_splits(db, expense_in.splits_in)

    current_total = Decimal("0.00")
    splits = []
    for split_in in expense_in.splits_in:
        if split_in.owed_amount <= Decimal("0"):
            raise InvalidOperationError(
                f"Owed amount for user {split_in.user_id} must be positive."
            )

        rounded_amount = round_money_func(split_in.owed_amount)
        current_total += rounded_amount
        splits.append(
            ExpenseSplitModel(user_id=split_in.user_id, owed_amount=rounded_amount)
        )

    if round_money_func(current_total) != expense_model.total_amount:
        raise InvalidOperationError(
            f"Sum of exact split amounts ({current_total}) != expense total ({expense_model.total_amount})."
        )

    return splits


async def _create_percentage_splits(
    db: AsyncSession,
    expense_model: ExpenseModel,
    expense_in: ExpenseCreate,
    round_money_func: Callable[[Decimal], Decimal],
    **kwargs: Any,
) -> PyList[ExpenseSplitModel]:
    """Create splits from the percentages in ``expense_in.splits_in``.

    Each percentage must lie in (0, 100] and together they must total 100.
    Any rounding difference is applied to the last split.
    """
    if not expense_in.splits_in:
        raise InvalidOperationError(
            "Splits data is required for PERCENTAGE split type."
        )

    await _validate_users_in_splits(db, expense_in.splits_in)

    total_percentage = Decimal("0.00")
    current_total = Decimal("0.00")
    splits = []

    for split_in in expense_in.splits_in:
        if not (
            split_in.share_percentage
            and Decimal("0") < split_in.share_percentage <= Decimal("100")
        ):
            raise InvalidOperationError(
                f"Invalid percentage {split_in.share_percentage} for user {split_in.user_id}."
            )

        total_percentage += split_in.share_percentage
        owed_amount = round_money_func(
            expense_model.total_amount * (split_in.share_percentage / Decimal("100"))
        )
        current_total += owed_amount

        splits.append(
            ExpenseSplitModel(
                user_id=split_in.user_id,
                owed_amount=owed_amount,
                share_percentage=split_in.share_percentage,
            )
        )

    if round_money_func(total_percentage) != Decimal("100.00"):
        raise InvalidOperationError(
            f"Sum of percentages ({total_percentage}%) is not 100%."
        )

    # Push any rounding difference onto the last split.
    if current_total != expense_model.total_amount and splits:
        diff = expense_model.total_amount - current_total
        splits[-1].owed_amount = round_money_func(splits[-1].owed_amount + diff)

    return splits


async def _create_shares_splits(
    db: AsyncSession,
    expense_model: ExpenseModel,
    expense_in: ExpenseCreate,
    round_money_func: Callable[[Decimal], Decimal],
    **kwargs: Any,
) -> PyList[ExpenseSplitModel]:
    """Create splits proportional to ``share_units`` in ``expense_in.splits_in``.

    Every entry needs positive share units.  Any rounding difference is
    applied to the last split.
    """
    if not expense_in.splits_in:
        raise InvalidOperationError("Splits data is required for SHARES split type.")

    await _validate_users_in_splits(db, expense_in.splits_in)

    total_shares = sum(
        s.share_units
        for s in expense_in.splits_in
        if s.share_units and s.share_units > 0
    )
    if total_shares == 0:
        raise InvalidOperationError("Total shares cannot be zero for SHARES split.")

    splits = []
    current_total = Decimal("0.00")

    for split_in in expense_in.splits_in:
        if not (split_in.share_units and split_in.share_units > 0):
            raise InvalidOperationError(
                f"Invalid share units for user {split_in.user_id}."
            )

        share_ratio = Decimal(split_in.share_units) / Decimal(total_shares)
        owed_amount = round_money_func(expense_model.total_amount * share_ratio)
        current_total += owed_amount

        splits.append(
            ExpenseSplitModel(
                user_id=split_in.user_id,
                owed_amount=owed_amount,
                share_units=split_in.share_units,
            )
        )

    # Push any rounding difference onto the last split.
    if current_total != expense_model.total_amount and splits:
        diff = expense_model.total_amount - current_total
        splits[-1].owed_amount = round_money_func(splits[-1].owed_amount + diff)

    return splits


async def _create_item_based_splits(
    db: AsyncSession,
    expense_model: ExpenseModel,
    expense_in: ExpenseCreate,
    round_money_func: Callable[[Decimal], Decimal],
    **kwargs: Any,
) -> PyList[ExpenseSplitModel]:
    """Create splits from the prices of items in the expense's list.

    Each priced item is charged to the user who added it.  With
    ``expense_model.item_id`` set only that item is considered; otherwise
    all items in the list with a positive price are.  The summed prices
    must equal the expense total.  ``expense_in.splits_in`` is ignored.
    """
    if not expense_model.list_id:
        raise InvalidOperationError(
            "ITEM_BASED expenses must be associated with a list_id."
        )

    if expense_in.splits_in:
        logger.warning("Ignoring provided 'splits_in' data for ITEM_BASED expense.")

    items_query = select(ItemModel).where(ItemModel.list_id == expense_model.list_id)
    if expense_model.item_id:
        items_query = items_query.where(ItemModel.id == expense_model.item_id)
    else:
        items_query = items_query.where(
            ItemModel.price.isnot(None) & (ItemModel.price > Decimal("0"))
        )

    # Load each item together with the user who added it.
    items_result = await db.execute(
        items_query.options(selectinload(ItemModel.added_by_user))
    )
    relevant_items = items_result.scalars().all()

    if not relevant_items:
        error_msg = (
            f"Specified item ID {expense_model.item_id} not found in list {expense_model.list_id}."
            if expense_model.item_id
            else f"List {expense_model.list_id} has no priced items to base the expense on."
        )
        raise InvalidOperationError(error_msg)

    # Aggregate owed amounts per user.
    calculated_total = Decimal("0.00")
    user_owed_amounts = defaultdict(Decimal)
    processed_items = 0

    for item in relevant_items:
        if item.price is None or item.price <= Decimal("0"):
            # An explicitly requested item must be priced; others are skipped.
            if expense_model.item_id:
                raise InvalidOperationError(
                    f"Item ID {expense_model.item_id} must have a positive price for ITEM_BASED expense."
                )
            continue

        if not item.added_by_user:
            logger.error(
                "Item ID %s is missing added_by_user relationship.", item.id
            )
            raise InvalidOperationError(
                f"Data integrity issue: Item {item.id} is missing adder information."
            )

        calculated_total += item.price
        user_owed_amounts[item.added_by_user.id] += item.price
        processed_items += 1

    if processed_items == 0:
        raise InvalidOperationError(
            f"No items with positive prices found in list {expense_model.list_id} to create ITEM_BASED expense."
        )

    if round_money_func(calculated_total) != expense_model.total_amount:
        raise InvalidOperationError(
            f"Expense total amount ({expense_model.total_amount}) does not match the "
            f"calculated total from item prices ({calculated_total})."
        )

    return [
        ExpenseSplitModel(user_id=user_id, owed_amount=round_money_func(owed_amount))
        for user_id, owed_amount in user_owed_amounts.items()
    ]


async def _validate_users_in_splits(
    db: AsyncSession, splits_in: PyList[ExpenseSplitCreate]
) -> None:
    """Check that every user referenced in ``splits_in`` exists exactly once.

    Raises:
        InvalidOperationError: a user id appears more than once.
        UserNotFoundError: at least one referenced user does not exist.
    """
    user_ids_in_split = [s.user_id for s in splits_in]

    # Report duplicates explicitly; comparing the size of the found-id set
    # with the raw list length would misreport them as missing users.
    seen_ids = set()
    duplicate_ids = []
    for user_id in user_ids_in_split:
        if user_id in seen_ids and user_id not in duplicate_ids:
            duplicate_ids.append(user_id)
        seen_ids.add(user_id)
    if duplicate_ids:
        raise InvalidOperationError(
            f"Duplicate users in split data: {duplicate_ids}"
        )

    user_results = await db.execute(
        select(UserModel.id).where(UserModel.id.in_(list(seen_ids)))
    )
    found_user_ids = {row[0] for row in user_results}

    missing_user_ids = seen_ids - found_user_ids
    if missing_user_ids:
        raise UserNotFoundError(
            identifier=f"users in split data: {list(missing_user_ids)}"
        )


async def get_expense_by_id(db: AsyncSession, expense_id: int) -> Optional[ExpenseModel]:
    """Return the expense with ``expense_id`` (relationships loaded) or ``None``."""
    result = await db.execute(
        select(ExpenseModel)
        .options(
            selectinload(ExpenseModel.splits).options(
                selectinload(ExpenseSplitModel.user)
            ),
            selectinload(ExpenseModel.paid_by_user),
            selectinload(ExpenseModel.list),
            selectinload(ExpenseModel.group),
            selectinload(ExpenseModel.item),
        )
        .where(ExpenseModel.id == expense_id)
    )
    return result.scalars().first()


async def get_expenses_for_list(
    db: AsyncSession, list_id: int, skip: int = 0, limit: int = 100
) -> Sequence[ExpenseModel]:
    """Return a page of expenses for a list, newest first, with splits and split users."""
    result = await db.execute(
        select(ExpenseModel)
        .where(ExpenseModel.list_id == list_id)
        .order_by(ExpenseModel.expense_date.desc(), ExpenseModel.created_at.desc())
        .offset(skip)
        .limit(limit)
        .options(
            selectinload(ExpenseModel.splits).options(
                selectinload(ExpenseSplitModel.user)
            )
        )
    )
    return result.scalars().all()


async def get_expenses_for_group(
    db: AsyncSession, group_id: int, skip: int = 0, limit: int = 100
) -> Sequence[ExpenseModel]:
    """Return a page of expenses for a group, newest first, with splits and split users."""
    result = await db.execute(
        select(ExpenseModel)
        .where(ExpenseModel.group_id == group_id)
        .order_by(ExpenseModel.expense_date.desc(), ExpenseModel.created_at.desc())
        .offset(skip)
        .limit(limit)
        .options(
            selectinload(ExpenseModel.splits).options(
                selectinload(ExpenseSplitModel.user)
            )
        )
    )
    return result.scalars().all()


async def update_expense(
    db: AsyncSession, expense_db: ExpenseModel, expense_in: ExpenseUpdate
) -> ExpenseModel:
    """Update an existing expense.

    Only ``description``, ``currency`` and ``expense_date`` may change, so
    splits never need recalculating.  ``expense_in.version`` must equal the
    stored version (optimistic locking).  The version is incremented even
    when the payload carries no updatable field.

    Raises:
        InvalidOperationError: version mismatch, or the payload sets a
            field that is not updatable.  Nothing is modified in that case.
        DatabaseIntegrityError: On ``IntegrityError``.
        DatabaseTransactionError: On any other ``SQLAlchemyError``.
    """
    if expense_db.version != expense_in.version:
        raise InvalidOperationError(
            f"Expense '{expense_db.description}' (ID: {expense_db.id}) has been modified. "
            f"Your version is {expense_in.version}, current version is {expense_db.version}. Please refresh.",
        )

    update_data = expense_in.model_dump(exclude_unset=True, exclude={"version"})

    # Validate the whole payload before touching the tracked object so a
    # rejected request never leaves it partly modified.
    for field in update_data:
        if field not in _UPDATABLE_EXPENSE_FIELDS:
            raise InvalidOperationError(
                f"Field '{field}' cannot be updated. Only {', '.join(_UPDATABLE_EXPENSE_FIELDS)} are allowed."
            )

    for field, value in update_data.items():
        setattr(expense_db, field, value)

    try:
        async with (db.begin_nested() if db.in_transaction() else db.begin()):
            expense_db.version += 1
            expense_db.updated_at = datetime.now(timezone.utc)
            await db.flush()  # Send the changes and run constraints.
            await db.refresh(expense_db)
            return expense_db
    except InvalidOperationError:
        raise
    except IntegrityError as e:
        logger.error(
            "Database integrity error during expense update for ID %s: %s",
            expense_db.id,
            e,
            exc_info=True,
        )
        raise DatabaseIntegrityError(
            f"Failed to update expense ID {expense_db.id} due to database integrity issue."
        ) from e
    except SQLAlchemyError as e:
        logger.error(
            "Database transaction error during expense update for ID %s: %s",
            expense_db.id,
            e,
            exc_info=True,
        )
        raise DatabaseTransactionError(
            f"Failed to update expense ID {expense_db.id} due to a database transaction error."
        ) from e


async def delete_expense(
    db: AsyncSession, expense_db: ExpenseModel, expected_version: Optional[int] = None
) -> None:
    """Delete an expense.

    When ``expected_version`` is given it must match the stored version.

    NOTE(review): associated splits are presumably removed by a cascade
    configured on the model or foreign key -- confirm against app.models.

    Raises:
        InvalidOperationError: version mismatch.
        DatabaseIntegrityError: On ``IntegrityError``.
        DatabaseTransactionError: On any other ``SQLAlchemyError``.
    """
    if expected_version is not None and expense_db.version != expected_version:
        raise InvalidOperationError(
            f"Expense '{expense_db.description}' (ID: {expense_db.id}) cannot be deleted. "
            f"Your expected version {expected_version} does not match current version {expense_db.version}. Please refresh.",
        )

    try:
        async with (db.begin_nested() if db.in_transaction() else db.begin()):
            await db.delete(expense_db)
            await db.flush()  # Send the DELETE to the database now.
    except InvalidOperationError:
        raise
    except IntegrityError as e:
        logger.error(
            "Database integrity error during expense deletion for ID %s: %s",
            expense_db.id,
            e,
            exc_info=True,
        )
        raise DatabaseIntegrityError(
            f"Failed to delete expense ID {expense_db.id} due to database integrity issue."
        ) from e
    except SQLAlchemyError as e:
        logger.error(
            "Database transaction error during expense deletion for ID %s: %s",
            expense_db.id,
            e,
            exc_info=True,
        )
        raise DatabaseTransactionError(
            f"Failed to delete expense ID {expense_db.id} due to a database transaction error."
        ) from e
    return None


# NOTE(review): the domain exceptions raised in this module are imported from
# app.core.exceptions.  For API endpoints they should be translated to the
# appropriate HTTP responses -- confirm those classes carry the needed status.