Refactor CRUD operations across multiple modules to standardize transaction handling using context managers, improving error logging and rollback mechanisms. Enhance error handling for database operations in expense, group, invite, item, list, settlement, and user modules, ensuring specific exceptions are raised for integrity and connection issues.

This commit is contained in:
mohamad 2025-05-20 01:18:49 +02:00
parent e4175db4aa
commit 98b2f907de
7 changed files with 192 additions and 144 deletions

View File

@ -3,6 +3,7 @@ import logging # Add logging import
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select from sqlalchemy.future import select
from sqlalchemy.orm import selectinload, joinedload from sqlalchemy.orm import selectinload, joinedload
from sqlalchemy.exc import SQLAlchemyError, IntegrityError, OperationalError # Added import
from decimal import Decimal, ROUND_HALF_UP, InvalidOperation as DecimalInvalidOperation from decimal import Decimal, ROUND_HALF_UP, InvalidOperation as DecimalInvalidOperation
from typing import Callable, List as PyList, Optional, Sequence, Dict, defaultdict, Any from typing import Callable, List as PyList, Optional, Sequence, Dict, defaultdict, Any
from datetime import datetime, timezone # Added timezone from datetime import datetime, timezone # Added timezone
@ -183,10 +184,11 @@ async def create_expense(db: AsyncSession, expense_in: ExpenseCreate, current_us
loaded_expense = result.scalar_one_or_none() loaded_expense = result.scalar_one_or_none()
if loaded_expense is None: if loaded_expense is None:
await transaction.rollback() # Should be handled by context manager # The context manager will handle rollback if an exception is raised.
# await transaction.rollback() # Should be handled by context manager
raise ExpenseOperationError("Failed to load expense after creation.") raise ExpenseOperationError("Failed to load expense after creation.")
await transaction.commit() # await transaction.commit() # Explicit commit removed, context manager handles it.
return loaded_expense return loaded_expense
except (UserNotFoundError, ListNotFoundError, GroupNotFoundError, InvalidOperationError) as e: except (UserNotFoundError, ListNotFoundError, GroupNotFoundError, InvalidOperationError) as e:
@ -564,18 +566,27 @@ async def update_expense(db: AsyncSession, expense_db: ExpenseModel, expense_in:
# For now, if only version was sent, we still increment if it matched. # For now, if only version was sent, we still increment if it matched.
pass # Or raise InvalidOperationError("No updatable fields provided.") pass # Or raise InvalidOperationError("No updatable fields provided.")
expense_db.version += 1
expense_db.updated_at = datetime.now(timezone.utc) # Manually update timestamp
try: try:
await db.commit() async with db.begin_nested() if db.in_transaction() else db.begin() as transaction:
await db.refresh(expense_db) expense_db.version += 1
except Exception as e: expense_db.updated_at = datetime.now(timezone.utc) # Manually update timestamp
await db.rollback() # db.add(expense_db) # Not strictly necessary as expense_db is already tracked by the session
# Consider specific DB error types if needed
raise InvalidOperationError(f"Failed to update expense: {str(e)}") await db.flush() # Persist changes to the DB and run constraints
await db.refresh(expense_db) # Refresh the object from the DB
return expense_db
except InvalidOperationError: # Re-raise validation errors to be handled by the caller
raise
except IntegrityError as e:
logger.error(f"Database integrity error during expense update for ID {expense_db.id}: {str(e)}", exc_info=True)
# The transaction context manager (begin_nested/begin) handles rollback.
raise DatabaseIntegrityError(f"Failed to update expense ID {expense_db.id} due to database integrity issue.") from e
except SQLAlchemyError as e: # Catch other SQLAlchemy errors
logger.error(f"Database transaction error during expense update for ID {expense_db.id}: {str(e)}", exc_info=True)
# The transaction context manager (begin_nested/begin) handles rollback.
raise DatabaseTransactionError(f"Failed to update expense ID {expense_db.id} due to a database transaction error.") from e
# No generic Exception catch here, let other unexpected errors propagate if not SQLAlchemy related.
return expense_db
async def delete_expense(db: AsyncSession, expense_db: ExpenseModel, expected_version: Optional[int] = None) -> None: async def delete_expense(db: AsyncSession, expense_db: ExpenseModel, expected_version: Optional[int] = None) -> None:
""" """
@ -589,12 +600,20 @@ async def delete_expense(db: AsyncSession, expense_db: ExpenseModel, expected_ve
# status_code=status.HTTP_409_CONFLICT # status_code=status.HTTP_409_CONFLICT
) )
await db.delete(expense_db)
try: try:
await db.commit() async with db.begin_nested() if db.in_transaction() else db.begin() as transaction:
except Exception as e: await db.delete(expense_db)
await db.rollback() await db.flush() # Ensure the delete operation is sent to the database
raise InvalidOperationError(f"Failed to delete expense: {str(e)}") except InvalidOperationError: # Re-raise validation errors
raise
except IntegrityError as e:
logger.error(f"Database integrity error during expense deletion for ID {expense_db.id}: {str(e)}", exc_info=True)
# The transaction context manager (begin_nested/begin) handles rollback.
raise DatabaseIntegrityError(f"Failed to delete expense ID {expense_db.id} due to database integrity issue.") from e
except SQLAlchemyError as e: # Catch other SQLAlchemy errors
logger.error(f"Database transaction error during expense deletion for ID {expense_db.id}: {str(e)}", exc_info=True)
# The transaction context manager (begin_nested/begin) handles rollback.
raise DatabaseTransactionError(f"Failed to delete expense ID {expense_db.id} due to a database transaction error.") from e
return None return None
# Note: The InvalidOperationError is a simple ValueError placeholder. # Note: The InvalidOperationError is a simple ValueError placeholder.

View File

@ -5,6 +5,7 @@ from sqlalchemy.orm import selectinload # For eager loading members
from sqlalchemy.exc import SQLAlchemyError, IntegrityError, OperationalError from sqlalchemy.exc import SQLAlchemyError, IntegrityError, OperationalError
from typing import Optional, List from typing import Optional, List
from sqlalchemy import delete, func from sqlalchemy import delete, func
import logging # Add logging import
from app.models import User as UserModel, Group as GroupModel, UserGroup as UserGroupModel from app.models import User as UserModel, Group as GroupModel, UserGroup as UserGroupModel
from app.schemas.group import GroupCreate from app.schemas.group import GroupCreate
@ -20,24 +21,16 @@ from app.core.exceptions import (
GroupPermissionError # Import GroupPermissionError GroupPermissionError # Import GroupPermissionError
) )
logger = logging.getLogger(__name__) # Initialize logger
# --- Group CRUD --- # --- Group CRUD ---
async def create_group(db: AsyncSession, group_in: GroupCreate, creator_id: int) -> GroupModel: async def create_group(db: AsyncSession, group_in: GroupCreate, creator_id: int) -> GroupModel:
"""Creates a group and adds the creator as the owner.""" """Creates a group and adds the creator as the owner."""
try: try:
# Defensive check: if a transaction is already active, try to roll it back. # Use the composability pattern for transactions as per fastapi-db-strategy.
# This is unusual and suggests an issue upstream (e.g., middleware or session configuration). # This creates a savepoint if already in a transaction (e.g., from get_transactional_session)
if db.in_transaction(): # or starts a new transaction if called outside of one (e.g., from a script).
# Log this occurrence if possible, as it's unexpected. async with db.begin_nested() if db.in_transaction() else db.begin():
# import logging; logging.warning("Transaction already active on session entering create_group. Attempting rollback.")
try:
await db.rollback() # Attempt to clear any existing transaction
except SQLAlchemyError as e_rb:
# Log e_rb if possible
# import logging; logging.error(f"Error rolling back pre-existing transaction: {e_rb}")
# Re-raise or handle as a critical error, as the session state is uncertain.
raise DatabaseTransactionError(f"Session had an active transaction that could not be rolled back: {e_rb}")
async with db.begin(): # Now attempt to start a clean transaction
db_group = GroupModel(name=group_in.name, created_by_id=creator_id) db_group = GroupModel(name=group_in.name, created_by_id=creator_id)
db.add(db_group) db.add(db_group)
await db.flush() # Assigns ID to db_group await db.flush() # Assigns ID to db_group
@ -67,10 +60,13 @@ async def create_group(db: AsyncSession, group_in: GroupCreate, creator_id: int)
return loaded_group return loaded_group
except IntegrityError as e: except IntegrityError as e:
logger.error(f"Database integrity error during group creation: {str(e)}", exc_info=True)
raise DatabaseIntegrityError(f"Failed to create group due to integrity issue: {str(e)}") raise DatabaseIntegrityError(f"Failed to create group due to integrity issue: {str(e)}")
except OperationalError as e: except OperationalError as e:
logger.error(f"Database connection error during group creation: {str(e)}", exc_info=True)
raise DatabaseConnectionError(f"Database connection error during group creation: {str(e)}") raise DatabaseConnectionError(f"Database connection error during group creation: {str(e)}")
except SQLAlchemyError as e: except SQLAlchemyError as e:
logger.error(f"Unexpected SQLAlchemy error during group creation: {str(e)}", exc_info=True)
raise DatabaseTransactionError(f"Database transaction error during group creation: {str(e)}") raise DatabaseTransactionError(f"Database transaction error during group creation: {str(e)}")
async def get_user_groups(db: AsyncSession, user_id: int) -> List[GroupModel]: async def get_user_groups(db: AsyncSession, user_id: int) -> List[GroupModel]:
@ -143,7 +139,7 @@ async def add_user_to_group(db: AsyncSession, group_id: int, user_id: int, role:
return None return None
# Use a single transaction # Use a single transaction
async with db.begin(): async with db.begin_nested() if db.in_transaction() else db.begin():
db_user_group = UserGroupModel(user_id=user_id, group_id=group_id, role=role) db_user_group = UserGroupModel(user_id=user_id, group_id=group_id, role=role)
db.add(db_user_group) db.add(db_user_group)
await db.flush() # Assigns ID to db_user_group await db.flush() # Assigns ID to db_user_group
@ -165,16 +161,19 @@ async def add_user_to_group(db: AsyncSession, group_id: int, user_id: int, role:
return loaded_user_group return loaded_user_group
except IntegrityError as e: except IntegrityError as e:
logger.error(f"Database integrity error while adding user to group: {str(e)}", exc_info=True)
raise DatabaseIntegrityError(f"Failed to add user to group: {str(e)}") raise DatabaseIntegrityError(f"Failed to add user to group: {str(e)}")
except OperationalError as e: except OperationalError as e:
logger.error(f"Database connection error while adding user to group: {str(e)}", exc_info=True)
raise DatabaseConnectionError(f"Database connection error: {str(e)}") raise DatabaseConnectionError(f"Database connection error: {str(e)}")
except SQLAlchemyError as e: except SQLAlchemyError as e:
logger.error(f"Unexpected SQLAlchemy error while adding user to group: {str(e)}", exc_info=True)
raise DatabaseTransactionError(f"Failed to add user to group: {str(e)}") raise DatabaseTransactionError(f"Failed to add user to group: {str(e)}")
async def remove_user_from_group(db: AsyncSession, group_id: int, user_id: int) -> bool: async def remove_user_from_group(db: AsyncSession, group_id: int, user_id: int) -> bool:
"""Removes a user from a group.""" """Removes a user from a group."""
try: try:
async with db.begin(): async with db.begin_nested() if db.in_transaction() else db.begin():
result = await db.execute( result = await db.execute(
delete(UserGroupModel) delete(UserGroupModel)
.where(UserGroupModel.group_id == group_id, UserGroupModel.user_id == user_id) .where(UserGroupModel.group_id == group_id, UserGroupModel.user_id == user_id)
@ -182,8 +181,10 @@ async def remove_user_from_group(db: AsyncSession, group_id: int, user_id: int)
) )
return result.scalar_one_or_none() is not None return result.scalar_one_or_none() is not None
except OperationalError as e: except OperationalError as e:
logger.error(f"Database connection error while removing user from group: {str(e)}", exc_info=True)
raise DatabaseConnectionError(f"Database connection error: {str(e)}") raise DatabaseConnectionError(f"Database connection error: {str(e)}")
except SQLAlchemyError as e: except SQLAlchemyError as e:
logger.error(f"Unexpected SQLAlchemy error while removing user from group: {str(e)}", exc_info=True)
raise DatabaseTransactionError(f"Failed to remove user from group: {str(e)}") raise DatabaseTransactionError(f"Failed to remove user from group: {str(e)}")
async def get_group_member_count(db: AsyncSession, group_id: int) -> int: async def get_group_member_count(db: AsyncSession, group_id: int) -> int:

View File

@ -1,4 +1,5 @@
# app/crud/invite.py # app/crud/invite.py
import logging # Add logging import
import secrets import secrets
from datetime import datetime, timedelta, timezone from datetime import datetime, timedelta, timezone
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
@ -17,93 +18,100 @@ from app.core.exceptions import (
InviteOperationError # Add new specific exception InviteOperationError # Add new specific exception
) )
logger = logging.getLogger(__name__) # Initialize logger
# Invite codes should be reasonably unique, but handle potential collision # Invite codes should be reasonably unique, but handle potential collision
MAX_CODE_GENERATION_ATTEMPTS = 5 MAX_CODE_GENERATION_ATTEMPTS = 5
async def deactivate_all_active_invites_for_group(db: AsyncSession, group_id: int): async def deactivate_all_active_invites_for_group(db: AsyncSession, group_id: int):
"""Deactivates all currently active invite codes for a specific group.""" """Deactivates all currently active invite codes for a specific group."""
try: try:
stmt = ( async with db.begin_nested() if db.in_transaction() else db.begin():
select(InviteModel) stmt = (
.where(InviteModel.group_id == group_id, InviteModel.is_active == True) select(InviteModel)
) .where(InviteModel.group_id == group_id, InviteModel.is_active == True)
result = await db.execute(stmt) )
active_invites = result.scalars().all() result = await db.execute(stmt)
active_invites = result.scalars().all()
if not active_invites: if not active_invites:
return # No active invites to deactivate return # No active invites to deactivate
for invite in active_invites: for invite in active_invites:
invite.is_active = False invite.is_active = False
db.add(invite) db.add(invite)
await db.flush() # Flush changes within this transaction block
# await db.flush() # Removed: Rely on caller to flush/commit # await db.flush() # Removed: Rely on caller to flush/commit
# No explicit commit here, assuming it's part of a larger transaction or caller handles commit. # No explicit commit here, assuming it's part of a larger transaction or caller handles commit.
except OperationalError as e: except OperationalError as e:
# It's better to let the caller handle rollback or commit based on overall operation success logger.error(f"Database connection error deactivating invites for group {group_id}: {str(e)}", exc_info=True)
raise DatabaseConnectionError(f"DB connection error deactivating invites for group {group_id}: {str(e)}") raise DatabaseConnectionError(f"DB connection error deactivating invites for group {group_id}: {str(e)}")
except SQLAlchemyError as e: except SQLAlchemyError as e:
logger.error(f"Unexpected SQLAlchemy error deactivating invites for group {group_id}: {str(e)}", exc_info=True)
raise DatabaseTransactionError(f"DB transaction error deactivating invites for group {group_id}: {str(e)}") raise DatabaseTransactionError(f"DB transaction error deactivating invites for group {group_id}: {str(e)}")
async def create_invite(db: AsyncSession, group_id: int, creator_id: int, expires_in_days: int = 365 * 100) -> Optional[InviteModel]: # Default to 100 years async def create_invite(db: AsyncSession, group_id: int, creator_id: int, expires_in_days: int = 365 * 100) -> Optional[InviteModel]: # Default to 100 years
"""Creates a new invite code for a group, deactivating any existing active ones for that group first.""" """Creates a new invite code for a group, deactivating any existing active ones for that group first."""
# Deactivate existing active invites for this group try:
await deactivate_all_active_invites_for_group(db, group_id) async with db.begin_nested() if db.in_transaction() else db.begin():
# Deactivate existing active invites for this group
await deactivate_all_active_invites_for_group(db, group_id)
expires_at = datetime.now(timezone.utc) + timedelta(days=expires_in_days) expires_at = datetime.now(timezone.utc) + timedelta(days=expires_in_days)
potential_code = None potential_code = None
for attempt in range(MAX_CODE_GENERATION_ATTEMPTS): for attempt in range(MAX_CODE_GENERATION_ATTEMPTS):
potential_code = secrets.token_urlsafe(16) potential_code = secrets.token_urlsafe(16)
existing_check_stmt = select(InviteModel.id).where(InviteModel.code == potential_code, InviteModel.is_active == True).limit(1) existing_check_stmt = select(InviteModel.id).where(InviteModel.code == potential_code, InviteModel.is_active == True).limit(1)
existing_result = await db.execute(existing_check_stmt) existing_result = await db.execute(existing_check_stmt)
if existing_result.scalar_one_or_none() is None: if existing_result.scalar_one_or_none() is None:
break break
if attempt == MAX_CODE_GENERATION_ATTEMPTS - 1: if attempt == MAX_CODE_GENERATION_ATTEMPTS - 1:
raise InviteOperationError("Failed to generate a unique invite code after several attempts.") raise InviteOperationError("Failed to generate a unique invite code after several attempts.")
# Removed explicit transaction block here, rely on FastAPI's request-level transaction. final_check_stmt = select(InviteModel.id).where(InviteModel.code == potential_code, InviteModel.is_active == True).limit(1)
# Final check for code collision (less critical now without explicit nested transaction rollback on collision) final_check_result = await db.execute(final_check_stmt)
# but still good to prevent duplicate active codes if possible, though the deactivate step helps. if final_check_result.scalar_one_or_none() is not None:
final_check_stmt = select(InviteModel.id).where(InviteModel.code == potential_code, InviteModel.is_active == True).limit(1) raise InviteOperationError("Invite code collision detected just before creation attempt.")
final_check_result = await db.execute(final_check_stmt)
if final_check_result.scalar_one_or_none() is not None:
# This is now more of a rare edge case if deactivate worked and code generation is diverse.
# Depending on strictness, could raise an error or just log and proceed,
# relying on the previous deactivation to ensure only one is active.
# For now, let's raise to be safe, as it implies a very quick collision.
raise InviteOperationError("Invite code collision detected just before creation attempt.")
db_invite = InviteModel( db_invite = InviteModel(
code=potential_code, code=potential_code,
group_id=group_id, group_id=group_id,
created_by_id=creator_id, created_by_id=creator_id,
expires_at=expires_at, expires_at=expires_at,
is_active=True is_active=True
) )
db.add(db_invite) db.add(db_invite)
await db.flush() # Flush to get ID for re-fetch and ensure it's in session before potential re-fetch. await db.flush()
# Re-fetch with relationships stmt = (
stmt = ( select(InviteModel)
select(InviteModel) .where(InviteModel.id == db_invite.id)
.where(InviteModel.id == db_invite.id) .options(
.options( selectinload(InviteModel.group),
selectinload(InviteModel.group), selectinload(InviteModel.creator)
selectinload(InviteModel.creator) )
) )
) result = await db.execute(stmt)
result = await db.execute(stmt) loaded_invite = result.scalar_one_or_none()
loaded_invite = result.scalar_one_or_none()
if loaded_invite is None: if loaded_invite is None:
# This would be an issue, implies flush didn't work or ID was wrong. raise InviteOperationError("Failed to load invite after creation and flush.")
# The main transaction will rollback if this exception is raised.
raise InviteOperationError("Failed to load invite after creation and flush.")
return loaded_invite return loaded_invite
# No explicit commit here, FastAPI handles it for the request. except InviteOperationError: # Already specific, re-raise
raise
except IntegrityError as e:
logger.error(f"Database integrity error during invite creation for group {group_id}: {str(e)}", exc_info=True)
raise DatabaseIntegrityError(f"Failed to create invite due to DB integrity issue: {str(e)}")
except OperationalError as e:
logger.error(f"Database connection error during invite creation for group {group_id}: {str(e)}", exc_info=True)
raise DatabaseConnectionError(f"DB connection error during invite creation: {str(e)}")
except SQLAlchemyError as e:
logger.error(f"Unexpected SQLAlchemy error during invite creation for group {group_id}: {str(e)}", exc_info=True)
raise DatabaseTransactionError(f"DB transaction error during invite creation: {str(e)}")
async def get_active_invite_for_group(db: AsyncSession, group_id: int) -> Optional[InviteModel]: async def get_active_invite_for_group(db: AsyncSession, group_id: int) -> Optional[InviteModel]:
"""Gets the currently active and non-expired invite for a specific group.""" """Gets the currently active and non-expired invite for a specific group."""
@ -125,8 +133,10 @@ async def get_active_invite_for_group(db: AsyncSession, group_id: int) -> Option
result = await db.execute(stmt) result = await db.execute(stmt)
return result.scalars().first() return result.scalars().first()
except OperationalError as e: except OperationalError as e:
logger.error(f"Database connection error fetching active invite for group {group_id}: {str(e)}", exc_info=True)
raise DatabaseConnectionError(f"DB connection error fetching active invite for group {group_id}: {str(e)}") raise DatabaseConnectionError(f"DB connection error fetching active invite for group {group_id}: {str(e)}")
except SQLAlchemyError as e: except SQLAlchemyError as e:
logger.error(f"DB query error fetching active invite for group {group_id}: {str(e)}", exc_info=True)
raise DatabaseQueryError(f"DB query error fetching active invite for group {group_id}: {str(e)}") raise DatabaseQueryError(f"DB query error fetching active invite for group {group_id}: {str(e)}")
async def get_active_invite_by_code(db: AsyncSession, code: str) -> Optional[InviteModel]: async def get_active_invite_by_code(db: AsyncSession, code: str) -> Optional[InviteModel]:
@ -172,14 +182,14 @@ async def deactivate_invite(db: AsyncSession, invite: InviteModel) -> InviteMode
updated_invite = result.scalar_one_or_none() updated_invite = result.scalar_one_or_none()
if updated_invite is None: # Should not happen as invite is passed in if updated_invite is None: # Should not happen as invite is passed in
await transaction.rollback()
raise InviteOperationError("Failed to load invite after deactivation.") raise InviteOperationError("Failed to load invite after deactivation.")
await transaction.commit()
return updated_invite return updated_invite
except OperationalError as e: except OperationalError as e:
logger.error(f"Database connection error deactivating invite: {str(e)}", exc_info=True)
raise DatabaseConnectionError(f"DB connection error deactivating invite: {str(e)}") raise DatabaseConnectionError(f"DB connection error deactivating invite: {str(e)}")
except SQLAlchemyError as e: except SQLAlchemyError as e:
logger.error(f"Unexpected SQLAlchemy error deactivating invite: {str(e)}", exc_info=True)
raise DatabaseTransactionError(f"DB transaction error deactivating invite: {str(e)}") raise DatabaseTransactionError(f"DB transaction error deactivating invite: {str(e)}")
# Ensure InviteOperationError is defined in app.core.exceptions # Ensure InviteOperationError is defined in app.core.exceptions

