Refactor CRUD operations across multiple modules to standardize transaction handling using context managers, improving error logging and rollback mechanisms. Enhance error handling for database operations in expense, group, invite, item, list, settlement, and user modules, ensuring specific exceptions are raised for integrity and connection issues.

This commit is contained in:
mohamad 2025-05-20 01:18:49 +02:00
parent e4175db4aa
commit 98b2f907de
7 changed files with 192 additions and 144 deletions

View File

@ -3,6 +3,7 @@ import logging # Add logging import
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy.orm import selectinload, joinedload
from sqlalchemy.exc import SQLAlchemyError, IntegrityError, OperationalError # Added import
from decimal import Decimal, ROUND_HALF_UP, InvalidOperation as DecimalInvalidOperation
from typing import Callable, List as PyList, Optional, Sequence, Dict, defaultdict, Any
from datetime import datetime, timezone # Added timezone
@ -183,10 +184,11 @@ async def create_expense(db: AsyncSession, expense_in: ExpenseCreate, current_us
loaded_expense = result.scalar_one_or_none()
if loaded_expense is None:
await transaction.rollback() # Should be handled by context manager
# The context manager will handle rollback if an exception is raised.
# await transaction.rollback() # Should be handled by context manager
raise ExpenseOperationError("Failed to load expense after creation.")
await transaction.commit()
# await transaction.commit() # Explicit commit removed, context manager handles it.
return loaded_expense
except (UserNotFoundError, ListNotFoundError, GroupNotFoundError, InvalidOperationError) as e:
@ -564,18 +566,27 @@ async def update_expense(db: AsyncSession, expense_db: ExpenseModel, expense_in:
# For now, if only version was sent, we still increment if it matched.
pass # Or raise InvalidOperationError("No updatable fields provided.")
expense_db.version += 1
expense_db.updated_at = datetime.now(timezone.utc) # Manually update timestamp
try:
await db.commit()
await db.refresh(expense_db)
except Exception as e:
await db.rollback()
# Consider specific DB error types if needed
raise InvalidOperationError(f"Failed to update expense: {str(e)}")
async with db.begin_nested() if db.in_transaction() else db.begin() as transaction:
expense_db.version += 1
expense_db.updated_at = datetime.now(timezone.utc) # Manually update timestamp
# db.add(expense_db) # Not strictly necessary as expense_db is already tracked by the session
await db.flush() # Persist changes to the DB and run constraints
await db.refresh(expense_db) # Refresh the object from the DB
return expense_db
except InvalidOperationError: # Re-raise validation errors to be handled by the caller
raise
except IntegrityError as e:
logger.error(f"Database integrity error during expense update for ID {expense_db.id}: {str(e)}", exc_info=True)
# The transaction context manager (begin_nested/begin) handles rollback.
raise DatabaseIntegrityError(f"Failed to update expense ID {expense_db.id} due to database integrity issue.") from e
except SQLAlchemyError as e: # Catch other SQLAlchemy errors
logger.error(f"Database transaction error during expense update for ID {expense_db.id}: {str(e)}", exc_info=True)
# The transaction context manager (begin_nested/begin) handles rollback.
raise DatabaseTransactionError(f"Failed to update expense ID {expense_db.id} due to a database transaction error.") from e
# No generic Exception catch here, let other unexpected errors propagate if not SQLAlchemy related.
return expense_db
async def delete_expense(db: AsyncSession, expense_db: ExpenseModel, expected_version: Optional[int] = None) -> None:
"""
@ -589,12 +600,20 @@ async def delete_expense(db: AsyncSession, expense_db: ExpenseModel, expected_ve
# status_code=status.HTTP_409_CONFLICT
)
await db.delete(expense_db)
try:
await db.commit()
except Exception as e:
await db.rollback()
raise InvalidOperationError(f"Failed to delete expense: {str(e)}")
async with db.begin_nested() if db.in_transaction() else db.begin() as transaction:
await db.delete(expense_db)
await db.flush() # Ensure the delete operation is sent to the database
except InvalidOperationError: # Re-raise validation errors
raise
except IntegrityError as e:
logger.error(f"Database integrity error during expense deletion for ID {expense_db.id}: {str(e)}", exc_info=True)
# The transaction context manager (begin_nested/begin) handles rollback.
raise DatabaseIntegrityError(f"Failed to delete expense ID {expense_db.id} due to database integrity issue.") from e
except SQLAlchemyError as e: # Catch other SQLAlchemy errors
logger.error(f"Database transaction error during expense deletion for ID {expense_db.id}: {str(e)}", exc_info=True)
# The transaction context manager (begin_nested/begin) handles rollback.
raise DatabaseTransactionError(f"Failed to delete expense ID {expense_db.id} due to a database transaction error.") from e
return None
# Note: The InvalidOperationError is a simple ValueError placeholder.

View File

@ -5,6 +5,7 @@ from sqlalchemy.orm import selectinload # For eager loading members
from sqlalchemy.exc import SQLAlchemyError, IntegrityError, OperationalError
from typing import Optional, List
from sqlalchemy import delete, func
import logging # Add logging import
from app.models import User as UserModel, Group as GroupModel, UserGroup as UserGroupModel
from app.schemas.group import GroupCreate
@ -20,24 +21,16 @@ from app.core.exceptions import (
GroupPermissionError # Import GroupPermissionError
)
logger = logging.getLogger(__name__) # Initialize logger
# --- Group CRUD ---
async def create_group(db: AsyncSession, group_in: GroupCreate, creator_id: int) -> GroupModel:
"""Creates a group and adds the creator as the owner."""
try:
# Defensive check: if a transaction is already active, try to roll it back.
# This is unusual and suggests an issue upstream (e.g., middleware or session configuration).
if db.in_transaction():
# Log this occurrence if possible, as it's unexpected.
# import logging; logging.warning("Transaction already active on session entering create_group. Attempting rollback.")
try:
await db.rollback() # Attempt to clear any existing transaction
except SQLAlchemyError as e_rb:
# Log e_rb if possible
# import logging; logging.error(f"Error rolling back pre-existing transaction: {e_rb}")
# Re-raise or handle as a critical error, as the session state is uncertain.
raise DatabaseTransactionError(f"Session had an active transaction that could not be rolled back: {e_rb}")
async with db.begin(): # Now attempt to start a clean transaction
# Use the composability pattern for transactions as per fastapi-db-strategy.
# This creates a savepoint if already in a transaction (e.g., from get_transactional_session)
# or starts a new transaction if called outside of one (e.g., from a script).
async with db.begin_nested() if db.in_transaction() else db.begin():
db_group = GroupModel(name=group_in.name, created_by_id=creator_id)
db.add(db_group)
await db.flush() # Assigns ID to db_group
@ -67,10 +60,13 @@ async def create_group(db: AsyncSession, group_in: GroupCreate, creator_id: int)
return loaded_group
except IntegrityError as e:
logger.error(f"Database integrity error during group creation: {str(e)}", exc_info=True)
raise DatabaseIntegrityError(f"Failed to create group due to integrity issue: {str(e)}")
except OperationalError as e:
logger.error(f"Database connection error during group creation: {str(e)}", exc_info=True)
raise DatabaseConnectionError(f"Database connection error during group creation: {str(e)}")
except SQLAlchemyError as e:
logger.error(f"Unexpected SQLAlchemy error during group creation: {str(e)}", exc_info=True)
raise DatabaseTransactionError(f"Database transaction error during group creation: {str(e)}")
async def get_user_groups(db: AsyncSession, user_id: int) -> List[GroupModel]:
@ -143,7 +139,7 @@ async def add_user_to_group(db: AsyncSession, group_id: int, user_id: int, role:
return None
# Use a single transaction
async with db.begin():
async with db.begin_nested() if db.in_transaction() else db.begin():
db_user_group = UserGroupModel(user_id=user_id, group_id=group_id, role=role)
db.add(db_user_group)
await db.flush() # Assigns ID to db_user_group
@ -165,16 +161,19 @@ async def add_user_to_group(db: AsyncSession, group_id: int, user_id: int, role:
return loaded_user_group
except IntegrityError as e:
logger.error(f"Database integrity error while adding user to group: {str(e)}", exc_info=True)
raise DatabaseIntegrityError(f"Failed to add user to group: {str(e)}")
except OperationalError as e:
logger.error(f"Database connection error while adding user to group: {str(e)}", exc_info=True)
raise DatabaseConnectionError(f"Database connection error: {str(e)}")
except SQLAlchemyError as e:
logger.error(f"Unexpected SQLAlchemy error while adding user to group: {str(e)}", exc_info=True)
raise DatabaseTransactionError(f"Failed to add user to group: {str(e)}")
async def remove_user_from_group(db: AsyncSession, group_id: int, user_id: int) -> bool:
"""Removes a user from a group."""
try:
async with db.begin():
async with db.begin_nested() if db.in_transaction() else db.begin():
result = await db.execute(
delete(UserGroupModel)
.where(UserGroupModel.group_id == group_id, UserGroupModel.user_id == user_id)
@ -182,8 +181,10 @@ async def remove_user_from_group(db: AsyncSession, group_id: int, user_id: int)
)
return result.scalar_one_or_none() is not None
except OperationalError as e:
logger.error(f"Database connection error while removing user from group: {str(e)}", exc_info=True)
raise DatabaseConnectionError(f"Database connection error: {str(e)}")
except SQLAlchemyError as e:
logger.error(f"Unexpected SQLAlchemy error while removing user from group: {str(e)}", exc_info=True)
raise DatabaseTransactionError(f"Failed to remove user from group: {str(e)}")
async def get_group_member_count(db: AsyncSession, group_id: int) -> int:

View File

@ -1,4 +1,5 @@
# app/crud/invite.py
import logging # Add logging import
import secrets
from datetime import datetime, timedelta, timezone
from sqlalchemy.ext.asyncio import AsyncSession
@ -17,93 +18,100 @@ from app.core.exceptions import (
InviteOperationError # Add new specific exception
)
logger = logging.getLogger(__name__) # Initialize logger
# Invite codes should be reasonably unique, but handle potential collision
MAX_CODE_GENERATION_ATTEMPTS = 5
async def deactivate_all_active_invites_for_group(db: AsyncSession, group_id: int):
"""Deactivates all currently active invite codes for a specific group."""
try:
stmt = (
select(InviteModel)
.where(InviteModel.group_id == group_id, InviteModel.is_active == True)
)
result = await db.execute(stmt)
active_invites = result.scalars().all()
async with db.begin_nested() if db.in_transaction() else db.begin():
stmt = (
select(InviteModel)
.where(InviteModel.group_id == group_id, InviteModel.is_active == True)
)
result = await db.execute(stmt)
active_invites = result.scalars().all()
if not active_invites:
return # No active invites to deactivate
if not active_invites:
return # No active invites to deactivate
for invite in active_invites:
invite.is_active = False
db.add(invite)
for invite in active_invites:
invite.is_active = False
db.add(invite)
await db.flush() # Flush changes within this transaction block
# await db.flush() # Removed: Rely on caller to flush/commit
# No explicit commit here, assuming it's part of a larger transaction or caller handles commit.
except OperationalError as e:
# It's better to let the caller handle rollback or commit based on overall operation success
logger.error(f"Database connection error deactivating invites for group {group_id}: {str(e)}", exc_info=True)
raise DatabaseConnectionError(f"DB connection error deactivating invites for group {group_id}: {str(e)}")
except SQLAlchemyError as e:
logger.error(f"Unexpected SQLAlchemy error deactivating invites for group {group_id}: {str(e)}", exc_info=True)
raise DatabaseTransactionError(f"DB transaction error deactivating invites for group {group_id}: {str(e)}")
async def create_invite(db: AsyncSession, group_id: int, creator_id: int, expires_in_days: int = 365 * 100) -> Optional[InviteModel]: # Default to 100 years
"""Creates a new invite code for a group, deactivating any existing active ones for that group first."""
# Deactivate existing active invites for this group
await deactivate_all_active_invites_for_group(db, group_id)
try:
async with db.begin_nested() if db.in_transaction() else db.begin():
# Deactivate existing active invites for this group
await deactivate_all_active_invites_for_group(db, group_id)
expires_at = datetime.now(timezone.utc) + timedelta(days=expires_in_days)
expires_at = datetime.now(timezone.utc) + timedelta(days=expires_in_days)
potential_code = None
for attempt in range(MAX_CODE_GENERATION_ATTEMPTS):
potential_code = secrets.token_urlsafe(16)
existing_check_stmt = select(InviteModel.id).where(InviteModel.code == potential_code, InviteModel.is_active == True).limit(1)
existing_result = await db.execute(existing_check_stmt)
if existing_result.scalar_one_or_none() is None:
break
if attempt == MAX_CODE_GENERATION_ATTEMPTS - 1:
raise InviteOperationError("Failed to generate a unique invite code after several attempts.")
potential_code = None
for attempt in range(MAX_CODE_GENERATION_ATTEMPTS):
potential_code = secrets.token_urlsafe(16)
existing_check_stmt = select(InviteModel.id).where(InviteModel.code == potential_code, InviteModel.is_active == True).limit(1)
existing_result = await db.execute(existing_check_stmt)
if existing_result.scalar_one_or_none() is None:
break
if attempt == MAX_CODE_GENERATION_ATTEMPTS - 1:
raise InviteOperationError("Failed to generate a unique invite code after several attempts.")
# Removed explicit transaction block here, rely on FastAPI's request-level transaction.
# Final check for code collision (less critical now without explicit nested transaction rollback on collision)
# but still good to prevent duplicate active codes if possible, though the deactivate step helps.
final_check_stmt = select(InviteModel.id).where(InviteModel.code == potential_code, InviteModel.is_active == True).limit(1)
final_check_result = await db.execute(final_check_stmt)
if final_check_result.scalar_one_or_none() is not None:
# This is now more of a rare edge case if deactivate worked and code generation is diverse.
# Depending on strictness, could raise an error or just log and proceed,
# relying on the previous deactivation to ensure only one is active.
# For now, let's raise to be safe, as it implies a very quick collision.
raise InviteOperationError("Invite code collision detected just before creation attempt.")
final_check_stmt = select(InviteModel.id).where(InviteModel.code == potential_code, InviteModel.is_active == True).limit(1)
final_check_result = await db.execute(final_check_stmt)
if final_check_result.scalar_one_or_none() is not None:
raise InviteOperationError("Invite code collision detected just before creation attempt.")
db_invite = InviteModel(
code=potential_code,
group_id=group_id,
created_by_id=creator_id,
expires_at=expires_at,
is_active=True
)
db.add(db_invite)
await db.flush() # Flush to get ID for re-fetch and ensure it's in session before potential re-fetch.
db_invite = InviteModel(
code=potential_code,
group_id=group_id,
created_by_id=creator_id,
expires_at=expires_at,
is_active=True
)
db.add(db_invite)
await db.flush()
# Re-fetch with relationships
stmt = (
select(InviteModel)
.where(InviteModel.id == db_invite.id)
.options(
selectinload(InviteModel.group),
selectinload(InviteModel.creator)
)
)
result = await db.execute(stmt)
loaded_invite = result.scalar_one_or_none()
stmt = (
select(InviteModel)
.where(InviteModel.id == db_invite.id)
.options(
selectinload(InviteModel.group),
selectinload(InviteModel.creator)
)
)
result = await db.execute(stmt)
loaded_invite = result.scalar_one_or_none()
if loaded_invite is None:
# This would be an issue, implies flush didn't work or ID was wrong.
# The main transaction will rollback if this exception is raised.
raise InviteOperationError("Failed to load invite after creation and flush.")
if loaded_invite is None:
raise InviteOperationError("Failed to load invite after creation and flush.")
return loaded_invite
# No explicit commit here, FastAPI handles it for the request.
return loaded_invite
except InviteOperationError: # Already specific, re-raise
raise
except IntegrityError as e:
logger.error(f"Database integrity error during invite creation for group {group_id}: {str(e)}", exc_info=True)
raise DatabaseIntegrityError(f"Failed to create invite due to DB integrity issue: {str(e)}")
except OperationalError as e:
logger.error(f"Database connection error during invite creation for group {group_id}: {str(e)}", exc_info=True)
raise DatabaseConnectionError(f"DB connection error during invite creation: {str(e)}")
except SQLAlchemyError as e:
logger.error(f"Unexpected SQLAlchemy error during invite creation for group {group_id}: {str(e)}", exc_info=True)
raise DatabaseTransactionError(f"DB transaction error during invite creation: {str(e)}")
async def get_active_invite_for_group(db: AsyncSession, group_id: int) -> Optional[InviteModel]:
"""Gets the currently active and non-expired invite for a specific group."""
@ -125,8 +133,10 @@ async def get_active_invite_for_group(db: AsyncSession, group_id: int) -> Option
result = await db.execute(stmt)
return result.scalars().first()
except OperationalError as e:
logger.error(f"Database connection error fetching active invite for group {group_id}: {str(e)}", exc_info=True)
raise DatabaseConnectionError(f"DB connection error fetching active invite for group {group_id}: {str(e)}")
except SQLAlchemyError as e:
logger.error(f"DB query error fetching active invite for group {group_id}: {str(e)}", exc_info=True)
raise DatabaseQueryError(f"DB query error fetching active invite for group {group_id}: {str(e)}")
async def get_active_invite_by_code(db: AsyncSession, code: str) -> Optional[InviteModel]:
@ -172,14 +182,14 @@ async def deactivate_invite(db: AsyncSession, invite: InviteModel) -> InviteMode
updated_invite = result.scalar_one_or_none()
if updated_invite is None: # Should not happen as invite is passed in
await transaction.rollback()
raise InviteOperationError("Failed to load invite after deactivation.")
await transaction.commit()
return updated_invite
except OperationalError as e:
logger.error(f"Database connection error deactivating invite: {str(e)}", exc_info=True)
raise DatabaseConnectionError(f"DB connection error deactivating invite: {str(e)}")
except SQLAlchemyError as e:
logger.error(f"Unexpected SQLAlchemy error deactivating invite: {str(e)}", exc_info=True)
raise DatabaseTransactionError(f"DB transaction error deactivating invite: {str(e)}")
# Ensure InviteOperationError is defined in app.core.exceptions

View File

@ -6,6 +6,7 @@ from sqlalchemy import delete as sql_delete, update as sql_update # Use aliases
from sqlalchemy.exc import SQLAlchemyError, IntegrityError, OperationalError
from typing import Optional, List as PyList
from datetime import datetime, timezone
import logging # Add logging import
from app.models import Item as ItemModel, User as UserModel # Import UserModel for type hints if needed for selectinload
from app.schemas.item import ItemCreate, ItemUpdate
@ -19,6 +20,8 @@ from app.core.exceptions import (
ItemOperationError # Add if specific item operation errors are needed
)
logger = logging.getLogger(__name__) # Initialize logger
async def create_item(db: AsyncSession, item_in: ItemCreate, list_id: int, user_id: int) -> ItemModel:
"""Creates a new item record for a specific list."""
try:
@ -46,17 +49,18 @@ async def create_item(db: AsyncSession, item_in: ItemCreate, list_id: int, user_
loaded_item = result.scalar_one_or_none()
if loaded_item is None:
await transaction.rollback() # Should be handled by context manager on raise, but explicit for clarity
# await transaction.rollback() # Redundant, context manager handles rollback on exception
raise ItemOperationError("Failed to load item after creation.") # Define ItemOperationError
await transaction.commit()
return loaded_item
except IntegrityError as e:
# Context manager handles rollback on error
logger.error(f"Database integrity error during item creation: {str(e)}", exc_info=True)
raise DatabaseIntegrityError(f"Failed to create item: {str(e)}")
except OperationalError as e:
logger.error(f"Database connection error during item creation: {str(e)}", exc_info=True)
raise DatabaseConnectionError(f"Database connection error: {str(e)}")
except SQLAlchemyError as e:
logger.error(f"Unexpected SQLAlchemy error during item creation: {str(e)}", exc_info=True)
raise DatabaseTransactionError(f"Failed to create item: {str(e)}")
# Removed generic Exception block as SQLAlchemyError should cover DB issues,
# and context manager handles rollback.
@ -144,15 +148,17 @@ async def update_item(db: AsyncSession, item_db: ItemModel, item_in: ItemUpdate,
# Rollback will be handled by context manager on raise
raise ItemOperationError("Failed to load item after update.")
await transaction.commit()
return updated_item
except IntegrityError as e:
logger.error(f"Database integrity error during item update: {str(e)}", exc_info=True)
raise DatabaseIntegrityError(f"Failed to update item due to integrity constraint: {str(e)}")
except OperationalError as e:
logger.error(f"Database connection error while updating item: {str(e)}", exc_info=True)
raise DatabaseConnectionError(f"Database connection error while updating item: {str(e)}")
except ConflictError: # Re-raise ConflictError, rollback handled by context manager
raise
except SQLAlchemyError as e:
logger.error(f"Unexpected SQLAlchemy error during item update: {str(e)}", exc_info=True)
raise DatabaseTransactionError(f"Failed to update item: {str(e)}")
async def delete_item(db: AsyncSession, item_db: ItemModel) -> None:
@ -160,13 +166,13 @@ async def delete_item(db: AsyncSession, item_db: ItemModel) -> None:
try:
async with db.begin_nested() if db.in_transaction() else db.begin() as transaction:
await db.delete(item_db)
await transaction.commit()
# await transaction.commit() # Removed
# No return needed for None
except OperationalError as e:
# Rollback handled by context manager
logger.error(f"Database connection error while deleting item: {str(e)}", exc_info=True)
raise DatabaseConnectionError(f"Database connection error while deleting item: {str(e)}")
except SQLAlchemyError as e:
# Rollback handled by context manager
logger.error(f"Unexpected SQLAlchemy error while deleting item: {str(e)}", exc_info=True)
raise DatabaseTransactionError(f"Failed to delete item: {str(e)}")
# Ensure ItemOperationError is defined in app.core.exceptions if used

View File

@ -5,6 +5,7 @@ from sqlalchemy.orm import selectinload, joinedload
from sqlalchemy import or_, and_, delete as sql_delete, func as sql_func, desc
from sqlalchemy.exc import SQLAlchemyError, IntegrityError, OperationalError
from typing import Optional, List as PyList
import logging # Add logging import
from app.schemas.list import ListStatus
from app.models import List as ListModel, UserGroup as UserGroupModel, Item as ItemModel
@ -21,6 +22,8 @@ from app.core.exceptions import (
ListOperationError
)
logger = logging.getLogger(__name__) # Initialize logger
async def create_list(db: AsyncSession, list_in: ListCreate, creator_id: int) -> ListModel:
"""Creates a new list record."""
try:
@ -49,16 +52,17 @@ async def create_list(db: AsyncSession, list_in: ListCreate, creator_id: int) ->
loaded_list = result.scalar_one_or_none()
if loaded_list is None:
await transaction.rollback()
raise ListOperationError("Failed to load list after creation.")
await transaction.commit()
return loaded_list
except IntegrityError as e:
logger.error(f"Database integrity error during list creation: {str(e)}", exc_info=True)
raise DatabaseIntegrityError(f"Failed to create list: {str(e)}")
except OperationalError as e:
logger.error(f"Database connection error during list creation: {str(e)}", exc_info=True)
raise DatabaseConnectionError(f"Database connection error: {str(e)}")
except SQLAlchemyError as e:
logger.error(f"Unexpected SQLAlchemy error during list creation: {str(e)}", exc_info=True)
raise DatabaseTransactionError(f"Failed to create list: {str(e)}")
async def get_lists_for_user(db: AsyncSession, user_id: int) -> PyList[ListModel]:
@ -80,8 +84,11 @@ async def get_lists_for_user(db: AsyncSession, user_id: int) -> PyList[ListModel
.where(or_(*conditions))
.options(
selectinload(ListModel.creator),
selectinload(ListModel.group)
# selectinload(ListModel.items) # Consider if items are needed for list previews
selectinload(ListModel.group),
selectinload(ListModel.items).options(
joinedload(ItemModel.added_by_user),
joinedload(ItemModel.completed_by_user)
)
)
.order_by(ListModel.updated_at.desc())
)
@ -123,7 +130,6 @@ async def update_list(db: AsyncSession, list_db: ListModel, list_in: ListUpdate)
try:
async with db.begin_nested() if db.in_transaction() else db.begin() as transaction:
if list_db.version != list_in.version: # list_db here is the one passed in, pre-loaded by API layer
await transaction.rollback() # Rollback before raising
raise ConflictError(
f"List '{list_db.name}' (ID: {list_db.id}) has been modified. "
f"Your version is {list_in.version}, current version is {list_db.version}. Please refresh."
@ -153,23 +159,19 @@ async def update_list(db: AsyncSession, list_db: ListModel, list_in: ListUpdate)
updated_list = result.scalar_one_or_none()
if updated_list is None: # Should not happen
await transaction.rollback()
raise ListOperationError("Failed to load list after update.")
await transaction.commit()
return updated_list
except IntegrityError as e:
# Ensure rollback if not handled by context manager (though it should be)
if db.in_transaction(): await db.rollback()
logger.error(f"Database integrity error during list update: {str(e)}", exc_info=True)
raise DatabaseIntegrityError(f"Failed to update list due to integrity constraint: {str(e)}")
except OperationalError as e:
if db.in_transaction(): await db.rollback()
logger.error(f"Database connection error while updating list: {str(e)}", exc_info=True)
raise DatabaseConnectionError(f"Database connection error while updating list: {str(e)}")
except ConflictError:
# Already rolled back or will be by context manager if transaction was started here
raise
except SQLAlchemyError as e:
if db.in_transaction(): await db.rollback()
logger.error(f"Unexpected SQLAlchemy error during list update: {str(e)}", exc_info=True)
raise DatabaseTransactionError(f"Failed to update list: {str(e)}")
async def delete_list(db: AsyncSession, list_db: ListModel) -> None:
@ -177,13 +179,11 @@ async def delete_list(db: AsyncSession, list_db: ListModel) -> None:
try:
async with db.begin_nested() if db.in_transaction() else db.begin() as transaction: # Standardize transaction
await db.delete(list_db)
await transaction.commit() # Explicit commit
# return None # Already implicitly returns None
except OperationalError as e:
# Rollback should be handled by the async with block on exception
logger.error(f"Database connection error while deleting list: {str(e)}", exc_info=True)
raise DatabaseConnectionError(f"Database connection error while deleting list: {str(e)}")
except SQLAlchemyError as e:
# Rollback should be handled by the async with block on exception
logger.error(f"Unexpected SQLAlchemy error while deleting list: {str(e)}", exc_info=True)
raise DatabaseTransactionError(f"Failed to delete list: {str(e)}")
async def check_list_permission(db: AsyncSession, list_id: int, user_id: int, require_creator: bool = False) -> ListModel:

View File

@ -7,6 +7,7 @@ from sqlalchemy.exc import SQLAlchemyError, OperationalError, IntegrityError
from decimal import Decimal, ROUND_HALF_UP
from typing import List as PyList, Optional, Sequence
from datetime import datetime, timezone
import logging # Add logging import
from app.models import (
Settlement as SettlementModel,
@ -27,10 +28,12 @@ from app.core.exceptions import (
ConflictError
)
logger = logging.getLogger(__name__) # Initialize logger
async def create_settlement(db: AsyncSession, settlement_in: SettlementCreate, current_user_id: int) -> SettlementModel:
"""Creates a new settlement record."""
try:
async with db.begin_nested() if db.in_transaction() else db.begin() as transaction:
async with db.begin_nested() if db.in_transaction() else db.begin():
payer = await db.get(UserModel, settlement_in.paid_by_user_id)
if not payer:
raise UserNotFoundError(user_id=settlement_in.paid_by_user_id, identifier="Payer")
@ -80,20 +83,21 @@ async def create_settlement(db: AsyncSession, settlement_in: SettlementCreate, c
loaded_settlement = result.scalar_one_or_none()
if loaded_settlement is None:
await transaction.rollback() # Should be handled by context manager
raise SettlementOperationError("Failed to load settlement after creation.")
await transaction.commit()
return loaded_settlement
except (UserNotFoundError, GroupNotFoundError, InvalidOperationError) as e:
# These are validation errors, re-raise them.
# If a transaction was started, context manager handles rollback.
raise
except IntegrityError as e:
logger.error(f"Database integrity error during settlement creation: {str(e)}", exc_info=True)
raise DatabaseIntegrityError(f"Failed to save settlement due to DB integrity: {str(e)}")
except OperationalError as e:
logger.error(f"Database connection error during settlement creation: {str(e)}", exc_info=True)
raise DatabaseConnectionError(f"DB connection error during settlement creation: {str(e)}")
except SQLAlchemyError as e:
logger.error(f"Unexpected SQLAlchemy error during settlement creation: {str(e)}", exc_info=True)
raise DatabaseTransactionError(f"DB transaction error during settlement creation: {str(e)}")
@ -111,8 +115,10 @@ async def get_settlement_by_id(db: AsyncSession, settlement_id: int) -> Optional
)
return result.scalars().first()
except OperationalError as e:
# Optional: logger.warning or info if needed for read operations
raise DatabaseConnectionError(f"DB connection error fetching settlement: {str(e)}")
except SQLAlchemyError as e:
# Optional: logger.warning or info if needed for read operations
raise DatabaseQueryError(f"DB query error fetching settlement: {str(e)}")
async def get_settlements_for_group(db: AsyncSession, group_id: int, skip: int = 0, limit: int = 100) -> Sequence[SettlementModel]:
@ -176,7 +182,7 @@ async def update_settlement(db: AsyncSession, settlement_db: SettlementModel, se
Assumes SettlementModel has version and updated_at fields.
"""
try:
async with db.begin_nested() if db.in_transaction() else db.begin() as transaction:
async with db.begin_nested() if db.in_transaction() else db.begin():
# Ensure the settlement_db passed is managed by the current session if not already.
# This is usually true if fetched by an endpoint dependency using the same session.
# If not, `db.add(settlement_db)` might be needed before modification if it's detached.
@ -228,20 +234,21 @@ async def update_settlement(db: AsyncSession, settlement_db: SettlementModel, se
updated_settlement = result.scalar_one_or_none()
if updated_settlement is None: # Should not happen
await transaction.rollback()
raise SettlementOperationError("Failed to load settlement after update.")
await transaction.commit()
return updated_settlement
except ConflictError as e: # ConflictError should be defined in exceptions
raise
except InvalidOperationError as e:
raise
except IntegrityError as e:
logger.error(f"Database integrity error during settlement update: {str(e)}", exc_info=True)
raise DatabaseIntegrityError(f"Failed to update settlement due to DB integrity: {str(e)}")
except OperationalError as e:
logger.error(f"Database connection error during settlement update: {str(e)}", exc_info=True)
raise DatabaseConnectionError(f"DB connection error during settlement update: {str(e)}")
except SQLAlchemyError as e:
logger.error(f"Unexpected SQLAlchemy error during settlement update: {str(e)}", exc_info=True)
raise DatabaseTransactionError(f"DB transaction error during settlement update: {str(e)}")
@ -251,7 +258,7 @@ async def delete_settlement(db: AsyncSession, settlement_db: SettlementModel, ex
Assumes SettlementModel has a version field.
"""
try:
async with db.begin_nested() if db.in_transaction() else db.begin() as transaction:
async with db.begin_nested() if db.in_transaction() else db.begin():
if expected_version is not None:
if not hasattr(settlement_db, 'version') or settlement_db.version != expected_version:
raise ConflictError( # Make sure ConflictError is defined
@ -260,12 +267,13 @@ async def delete_settlement(db: AsyncSession, settlement_db: SettlementModel, ex
)
await db.delete(settlement_db)
await transaction.commit()
except ConflictError as e: # ConflictError should be defined
raise
except OperationalError as e:
logger.error(f"Database connection error during settlement deletion: {str(e)}", exc_info=True)
raise DatabaseConnectionError(f"DB connection error during settlement deletion: {str(e)}")
except SQLAlchemyError as e:
logger.error(f"Unexpected SQLAlchemy error during settlement deletion: {str(e)}", exc_info=True)
raise DatabaseTransactionError(f"DB transaction error during settlement deletion: {str(e)}")
# Ensure SettlementOperationError and ConflictError are defined in app.core.exceptions

View File

@ -4,6 +4,7 @@ from sqlalchemy.future import select
from sqlalchemy.orm import selectinload # Ensure selectinload is imported
from sqlalchemy.exc import SQLAlchemyError, IntegrityError, OperationalError
from typing import Optional
import logging # Add logging import
from app.models import User as UserModel, UserGroup as UserGroupModel, Group as GroupModel # Import related models for selectinload
from app.schemas.user import UserCreate
@ -18,26 +19,29 @@ from app.core.exceptions import (
UserOperationError # Add if specific user operation errors are needed
)
logger = logging.getLogger(__name__) # Initialize logger
async def get_user_by_email(db: AsyncSession, email: str) -> Optional[UserModel]:
"""Fetches a user from the database by email, with common relationships."""
try:
# db.begin() is not strictly necessary for a single read, but ensures atomicity if multiple reads were added.
# For a single select, it can be omitted if preferred, session handles connection.
async with db.begin(): # Or remove if only a single select operation
stmt = (
select(UserModel)
.filter(UserModel.email == email)
.options(
selectinload(UserModel.group_associations).selectinload(UserGroupModel.group), # Groups user is member of
selectinload(UserModel.created_groups) # Groups user created
# Add other relationships as needed by UserPublic schema
)
stmt = (
select(UserModel)
.filter(UserModel.email == email)
.options(
selectinload(UserModel.group_associations).selectinload(UserGroupModel.group), # Groups user is member of
selectinload(UserModel.created_groups) # Groups user created
# Add other relationships as needed by UserPublic schema
)
result = await db.execute(stmt)
return result.scalars().first()
)
result = await db.execute(stmt)
return result.scalars().first()
except OperationalError as e:
logger.error(f"Database connection error while fetching user by email '{email}': {str(e)}", exc_info=True)
raise DatabaseConnectionError(f"Failed to connect to database: {str(e)}")
except SQLAlchemyError as e:
logger.error(f"Unexpected SQLAlchemy error while fetching user by email '{email}': {str(e)}", exc_info=True)
raise DatabaseQueryError(f"Failed to query user: {str(e)}")
async def create_user(db: AsyncSession, user_in: UserCreate) -> UserModel:
@ -67,19 +71,19 @@ async def create_user(db: AsyncSession, user_in: UserCreate) -> UserModel:
loaded_user = result.scalar_one_or_none()
if loaded_user is None:
await transaction.rollback() # Should be handled by context manager, but explicit
raise UserOperationError("Failed to load user after creation.") # Define UserOperationError
await transaction.commit()
return loaded_user
except IntegrityError as e:
# Context manager handles rollback on error
logger.error(f"Database integrity error during user creation for email '{user_in.email}': {str(e)}", exc_info=True)
if "unique constraint" in str(e).lower() and ("users_email_key" in str(e).lower() or "ix_users_email" in str(e).lower()):
raise EmailAlreadyRegisteredError(email=user_in.email)
raise DatabaseIntegrityError(f"Failed to create user due to integrity issue: {str(e)}")
except OperationalError as e:
logger.error(f"Database connection error during user creation for email '{user_in.email}': {str(e)}", exc_info=True)
raise DatabaseConnectionError(f"Database connection error during user creation: {str(e)}")
except SQLAlchemyError as e:
logger.error(f"Unexpected SQLAlchemy error during user creation for email '{user_in.email}': {str(e)}", exc_info=True)
raise DatabaseTransactionError(f"Failed to create user due to other DB error: {str(e)}")
# Ensure UserOperationError is defined in app.core.exceptions if used