From 7a88ea258a7cadad51580b1d050c69215ec17c8f Mon Sep 17 00:00:00 2001 From: mohamad Date: Fri, 16 May 2025 21:54:29 +0200 Subject: [PATCH] Refactor database session management and exception handling across CRUD operations; streamline transaction handling in expense, group, invite, item, list, settlement, and user modules for improved reliability and clarity. Introduce specific operation errors for better error reporting. --- be/app/core/exceptions.py | 40 +++++ be/app/crud/expense.py | 272 +++++++++++++++--------------- be/app/crud/group.py | 84 +++++++--- be/app/crud/invite.py | 156 ++++++++++++----- be/app/crud/item.py | 168 ++++++++++++------- be/app/crud/list.py | 167 ++++++++++++------- be/app/crud/settlement.py | 343 +++++++++++++++++++++++++------------- be/app/crud/user.py | 67 ++++++-- be/app/database.py | 12 +- 9 files changed, 850 insertions(+), 459 deletions(-) diff --git a/be/app/core/exceptions.py b/be/app/core/exceptions.py index 14911ac..6bae250 100644 --- a/be/app/core/exceptions.py +++ b/be/app/core/exceptions.py @@ -128,6 +128,14 @@ class DatabaseQueryError(HTTPException): detail=detail ) +class ExpenseOperationError(HTTPException): + """Raised when an expense operation fails.""" + def __init__(self, detail: str): + super().__init__( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=detail + ) + class OCRServiceUnavailableError(HTTPException): """Raised when the OCR service is unavailable.""" def __init__(self): @@ -240,6 +248,22 @@ class ListStatusNotFoundError(HTTPException): detail=f"Status for list {list_id} not found" ) +class InviteOperationError(HTTPException): + """Raised when an invite operation fails.""" + def __init__(self, detail: str): + super().__init__( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=detail + ) + +class SettlementOperationError(HTTPException): + """Raised when a settlement operation fails.""" + def __init__(self, detail: str): + super().__init__( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=detail + ) + class ConflictError(HTTPException): """Raised when an optimistic lock version conflict occurs.""" def __init__(self, detail: str): @@ -282,4 +306,20 @@ class JWTUnexpectedError(HTTPException): status_code=status.HTTP_401_UNAUTHORIZED, detail=settings.JWT_UNEXPECTED_ERROR.format(error=error), headers={settings.AUTH_HEADER_NAME: f"{settings.AUTH_HEADER_PREFIX} error=\"invalid_token\""} + ) + +class ListOperationError(HTTPException): + """Raised when a list operation fails.""" + def __init__(self, detail: str): + super().__init__( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=detail + ) + +class ItemOperationError(HTTPException): + """Raised when an item operation fails.""" + def __init__(self, detail: str): + super().__init__( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=detail ) \ No newline at end of file diff --git a/be/app/crud/expense.py b/be/app/crud/expense.py index 5a5335f..43c58d4 100644 --- a/be/app/crud/expense.py +++ b/be/app/crud/expense.py @@ -4,7 +4,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.future import select from sqlalchemy.orm import selectinload, joinedload from decimal import Decimal, ROUND_HALF_UP, InvalidOperation as DecimalInvalidOperation -from typing import Callable, List as PyList, Optional, Sequence, Dict, defaultdict +from typing import Callable, List as PyList, Optional, Sequence, Dict, defaultdict, Any from datetime import datetime, timezone # Added timezone from app.models import ( @@ -23,7 +23,12 @@ from app.core.exceptions import ( ListNotFoundError, GroupNotFoundError, UserNotFoundError, - InvalidOperationError # Import the new exception + InvalidOperationError, # Import the new exception + DatabaseConnectionError, # Added + DatabaseIntegrityError, # Added + DatabaseQueryError, # Added + DatabaseTransactionError,# Added + ExpenseOperationError # Added specific exception ) # Placeholder for InvalidOperationError if not defined in app.core.exceptions @@ -108,60 +113,97 @@ async def create_expense(db: AsyncSession, expense_in: ExpenseCreate, current_us GroupNotFoundError: If specified group doesn't exist InvalidOperationError: For various validation failures """ - # Helper function to round decimals consistently - def round_money(amount: Decimal) -> Decimal: - return amount.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP) - - # 1. Context Validation - # Validate basic context requirements first - if not expense_in.list_id and not expense_in.group_id: - raise InvalidOperationError("Expense must be associated with a list or a group.") - - # 2. User Validation - payer = await db.get(UserModel, expense_in.paid_by_user_id) - if not payer: - raise UserNotFoundError(user_id=expense_in.paid_by_user_id) - - # 3. List/Group Context Resolution - final_group_id = await _resolve_expense_context(db, expense_in) - - # 4. Create the expense object - db_expense = ExpenseModel( - description=expense_in.description, - total_amount=round_money(expense_in.total_amount), - currency=expense_in.currency or "USD", - expense_date=expense_in.expense_date or datetime.now(timezone.utc), - split_type=expense_in.split_type, - list_id=expense_in.list_id, - group_id=final_group_id, - item_id=expense_in.item_id, - paid_by_user_id=expense_in.paid_by_user_id, - created_by_user_id=current_user_id # Track who created this expense - ) - - # 5. Generate splits based on split type - splits_to_create = await _generate_expense_splits(db, db_expense, expense_in, round_money) - - # 6. Single transaction for expense and all splits try: - db.add(db_expense) - await db.flush() # Get expense ID without committing - - # Update all splits with the expense ID - for split in splits_to_create: - split.expense_id = db_expense.id - - db.add_all(splits_to_create) - await db.commit() - - except Exception as e: - await db.rollback() - logger.error(f"Failed to save expense: {str(e)}", exc_info=True) - raise InvalidOperationError(f"Failed to save expense: {str(e)}") - - # Refresh to get the splits relationship populated - await db.refresh(db_expense, attribute_names=["splits"]) - return db_expense + async with db.begin_nested() if db.in_transaction() else db.begin() as transaction: + # 1. Validate payer + payer = await db.get(UserModel, expense_in.paid_by_user_id) + if not payer: + raise UserNotFoundError(user_id=expense_in.paid_by_user_id, identifier="Payer") + + # 2. Context Resolution and Validation (now part of the transaction) + if not expense_in.list_id and not expense_in.group_id and not expense_in.item_id: + raise InvalidOperationError("Expense must be associated with a list, a group, or an item.") + + final_group_id = await _resolve_expense_context(db, expense_in) + # Further validation for item_id if provided + db_item_instance = None + if expense_in.item_id: + db_item_instance = await db.get(ItemModel, expense_in.item_id) + if not db_item_instance: + raise InvalidOperationError(f"Item with ID {expense_in.item_id} not found.") + # Potentially link item's list/group if not already set on expense_in + if db_item_instance.list_id and not expense_in.list_id: + expense_in.list_id = db_item_instance.list_id + # Re-resolve context if list_id was derived from item + final_group_id = await _resolve_expense_context(db, expense_in) + + # 3. Create the ExpenseModel instance + db_expense = ExpenseModel( + description=expense_in.description, + total_amount=_round_money(expense_in.total_amount), + currency=expense_in.currency or "USD", + expense_date=expense_in.expense_date or datetime.now(timezone.utc), + split_type=expense_in.split_type, + list_id=expense_in.list_id, + group_id=final_group_id, # Use resolved group_id + item_id=expense_in.item_id, + paid_by_user_id=expense_in.paid_by_user_id, + created_by_user_id=current_user_id + ) + db.add(db_expense) + await db.flush() # Get expense ID + + # 4. Generate splits (passing current_user_id through kwargs if needed by specific split types) + splits_to_create = await _generate_expense_splits( + db=db, + expense_model=db_expense, + expense_in=expense_in, + current_user_id=current_user_id # Pass for item-based splits needing creator info + ) + + for split_model in splits_to_create: + split_model.expense_id = db_expense.id # Set FK after db_expense has ID + db.add_all(splits_to_create) + await db.flush() # Persist splits + + # 5. Re-fetch the expense with all necessary relationships for the response + stmt = ( + select(ExpenseModel) + .where(ExpenseModel.id == db_expense.id) + .options( + selectinload(ExpenseModel.paid_by_user), + selectinload(ExpenseModel.created_by_user), # If you have this relationship + selectinload(ExpenseModel.list), + selectinload(ExpenseModel.group), + selectinload(ExpenseModel.item), + selectinload(ExpenseModel.splits).selectinload(ExpenseSplitModel.user) + ) + ) + result = await db.execute(stmt) + loaded_expense = result.scalar_one_or_none() + + if loaded_expense is None: + await transaction.rollback() # Should be handled by context manager + raise ExpenseOperationError("Failed to load expense after creation.") + + await transaction.commit() + return loaded_expense + + except (UserNotFoundError, ListNotFoundError, GroupNotFoundError, InvalidOperationError) as e: + # These are business logic validation errors, re-raise them. + # If a transaction was started, the context manager handles rollback. + raise + except IntegrityError as e: + # Context manager handles rollback. + logger.error(f"Database integrity error during expense creation: {str(e)}", exc_info=True) + raise DatabaseIntegrityError(f"Failed to save expense due to database integrity issue: {str(e)}") + except OperationalError as e: + logger.error(f"Database connection error during expense creation: {str(e)}", exc_info=True) + raise DatabaseConnectionError(f"Database connection error during expense creation: {str(e)}") + except SQLAlchemyError as e: + # Context manager handles rollback. + logger.error(f"Unexpected SQLAlchemy error during expense creation: {str(e)}", exc_info=True) + raise DatabaseTransactionError(f"Failed to save expense due to a database transaction error: {str(e)}") async def _resolve_expense_context(db: AsyncSession, expense_in: ExpenseCreate) -> Optional[int]: @@ -197,39 +239,32 @@ async def _resolve_expense_context(db: AsyncSession, expense_in: ExpenseCreate) async def _generate_expense_splits( db: AsyncSession, - db_expense: ExpenseModel, + expense_model: ExpenseModel, expense_in: ExpenseCreate, - round_money: Callable[[Decimal], Decimal] + **kwargs: Any ) -> PyList[ExpenseSplitModel]: """Generates appropriate expense splits based on split type.""" splits_to_create: PyList[ExpenseSplitModel] = [] + # Pass db to split creation helpers if they need to fetch more data (e.g., item details for item-based) + common_args = {"db": db, "expense_model": expense_model, "expense_in": expense_in, "round_money_func": _round_money, "kwargs": kwargs} + # Create splits based on the split type if expense_in.split_type == SplitTypeEnum.EQUAL: - splits_to_create = await _create_equal_splits( - db, db_expense, expense_in, round_money - ) + splits_to_create = await _create_equal_splits(**common_args) elif expense_in.split_type == SplitTypeEnum.EXACT_AMOUNTS: - splits_to_create = await _create_exact_amount_splits( - db, db_expense, expense_in, round_money - ) + splits_to_create = await _create_exact_amount_splits(**common_args) elif expense_in.split_type == SplitTypeEnum.PERCENTAGE: - splits_to_create = await _create_percentage_splits( - db, db_expense, expense_in, round_money - ) + splits_to_create = await _create_percentage_splits(**common_args) elif expense_in.split_type == SplitTypeEnum.SHARES: - splits_to_create = await _create_shares_splits( - db, db_expense, expense_in, round_money - ) + splits_to_create = await _create_shares_splits(**common_args) elif expense_in.split_type == SplitTypeEnum.ITEM_BASED: - splits_to_create = await _create_item_based_splits( - db, db_expense, expense_in, round_money - ) + splits_to_create = await _create_item_based_splits(**common_args) else: raise InvalidOperationError(f"Unsupported split type: {expense_in.split_type.value}") @@ -240,29 +275,24 @@ async def _generate_expense_splits( return splits_to_create -async def _create_equal_splits( - db: AsyncSession, - db_expense: ExpenseModel, - expense_in: ExpenseCreate, - round_money: Callable[[Decimal], Decimal] -) -> PyList[ExpenseSplitModel]: +async def _create_equal_splits(db: AsyncSession, expense_model: ExpenseModel, expense_in: ExpenseCreate, round_money_func: Callable[[Decimal], Decimal], **kwargs: Any) -> PyList[ExpenseSplitModel]: """Creates equal splits among users.""" users_for_splitting = await get_users_for_splitting( - db, db_expense.group_id, expense_in.list_id, expense_in.paid_by_user_id + db, expense_model.group_id, expense_model.list_id, expense_model.paid_by_user_id ) if not users_for_splitting: raise InvalidOperationError("No users found for EQUAL split.") num_users = len(users_for_splitting) - amount_per_user = round_money(db_expense.total_amount / Decimal(num_users)) - remainder = db_expense.total_amount - (amount_per_user * num_users) + amount_per_user = round_money_func(expense_model.total_amount / Decimal(num_users)) + remainder = expense_model.total_amount - (amount_per_user * num_users) splits = [] for i, user in enumerate(users_for_splitting): split_amount = amount_per_user if i == 0 and remainder != Decimal('0'): - split_amount = round_money(amount_per_user + remainder) + split_amount = round_money_func(amount_per_user + remainder) splits.append(ExpenseSplitModel( user_id=user.id, @@ -272,12 +302,7 @@ async def _create_equal_splits( return splits -async def _create_exact_amount_splits( - db: AsyncSession, - db_expense: ExpenseModel, - expense_in: ExpenseCreate, - round_money: Callable[[Decimal], Decimal] -) -> PyList[ExpenseSplitModel]: +async def _create_exact_amount_splits(db: AsyncSession, expense_model: ExpenseModel, expense_in: ExpenseCreate, round_money_func: Callable[[Decimal], Decimal], **kwargs: Any) -> PyList[ExpenseSplitModel]: """Creates splits with exact amounts.""" if not expense_in.splits_in: @@ -293,7 +318,7 @@ async def _create_exact_amount_splits( if split_in.owed_amount <= Decimal('0'): raise InvalidOperationError(f"Owed amount for user {split_in.user_id} must be positive.") - rounded_amount = round_money(split_in.owed_amount) + rounded_amount = round_money_func(split_in.owed_amount) current_total += rounded_amount splits.append(ExpenseSplitModel( @@ -301,20 +326,15 @@ async def _create_exact_amount_splits( owed_amount=rounded_amount )) - if round_money(current_total) != db_expense.total_amount: + if round_money_func(current_total) != expense_model.total_amount: raise InvalidOperationError( - f"Sum of exact split amounts ({current_total}) != expense total ({db_expense.total_amount})." + f"Sum of exact split amounts ({current_total}) != expense total ({expense_model.total_amount})." ) return splits -async def _create_percentage_splits( - db: AsyncSession, - db_expense: ExpenseModel, - expense_in: ExpenseCreate, - round_money: Callable[[Decimal], Decimal] -) -> PyList[ExpenseSplitModel]: +async def _create_percentage_splits(db: AsyncSession, expense_model: ExpenseModel, expense_in: ExpenseCreate, round_money_func: Callable[[Decimal], Decimal], **kwargs: Any) -> PyList[ExpenseSplitModel]: """Creates splits based on percentages.""" if not expense_in.splits_in: @@ -334,7 +354,7 @@ async def _create_percentage_splits( ) total_percentage += split_in.share_percentage - owed_amount = round_money(db_expense.total_amount * (split_in.share_percentage / Decimal("100"))) + owed_amount = round_money_func(expense_model.total_amount * (split_in.share_percentage / Decimal("100"))) current_total += owed_amount splits.append(ExpenseSplitModel( @@ -343,23 +363,18 @@ async def _create_percentage_splits( share_percentage=split_in.share_percentage )) - if round_money(total_percentage) != Decimal("100.00"): + if round_money_func(total_percentage) != Decimal("100.00"): raise InvalidOperationError(f"Sum of percentages ({total_percentage}%) is not 100%.") # Adjust for rounding differences - if current_total != db_expense.total_amount and splits: - diff = db_expense.total_amount - current_total - splits[-1].owed_amount = round_money(splits[-1].owed_amount + diff) + if current_total != expense_model.total_amount and splits: + diff = expense_model.total_amount - current_total + splits[-1].owed_amount = round_money_func(splits[-1].owed_amount + diff) return splits -async def _create_shares_splits( - db: AsyncSession, - db_expense: ExpenseModel, - expense_in: ExpenseCreate, - round_money: Callable[[Decimal], Decimal] -) -> PyList[ExpenseSplitModel]: +async def _create_shares_splits(db: AsyncSession, expense_model: ExpenseModel, expense_in: ExpenseCreate, round_money_func: Callable[[Decimal], Decimal], **kwargs: Any) -> PyList[ExpenseSplitModel]: """Creates splits based on shares.""" if not expense_in.splits_in: @@ -381,7 +396,7 @@ async def _create_shares_splits( raise InvalidOperationError(f"Invalid share units for user {split_in.user_id}.") share_ratio = Decimal(split_in.share_units) / Decimal(total_shares) - owed_amount = round_money(db_expense.total_amount * share_ratio) + owed_amount = round_money_func(expense_model.total_amount * share_ratio) current_total += owed_amount splits.append(ExpenseSplitModel( @@ -391,31 +406,26 @@ async def _create_shares_splits( )) # Adjust for rounding differences - if current_total != db_expense.total_amount and splits: - diff = db_expense.total_amount - current_total - splits[-1].owed_amount = round_money(splits[-1].owed_amount + diff) + if current_total != expense_model.total_amount and splits: + diff = expense_model.total_amount - current_total + splits[-1].owed_amount = round_money_func(splits[-1].owed_amount + diff) return splits -async def _create_item_based_splits( - db: AsyncSession, - db_expense: ExpenseModel, - expense_in: ExpenseCreate, - round_money: Callable[[Decimal], Decimal] -) -> PyList[ExpenseSplitModel]: +async def _create_item_based_splits(db: AsyncSession, expense_model: ExpenseModel, expense_in: ExpenseCreate, round_money_func: Callable[[Decimal], Decimal], **kwargs: Any) -> PyList[ExpenseSplitModel]: """Creates splits based on items in a shopping list.""" - if not expense_in.list_id: + if not expense_model.list_id: raise InvalidOperationError("ITEM_BASED expenses must be associated with a list_id.") if expense_in.splits_in: logger.warning("Ignoring provided 'splits_in' data for ITEM_BASED expense.") # Build query to fetch relevant items - items_query = select(ItemModel).where(ItemModel.list_id == expense_in.list_id) - if expense_in.item_id: - items_query = items_query.where(ItemModel.id == expense_in.item_id) + items_query = select(ItemModel).where(ItemModel.list_id == expense_model.list_id) + if expense_model.item_id: + items_query = items_query.where(ItemModel.id == expense_model.item_id) else: items_query = items_query.where(ItemModel.price.isnot(None) & (ItemModel.price > Decimal("0"))) @@ -425,9 +435,9 @@ async def _create_item_based_splits( if not relevant_items: error_msg = ( - f"Specified item ID {expense_in.item_id} not found in list {expense_in.list_id}." - if expense_in.item_id else - f"List {expense_in.list_id} has no priced items to base the expense on." + f"Specified item ID {expense_model.item_id} not found in list {expense_model.list_id}." + if expense_model.item_id else + f"List {expense_model.list_id} has no priced items to base the expense on." ) raise InvalidOperationError(error_msg) @@ -438,9 +448,9 @@ async def _create_item_based_splits( for item in relevant_items: if item.price is None or item.price <= Decimal("0"): - if expense_in.item_id: + if expense_model.item_id: raise InvalidOperationError( - f"Item ID {expense_in.item_id} must have a positive price for ITEM_BASED expense." + f"Item ID {expense_model.item_id} must have a positive price for ITEM_BASED expense." ) continue @@ -454,13 +464,13 @@ async def _create_item_based_splits( if processed_items == 0: raise InvalidOperationError( - f"No items with positive prices found in list {expense_in.list_id} to create ITEM_BASED expense." + f"No items with positive prices found in list {expense_model.list_id} to create ITEM_BASED expense." ) # Validate total matches calculated total - if round_money(calculated_total) != db_expense.total_amount: + if round_money_func(calculated_total) != expense_model.total_amount: raise InvalidOperationError( - f"Expense total amount ({db_expense.total_amount}) does not match the " + f"Expense total amount ({expense_model.total_amount}) does not match the " f"calculated total from item prices ({calculated_total})." ) @@ -469,7 +479,7 @@ async def _create_item_based_splits( for user_id, owed_amount in user_owed_amounts.items(): splits.append(ExpenseSplitModel( user_id=user_id, - owed_amount=round_money(owed_amount) + owed_amount=round_money_func(owed_amount) )) return splits diff --git a/be/app/crud/group.py b/be/app/crud/group.py index 8a6ebc6..2617f7b 100644 --- a/be/app/crud/group.py +++ b/be/app/crud/group.py @@ -4,7 +4,7 @@ from sqlalchemy.future import select from sqlalchemy.orm import selectinload # For eager loading members from sqlalchemy.exc import SQLAlchemyError, IntegrityError, OperationalError from typing import Optional, List -from sqlalchemy import func +from sqlalchemy import delete, func from app.models import User as UserModel, Group as GroupModel, UserGroup as UserGroupModel from app.schemas.group import GroupCreate @@ -24,10 +24,23 @@ from app.core.exceptions import ( async def create_group(db: AsyncSession, group_in: GroupCreate, creator_id: int) -> GroupModel: """Creates a group and adds the creator as the owner.""" try: - async with db.begin(): + # Defensive check: if a transaction is already active, try to roll it back. + # This is unusual and suggests an issue upstream (e.g., middleware or session configuration). + if db.in_transaction(): + # Log this occurrence if possible, as it's unexpected. + # import logging; logging.warning("Transaction already active on session entering create_group. Attempting rollback.") + try: + await db.rollback() # Attempt to clear any existing transaction + except SQLAlchemyError as e_rb: + # Log e_rb if possible + # import logging; logging.error(f"Error rolling back pre-existing transaction: {e_rb}") + # Re-raise or handle as a critical error, as the session state is uncertain. + raise DatabaseTransactionError(f"Session had an active transaction that could not be rolled back: {e_rb}") + + async with db.begin(): # Now attempt to start a clean transaction db_group = GroupModel(name=group_in.name, created_by_id=creator_id) db.add(db_group) - await db.flush() + await db.flush() # Assigns ID to db_group db_user_group = UserGroupModel( user_id=creator_id, @@ -35,15 +48,30 @@ async def create_group(db: AsyncSession, group_in: GroupCreate, creator_id: int) role=UserRoleEnum.owner ) db.add(db_user_group) - await db.flush() - await db.refresh(db_group) - return db_group + await db.flush() # Commits user_group, links to group + + # After creation and linking, explicitly load the group with its member associations and users + stmt = ( + select(GroupModel) + .where(GroupModel.id == db_group.id) + .options( + selectinload(GroupModel.member_associations).selectinload(UserGroupModel.user) + ) + ) + result = await db.execute(stmt) + loaded_group = result.scalar_one_or_none() + + if loaded_group is None: + # This should not happen if we just created it, but as a safeguard + raise GroupOperationError("Failed to load group after creation.") + + return loaded_group except IntegrityError as e: - raise DatabaseIntegrityError(f"Failed to create group: {str(e)}") + raise DatabaseIntegrityError(f"Failed to create group due to integrity issue: {str(e)}") except OperationalError as e: - raise DatabaseConnectionError(f"Database connection error: {str(e)}") + raise DatabaseConnectionError(f"Database connection error during group creation: {str(e)}") except SQLAlchemyError as e: - raise DatabaseTransactionError(f"Failed to create group: {str(e)}") + raise DatabaseTransactionError(f"Database transaction error during group creation: {str(e)}") async def get_user_groups(db: AsyncSession, user_id: int) -> List[GroupModel]: """Gets all groups a user is a member of.""" @@ -52,7 +80,9 @@ async def get_user_groups(db: AsyncSession, user_id: int) -> List[GroupModel]: select(GroupModel) .join(UserGroupModel) .where(UserGroupModel.user_id == user_id) - .options(selectinload(GroupModel.member_associations)) + .options( + selectinload(GroupModel.member_associations).selectinload(UserGroupModel.user) + ) ) return result.scalars().all() except OperationalError as e: @@ -106,18 +136,34 @@ async def get_user_role_in_group(db: AsyncSession, group_id: int, user_id: int) async def add_user_to_group(db: AsyncSession, group_id: int, user_id: int, role: UserRoleEnum = UserRoleEnum.member) -> Optional[UserGroupModel]: """Adds a user to a group if they aren't already a member.""" try: - async with db.begin(): - existing = await db.execute( - select(UserGroupModel).where(UserGroupModel.group_id == group_id, UserGroupModel.user_id == user_id) - ) - if existing.scalar_one_or_none(): - return None + # Check if user is already a member before starting a transaction + existing_stmt = select(UserGroupModel.id).where(UserGroupModel.group_id == group_id, UserGroupModel.user_id == user_id) + existing_result = await db.execute(existing_stmt) + if existing_result.scalar_one_or_none(): + return None + # Use a single transaction + async with db.begin(): db_user_group = UserGroupModel(user_id=user_id, group_id=group_id, role=role) db.add(db_user_group) - await db.flush() - await db.refresh(db_user_group) - return db_user_group + await db.flush() # Assigns ID to db_user_group + + # Eagerly load the 'user' and 'group' relationships for the response + stmt = ( + select(UserGroupModel) + .where(UserGroupModel.id == db_user_group.id) + .options( + selectinload(UserGroupModel.user), + selectinload(UserGroupModel.group) + ) + ) + result = await db.execute(stmt) + loaded_user_group = result.scalar_one_or_none() + + if loaded_user_group is None: + raise GroupOperationError(f"Failed to load user group association after adding user {user_id} to group {group_id}.") + + return loaded_user_group except IntegrityError as e: raise DatabaseIntegrityError(f"Failed to add user to group: {str(e)}") except OperationalError as e: diff --git a/be/app/crud/invite.py b/be/app/crud/invite.py index ed64438..8e5ae09 100644 --- a/be/app/crud/invite.py +++ b/be/app/crud/invite.py @@ -3,10 +3,19 @@ import secrets from datetime import datetime, timedelta, timezone from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.future import select +from sqlalchemy.orm import selectinload # Ensure selectinload is imported from sqlalchemy import delete # Import delete statement +from sqlalchemy.exc import SQLAlchemyError, OperationalError, IntegrityError from typing import Optional -from app.models import Invite as InviteModel +from app.models import Invite as InviteModel, Group as GroupModel, User as UserModel # Import related models for selectinload +from app.core.exceptions import ( + DatabaseConnectionError, + DatabaseIntegrityError, + DatabaseQueryError, + DatabaseTransactionError, + InviteOperationError # Add new specific exception +) # Invite codes should be reasonably unique, but handle potential collision MAX_CODE_GENERATION_ATTEMPTS = 5 @@ -14,56 +23,121 @@ MAX_CODE_GENERATION_ATTEMPTS = 5 async def create_invite(db: AsyncSession, group_id: int, creator_id: int, expires_in_days: int = 7) -> Optional[InviteModel]: """Creates a new invite code for a group.""" expires_at = datetime.now(timezone.utc) + timedelta(days=expires_in_days) - code = None - attempts = 0 - - # Generate a unique code, retrying if a collision occurs (highly unlikely but safe) - while attempts < MAX_CODE_GENERATION_ATTEMPTS: - attempts += 1 + + potential_code = None + for attempt in range(MAX_CODE_GENERATION_ATTEMPTS): potential_code = secrets.token_urlsafe(16) - # Check if an *active* invite with this code already exists - existing = await db.execute( - select(InviteModel.id).where(InviteModel.code == potential_code, InviteModel.is_active == True).limit(1) - ) - if existing.scalar_one_or_none() is None: - code = potential_code - break + # Check if an *active* invite with this code already exists (outside main transaction for now) + # Ideally, unique constraint on (code, is_active=true) in DB and catch IntegrityError. + # This check reduces collision chance before attempting transaction. + existing_check_stmt = select(InviteModel.id).where(InviteModel.code == potential_code, InviteModel.is_active == True).limit(1) + existing_result = await db.execute(existing_check_stmt) + if existing_result.scalar_one_or_none() is None: + break # Found a potentially unique code + if attempt == MAX_CODE_GENERATION_ATTEMPTS - 1: + raise InviteOperationError("Failed to generate a unique invite code after several attempts.") - if code is None: - # Failed to generate a unique code after several attempts - return None + try: + async with db.begin_nested() if db.in_transaction() else db.begin() as transaction: + # Final check within transaction to be absolutely sure before insert + final_check_stmt = select(InviteModel.id).where(InviteModel.code == potential_code, InviteModel.is_active == True).limit(1) + final_check_result = await db.execute(final_check_stmt) + if final_check_result.scalar_one_or_none() is not None: + # Extremely unlikely if previous check passed, but handles race condition + await transaction.rollback() + raise InviteOperationError("Invite code collision detected during transaction.") - db_invite = InviteModel( - code=code, - group_id=group_id, - created_by_id=creator_id, - expires_at=expires_at, - is_active=True - ) - db.add(db_invite) - await db.commit() - await db.refresh(db_invite) - return db_invite + db_invite = InviteModel( + code=potential_code, + group_id=group_id, + created_by_id=creator_id, + expires_at=expires_at, + is_active=True + ) + db.add(db_invite) + await db.flush() # Assigns ID + + # Re-fetch with relationships + stmt = ( + select(InviteModel) + .where(InviteModel.id == db_invite.id) + .options( + selectinload(InviteModel.group), + selectinload(InviteModel.creator) + ) + ) + result = await db.execute(stmt) + loaded_invite = result.scalar_one_or_none() + + if loaded_invite is None: + await transaction.rollback() + raise InviteOperationError("Failed to load invite after creation.") + + await transaction.commit() + return loaded_invite + except IntegrityError as e: # Catch if DB unique constraint on code is violated + # Rollback handled by context manager + raise DatabaseIntegrityError(f"Failed to create invite due to DB integrity: {str(e)}") + except OperationalError as e: + raise DatabaseConnectionError(f"DB connection error during invite creation: {str(e)}") + except SQLAlchemyError as e: + raise DatabaseTransactionError(f"DB transaction error during invite creation: {str(e)}") async def get_active_invite_by_code(db: AsyncSession, code: str) -> Optional[InviteModel]: """Gets an active and non-expired invite by its code.""" now = datetime.now(timezone.utc) - result = await db.execute( - select(InviteModel).where( - InviteModel.code == code, - InviteModel.is_active == True, - InviteModel.expires_at > now + try: + stmt = ( + select(InviteModel).where( + InviteModel.code == code, + InviteModel.is_active == True, + InviteModel.expires_at > now + ) + .options( + selectinload(InviteModel.group), + selectinload(InviteModel.creator) + ) ) - ) - return result.scalars().first() + result = await db.execute(stmt) + return result.scalars().first() + except OperationalError as e: + raise DatabaseConnectionError(f"DB connection error fetching invite: {str(e)}") + except SQLAlchemyError as e: + raise DatabaseQueryError(f"DB query error fetching invite: {str(e)}") async def deactivate_invite(db: AsyncSession, invite: InviteModel) -> InviteModel: - """Marks an invite as inactive (used).""" - invite.is_active = False - db.add(invite) # Add to session to track change - await db.commit() - await db.refresh(invite) - return invite + """Marks an invite as inactive (used) and reloads with relationships.""" + try: + async with db.begin_nested() if db.in_transaction() else db.begin() as transaction: + invite.is_active = False + db.add(invite) # Add to session to track change + await db.flush() # Persist is_active change + + # Re-fetch with relationships + stmt = ( + select(InviteModel) + .where(InviteModel.id == invite.id) + .options( + selectinload(InviteModel.group), + selectinload(InviteModel.creator) + ) + ) + result = await db.execute(stmt) + updated_invite = result.scalar_one_or_none() + + if updated_invite is None: # Should not happen as invite is passed in + await transaction.rollback() + raise InviteOperationError("Failed to load invite after deactivation.") + + await transaction.commit() + return updated_invite + except OperationalError as e: + raise DatabaseConnectionError(f"DB connection error deactivating invite: {str(e)}") + except SQLAlchemyError as e: + raise DatabaseTransactionError(f"DB transaction error deactivating invite: {str(e)}") + +# Ensure InviteOperationError is defined in app.core.exceptions +# Example: class InviteOperationError(AppException): pass # Optional: Function to periodically delete old, inactive invites # async def cleanup_old_invites(db: AsyncSession, older_than_days: int = 30): ... \ No newline at end of file diff --git a/be/app/crud/item.py b/be/app/crud/item.py index 6fb7456..31e9e80 100644 --- a/be/app/crud/item.py +++ b/be/app/crud/item.py @@ -1,12 +1,13 @@ # app/crud/item.py from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.future import select +from sqlalchemy.orm import selectinload # Ensure selectinload is imported from sqlalchemy import delete as sql_delete, update as sql_update # Use aliases from sqlalchemy.exc import SQLAlchemyError, IntegrityError, OperationalError from typing import Optional, List as PyList from datetime import datetime, timezone -from app.models import Item as ItemModel +from app.models import Item as ItemModel, User as UserModel # Import UserModel for type hints if needed for selectinload from app.schemas.item import ItemCreate, ItemUpdate from app.core.exceptions import ( ItemNotFoundError, @@ -14,46 +15,65 @@ from app.core.exceptions import ( DatabaseIntegrityError, DatabaseQueryError, DatabaseTransactionError, - ConflictError + ConflictError, + ItemOperationError # Add if specific item operation errors are needed ) async def create_item(db: AsyncSession, item_in: ItemCreate, list_id: int, user_id: int) -> ItemModel: """Creates a new item record for a specific list.""" try: - db_item = ItemModel( - name=item_in.name, - quantity=item_in.quantity, - list_id=list_id, - added_by_id=user_id, - is_complete=False # Default on creation - # version is implicitly set to 1 by model default - ) - db.add(db_item) - await db.flush() - await db.refresh(db_item) - await db.commit() # Explicitly commit here - return db_item + async with db.begin_nested() if db.in_transaction() else db.begin() as transaction: + db_item = ItemModel( + name=item_in.name, + quantity=item_in.quantity, + list_id=list_id, + added_by_id=user_id, + is_complete=False + ) + db.add(db_item) + await db.flush() # Assigns ID + + # Re-fetch with relationships + stmt = ( + select(ItemModel) + .where(ItemModel.id == db_item.id) + .options( + selectinload(ItemModel.added_by_user), + selectinload(ItemModel.completed_by_user) # Will be None but loads relationship + ) + ) + result = await db.execute(stmt) + loaded_item = result.scalar_one_or_none() + + if loaded_item is None: + await transaction.rollback() # Should be handled by context manager on raise, but explicit for clarity + raise ItemOperationError("Failed to load item after creation.") # Define ItemOperationError + + await transaction.commit() + return loaded_item except IntegrityError as e: - await db.rollback() # Rollback on integrity error + # Context manager handles rollback on error raise DatabaseIntegrityError(f"Failed to create item: {str(e)}") except OperationalError as e: - await db.rollback() # Rollback on operational error raise DatabaseConnectionError(f"Database connection error: {str(e)}") except SQLAlchemyError as e: - await db.rollback() # Rollback on other SQLAlchemy errors raise DatabaseTransactionError(f"Failed to create item: {str(e)}") - except Exception as e: # Catch any other exception and attempt rollback - await db.rollback() - raise # Re-raise the original exception + # Removed generic Exception block as SQLAlchemyError should cover DB issues, + # and context manager handles rollback. async def get_items_by_list_id(db: AsyncSession, list_id: int) -> PyList[ItemModel]: """Gets all items belonging to a specific list, ordered by creation time.""" try: - result = await db.execute( + stmt = ( select(ItemModel) .where(ItemModel.list_id == list_id) - .order_by(ItemModel.created_at.asc()) # Or desc() if preferred + .options( + selectinload(ItemModel.added_by_user), + selectinload(ItemModel.completed_by_user) + ) + .order_by(ItemModel.created_at.asc()) ) + result = await db.execute(stmt) return result.scalars().all() except OperationalError as e: raise DatabaseConnectionError(f"Failed to connect to database: {str(e)}") @@ -63,7 +83,16 @@ async def get_items_by_list_id(db: AsyncSession, list_id: int) -> PyList[ItemMod async def get_item_by_id(db: AsyncSession, item_id: int) -> Optional[ItemModel]: """Gets a single item by its ID.""" try: - result = await db.execute(select(ItemModel).where(ItemModel.id == item_id)) + stmt = ( + select(ItemModel) + .where(ItemModel.id == item_id) + .options( + selectinload(ItemModel.added_by_user), + selectinload(ItemModel.completed_by_user), + selectinload(ItemModel.list) # Often useful to get the parent list + ) + ) + result = await db.execute(stmt) return result.scalars().first() except OperationalError as e: raise DatabaseConnectionError(f"Failed to connect to database: {str(e)}") @@ -73,59 +102,72 @@ async def get_item_by_id(db: AsyncSession, item_id: int) -> Optional[ItemModel]: async def update_item(db: AsyncSession, item_db: ItemModel, item_in: ItemUpdate, user_id: int) -> ItemModel: """Updates an existing item record, checking for version conflicts.""" try: - # Check version conflict - if item_db.version != item_in.version: - raise ConflictError( - f"Item '{item_db.name}' (ID: {item_db.id}) has been modified. " - f"Your version is {item_in.version}, current version is {item_db.version}. Please refresh." + async with db.begin_nested() if db.in_transaction() else db.begin() as transaction: + if item_db.version != item_in.version: + # No need to rollback here, as the transaction hasn't committed. + # The context manager will handle rollback if an exception is raised. + raise ConflictError( + f"Item '{item_db.name}' (ID: {item_db.id}) has been modified. " + f"Your version is {item_in.version}, current version is {item_db.version}. Please refresh." + ) + + update_data = item_in.model_dump(exclude_unset=True, exclude={'version'}) + + if 'is_complete' in update_data: + if update_data['is_complete'] is True: + if item_db.completed_by_id is None: + update_data['completed_by_id'] = user_id + else: + update_data['completed_by_id'] = None + + for key, value in update_data.items(): + setattr(item_db, key, value) + + item_db.version += 1 + db.add(item_db) # Mark as dirty + await db.flush() + + # Re-fetch with relationships + stmt = ( + select(ItemModel) + .where(ItemModel.id == item_db.id) + .options( + selectinload(ItemModel.added_by_user), + selectinload(ItemModel.completed_by_user), + selectinload(ItemModel.list) + ) ) + result = await db.execute(stmt) + updated_item = result.scalar_one_or_none() - update_data = item_in.model_dump(exclude_unset=True, exclude={'version'}) # Exclude version + if updated_item is None: # Should not happen + # Rollback will be handled by context manager on raise + raise ItemOperationError("Failed to load item after update.") - # Special handling for is_complete - if 'is_complete' in update_data: - if update_data['is_complete'] is True: - if item_db.completed_by_id is None: # Only set if not already completed by someone - update_data['completed_by_id'] = user_id - else: - update_data['completed_by_id'] = None # Clear if marked incomplete - - # Apply updates - for key, value in update_data.items(): - setattr(item_db, key, value) - - item_db.version += 1 # Increment version - - db.add(item_db) - await db.flush() - await db.refresh(item_db) - - # Commit the transaction if not part of a larger transaction - await db.commit() - - return item_db + await transaction.commit() + return updated_item except IntegrityError as e: - await db.rollback() raise DatabaseIntegrityError(f"Failed to update item due to integrity constraint: {str(e)}") except OperationalError as e: - await db.rollback() raise DatabaseConnectionError(f"Database connection error while updating item: {str(e)}") - except ConflictError: # Re-raise ConflictError - await db.rollback() + except ConflictError: # Re-raise ConflictError, rollback handled by context manager raise except SQLAlchemyError as e: - await db.rollback() raise DatabaseTransactionError(f"Failed to update item: {str(e)}") async def delete_item(db: AsyncSession, item_db: ItemModel) -> None: """Deletes an item record. Version check should be done by the caller (API endpoint).""" try: - await db.delete(item_db) - await db.commit() - return None + async with db.begin_nested() if db.in_transaction() else db.begin() as transaction: + await db.delete(item_db) + await transaction.commit() + # No return needed for None except OperationalError as e: - await db.rollback() + # Rollback handled by context manager raise DatabaseConnectionError(f"Database connection error while deleting item: {str(e)}") except SQLAlchemyError as e: - await db.rollback() - raise DatabaseTransactionError(f"Failed to delete item: {str(e)}") \ No newline at end of file + # Rollback handled by context manager + raise DatabaseTransactionError(f"Failed to delete item: {str(e)}") + +# Ensure ItemOperationError is defined in app.core.exceptions if used +# Example: class ItemOperationError(AppException): pass \ No newline at end of file diff --git a/be/app/crud/list.py b/be/app/crud/list.py index 35527c8..51ac556 100644 --- a/be/app/crud/list.py +++ b/be/app/crud/list.py @@ -17,15 +17,14 @@ from app.core.exceptions import ( DatabaseIntegrityError, DatabaseQueryError, DatabaseTransactionError, - ConflictError + ConflictError, + ListOperationError ) async def create_list(db: AsyncSession, list_in: ListCreate, creator_id: int) -> ListModel: """Creates a new list record.""" try: - # Check if we're already in a transaction - if db.in_transaction(): - # If we're already in a transaction, just create the list + async with db.begin_nested() if db.in_transaction() else db.begin() as transaction: db_list = ListModel( name=list_in.name, description=list_in.description, @@ -34,23 +33,27 @@ async def create_list(db: AsyncSession, list_in: ListCreate, creator_id: int) -> is_complete=False ) db.add(db_list) - await db.flush() - await db.refresh(db_list) - return db_list - else: - # If no transaction is active, start one - async with db.begin(): - db_list = ListModel( - name=list_in.name, - description=list_in.description, - group_id=list_in.group_id, - created_by_id=creator_id, - is_complete=False + await db.flush() # Assigns ID + + # Re-fetch with relationships for the response + stmt = ( + select(ListModel) + .where(ListModel.id == db_list.id) + .options( + selectinload(ListModel.creator), + selectinload(ListModel.group) + # selectinload(ListModel.items) # Optionally add if items are always needed in response ) - db.add(db_list) - await db.flush() - await db.refresh(db_list) - return db_list + ) + result = await db.execute(stmt) + loaded_list = result.scalar_one_or_none() + + if loaded_list is None: + await transaction.rollback() + raise ListOperationError("Failed to load list after creation.") + + await transaction.commit() + return loaded_list except IntegrityError as e: raise DatabaseIntegrityError(f"Failed to create list: {str(e)}") except OperationalError as e: @@ -66,14 +69,22 @@ async def get_lists_for_user(db: AsyncSession, user_id: int) -> PyList[ListModel ) user_group_ids = group_ids_result.scalars().all() - # Build conditions for the OR clause dynamically conditions = [ and_(ListModel.created_by_id == user_id, ListModel.group_id.is_(None)) ] - if user_group_ids: # Only add the IN clause if there are group IDs + if user_group_ids: conditions.append(ListModel.group_id.in_(user_group_ids)) - query = select(ListModel).where(or_(*conditions)).order_by(ListModel.updated_at.desc()) + query = ( + select(ListModel) + .where(or_(*conditions)) + .options( + selectinload(ListModel.creator), + selectinload(ListModel.group) + # selectinload(ListModel.items) # Consider if items are needed for list previews + ) + .order_by(ListModel.updated_at.desc()) + ) result = await db.execute(query) return result.scalars().all() @@ -85,11 +96,17 @@ async def get_lists_for_user(db: AsyncSession, user_id: int) -> PyList[ListModel async def get_list_by_id(db: AsyncSession, list_id: int, load_items: bool = False) -> Optional[ListModel]: """Gets a single list by ID, optionally loading its items.""" try: - query = select(ListModel).where(ListModel.id == list_id) + query = ( + select(ListModel) + .where(ListModel.id == list_id) + .options( + selectinload(ListModel.creator), + selectinload(ListModel.group) + ) + ) if load_items: query = query.options( - selectinload(ListModel.items) - .options( + selectinload(ListModel.items).options( joinedload(ItemModel.added_by_user), joinedload(ItemModel.completed_by_user) ) @@ -104,8 +121,9 @@ async def get_list_by_id(db: AsyncSession, list_id: int, load_items: bool = Fals async def update_list(db: AsyncSession, list_db: ListModel, list_in: ListUpdate) -> ListModel: """Updates an existing list record, checking for version conflicts.""" try: - async with db.begin(): - if list_db.version != list_in.version: + async with db.begin_nested() if db.in_transaction() else db.begin() as transaction: + if list_db.version != list_in.version: # list_db here is the one passed in, pre-loaded by API layer + await transaction.rollback() # Rollback before raising raise ConflictError( f"List '{list_db.name}' (ID: {list_db.id}) has been modified. " f"Your version is {list_in.version}, current version is {list_db.version}. Please refresh." @@ -118,34 +136,54 @@ async def update_list(db: AsyncSession, list_db: ListModel, list_in: ListUpdate) list_db.version += 1 - db.add(list_db) + db.add(list_db) # Add the already attached list_db to mark it dirty for the session await db.flush() - await db.refresh(list_db) - return list_db + + # Re-fetch with relationships for the response + stmt = ( + select(ListModel) + .where(ListModel.id == list_db.id) + .options( + selectinload(ListModel.creator), + selectinload(ListModel.group) + # selectinload(ListModel.items) # Optionally add if items are always needed in response + ) + ) + result = await db.execute(stmt) + updated_list = result.scalar_one_or_none() + + if updated_list is None: # Should not happen + await transaction.rollback() + raise ListOperationError("Failed to load list after update.") + + await transaction.commit() + return updated_list except IntegrityError as e: - await db.rollback() + # Ensure rollback if not handled by context manager (though it should be) + if db.in_transaction(): await db.rollback() raise DatabaseIntegrityError(f"Failed to update list due to integrity constraint: {str(e)}") except OperationalError as e: - await db.rollback() + if db.in_transaction(): await db.rollback() raise DatabaseConnectionError(f"Database connection error while updating list: {str(e)}") except ConflictError: - await db.rollback() + # Already rolled back or will be by context manager if transaction was started here raise except SQLAlchemyError as e: - await db.rollback() + if db.in_transaction(): await db.rollback() raise DatabaseTransactionError(f"Failed to update list: {str(e)}") async def delete_list(db: AsyncSession, list_db: ListModel) -> None: """Deletes a list record. Version check should be done by the caller (API endpoint).""" try: - async with db.begin(): + async with db.begin_nested() if db.in_transaction() else db.begin() as transaction: # Standardize transaction await db.delete(list_db) - return None + await transaction.commit() # Explicit commit + # return None # Already implicitly returns None except OperationalError as e: - await db.rollback() + # Rollback should be handled by the async with block on exception raise DatabaseConnectionError(f"Database connection error while deleting list: {str(e)}") except SQLAlchemyError as e: - await db.rollback() + # Rollback should be handled by the async with block on exception raise DatabaseTransactionError(f"Failed to delete list: {str(e)}") async def check_list_permission(db: AsyncSession, list_id: int, user_id: int, require_creator: bool = False) -> ListModel: @@ -212,39 +250,48 @@ async def get_list_by_name_and_group( db: AsyncSession, name: str, group_id: Optional[int], - user_id: int + user_id: int # user_id is for permission check, not direct list attribute ) -> Optional[ListModel]: """ Gets a list by name and group, ensuring the user has permission to access it. Used for conflict resolution when creating lists. """ try: - # Build the base query - query = select(ListModel).where(ListModel.name == name) + # Base query for the list itself + base_query = select(ListModel).where(ListModel.name == name) - # Add group condition if group_id is not None: - query = query.where(ListModel.group_id == group_id) + base_query = base_query.where(ListModel.group_id == group_id) else: - query = query.where(ListModel.group_id.is_(None)) + base_query = base_query.where(ListModel.group_id.is_(None)) - # Add permission conditions - conditions = [ - ListModel.created_by_id == user_id # User is creator - ] - if group_id is not None: - # User is member of the group - conditions.append( - and_( - ListModel.group_id == group_id, - ListModel.created_by_id != user_id # Not the creator - ) - ) + # Add eager loading for common relationships + base_query = base_query.options( + selectinload(ListModel.creator), + selectinload(ListModel.group) + ) - query = query.where(or_(*conditions)) + list_result = await db.execute(base_query) + target_list = list_result.scalar_one_or_none() + + if not target_list: + return None + + # Permission check + is_creator = target_list.created_by_id == user_id + + if is_creator: + return target_list + + if target_list.group_id: + from app.crud.group import is_user_member # Assuming this is a quick check not needing its own transaction + is_member_of_group = await is_user_member(db, group_id=target_list.group_id, user_id=user_id) + if is_member_of_group: + return target_list + + # If not creator and (not a group list or not a member of the group list) + return None - result = await db.execute(query) - return result.scalars().first() except OperationalError as e: raise DatabaseConnectionError(f"Failed to connect to database: {str(e)}") except SQLAlchemyError as e: diff --git a/be/app/crud/settlement.py b/be/app/crud/settlement.py index fdb0b37..30ec0ce 100644 --- a/be/app/crud/settlement.py +++ b/be/app/crud/settlement.py @@ -3,84 +3,135 @@ from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.future import select from sqlalchemy.orm import selectinload, joinedload from sqlalchemy import or_ -from decimal import Decimal +from sqlalchemy.exc import SQLAlchemyError, OperationalError, IntegrityError +from decimal import Decimal, ROUND_HALF_UP from typing import List as PyList, Optional, Sequence from datetime import datetime, timezone from app.models import ( Settlement as SettlementModel, User as UserModel, - Group as GroupModel + Group as GroupModel, + UserGroup as UserGroupModel +) +from app.schemas.expense import SettlementCreate, SettlementUpdate +from app.core.exceptions import ( + UserNotFoundError, + GroupNotFoundError, + InvalidOperationError, + DatabaseConnectionError, + DatabaseIntegrityError, + DatabaseQueryError, + DatabaseTransactionError, + SettlementOperationError, + ConflictError ) -from app.schemas.expense import SettlementCreate, SettlementUpdate # SettlementUpdate not used yet -from app.core.exceptions import UserNotFoundError, GroupNotFoundError, InvalidOperationError async def create_settlement(db: AsyncSession, settlement_in: SettlementCreate, current_user_id: int) -> SettlementModel: """Creates a new settlement record.""" - - # Validate Payer, Payee, and Group exist - payer = await db.get(UserModel, settlement_in.paid_by_user_id) - if not payer: - raise UserNotFoundError(user_id=settlement_in.paid_by_user_id, identifier="Payer") - - payee = await db.get(UserModel, settlement_in.paid_to_user_id) - if not payee: - raise UserNotFoundError(user_id=settlement_in.paid_to_user_id, identifier="Payee") - - if settlement_in.paid_by_user_id == settlement_in.paid_to_user_id: - raise InvalidOperationError("Payer and Payee cannot be the same user.") - - group = await db.get(GroupModel, settlement_in.group_id) - if not group: - raise GroupNotFoundError(settlement_in.group_id) - - # Optional: Check if current_user_id is part of the group or is one of the parties involved - # This is more of an API-level permission check but could be added here if strict. - # For example: if current_user_id not in [settlement_in.paid_by_user_id, settlement_in.paid_to_user_id]: - # is_in_group = await db.execute(select(UserGroupModel).where(UserGroupModel.group_id == group.id, UserGroupModel.user_id == current_user_id)) - # if not is_in_group.first(): - # raise InvalidOperationError("You can only record settlements you are part of or for groups you belong to.") - - db_settlement = SettlementModel( - group_id=settlement_in.group_id, - paid_by_user_id=settlement_in.paid_by_user_id, - paid_to_user_id=settlement_in.paid_to_user_id, - amount=settlement_in.amount.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP), - settlement_date=settlement_in.settlement_date if settlement_in.settlement_date else datetime.now(timezone.utc), - description=settlement_in.description - # created_by_user_id = current_user_id # Optional: Who recorded this settlement - ) - db.add(db_settlement) try: - await db.commit() - await db.refresh(db_settlement, attribute_names=["payer", "payee", "group"]) - except Exception as e: - await db.rollback() - raise InvalidOperationError(f"Failed to save settlement: {str(e)}") - - return db_settlement + async with db.begin_nested() if db.in_transaction() else db.begin() as transaction: + payer = await db.get(UserModel, settlement_in.paid_by_user_id) + if not payer: + raise UserNotFoundError(user_id=settlement_in.paid_by_user_id, identifier="Payer") + + payee = await db.get(UserModel, settlement_in.paid_to_user_id) + if not payee: + raise UserNotFoundError(user_id=settlement_in.paid_to_user_id, identifier="Payee") + + if settlement_in.paid_by_user_id == settlement_in.paid_to_user_id: + raise InvalidOperationError("Payer and Payee cannot be the same user.") + + group = await db.get(GroupModel, settlement_in.group_id) + if not group: + raise GroupNotFoundError(settlement_in.group_id) + + # Permission check example (can be in API layer too) + # if current_user_id not in [payer.id, payee.id]: + # is_member_stmt = select(UserGroupModel.id).where(UserGroupModel.group_id == group.id, UserGroupModel.user_id == current_user_id).limit(1) + # is_member_result = await db.execute(is_member_stmt) + # if not is_member_result.scalar_one_or_none(): + # raise InvalidOperationError("Settlement recorder must be part of the group or one of the parties.") + + db_settlement = SettlementModel( + group_id=settlement_in.group_id, + paid_by_user_id=settlement_in.paid_by_user_id, + paid_to_user_id=settlement_in.paid_to_user_id, + amount=settlement_in.amount.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP), + settlement_date=settlement_in.settlement_date if settlement_in.settlement_date else datetime.now(timezone.utc), + description=settlement_in.description + # created_by_user_id = current_user_id # Optional: Who recorded this settlement + ) + db.add(db_settlement) + await db.flush() + + # Re-fetch with relationships + stmt = ( + select(SettlementModel) + .where(SettlementModel.id == db_settlement.id) + .options( + selectinload(SettlementModel.payer), + selectinload(SettlementModel.payee), + selectinload(SettlementModel.group) + ) + ) + result = await db.execute(stmt) + loaded_settlement = result.scalar_one_or_none() + + if loaded_settlement is None: + await transaction.rollback() # Should be handled by context manager + raise SettlementOperationError("Failed to load settlement after creation.") + + await transaction.commit() + return loaded_settlement + except (UserNotFoundError, GroupNotFoundError, InvalidOperationError) as e: + # These are validation errors, re-raise them. + # If a transaction was started, context manager handles rollback. + raise + except IntegrityError as e: + raise DatabaseIntegrityError(f"Failed to save settlement due to DB integrity: {str(e)}") + except OperationalError as e: + raise DatabaseConnectionError(f"DB connection error during settlement creation: {str(e)}") + except SQLAlchemyError as e: + raise DatabaseTransactionError(f"DB transaction error during settlement creation: {str(e)}") + async def get_settlement_by_id(db: AsyncSession, settlement_id: int) -> Optional[SettlementModel]: - result = await db.execute( - select(SettlementModel) - .options( - selectinload(SettlementModel.payer), - selectinload(SettlementModel.payee), - selectinload(SettlementModel.group) + try: + result = await db.execute( + select(SettlementModel) + .options( + selectinload(SettlementModel.payer), + selectinload(SettlementModel.payee), + selectinload(SettlementModel.group) + ) + .where(SettlementModel.id == settlement_id) ) - .where(SettlementModel.id == settlement_id) - ) - return result.scalars().first() + return result.scalars().first() + except OperationalError as e: + raise DatabaseConnectionError(f"DB connection error fetching settlement: {str(e)}") + except SQLAlchemyError as e: + raise DatabaseQueryError(f"DB query error fetching settlement: {str(e)}") async def get_settlements_for_group(db: AsyncSession, group_id: int, skip: int = 0, limit: int = 100) -> Sequence[SettlementModel]: - result = await db.execute( - select(SettlementModel) - .where(SettlementModel.group_id == group_id) - .order_by(SettlementModel.settlement_date.desc(), SettlementModel.created_at.desc()) - .offset(skip).limit(limit) - .options(selectinload(SettlementModel.payer), selectinload(SettlementModel.payee)) - ) - return result.scalars().all() + try: + result = await db.execute( + select(SettlementModel) + .where(SettlementModel.group_id == group_id) + .order_by(SettlementModel.settlement_date.desc(), SettlementModel.created_at.desc()) + .offset(skip).limit(limit) + .options( + selectinload(SettlementModel.payer), + selectinload(SettlementModel.payee), + selectinload(SettlementModel.group) + ) + ) + return result.scalars().all() + except OperationalError as e: + raise DatabaseConnectionError(f"DB connection error fetching group settlements: {str(e)}") + except SQLAlchemyError as e: + raise DatabaseQueryError(f"DB query error fetching group settlements: {str(e)}") + async def get_settlements_involving_user( db: AsyncSession, @@ -89,18 +140,28 @@ async def get_settlements_involving_user( skip: int = 0, limit: int = 100 ) -> Sequence[SettlementModel]: - query = ( - select(SettlementModel) - .where(or_(SettlementModel.paid_by_user_id == user_id, SettlementModel.paid_to_user_id == user_id)) - .order_by(SettlementModel.settlement_date.desc(), SettlementModel.created_at.desc()) - .offset(skip).limit(limit) - .options(selectinload(SettlementModel.payer), selectinload(SettlementModel.payee), selectinload(SettlementModel.group)) - ) - if group_id: - query = query.where(SettlementModel.group_id == group_id) - - result = await db.execute(query) - return result.scalars().all() + try: + query = ( + select(SettlementModel) + .where(or_(SettlementModel.paid_by_user_id == user_id, SettlementModel.paid_to_user_id == user_id)) + .order_by(SettlementModel.settlement_date.desc(), SettlementModel.created_at.desc()) + .offset(skip).limit(limit) + .options( + selectinload(SettlementModel.payer), + selectinload(SettlementModel.payee), + selectinload(SettlementModel.group) + ) + ) + if group_id: + query = query.where(SettlementModel.group_id == group_id) + + result = await db.execute(query) + return result.scalars().all() + except OperationalError as e: + raise DatabaseConnectionError(f"DB connection error fetching user settlements: {str(e)}") + except SQLAlchemyError as e: + raise DatabaseQueryError(f"DB query error fetching user settlements: {str(e)}") + async def update_settlement(db: AsyncSession, settlement_db: SettlementModel, settlement_in: SettlementUpdate) -> SettlementModel: """ @@ -108,58 +169,100 @@ async def update_settlement(db: AsyncSession, settlement_db: SettlementModel, se Only allows updates to description and settlement_date. Requires version matching for optimistic locking. Assumes SettlementUpdate schema includes a version field. + Assumes SettlementModel has version and updated_at fields. """ - # Check if SettlementUpdate schema has 'version'. If not, this check needs to be adapted or version passed differently. - if not hasattr(settlement_in, 'version') or settlement_db.version != settlement_in.version: - raise InvalidOperationError( - f"Settlement (ID: {settlement_db.id}) has been modified. " - f"Your version does not match current version {settlement_db.version}. Please refresh.", - # status_code=status.HTTP_409_CONFLICT - ) - - update_data = settlement_in.model_dump(exclude_unset=True, exclude={"version"}) - allowed_to_update = {"description", "settlement_date"} - updated_something = False - - for field, value in update_data.items(): - if field in allowed_to_update: - setattr(settlement_db, field, value) - updated_something = True - else: - raise InvalidOperationError(f"Field '{field}' cannot be updated. Only {', '.join(allowed_to_update)} are allowed for settlements.") - - if not updated_something and not settlement_in.model_fields_set.intersection(allowed_to_update): - pass # No actual updatable fields provided, but version matched. - - settlement_db.version += 1 # Assuming SettlementModel has a version field, add if missing - settlement_db.updated_at = datetime.now(timezone.utc) - try: - await db.commit() - await db.refresh(settlement_db) - except Exception as e: - await db.rollback() - raise InvalidOperationError(f"Failed to update settlement: {str(e)}") - - return settlement_db + async with db.begin_nested() if db.in_transaction() else db.begin() as transaction: + # Ensure the settlement_db passed is managed by the current session if not already. + # This is usually true if fetched by an endpoint dependency using the same session. + # If not, `db.add(settlement_db)` might be needed before modification if it's detached. + + if not hasattr(settlement_db, 'version') or not hasattr(settlement_in, 'version'): + raise InvalidOperationError("Version field is missing in model or input for optimistic locking.") + + if settlement_db.version != settlement_in.version: + raise ConflictError( # Make sure ConflictError is defined in exceptions + f"Settlement (ID: {settlement_db.id}) has been modified. " + f"Your version {settlement_in.version} does not match current version {settlement_db.version}. Please refresh." + ) + + update_data = settlement_in.model_dump(exclude_unset=True, exclude={"version"}) + allowed_to_update = {"description", "settlement_date"} + updated_something = False + + for field, value in update_data.items(): + if field in allowed_to_update: + setattr(settlement_db, field, value) + updated_something = True + # Silently ignore fields not allowed to update or raise error: + # else: + # raise InvalidOperationError(f"Field '{field}' cannot be updated.") + + if not updated_something and not settlement_in.model_fields_set.intersection(allowed_to_update): + # No updatable fields were actually provided, or they didn't change + # Still, we might want to return the re-loaded settlement if version matched. + pass + + settlement_db.version += 1 + settlement_db.updated_at = datetime.now(timezone.utc) # Ensure model has this field + + db.add(settlement_db) # Mark as dirty + await db.flush() + + # Re-fetch with relationships + stmt = ( + select(SettlementModel) + .where(SettlementModel.id == settlement_db.id) + .options( + selectinload(SettlementModel.payer), + selectinload(SettlementModel.payee), + selectinload(SettlementModel.group) + ) + ) + result = await db.execute(stmt) + updated_settlement = result.scalar_one_or_none() + + if updated_settlement is None: # Should not happen + await transaction.rollback() + raise SettlementOperationError("Failed to load settlement after update.") + + await transaction.commit() + return updated_settlement + except ConflictError as e: # ConflictError should be defined in exceptions + raise + except InvalidOperationError as e: + raise + except IntegrityError as e: + raise DatabaseIntegrityError(f"Failed to update settlement due to DB integrity: {str(e)}") + except OperationalError as e: + raise DatabaseConnectionError(f"DB connection error during settlement update: {str(e)}") + except SQLAlchemyError as e: + raise DatabaseTransactionError(f"DB transaction error during settlement update: {str(e)}") + async def delete_settlement(db: AsyncSession, settlement_db: SettlementModel, expected_version: Optional[int] = None) -> None: """ Deletes a settlement. Requires version matching if expected_version is provided. Assumes SettlementModel has a version field. """ - if expected_version is not None: - if not hasattr(settlement_db, 'version') or settlement_db.version != expected_version: - raise InvalidOperationError( - f"Settlement (ID: {settlement_db.id}) cannot be deleted. " - f"Expected version {expected_version} does not match current version. Please refresh.", - # status_code=status.HTTP_409_CONFLICT - ) - - await db.delete(settlement_db) try: - await db.commit() - except Exception as e: - await db.rollback() - raise InvalidOperationError(f"Failed to delete settlement: {str(e)}") - return None \ No newline at end of file + async with db.begin_nested() if db.in_transaction() else db.begin() as transaction: + if expected_version is not None: + if not hasattr(settlement_db, 'version') or settlement_db.version != expected_version: + raise ConflictError( # Make sure ConflictError is defined + f"Settlement (ID: {settlement_db.id}) cannot be deleted. " + f"Expected version {expected_version} does not match current version {settlement_db.version}. Please refresh." + ) + + await db.delete(settlement_db) + await transaction.commit() + except ConflictError as e: # ConflictError should be defined + raise + except OperationalError as e: + raise DatabaseConnectionError(f"DB connection error during settlement deletion: {str(e)}") + except SQLAlchemyError as e: + raise DatabaseTransactionError(f"DB transaction error during settlement deletion: {str(e)}") + +# Ensure SettlementOperationError and ConflictError are defined in app.core.exceptions +# Example: class SettlementOperationError(AppException): pass +# Example: class ConflictError(AppException): status_code = 409 \ No newline at end of file diff --git a/be/app/crud/user.py b/be/app/crud/user.py index f0aad92..f36201b 100644 --- a/be/app/crud/user.py +++ b/be/app/crud/user.py @@ -1,10 +1,11 @@ # app/crud/user.py from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.future import select +from sqlalchemy.orm import selectinload # Ensure selectinload is imported from sqlalchemy.exc import SQLAlchemyError, IntegrityError, OperationalError from typing import Optional -from app.models import User as UserModel # Alias to avoid name clash +from app.models import User as UserModel, UserGroup as UserGroupModel, Group as GroupModel # Import related models for selectinload from app.schemas.user import UserCreate from app.core.security import hash_password from app.core.exceptions import ( @@ -13,14 +14,26 @@ from app.core.exceptions import ( DatabaseConnectionError, DatabaseIntegrityError, DatabaseQueryError, - DatabaseTransactionError + DatabaseTransactionError, + UserOperationError # Add if specific user operation errors are needed ) async def get_user_by_email(db: AsyncSession, email: str) -> Optional[UserModel]: - """Fetches a user from the database by email.""" + """Fetches a user from the database by email, with common relationships.""" try: - async with db.begin(): - result = await db.execute(select(UserModel).filter(UserModel.email == email)) + # db.begin() is not strictly necessary for a single read, but ensures atomicity if multiple reads were added. + # For a single select, it can be omitted if preferred, session handles connection. + async with db.begin(): # Or remove if only a single select operation + stmt = ( + select(UserModel) + .filter(UserModel.email == email) + .options( + selectinload(UserModel.group_associations).selectinload(UserGroupModel.group), # Groups user is member of + selectinload(UserModel.created_groups) # Groups user created + # Add other relationships as needed by UserPublic schema + ) + ) + result = await db.execute(stmt) return result.scalars().first() except OperationalError as e: raise DatabaseConnectionError(f"Failed to connect to database: {str(e)}") @@ -28,24 +41,46 @@ async def get_user_by_email(db: AsyncSession, email: str) -> Optional[UserModel] raise DatabaseQueryError(f"Failed to query user: {str(e)}") async def create_user(db: AsyncSession, user_in: UserCreate) -> UserModel: - """Creates a new user record in the database.""" + """Creates a new user record in the database with common relationships loaded.""" try: - async with db.begin(): + async with db.begin_nested() if db.in_transaction() else db.begin() as transaction: _hashed_password = hash_password(user_in.password) db_user = UserModel( email=user_in.email, - password_hash=_hashed_password, + hashed_password=_hashed_password, # Field name in model is hashed_password name=user_in.name ) db.add(db_user) - await db.flush() # Flush to get DB-generated values - await db.refresh(db_user) - return db_user + await db.flush() # Flush to get DB-generated values like ID + + # Re-fetch with relationships + stmt = ( + select(UserModel) + .where(UserModel.id == db_user.id) + .options( + selectinload(UserModel.group_associations).selectinload(UserGroupModel.group), + selectinload(UserModel.created_groups) + # Add other relationships as needed by UserPublic schema + ) + ) + result = await db.execute(stmt) + loaded_user = result.scalar_one_or_none() + + if loaded_user is None: + await transaction.rollback() # Should be handled by context manager, but explicit + raise UserOperationError("Failed to load user after creation.") # Define UserOperationError + + await transaction.commit() + return loaded_user except IntegrityError as e: - if "unique constraint" in str(e).lower(): - raise EmailAlreadyRegisteredError() - raise DatabaseIntegrityError(f"Failed to create user: {str(e)}") + # Context manager handles rollback on error + if "unique constraint" in str(e).lower() and ("users_email_key" in str(e).lower() or "ix_users_email" in str(e).lower()): + raise EmailAlreadyRegisteredError(email=user_in.email) + raise DatabaseIntegrityError(f"Failed to create user due to integrity issue: {str(e)}") except OperationalError as e: - raise DatabaseConnectionError(f"Database connection error: {str(e)}") + raise DatabaseConnectionError(f"Database connection error during user creation: {str(e)}") except SQLAlchemyError as e: - raise DatabaseTransactionError(f"Failed to create user: {str(e)}") \ No newline at end of file + raise DatabaseTransactionError(f"Failed to create user due to other DB error: {str(e)}") + +# Ensure UserOperationError is defined in app.core.exceptions if used +# Example: class UserOperationError(AppException): pass \ No newline at end of file diff --git a/be/app/database.py b/be/app/database.py index f14bb93..db7768d 100644 --- a/be/app/database.py +++ b/be/app/database.py @@ -36,15 +36,9 @@ async def get_async_session() -> AsyncSession: # type: ignore Ensures the session is closed after the request. """ async with AsyncSessionLocal() as session: - try: - yield session - # Commit the transaction if no errors occurred - await session.commit() - except Exception: - await session.rollback() - raise - finally: - await session.close() # Not strictly necessary with async context manager, but explicit + yield session + # The 'async with' block handles session.close() automatically. + # Commit/rollback should be handled by the functions using the session. # Alias for backward compatibility get_db = get_async_session \ No newline at end of file