View File

@ -6,6 +6,7 @@ from sqlalchemy import delete as sql_delete, update as sql_update # Use aliases
from sqlalchemy.exc import SQLAlchemyError, IntegrityError, OperationalError from sqlalchemy.exc import SQLAlchemyError, IntegrityError, OperationalError
from typing import Optional, List as PyList from typing import Optional, List as PyList
from datetime import datetime, timezone from datetime import datetime, timezone
import logging # Add logging import
from app.models import Item as ItemModel, User as UserModel # Import UserModel for type hints if needed for selectinload from app.models import Item as ItemModel, User as UserModel # Import UserModel for type hints if needed for selectinload
from app.schemas.item import ItemCreate, ItemUpdate from app.schemas.item import ItemCreate, ItemUpdate
@ -19,6 +20,8 @@ from app.core.exceptions import (
ItemOperationError # Add if specific item operation errors are needed ItemOperationError # Add if specific item operation errors are needed
) )
logger = logging.getLogger(__name__) # Initialize logger
async def create_item(db: AsyncSession, item_in: ItemCreate, list_id: int, user_id: int) -> ItemModel: async def create_item(db: AsyncSession, item_in: ItemCreate, list_id: int, user_id: int) -> ItemModel:
"""Creates a new item record for a specific list.""" """Creates a new item record for a specific list."""
try: try:
@ -46,17 +49,18 @@ async def create_item(db: AsyncSession, item_in: ItemCreate, list_id: int, user_
loaded_item = result.scalar_one_or_none() loaded_item = result.scalar_one_or_none()
if loaded_item is None: if loaded_item is None:
await transaction.rollback() # Should be handled by context manager on raise, but explicit for clarity # await transaction.rollback() # Redundant, context manager handles rollback on exception
raise ItemOperationError("Failed to load item after creation.") # Define ItemOperationError raise ItemOperationError("Failed to load item after creation.") # Define ItemOperationError
await transaction.commit()
return loaded_item return loaded_item
except IntegrityError as e: except IntegrityError as e:
# Context manager handles rollback on error logger.error(f"Database integrity error during item creation: {str(e)}", exc_info=True)
raise DatabaseIntegrityError(f"Failed to create item: {str(e)}") raise DatabaseIntegrityError(f"Failed to create item: {str(e)}")
except OperationalError as e: except OperationalError as e:
logger.error(f"Database connection error during item creation: {str(e)}", exc_info=True)
raise DatabaseConnectionError(f"Database connection error: {str(e)}") raise DatabaseConnectionError(f"Database connection error: {str(e)}")
except SQLAlchemyError as e: except SQLAlchemyError as e:
logger.error(f"Unexpected SQLAlchemy error during item creation: {str(e)}", exc_info=True)
raise DatabaseTransactionError(f"Failed to create item: {str(e)}") raise DatabaseTransactionError(f"Failed to create item: {str(e)}")
# Removed generic Exception block as SQLAlchemyError should cover DB issues, # Removed generic Exception block as SQLAlchemyError should cover DB issues,
# and context manager handles rollback. # and context manager handles rollback.
@ -144,15 +148,17 @@ async def update_item(db: AsyncSession, item_db: ItemModel, item_in: ItemUpdate,
# Rollback will be handled by context manager on raise # Rollback will be handled by context manager on raise
raise ItemOperationError("Failed to load item after update.") raise ItemOperationError("Failed to load item after update.")
await transaction.commit()
return updated_item return updated_item
except IntegrityError as e: except IntegrityError as e:
logger.error(f"Database integrity error during item update: {str(e)}", exc_info=True)
raise DatabaseIntegrityError(f"Failed to update item due to integrity constraint: {str(e)}") raise DatabaseIntegrityError(f"Failed to update item due to integrity constraint: {str(e)}")
except OperationalError as e: except OperationalError as e:
logger.error(f"Database connection error while updating item: {str(e)}", exc_info=True)
raise DatabaseConnectionError(f"Database connection error while updating item: {str(e)}") raise DatabaseConnectionError(f"Database connection error while updating item: {str(e)}")
except ConflictError: # Re-raise ConflictError, rollback handled by context manager except ConflictError: # Re-raise ConflictError, rollback handled by context manager
raise raise
except SQLAlchemyError as e: except SQLAlchemyError as e:
logger.error(f"Unexpected SQLAlchemy error during item update: {str(e)}", exc_info=True)
raise DatabaseTransactionError(f"Failed to update item: {str(e)}") raise DatabaseTransactionError(f"Failed to update item: {str(e)}")
async def delete_item(db: AsyncSession, item_db: ItemModel) -> None: async def delete_item(db: AsyncSession, item_db: ItemModel) -> None:
@ -160,13 +166,13 @@ async def delete_item(db: AsyncSession, item_db: ItemModel) -> None:
try: try:
async with db.begin_nested() if db.in_transaction() else db.begin() as transaction: async with db.begin_nested() if db.in_transaction() else db.begin() as transaction:
await db.delete(item_db) await db.delete(item_db)
await transaction.commit() # await transaction.commit() # Removed
# No return needed for None # No return needed for None
except OperationalError as e: except OperationalError as e:
# Rollback handled by context manager logger.error(f"Database connection error while deleting item: {str(e)}", exc_info=True)
raise DatabaseConnectionError(f"Database connection error while deleting item: {str(e)}") raise DatabaseConnectionError(f"Database connection error while deleting item: {str(e)}")
except SQLAlchemyError as e: except SQLAlchemyError as e:
# Rollback handled by context manager logger.error(f"Unexpected SQLAlchemy error while deleting item: {str(e)}", exc_info=True)
raise DatabaseTransactionError(f"Failed to delete item: {str(e)}") raise DatabaseTransactionError(f"Failed to delete item: {str(e)}")
# Ensure ItemOperationError is defined in app.core.exceptions if used # Ensure ItemOperationError is defined in app.core.exceptions if used

View File

@ -5,6 +5,7 @@ from sqlalchemy.orm import selectinload, joinedload
from sqlalchemy import or_, and_, delete as sql_delete, func as sql_func, desc from sqlalchemy import or_, and_, delete as sql_delete, func as sql_func, desc
from sqlalchemy.exc import SQLAlchemyError, IntegrityError, OperationalError from sqlalchemy.exc import SQLAlchemyError, IntegrityError, OperationalError
from typing import Optional, List as PyList from typing import Optional, List as PyList
import logging # Add logging import
from app.schemas.list import ListStatus from app.schemas.list import ListStatus
from app.models import List as ListModel, UserGroup as UserGroupModel, Item as ItemModel from app.models import List as ListModel, UserGroup as UserGroupModel, Item as ItemModel
@ -21,6 +22,8 @@ from app.core.exceptions import (
ListOperationError ListOperationError
) )
logger = logging.getLogger(__name__) # Initialize logger
async def create_list(db: AsyncSession, list_in: ListCreate, creator_id: int) -> ListModel: async def create_list(db: AsyncSession, list_in: ListCreate, creator_id: int) -> ListModel:
"""Creates a new list record.""" """Creates a new list record."""
try: try:
@ -49,16 +52,17 @@ async def create_list(db: AsyncSession, list_in: ListCreate, creator_id: int) ->
loaded_list = result.scalar_one_or_none() loaded_list = result.scalar_one_or_none()
if loaded_list is None: if loaded_list is None:
await transaction.rollback()
raise ListOperationError("Failed to load list after creation.") raise ListOperationError("Failed to load list after creation.")
await transaction.commit()
return loaded_list return loaded_list
except IntegrityError as e: except IntegrityError as e:
logger.error(f"Database integrity error during list creation: {str(e)}", exc_info=True)
raise DatabaseIntegrityError(f"Failed to create list: {str(e)}") raise DatabaseIntegrityError(f"Failed to create list: {str(e)}")
except OperationalError as e: except OperationalError as e:
logger.error(f"Database connection error during list creation: {str(e)}", exc_info=True)
raise DatabaseConnectionError(f"Database connection error: {str(e)}") raise DatabaseConnectionError(f"Database connection error: {str(e)}")
except SQLAlchemyError as e: except SQLAlchemyError as e:
logger.error(f"Unexpected SQLAlchemy error during list creation: {str(e)}", exc_info=True)
raise DatabaseTransactionError(f"Failed to create list: {str(e)}") raise DatabaseTransactionError(f"Failed to create list: {str(e)}")
async def get_lists_for_user(db: AsyncSession, user_id: int) -> PyList[ListModel]: async def get_lists_for_user(db: AsyncSession, user_id: int) -> PyList[ListModel]:
@ -80,8 +84,11 @@ async def get_lists_for_user(db: AsyncSession, user_id: int) -> PyList[ListModel
.where(or_(*conditions)) .where(or_(*conditions))
.options( .options(
selectinload(ListModel.creator), selectinload(ListModel.creator),
selectinload(ListModel.group) selectinload(ListModel.group),
# selectinload(ListModel.items) # Consider if items are needed for list previews selectinload(ListModel.items).options(
joinedload(ItemModel.added_by_user),
joinedload(ItemModel.completed_by_user)
)
) )
.order_by(ListModel.updated_at.desc()) .order_by(ListModel.updated_at.desc())
) )
@ -123,7 +130,6 @@ async def update_list(db: AsyncSession, list_db: ListModel, list_in: ListUpdate)
try: try:
async with db.begin_nested() if db.in_transaction() else db.begin() as transaction: async with db.begin_nested() if db.in_transaction() else db.begin() as transaction:
if list_db.version != list_in.version: # list_db here is the one passed in, pre-loaded by API layer if list_db.version != list_in.version: # list_db here is the one passed in, pre-loaded by API layer
await transaction.rollback() # Rollback before raising
raise ConflictError( raise ConflictError(
f"List '{list_db.name}' (ID: {list_db.id}) has been modified. " f"List '{list_db.name}' (ID: {list_db.id}) has been modified. "
f"Your version is {list_in.version}, current version is {list_db.version}. Please refresh." f"Your version is {list_in.version}, current version is {list_db.version}. Please refresh."
@ -153,23 +159,19 @@ async def update_list(db: AsyncSession, list_db: ListModel, list_in: ListUpdate)
updated_list = result.scalar_one_or_none() updated_list = result.scalar_one_or_none()
if updated_list is None: # Should not happen if updated_list is None: # Should not happen
await transaction.rollback()
raise ListOperationError("Failed to load list after update.") raise ListOperationError("Failed to load list after update.")
await transaction.commit()
return updated_list return updated_list
except IntegrityError as e: except IntegrityError as e:
# Ensure rollback if not handled by context manager (though it should be) logger.error(f"Database integrity error during list update: {str(e)}", exc_info=True)
if db.in_transaction(): await db.rollback()
raise DatabaseIntegrityError(f"Failed to update list due to integrity constraint: {str(e)}") raise DatabaseIntegrityError(f"Failed to update list due to integrity constraint: {str(e)}")
except OperationalError as e: except OperationalError as e:
if db.in_transaction(): await db.rollback() logger.error(f"Database connection error while updating list: {str(e)}", exc_info=True)
raise DatabaseConnectionError(f"Database connection error while updating list: {str(e)}") raise DatabaseConnectionError(f"Database connection error while updating list: {str(e)}")
except ConflictError: except ConflictError:
# Already rolled back or will be by context manager if transaction was started here
raise raise
except SQLAlchemyError as e: except SQLAlchemyError as e:
if db.in_transaction(): await db.rollback() logger.error(f"Unexpected SQLAlchemy error during list update: {str(e)}", exc_info=True)
raise DatabaseTransactionError(f"Failed to update list: {str(e)}") raise DatabaseTransactionError(f"Failed to update list: {str(e)}")
async def delete_list(db: AsyncSession, list_db: ListModel) -> None: async def delete_list(db: AsyncSession, list_db: ListModel) -> None:
@ -177,13 +179,11 @@ async def delete_list(db: AsyncSession, list_db: ListModel) -> None:
try: try:
async with db.begin_nested() if db.in_transaction() else db.begin() as transaction: # Standardize transaction async with db.begin_nested() if db.in_transaction() else db.begin() as transaction: # Standardize transaction
await db.delete(list_db) await db.delete(list_db)
await transaction.commit() # Explicit commit
# return None # Already implicitly returns None
except OperationalError as e: except OperationalError as e:
# Rollback should be handled by the async with block on exception logger.error(f"Database connection error while deleting list: {str(e)}", exc_info=True)
raise DatabaseConnectionError(f"Database connection error while deleting list: {str(e)}") raise DatabaseConnectionError(f"Database connection error while deleting list: {str(e)}")
except SQLAlchemyError as e: except SQLAlchemyError as e:
# Rollback should be handled by the async with block on exception logger.error(f"Unexpected SQLAlchemy error while deleting list: {str(e)}", exc_info=True)
raise DatabaseTransactionError(f"Failed to delete list: {str(e)}") raise DatabaseTransactionError(f"Failed to delete list: {str(e)}")
async def check_list_permission(db: AsyncSession, list_id: int, user_id: int, require_creator: bool = False) -> ListModel: async def check_list_permission(db: AsyncSession, list_id: int, user_id: int, require_creator: bool = False) -> ListModel:

View File

@ -7,6 +7,7 @@ from sqlalchemy.exc import SQLAlchemyError, OperationalError, IntegrityError
from decimal import Decimal, ROUND_HALF_UP from decimal import Decimal, ROUND_HALF_UP
from typing import List as PyList, Optional, Sequence from typing import List as PyList, Optional, Sequence
from datetime import datetime, timezone from datetime import datetime, timezone
import logging # Add logging import
from app.models import ( from app.models import (
Settlement as SettlementModel, Settlement as SettlementModel,
@ -27,10 +28,12 @@ from app.core.exceptions import (
ConflictError ConflictError
) )
logger = logging.getLogger(__name__) # Initialize logger
async def create_settlement(db: AsyncSession, settlement_in: SettlementCreate, current_user_id: int) -> SettlementModel: async def create_settlement(db: AsyncSession, settlement_in: SettlementCreate, current_user_id: int) -> SettlementModel:
"""Creates a new settlement record.""" """Creates a new settlement record."""
try: try:
async with db.begin_nested() if db.in_transaction() else db.begin() as transaction: async with db.begin_nested() if db.in_transaction() else db.begin():
payer = await db.get(UserModel, settlement_in.paid_by_user_id) payer = await db.get(UserModel, settlement_in.paid_by_user_id)
if not payer: if not payer:
raise UserNotFoundError(user_id=settlement_in.paid_by_user_id, identifier="Payer") raise UserNotFoundError(user_id=settlement_in.paid_by_user_id, identifier="Payer")
@ -80,20 +83,21 @@ async def create_settlement(db: AsyncSession, settlement_in: SettlementCreate, c
loaded_settlement = result.scalar_one_or_none() loaded_settlement = result.scalar_one_or_none()
if loaded_settlement is None: if loaded_settlement is None:
await transaction.rollback() # Should be handled by context manager
raise SettlementOperationError("Failed to load settlement after creation.") raise SettlementOperationError("Failed to load settlement after creation.")
await transaction.commit()
return loaded_settlement return loaded_settlement
except (UserNotFoundError, GroupNotFoundError, InvalidOperationError) as e: except (UserNotFoundError, GroupNotFoundError, InvalidOperationError) as e:
# These are validation errors, re-raise them. # These are validation errors, re-raise them.
# If a transaction was started, context manager handles rollback. # If a transaction was started, context manager handles rollback.
raise raise
except IntegrityError as e: except IntegrityError as e:
logger.error(f"Database integrity error during settlement creation: {str(e)}", exc_info=True)
raise DatabaseIntegrityError(f"Failed to save settlement due to DB integrity: {str(e)}") raise DatabaseIntegrityError(f"Failed to save settlement due to DB integrity: {str(e)}")
except OperationalError as e: except OperationalError as e:
logger.error(f"Database connection error during settlement creation: {str(e)}", exc_info=True)
raise DatabaseConnectionError(f"DB connection error during settlement creation: {str(e)}") raise DatabaseConnectionError(f"DB connection error during settlement creation: {str(e)}")
except SQLAlchemyError as e: except SQLAlchemyError as e:
logger.error(f"Unexpected SQLAlchemy error during settlement creation: {str(e)}", exc_info=True)
raise DatabaseTransactionError(f"DB transaction error during settlement creation: {str(e)}") raise DatabaseTransactionError(f"DB transaction error during settlement creation: {str(e)}")
@ -111,8 +115,10 @@ async def get_settlement_by_id(db: AsyncSession, settlement_id: int) -> Optional
) )
return result.scalars().first() return result.scalars().first()
except OperationalError as e: except OperationalError as e:
# Optional: logger.warning or info if needed for read operations
raise DatabaseConnectionError(f"DB connection error fetching settlement: {str(e)}") raise DatabaseConnectionError(f"DB connection error fetching settlement: {str(e)}")
except SQLAlchemyError as e: except SQLAlchemyError as e:
# Optional: logger.warning or info if needed for read operations
raise DatabaseQueryError(f"DB query error fetching settlement: {str(e)}") raise DatabaseQueryError(f"DB query error fetching settlement: {str(e)}")
async def get_settlements_for_group(db: AsyncSession, group_id: int, skip: int = 0, limit: int = 100) -> Sequence[SettlementModel]: async def get_settlements_for_group(db: AsyncSession, group_id: int, skip: int = 0, limit: int = 100) -> Sequence[SettlementModel]:
@ -176,7 +182,7 @@ async def update_settlement(db: AsyncSession, settlement_db: SettlementModel, se
Assumes SettlementModel has version and updated_at fields. Assumes SettlementModel has version and updated_at fields.
""" """
try: try:
async with db.begin_nested() if db.in_transaction() else db.begin() as transaction: async with db.begin_nested() if db.in_transaction() else db.begin():
# Ensure the settlement_db passed is managed by the current session if not already. # Ensure the settlement_db passed is managed by the current session if not already.
# This is usually true if fetched by an endpoint dependency using the same session. # This is usually true if fetched by an endpoint dependency using the same session.
# If not, `db.add(settlement_db)` might be needed before modification if it's detached. # If not, `db.add(settlement_db)` might be needed before modification if it's detached.
@ -228,20 +234,21 @@ async def update_settlement(db: AsyncSession, settlement_db: SettlementModel, se
updated_settlement = result.scalar_one_or_none() updated_settlement = result.scalar_one_or_none()
if updated_settlement is None: # Should not happen if updated_settlement is None: # Should not happen
await transaction.rollback()
raise SettlementOperationError("Failed to load settlement after update.") raise SettlementOperationError("Failed to load settlement after update.")
await transaction.commit()
return updated_settlement return updated_settlement
except ConflictError as e: # ConflictError should be defined in exceptions except ConflictError as e: # ConflictError should be defined in exceptions
raise raise
except InvalidOperationError as e: except InvalidOperationError as e:
raise raise
except IntegrityError as e: except IntegrityError as e:
logger.error(f"Database integrity error during settlement update: {str(e)}", exc_info=True)
raise DatabaseIntegrityError(f"Failed to update settlement due to DB integrity: {str(e)}") raise DatabaseIntegrityError(f"Failed to update settlement due to DB integrity: {str(e)}")
except OperationalError as e: except OperationalError as e:
logger.error(f"Database connection error during settlement update: {str(e)}", exc_info=True)
raise DatabaseConnectionError(f"DB connection error during settlement update: {str(e)}") raise DatabaseConnectionError(f"DB connection error during settlement update: {str(e)}")
except SQLAlchemyError as e: except SQLAlchemyError as e:
logger.error(f"Unexpected SQLAlchemy error during settlement update: {str(e)}", exc_info=True)
raise DatabaseTransactionError(f"DB transaction error during settlement update: {str(e)}") raise DatabaseTransactionError(f"DB transaction error during settlement update: {str(e)}")
@ -251,7 +258,7 @@ async def delete_settlement(db: AsyncSession, settlement_db: SettlementModel, ex
Assumes SettlementModel has a version field. Assumes SettlementModel has a version field.
""" """
try: try:
async with db.begin_nested() if db.in_transaction() else db.begin() as transaction: async with db.begin_nested() if db.in_transaction() else db.begin():
if expected_version is not None: if expected_version is not None:
if not hasattr(settlement_db, 'version') or settlement_db.version != expected_version: if not hasattr(settlement_db, 'version') or settlement_db.version != expected_version:
raise ConflictError( # Make sure ConflictError is defined raise ConflictError( # Make sure ConflictError is defined
@ -260,12 +267,13 @@ async def delete_settlement(db: AsyncSession, settlement_db: SettlementModel, ex
) )
await db.delete(settlement_db) await db.delete(settlement_db)
await transaction.commit()
except ConflictError as e: # ConflictError should be defined except ConflictError as e: # ConflictError should be defined
raise raise
except OperationalError as e: except OperationalError as e:
logger.error(f"Database connection error during settlement deletion: {str(e)}", exc_info=True)
raise DatabaseConnectionError(f"DB connection error during settlement deletion: {str(e)}") raise DatabaseConnectionError(f"DB connection error during settlement deletion: {str(e)}")
except SQLAlchemyError as e: except SQLAlchemyError as e:
logger.error(f"Unexpected SQLAlchemy error during settlement deletion: {str(e)}", exc_info=True)
raise DatabaseTransactionError(f"DB transaction error during settlement deletion: {str(e)}") raise DatabaseTransactionError(f"DB transaction error during settlement deletion: {str(e)}")
# Ensure SettlementOperationError and ConflictError are defined in app.core.exceptions # Ensure SettlementOperationError and ConflictError are defined in app.core.exceptions

View File

@ -4,6 +4,7 @@ from sqlalchemy.future import select
from sqlalchemy.orm import selectinload # Ensure selectinload is imported from sqlalchemy.orm import selectinload # Ensure selectinload is imported
from sqlalchemy.exc import SQLAlchemyError, IntegrityError, OperationalError from sqlalchemy.exc import SQLAlchemyError, IntegrityError, OperationalError
from typing import Optional from typing import Optional
import logging # Add logging import
from app.models import User as UserModel, UserGroup as UserGroupModel, Group as GroupModel # Import related models for selectinload from app.models import User as UserModel, UserGroup as UserGroupModel, Group as GroupModel # Import related models for selectinload
from app.schemas.user import UserCreate from app.schemas.user import UserCreate
@ -18,26 +19,29 @@ from app.core.exceptions import (
UserOperationError # Add if specific user operation errors are needed UserOperationError # Add if specific user operation errors are needed
) )
logger = logging.getLogger(__name__) # Initialize logger
async def get_user_by_email(db: AsyncSession, email: str) -> Optional[UserModel]: async def get_user_by_email(db: AsyncSession, email: str) -> Optional[UserModel]:
"""Fetches a user from the database by email, with common relationships.""" """Fetches a user from the database by email, with common relationships."""
try: try:
# db.begin() is not strictly necessary for a single read, but ensures atomicity if multiple reads were added. # db.begin() is not strictly necessary for a single read, but ensures atomicity if multiple reads were added.
# For a single select, it can be omitted if preferred, session handles connection. # For a single select, it can be omitted if preferred, session handles connection.
async with db.begin(): # Or remove if only a single select operation stmt = (
stmt = ( select(UserModel)
select(UserModel) .filter(UserModel.email == email)
.filter(UserModel.email == email) .options(
.options( selectinload(UserModel.group_associations).selectinload(UserGroupModel.group), # Groups user is member of
selectinload(UserModel.group_associations).selectinload(UserGroupModel.group), # Groups user is member of selectinload(UserModel.created_groups) # Groups user created
selectinload(UserModel.created_groups) # Groups user created # Add other relationships as needed by UserPublic schema
# Add other relationships as needed by UserPublic schema
)
) )
result = await db.execute(stmt) )
return result.scalars().first() result = await db.execute(stmt)
return result.scalars().first()
except OperationalError as e: except OperationalError as e:
logger.error(f"Database connection error while fetching user by email '{email}': {str(e)}", exc_info=True)
raise DatabaseConnectionError(f"Failed to connect to database: {str(e)}") raise DatabaseConnectionError(f"Failed to connect to database: {str(e)}")
except SQLAlchemyError as e: except SQLAlchemyError as e:
logger.error(f"Unexpected SQLAlchemy error while fetching user by email '{email}': {str(e)}", exc_info=True)
raise DatabaseQueryError(f"Failed to query user: {str(e)}") raise DatabaseQueryError(f"Failed to query user: {str(e)}")
async def create_user(db: AsyncSession, user_in: UserCreate) -> UserModel: async def create_user(db: AsyncSession, user_in: UserCreate) -> UserModel:
@ -67,19 +71,19 @@ async def create_user(db: AsyncSession, user_in: UserCreate) -> UserModel:
loaded_user = result.scalar_one_or_none() loaded_user = result.scalar_one_or_none()
if loaded_user is None: if loaded_user is None:
await transaction.rollback() # Should be handled by context manager, but explicit
raise UserOperationError("Failed to load user after creation.") # Define UserOperationError raise UserOperationError("Failed to load user after creation.") # Define UserOperationError
await transaction.commit()
return loaded_user return loaded_user
except IntegrityError as e: except IntegrityError as e:
# Context manager handles rollback on error logger.error(f"Database integrity error during user creation for email '{user_in.email}': {str(e)}", exc_info=True)
if "unique constraint" in str(e).lower() and ("users_email_key" in str(e).lower() or "ix_users_email" in str(e).lower()): if "unique constraint" in str(e).lower() and ("users_email_key" in str(e).lower() or "ix_users_email" in str(e).lower()):
raise EmailAlreadyRegisteredError(email=user_in.email) raise EmailAlreadyRegisteredError(email=user_in.email)
raise DatabaseIntegrityError(f"Failed to create user due to integrity issue: {str(e)}") raise DatabaseIntegrityError(f"Failed to create user due to integrity issue: {str(e)}")
except OperationalError as e: except OperationalError as e:
logger.error(f"Database connection error during user creation for email '{user_in.email}': {str(e)}", exc_info=True)
raise DatabaseConnectionError(f"Database connection error during user creation: {str(e)}") raise DatabaseConnectionError(f"Database connection error during user creation: {str(e)}")
except SQLAlchemyError as e: except SQLAlchemyError as e:
logger.error(f"Unexpected SQLAlchemy error during user creation for email '{user_in.email}': {str(e)}", exc_info=True)
raise DatabaseTransactionError(f"Failed to create user due to other DB error: {str(e)}") raise DatabaseTransactionError(f"Failed to create user due to other DB error: {str(e)}")
# Ensure UserOperationError is defined in app.core.exceptions if used # Ensure UserOperationError is defined in app.core.exceptions if used