Refactor CRUD operations across multiple modules to standardize transaction handling using context managers, improving error logging and rollback mechanisms. Enhance error handling for database operations in expense, group, invite, item, list, settlement, and user modules, ensuring specific exceptions are raised for integrity and connection issues.
This commit is contained in:
parent
e4175db4aa
commit
98b2f907de
@ -3,6 +3,7 @@ import logging # Add logging import
|
|||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
from sqlalchemy.future import select
|
from sqlalchemy.future import select
|
||||||
from sqlalchemy.orm import selectinload, joinedload
|
from sqlalchemy.orm import selectinload, joinedload
|
||||||
|
from sqlalchemy.exc import SQLAlchemyError, IntegrityError, OperationalError # Added import
|
||||||
from decimal import Decimal, ROUND_HALF_UP, InvalidOperation as DecimalInvalidOperation
|
from decimal import Decimal, ROUND_HALF_UP, InvalidOperation as DecimalInvalidOperation
|
||||||
from typing import Callable, List as PyList, Optional, Sequence, Dict, defaultdict, Any
|
from typing import Callable, List as PyList, Optional, Sequence, Dict, defaultdict, Any
|
||||||
from datetime import datetime, timezone # Added timezone
|
from datetime import datetime, timezone # Added timezone
|
||||||
@ -183,10 +184,11 @@ async def create_expense(db: AsyncSession, expense_in: ExpenseCreate, current_us
|
|||||||
loaded_expense = result.scalar_one_or_none()
|
loaded_expense = result.scalar_one_or_none()
|
||||||
|
|
||||||
if loaded_expense is None:
|
if loaded_expense is None:
|
||||||
await transaction.rollback() # Should be handled by context manager
|
# The context manager will handle rollback if an exception is raised.
|
||||||
|
# await transaction.rollback() # Should be handled by context manager
|
||||||
raise ExpenseOperationError("Failed to load expense after creation.")
|
raise ExpenseOperationError("Failed to load expense after creation.")
|
||||||
|
|
||||||
await transaction.commit()
|
# await transaction.commit() # Explicit commit removed, context manager handles it.
|
||||||
return loaded_expense
|
return loaded_expense
|
||||||
|
|
||||||
except (UserNotFoundError, ListNotFoundError, GroupNotFoundError, InvalidOperationError) as e:
|
except (UserNotFoundError, ListNotFoundError, GroupNotFoundError, InvalidOperationError) as e:
|
||||||
@ -564,18 +566,27 @@ async def update_expense(db: AsyncSession, expense_db: ExpenseModel, expense_in:
|
|||||||
# For now, if only version was sent, we still increment if it matched.
|
# For now, if only version was sent, we still increment if it matched.
|
||||||
pass # Or raise InvalidOperationError("No updatable fields provided.")
|
pass # Or raise InvalidOperationError("No updatable fields provided.")
|
||||||
|
|
||||||
|
try:
|
||||||
|
async with db.begin_nested() if db.in_transaction() else db.begin() as transaction:
|
||||||
expense_db.version += 1
|
expense_db.version += 1
|
||||||
expense_db.updated_at = datetime.now(timezone.utc) # Manually update timestamp
|
expense_db.updated_at = datetime.now(timezone.utc) # Manually update timestamp
|
||||||
|
# db.add(expense_db) # Not strictly necessary as expense_db is already tracked by the session
|
||||||
|
|
||||||
try:
|
await db.flush() # Persist changes to the DB and run constraints
|
||||||
await db.commit()
|
await db.refresh(expense_db) # Refresh the object from the DB
|
||||||
await db.refresh(expense_db)
|
|
||||||
except Exception as e:
|
|
||||||
await db.rollback()
|
|
||||||
# Consider specific DB error types if needed
|
|
||||||
raise InvalidOperationError(f"Failed to update expense: {str(e)}")
|
|
||||||
|
|
||||||
return expense_db
|
return expense_db
|
||||||
|
except InvalidOperationError: # Re-raise validation errors to be handled by the caller
|
||||||
|
raise
|
||||||
|
except IntegrityError as e:
|
||||||
|
logger.error(f"Database integrity error during expense update for ID {expense_db.id}: {str(e)}", exc_info=True)
|
||||||
|
# The transaction context manager (begin_nested/begin) handles rollback.
|
||||||
|
raise DatabaseIntegrityError(f"Failed to update expense ID {expense_db.id} due to database integrity issue.") from e
|
||||||
|
except SQLAlchemyError as e: # Catch other SQLAlchemy errors
|
||||||
|
logger.error(f"Database transaction error during expense update for ID {expense_db.id}: {str(e)}", exc_info=True)
|
||||||
|
# The transaction context manager (begin_nested/begin) handles rollback.
|
||||||
|
raise DatabaseTransactionError(f"Failed to update expense ID {expense_db.id} due to a database transaction error.") from e
|
||||||
|
# No generic Exception catch here, let other unexpected errors propagate if not SQLAlchemy related.
|
||||||
|
|
||||||
|
|
||||||
async def delete_expense(db: AsyncSession, expense_db: ExpenseModel, expected_version: Optional[int] = None) -> None:
|
async def delete_expense(db: AsyncSession, expense_db: ExpenseModel, expected_version: Optional[int] = None) -> None:
|
||||||
"""
|
"""
|
||||||
@ -589,12 +600,20 @@ async def delete_expense(db: AsyncSession, expense_db: ExpenseModel, expected_ve
|
|||||||
# status_code=status.HTTP_409_CONFLICT
|
# status_code=status.HTTP_409_CONFLICT
|
||||||
)
|
)
|
||||||
|
|
||||||
await db.delete(expense_db)
|
|
||||||
try:
|
try:
|
||||||
await db.commit()
|
async with db.begin_nested() if db.in_transaction() else db.begin() as transaction:
|
||||||
except Exception as e:
|
await db.delete(expense_db)
|
||||||
await db.rollback()
|
await db.flush() # Ensure the delete operation is sent to the database
|
||||||
raise InvalidOperationError(f"Failed to delete expense: {str(e)}")
|
except InvalidOperationError: # Re-raise validation errors
|
||||||
|
raise
|
||||||
|
except IntegrityError as e:
|
||||||
|
logger.error(f"Database integrity error during expense deletion for ID {expense_db.id}: {str(e)}", exc_info=True)
|
||||||
|
# The transaction context manager (begin_nested/begin) handles rollback.
|
||||||
|
raise DatabaseIntegrityError(f"Failed to delete expense ID {expense_db.id} due to database integrity issue.") from e
|
||||||
|
except SQLAlchemyError as e: # Catch other SQLAlchemy errors
|
||||||
|
logger.error(f"Database transaction error during expense deletion for ID {expense_db.id}: {str(e)}", exc_info=True)
|
||||||
|
# The transaction context manager (begin_nested/begin) handles rollback.
|
||||||
|
raise DatabaseTransactionError(f"Failed to delete expense ID {expense_db.id} due to a database transaction error.") from e
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# Note: The InvalidOperationError is a simple ValueError placeholder.
|
# Note: The InvalidOperationError is a simple ValueError placeholder.
|
||||||
|
@ -5,6 +5,7 @@ from sqlalchemy.orm import selectinload # For eager loading members
|
|||||||
from sqlalchemy.exc import SQLAlchemyError, IntegrityError, OperationalError
|
from sqlalchemy.exc import SQLAlchemyError, IntegrityError, OperationalError
|
||||||
from typing import Optional, List
|
from typing import Optional, List
|
||||||
from sqlalchemy import delete, func
|
from sqlalchemy import delete, func
|
||||||
|
import logging # Add logging import
|
||||||
|
|
||||||
from app.models import User as UserModel, Group as GroupModel, UserGroup as UserGroupModel
|
from app.models import User as UserModel, Group as GroupModel, UserGroup as UserGroupModel
|
||||||
from app.schemas.group import GroupCreate
|
from app.schemas.group import GroupCreate
|
||||||
@ -20,24 +21,16 @@ from app.core.exceptions import (
|
|||||||
GroupPermissionError # Import GroupPermissionError
|
GroupPermissionError # Import GroupPermissionError
|
||||||
)
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__) # Initialize logger
|
||||||
|
|
||||||
# --- Group CRUD ---
|
# --- Group CRUD ---
|
||||||
async def create_group(db: AsyncSession, group_in: GroupCreate, creator_id: int) -> GroupModel:
|
async def create_group(db: AsyncSession, group_in: GroupCreate, creator_id: int) -> GroupModel:
|
||||||
"""Creates a group and adds the creator as the owner."""
|
"""Creates a group and adds the creator as the owner."""
|
||||||
try:
|
try:
|
||||||
# Defensive check: if a transaction is already active, try to roll it back.
|
# Use the composability pattern for transactions as per fastapi-db-strategy.
|
||||||
# This is unusual and suggests an issue upstream (e.g., middleware or session configuration).
|
# This creates a savepoint if already in a transaction (e.g., from get_transactional_session)
|
||||||
if db.in_transaction():
|
# or starts a new transaction if called outside of one (e.g., from a script).
|
||||||
# Log this occurrence if possible, as it's unexpected.
|
async with db.begin_nested() if db.in_transaction() else db.begin():
|
||||||
# import logging; logging.warning("Transaction already active on session entering create_group. Attempting rollback.")
|
|
||||||
try:
|
|
||||||
await db.rollback() # Attempt to clear any existing transaction
|
|
||||||
except SQLAlchemyError as e_rb:
|
|
||||||
# Log e_rb if possible
|
|
||||||
# import logging; logging.error(f"Error rolling back pre-existing transaction: {e_rb}")
|
|
||||||
# Re-raise or handle as a critical error, as the session state is uncertain.
|
|
||||||
raise DatabaseTransactionError(f"Session had an active transaction that could not be rolled back: {e_rb}")
|
|
||||||
|
|
||||||
async with db.begin(): # Now attempt to start a clean transaction
|
|
||||||
db_group = GroupModel(name=group_in.name, created_by_id=creator_id)
|
db_group = GroupModel(name=group_in.name, created_by_id=creator_id)
|
||||||
db.add(db_group)
|
db.add(db_group)
|
||||||
await db.flush() # Assigns ID to db_group
|
await db.flush() # Assigns ID to db_group
|
||||||
@ -67,10 +60,13 @@ async def create_group(db: AsyncSession, group_in: GroupCreate, creator_id: int)
|
|||||||
|
|
||||||
return loaded_group
|
return loaded_group
|
||||||
except IntegrityError as e:
|
except IntegrityError as e:
|
||||||
|
logger.error(f"Database integrity error during group creation: {str(e)}", exc_info=True)
|
||||||
raise DatabaseIntegrityError(f"Failed to create group due to integrity issue: {str(e)}")
|
raise DatabaseIntegrityError(f"Failed to create group due to integrity issue: {str(e)}")
|
||||||
except OperationalError as e:
|
except OperationalError as e:
|
||||||
|
logger.error(f"Database connection error during group creation: {str(e)}", exc_info=True)
|
||||||
raise DatabaseConnectionError(f"Database connection error during group creation: {str(e)}")
|
raise DatabaseConnectionError(f"Database connection error during group creation: {str(e)}")
|
||||||
except SQLAlchemyError as e:
|
except SQLAlchemyError as e:
|
||||||
|
logger.error(f"Unexpected SQLAlchemy error during group creation: {str(e)}", exc_info=True)
|
||||||
raise DatabaseTransactionError(f"Database transaction error during group creation: {str(e)}")
|
raise DatabaseTransactionError(f"Database transaction error during group creation: {str(e)}")
|
||||||
|
|
||||||
async def get_user_groups(db: AsyncSession, user_id: int) -> List[GroupModel]:
|
async def get_user_groups(db: AsyncSession, user_id: int) -> List[GroupModel]:
|
||||||
@ -143,7 +139,7 @@ async def add_user_to_group(db: AsyncSession, group_id: int, user_id: int, role:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
# Use a single transaction
|
# Use a single transaction
|
||||||
async with db.begin():
|
async with db.begin_nested() if db.in_transaction() else db.begin():
|
||||||
db_user_group = UserGroupModel(user_id=user_id, group_id=group_id, role=role)
|
db_user_group = UserGroupModel(user_id=user_id, group_id=group_id, role=role)
|
||||||
db.add(db_user_group)
|
db.add(db_user_group)
|
||||||
await db.flush() # Assigns ID to db_user_group
|
await db.flush() # Assigns ID to db_user_group
|
||||||
@ -165,16 +161,19 @@ async def add_user_to_group(db: AsyncSession, group_id: int, user_id: int, role:
|
|||||||
|
|
||||||
return loaded_user_group
|
return loaded_user_group
|
||||||
except IntegrityError as e:
|
except IntegrityError as e:
|
||||||
|
logger.error(f"Database integrity error while adding user to group: {str(e)}", exc_info=True)
|
||||||
raise DatabaseIntegrityError(f"Failed to add user to group: {str(e)}")
|
raise DatabaseIntegrityError(f"Failed to add user to group: {str(e)}")
|
||||||
except OperationalError as e:
|
except OperationalError as e:
|
||||||
|
logger.error(f"Database connection error while adding user to group: {str(e)}", exc_info=True)
|
||||||
raise DatabaseConnectionError(f"Database connection error: {str(e)}")
|
raise DatabaseConnectionError(f"Database connection error: {str(e)}")
|
||||||
except SQLAlchemyError as e:
|
except SQLAlchemyError as e:
|
||||||
|
logger.error(f"Unexpected SQLAlchemy error while adding user to group: {str(e)}", exc_info=True)
|
||||||
raise DatabaseTransactionError(f"Failed to add user to group: {str(e)}")
|
raise DatabaseTransactionError(f"Failed to add user to group: {str(e)}")
|
||||||
|
|
||||||
async def remove_user_from_group(db: AsyncSession, group_id: int, user_id: int) -> bool:
|
async def remove_user_from_group(db: AsyncSession, group_id: int, user_id: int) -> bool:
|
||||||
"""Removes a user from a group."""
|
"""Removes a user from a group."""
|
||||||
try:
|
try:
|
||||||
async with db.begin():
|
async with db.begin_nested() if db.in_transaction() else db.begin():
|
||||||
result = await db.execute(
|
result = await db.execute(
|
||||||
delete(UserGroupModel)
|
delete(UserGroupModel)
|
||||||
.where(UserGroupModel.group_id == group_id, UserGroupModel.user_id == user_id)
|
.where(UserGroupModel.group_id == group_id, UserGroupModel.user_id == user_id)
|
||||||
@ -182,8 +181,10 @@ async def remove_user_from_group(db: AsyncSession, group_id: int, user_id: int)
|
|||||||
)
|
)
|
||||||
return result.scalar_one_or_none() is not None
|
return result.scalar_one_or_none() is not None
|
||||||
except OperationalError as e:
|
except OperationalError as e:
|
||||||
|
logger.error(f"Database connection error while removing user from group: {str(e)}", exc_info=True)
|
||||||
raise DatabaseConnectionError(f"Database connection error: {str(e)}")
|
raise DatabaseConnectionError(f"Database connection error: {str(e)}")
|
||||||
except SQLAlchemyError as e:
|
except SQLAlchemyError as e:
|
||||||
|
logger.error(f"Unexpected SQLAlchemy error while removing user from group: {str(e)}", exc_info=True)
|
||||||
raise DatabaseTransactionError(f"Failed to remove user from group: {str(e)}")
|
raise DatabaseTransactionError(f"Failed to remove user from group: {str(e)}")
|
||||||
|
|
||||||
async def get_group_member_count(db: AsyncSession, group_id: int) -> int:
|
async def get_group_member_count(db: AsyncSession, group_id: int) -> int:
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
# app/crud/invite.py
|
# app/crud/invite.py
|
||||||
|
import logging # Add logging import
|
||||||
import secrets
|
import secrets
|
||||||
from datetime import datetime, timedelta, timezone
|
from datetime import datetime, timedelta, timezone
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
@ -17,12 +18,15 @@ from app.core.exceptions import (
|
|||||||
InviteOperationError # Add new specific exception
|
InviteOperationError # Add new specific exception
|
||||||
)
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__) # Initialize logger
|
||||||
|
|
||||||
# Invite codes should be reasonably unique, but handle potential collision
|
# Invite codes should be reasonably unique, but handle potential collision
|
||||||
MAX_CODE_GENERATION_ATTEMPTS = 5
|
MAX_CODE_GENERATION_ATTEMPTS = 5
|
||||||
|
|
||||||
async def deactivate_all_active_invites_for_group(db: AsyncSession, group_id: int):
|
async def deactivate_all_active_invites_for_group(db: AsyncSession, group_id: int):
|
||||||
"""Deactivates all currently active invite codes for a specific group."""
|
"""Deactivates all currently active invite codes for a specific group."""
|
||||||
try:
|
try:
|
||||||
|
async with db.begin_nested() if db.in_transaction() else db.begin():
|
||||||
stmt = (
|
stmt = (
|
||||||
select(InviteModel)
|
select(InviteModel)
|
||||||
.where(InviteModel.group_id == group_id, InviteModel.is_active == True)
|
.where(InviteModel.group_id == group_id, InviteModel.is_active == True)
|
||||||
@ -36,18 +40,22 @@ async def deactivate_all_active_invites_for_group(db: AsyncSession, group_id: in
|
|||||||
for invite in active_invites:
|
for invite in active_invites:
|
||||||
invite.is_active = False
|
invite.is_active = False
|
||||||
db.add(invite)
|
db.add(invite)
|
||||||
|
await db.flush() # Flush changes within this transaction block
|
||||||
|
|
||||||
# await db.flush() # Removed: Rely on caller to flush/commit
|
# await db.flush() # Removed: Rely on caller to flush/commit
|
||||||
# No explicit commit here, assuming it's part of a larger transaction or caller handles commit.
|
# No explicit commit here, assuming it's part of a larger transaction or caller handles commit.
|
||||||
except OperationalError as e:
|
except OperationalError as e:
|
||||||
# It's better to let the caller handle rollback or commit based on overall operation success
|
logger.error(f"Database connection error deactivating invites for group {group_id}: {str(e)}", exc_info=True)
|
||||||
raise DatabaseConnectionError(f"DB connection error deactivating invites for group {group_id}: {str(e)}")
|
raise DatabaseConnectionError(f"DB connection error deactivating invites for group {group_id}: {str(e)}")
|
||||||
except SQLAlchemyError as e:
|
except SQLAlchemyError as e:
|
||||||
|
logger.error(f"Unexpected SQLAlchemy error deactivating invites for group {group_id}: {str(e)}", exc_info=True)
|
||||||
raise DatabaseTransactionError(f"DB transaction error deactivating invites for group {group_id}: {str(e)}")
|
raise DatabaseTransactionError(f"DB transaction error deactivating invites for group {group_id}: {str(e)}")
|
||||||
|
|
||||||
async def create_invite(db: AsyncSession, group_id: int, creator_id: int, expires_in_days: int = 365 * 100) -> Optional[InviteModel]: # Default to 100 years
|
async def create_invite(db: AsyncSession, group_id: int, creator_id: int, expires_in_days: int = 365 * 100) -> Optional[InviteModel]: # Default to 100 years
|
||||||
"""Creates a new invite code for a group, deactivating any existing active ones for that group first."""
|
"""Creates a new invite code for a group, deactivating any existing active ones for that group first."""
|
||||||
|
|
||||||
|
try:
|
||||||
|
async with db.begin_nested() if db.in_transaction() else db.begin():
|
||||||
# Deactivate existing active invites for this group
|
# Deactivate existing active invites for this group
|
||||||
await deactivate_all_active_invites_for_group(db, group_id)
|
await deactivate_all_active_invites_for_group(db, group_id)
|
||||||
|
|
||||||
@ -63,16 +71,9 @@ async def create_invite(db: AsyncSession, group_id: int, creator_id: int, expire
|
|||||||
if attempt == MAX_CODE_GENERATION_ATTEMPTS - 1:
|
if attempt == MAX_CODE_GENERATION_ATTEMPTS - 1:
|
||||||
raise InviteOperationError("Failed to generate a unique invite code after several attempts.")
|
raise InviteOperationError("Failed to generate a unique invite code after several attempts.")
|
||||||
|
|
||||||
# Removed explicit transaction block here, rely on FastAPI's request-level transaction.
|
|
||||||
# Final check for code collision (less critical now without explicit nested transaction rollback on collision)
|
|
||||||
# but still good to prevent duplicate active codes if possible, though the deactivate step helps.
|
|
||||||
final_check_stmt = select(InviteModel.id).where(InviteModel.code == potential_code, InviteModel.is_active == True).limit(1)
|
final_check_stmt = select(InviteModel.id).where(InviteModel.code == potential_code, InviteModel.is_active == True).limit(1)
|
||||||
final_check_result = await db.execute(final_check_stmt)
|
final_check_result = await db.execute(final_check_stmt)
|
||||||
if final_check_result.scalar_one_or_none() is not None:
|
if final_check_result.scalar_one_or_none() is not None:
|
||||||
# This is now more of a rare edge case if deactivate worked and code generation is diverse.
|
|
||||||
# Depending on strictness, could raise an error or just log and proceed,
|
|
||||||
# relying on the previous deactivation to ensure only one is active.
|
|
||||||
# For now, let's raise to be safe, as it implies a very quick collision.
|
|
||||||
raise InviteOperationError("Invite code collision detected just before creation attempt.")
|
raise InviteOperationError("Invite code collision detected just before creation attempt.")
|
||||||
|
|
||||||
db_invite = InviteModel(
|
db_invite = InviteModel(
|
||||||
@ -83,9 +84,8 @@ async def create_invite(db: AsyncSession, group_id: int, creator_id: int, expire
|
|||||||
is_active=True
|
is_active=True
|
||||||
)
|
)
|
||||||
db.add(db_invite)
|
db.add(db_invite)
|
||||||
await db.flush() # Flush to get ID for re-fetch and ensure it's in session before potential re-fetch.
|
await db.flush()
|
||||||
|
|
||||||
# Re-fetch with relationships
|
|
||||||
stmt = (
|
stmt = (
|
||||||
select(InviteModel)
|
select(InviteModel)
|
||||||
.where(InviteModel.id == db_invite.id)
|
.where(InviteModel.id == db_invite.id)
|
||||||
@ -98,12 +98,20 @@ async def create_invite(db: AsyncSession, group_id: int, creator_id: int, expire
|
|||||||
loaded_invite = result.scalar_one_or_none()
|
loaded_invite = result.scalar_one_or_none()
|
||||||
|
|
||||||
if loaded_invite is None:
|
if loaded_invite is None:
|
||||||
# This would be an issue, implies flush didn't work or ID was wrong.
|
|
||||||
# The main transaction will rollback if this exception is raised.
|
|
||||||
raise InviteOperationError("Failed to load invite after creation and flush.")
|
raise InviteOperationError("Failed to load invite after creation and flush.")
|
||||||
|
|
||||||
return loaded_invite
|
return loaded_invite
|
||||||
# No explicit commit here, FastAPI handles it for the request.
|
except InviteOperationError: # Already specific, re-raise
|
||||||
|
raise
|
||||||
|
except IntegrityError as e:
|
||||||
|
logger.error(f"Database integrity error during invite creation for group {group_id}: {str(e)}", exc_info=True)
|
||||||
|
raise DatabaseIntegrityError(f"Failed to create invite due to DB integrity issue: {str(e)}")
|
||||||
|
except OperationalError as e:
|
||||||
|
logger.error(f"Database connection error during invite creation for group {group_id}: {str(e)}", exc_info=True)
|
||||||
|
raise DatabaseConnectionError(f"DB connection error during invite creation: {str(e)}")
|
||||||
|
except SQLAlchemyError as e:
|
||||||
|
logger.error(f"Unexpected SQLAlchemy error during invite creation for group {group_id}: {str(e)}", exc_info=True)
|
||||||
|
raise DatabaseTransactionError(f"DB transaction error during invite creation: {str(e)}")
|
||||||
|
|
||||||
async def get_active_invite_for_group(db: AsyncSession, group_id: int) -> Optional[InviteModel]:
|
async def get_active_invite_for_group(db: AsyncSession, group_id: int) -> Optional[InviteModel]:
|
||||||
"""Gets the currently active and non-expired invite for a specific group."""
|
"""Gets the currently active and non-expired invite for a specific group."""
|
||||||
@ -125,8 +133,10 @@ async def get_active_invite_for_group(db: AsyncSession, group_id: int) -> Option
|
|||||||
result = await db.execute(stmt)
|
result = await db.execute(stmt)
|
||||||
return result.scalars().first()
|
return result.scalars().first()
|
||||||
except OperationalError as e:
|
except OperationalError as e:
|
||||||
|
logger.error(f"Database connection error fetching active invite for group {group_id}: {str(e)}", exc_info=True)
|
||||||
raise DatabaseConnectionError(f"DB connection error fetching active invite for group {group_id}: {str(e)}")
|
raise DatabaseConnectionError(f"DB connection error fetching active invite for group {group_id}: {str(e)}")
|
||||||
except SQLAlchemyError as e:
|
except SQLAlchemyError as e:
|
||||||
|
logger.error(f"DB query error fetching active invite for group {group_id}: {str(e)}", exc_info=True)
|
||||||
raise DatabaseQueryError(f"DB query error fetching active invite for group {group_id}: {str(e)}")
|
raise DatabaseQueryError(f"DB query error fetching active invite for group {group_id}: {str(e)}")
|
||||||
|
|
||||||
async def get_active_invite_by_code(db: AsyncSession, code: str) -> Optional[InviteModel]:
|
async def get_active_invite_by_code(db: AsyncSession, code: str) -> Optional[InviteModel]:
|
||||||
@ -172,14 +182,14 @@ async def deactivate_invite(db: AsyncSession, invite: InviteModel) -> InviteMode
|
|||||||
updated_invite = result.scalar_one_or_none()
|
updated_invite = result.scalar_one_or_none()
|
||||||
|
|
||||||
if updated_invite is None: # Should not happen as invite is passed in
|
if updated_invite is None: # Should not happen as invite is passed in
|
||||||
await transaction.rollback()
|
|
||||||
raise InviteOperationError("Failed to load invite after deactivation.")
|
raise InviteOperationError("Failed to load invite after deactivation.")
|
||||||
|
|
||||||
await transaction.commit()
|
|
||||||
return updated_invite
|
return updated_invite
|
||||||
except OperationalError as e:
|
except OperationalError as e:
|
||||||
|
logger.error(f"Database connection error deactivating invite: {str(e)}", exc_info=True)
|
||||||
raise DatabaseConnectionError(f"DB connection error deactivating invite: {str(e)}")
|
raise DatabaseConnectionError(f"DB connection error deactivating invite: {str(e)}")
|
||||||
except SQLAlchemyError as e:
|
except SQLAlchemyError as e:
|
||||||
|
logger.error(f"Unexpected SQLAlchemy error deactivating invite: {str(e)}", exc_info=True)
|
||||||
raise DatabaseTransactionError(f"DB transaction error deactivating invite: {str(e)}")
|
raise DatabaseTransactionError(f"DB transaction error deactivating invite: {str(e)}")
|
||||||
|
|
||||||
# Ensure InviteOperationError is defined in app.core.exceptions
|
# Ensure InviteOperationError is defined in app.core.exceptions
|
||||||
|
@ -6,6 +6,7 @@ from sqlalchemy import delete as sql_delete, update as sql_update # Use aliases
|
|||||||
from sqlalchemy.exc import SQLAlchemyError, IntegrityError, OperationalError
|
from sqlalchemy.exc import SQLAlchemyError, IntegrityError, OperationalError
|
||||||
from typing import Optional, List as PyList
|
from typing import Optional, List as PyList
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
|
import logging # Add logging import
|
||||||
|
|
||||||
from app.models import Item as ItemModel, User as UserModel # Import UserModel for type hints if needed for selectinload
|
from app.models import Item as ItemModel, User as UserModel # Import UserModel for type hints if needed for selectinload
|
||||||
from app.schemas.item import ItemCreate, ItemUpdate
|
from app.schemas.item import ItemCreate, ItemUpdate
|
||||||
@ -19,6 +20,8 @@ from app.core.exceptions import (
|
|||||||
ItemOperationError # Add if specific item operation errors are needed
|
ItemOperationError # Add if specific item operation errors are needed
|
||||||
)
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__) # Initialize logger
|
||||||
|
|
||||||
async def create_item(db: AsyncSession, item_in: ItemCreate, list_id: int, user_id: int) -> ItemModel:
|
async def create_item(db: AsyncSession, item_in: ItemCreate, list_id: int, user_id: int) -> ItemModel:
|
||||||
"""Creates a new item record for a specific list."""
|
"""Creates a new item record for a specific list."""
|
||||||
try:
|
try:
|
||||||
@ -46,17 +49,18 @@ async def create_item(db: AsyncSession, item_in: ItemCreate, list_id: int, user_
|
|||||||
loaded_item = result.scalar_one_or_none()
|
loaded_item = result.scalar_one_or_none()
|
||||||
|
|
||||||
if loaded_item is None:
|
if loaded_item is None:
|
||||||
await transaction.rollback() # Should be handled by context manager on raise, but explicit for clarity
|
# await transaction.rollback() # Redundant, context manager handles rollback on exception
|
||||||
raise ItemOperationError("Failed to load item after creation.") # Define ItemOperationError
|
raise ItemOperationError("Failed to load item after creation.") # Define ItemOperationError
|
||||||
|
|
||||||
await transaction.commit()
|
|
||||||
return loaded_item
|
return loaded_item
|
||||||
except IntegrityError as e:
|
except IntegrityError as e:
|
||||||
# Context manager handles rollback on error
|
logger.error(f"Database integrity error during item creation: {str(e)}", exc_info=True)
|
||||||
raise DatabaseIntegrityError(f"Failed to create item: {str(e)}")
|
raise DatabaseIntegrityError(f"Failed to create item: {str(e)}")
|
||||||
except OperationalError as e:
|
except OperationalError as e:
|
||||||
|
logger.error(f"Database connection error during item creation: {str(e)}", exc_info=True)
|
||||||
raise DatabaseConnectionError(f"Database connection error: {str(e)}")
|
raise DatabaseConnectionError(f"Database connection error: {str(e)}")
|
||||||
except SQLAlchemyError as e:
|
except SQLAlchemyError as e:
|
||||||
|
logger.error(f"Unexpected SQLAlchemy error during item creation: {str(e)}", exc_info=True)
|
||||||
raise DatabaseTransactionError(f"Failed to create item: {str(e)}")
|
raise DatabaseTransactionError(f"Failed to create item: {str(e)}")
|
||||||
# Removed generic Exception block as SQLAlchemyError should cover DB issues,
|
# Removed generic Exception block as SQLAlchemyError should cover DB issues,
|
||||||
# and context manager handles rollback.
|
# and context manager handles rollback.
|
||||||
@ -144,15 +148,17 @@ async def update_item(db: AsyncSession, item_db: ItemModel, item_in: ItemUpdate,
|
|||||||
# Rollback will be handled by context manager on raise
|
# Rollback will be handled by context manager on raise
|
||||||
raise ItemOperationError("Failed to load item after update.")
|
raise ItemOperationError("Failed to load item after update.")
|
||||||
|
|
||||||
await transaction.commit()
|
|
||||||
return updated_item
|
return updated_item
|
||||||
except IntegrityError as e:
|
except IntegrityError as e:
|
||||||
|
logger.error(f"Database integrity error during item update: {str(e)}", exc_info=True)
|
||||||
raise DatabaseIntegrityError(f"Failed to update item due to integrity constraint: {str(e)}")
|
raise DatabaseIntegrityError(f"Failed to update item due to integrity constraint: {str(e)}")
|
||||||
except OperationalError as e:
|
except OperationalError as e:
|
||||||
|
logger.error(f"Database connection error while updating item: {str(e)}", exc_info=True)
|
||||||
raise DatabaseConnectionError(f"Database connection error while updating item: {str(e)}")
|
raise DatabaseConnectionError(f"Database connection error while updating item: {str(e)}")
|
||||||
except ConflictError: # Re-raise ConflictError, rollback handled by context manager
|
except ConflictError: # Re-raise ConflictError, rollback handled by context manager
|
||||||
raise
|
raise
|
||||||
except SQLAlchemyError as e:
|
except SQLAlchemyError as e:
|
||||||
|
logger.error(f"Unexpected SQLAlchemy error during item update: {str(e)}", exc_info=True)
|
||||||
raise DatabaseTransactionError(f"Failed to update item: {str(e)}")
|
raise DatabaseTransactionError(f"Failed to update item: {str(e)}")
|
||||||
|
|
||||||
async def delete_item(db: AsyncSession, item_db: ItemModel) -> None:
|
async def delete_item(db: AsyncSession, item_db: ItemModel) -> None:
|
||||||
@ -160,13 +166,13 @@ async def delete_item(db: AsyncSession, item_db: ItemModel) -> None:
|
|||||||
try:
|
try:
|
||||||
async with db.begin_nested() if db.in_transaction() else db.begin() as transaction:
|
async with db.begin_nested() if db.in_transaction() else db.begin() as transaction:
|
||||||
await db.delete(item_db)
|
await db.delete(item_db)
|
||||||
await transaction.commit()
|
# await transaction.commit() # Removed
|
||||||
# No return needed for None
|
# No return needed for None
|
||||||
except OperationalError as e:
|
except OperationalError as e:
|
||||||
# Rollback handled by context manager
|
logger.error(f"Database connection error while deleting item: {str(e)}", exc_info=True)
|
||||||
raise DatabaseConnectionError(f"Database connection error while deleting item: {str(e)}")
|
raise DatabaseConnectionError(f"Database connection error while deleting item: {str(e)}")
|
||||||
except SQLAlchemyError as e:
|
except SQLAlchemyError as e:
|
||||||
# Rollback handled by context manager
|
logger.error(f"Unexpected SQLAlchemy error while deleting item: {str(e)}", exc_info=True)
|
||||||
raise DatabaseTransactionError(f"Failed to delete item: {str(e)}")
|
raise DatabaseTransactionError(f"Failed to delete item: {str(e)}")
|
||||||
|
|
||||||
# Ensure ItemOperationError is defined in app.core.exceptions if used
|
# Ensure ItemOperationError is defined in app.core.exceptions if used
|
||||||
|
@ -5,6 +5,7 @@ from sqlalchemy.orm import selectinload, joinedload
|
|||||||
from sqlalchemy import or_, and_, delete as sql_delete, func as sql_func, desc
|
from sqlalchemy import or_, and_, delete as sql_delete, func as sql_func, desc
|
||||||
from sqlalchemy.exc import SQLAlchemyError, IntegrityError, OperationalError
|
from sqlalchemy.exc import SQLAlchemyError, IntegrityError, OperationalError
|
||||||
from typing import Optional, List as PyList
|
from typing import Optional, List as PyList
|
||||||
|
import logging # Add logging import
|
||||||
|
|
||||||
from app.schemas.list import ListStatus
|
from app.schemas.list import ListStatus
|
||||||
from app.models import List as ListModel, UserGroup as UserGroupModel, Item as ItemModel
|
from app.models import List as ListModel, UserGroup as UserGroupModel, Item as ItemModel
|
||||||
@ -21,6 +22,8 @@ from app.core.exceptions import (
|
|||||||
ListOperationError
|
ListOperationError
|
||||||
)
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__) # Initialize logger
|
||||||
|
|
||||||
async def create_list(db: AsyncSession, list_in: ListCreate, creator_id: int) -> ListModel:
|
async def create_list(db: AsyncSession, list_in: ListCreate, creator_id: int) -> ListModel:
|
||||||
"""Creates a new list record."""
|
"""Creates a new list record."""
|
||||||
try:
|
try:
|
||||||
@ -49,16 +52,17 @@ async def create_list(db: AsyncSession, list_in: ListCreate, creator_id: int) ->
|
|||||||
loaded_list = result.scalar_one_or_none()
|
loaded_list = result.scalar_one_or_none()
|
||||||
|
|
||||||
if loaded_list is None:
|
if loaded_list is None:
|
||||||
await transaction.rollback()
|
|
||||||
raise ListOperationError("Failed to load list after creation.")
|
raise ListOperationError("Failed to load list after creation.")
|
||||||
|
|
||||||
await transaction.commit()
|
|
||||||
return loaded_list
|
return loaded_list
|
||||||
except IntegrityError as e:
|
except IntegrityError as e:
|
||||||
|
logger.error(f"Database integrity error during list creation: {str(e)}", exc_info=True)
|
||||||
raise DatabaseIntegrityError(f"Failed to create list: {str(e)}")
|
raise DatabaseIntegrityError(f"Failed to create list: {str(e)}")
|
||||||
except OperationalError as e:
|
except OperationalError as e:
|
||||||
|
logger.error(f"Database connection error during list creation: {str(e)}", exc_info=True)
|
||||||
raise DatabaseConnectionError(f"Database connection error: {str(e)}")
|
raise DatabaseConnectionError(f"Database connection error: {str(e)}")
|
||||||
except SQLAlchemyError as e:
|
except SQLAlchemyError as e:
|
||||||
|
logger.error(f"Unexpected SQLAlchemy error during list creation: {str(e)}", exc_info=True)
|
||||||
raise DatabaseTransactionError(f"Failed to create list: {str(e)}")
|
raise DatabaseTransactionError(f"Failed to create list: {str(e)}")
|
||||||
|
|
||||||
async def get_lists_for_user(db: AsyncSession, user_id: int) -> PyList[ListModel]:
|
async def get_lists_for_user(db: AsyncSession, user_id: int) -> PyList[ListModel]:
|
||||||
@ -80,8 +84,11 @@ async def get_lists_for_user(db: AsyncSession, user_id: int) -> PyList[ListModel
|
|||||||
.where(or_(*conditions))
|
.where(or_(*conditions))
|
||||||
.options(
|
.options(
|
||||||
selectinload(ListModel.creator),
|
selectinload(ListModel.creator),
|
||||||
selectinload(ListModel.group)
|
selectinload(ListModel.group),
|
||||||
# selectinload(ListModel.items) # Consider if items are needed for list previews
|
selectinload(ListModel.items).options(
|
||||||
|
joinedload(ItemModel.added_by_user),
|
||||||
|
joinedload(ItemModel.completed_by_user)
|
||||||
|
)
|
||||||
)
|
)
|
||||||
.order_by(ListModel.updated_at.desc())
|
.order_by(ListModel.updated_at.desc())
|
||||||
)
|
)
|
||||||
@ -123,7 +130,6 @@ async def update_list(db: AsyncSession, list_db: ListModel, list_in: ListUpdate)
|
|||||||
try:
|
try:
|
||||||
async with db.begin_nested() if db.in_transaction() else db.begin() as transaction:
|
async with db.begin_nested() if db.in_transaction() else db.begin() as transaction:
|
||||||
if list_db.version != list_in.version: # list_db here is the one passed in, pre-loaded by API layer
|
if list_db.version != list_in.version: # list_db here is the one passed in, pre-loaded by API layer
|
||||||
await transaction.rollback() # Rollback before raising
|
|
||||||
raise ConflictError(
|
raise ConflictError(
|
||||||
f"List '{list_db.name}' (ID: {list_db.id}) has been modified. "
|
f"List '{list_db.name}' (ID: {list_db.id}) has been modified. "
|
||||||
f"Your version is {list_in.version}, current version is {list_db.version}. Please refresh."
|
f"Your version is {list_in.version}, current version is {list_db.version}. Please refresh."
|
||||||
@ -153,23 +159,19 @@ async def update_list(db: AsyncSession, list_db: ListModel, list_in: ListUpdate)
|
|||||||
updated_list = result.scalar_one_or_none()
|
updated_list = result.scalar_one_or_none()
|
||||||
|
|
||||||
if updated_list is None: # Should not happen
|
if updated_list is None: # Should not happen
|
||||||
await transaction.rollback()
|
|
||||||
raise ListOperationError("Failed to load list after update.")
|
raise ListOperationError("Failed to load list after update.")
|
||||||
|
|
||||||
await transaction.commit()
|
|
||||||
return updated_list
|
return updated_list
|
||||||
except IntegrityError as e:
|
except IntegrityError as e:
|
||||||
# Ensure rollback if not handled by context manager (though it should be)
|
logger.error(f"Database integrity error during list update: {str(e)}", exc_info=True)
|
||||||
if db.in_transaction(): await db.rollback()
|
|
||||||
raise DatabaseIntegrityError(f"Failed to update list due to integrity constraint: {str(e)}")
|
raise DatabaseIntegrityError(f"Failed to update list due to integrity constraint: {str(e)}")
|
||||||
except OperationalError as e:
|
except OperationalError as e:
|
||||||
if db.in_transaction(): await db.rollback()
|
logger.error(f"Database connection error while updating list: {str(e)}", exc_info=True)
|
||||||
raise DatabaseConnectionError(f"Database connection error while updating list: {str(e)}")
|
raise DatabaseConnectionError(f"Database connection error while updating list: {str(e)}")
|
||||||
except ConflictError:
|
except ConflictError:
|
||||||
# Already rolled back or will be by context manager if transaction was started here
|
|
||||||
raise
|
raise
|
||||||
except SQLAlchemyError as e:
|
except SQLAlchemyError as e:
|
||||||
if db.in_transaction(): await db.rollback()
|
logger.error(f"Unexpected SQLAlchemy error during list update: {str(e)}", exc_info=True)
|
||||||
raise DatabaseTransactionError(f"Failed to update list: {str(e)}")
|
raise DatabaseTransactionError(f"Failed to update list: {str(e)}")
|
||||||
|
|
||||||
async def delete_list(db: AsyncSession, list_db: ListModel) -> None:
|
async def delete_list(db: AsyncSession, list_db: ListModel) -> None:
|
||||||
@ -177,13 +179,11 @@ async def delete_list(db: AsyncSession, list_db: ListModel) -> None:
|
|||||||
try:
|
try:
|
||||||
async with db.begin_nested() if db.in_transaction() else db.begin() as transaction: # Standardize transaction
|
async with db.begin_nested() if db.in_transaction() else db.begin() as transaction: # Standardize transaction
|
||||||
await db.delete(list_db)
|
await db.delete(list_db)
|
||||||
await transaction.commit() # Explicit commit
|
|
||||||
# return None # Already implicitly returns None
|
|
||||||
except OperationalError as e:
|
except OperationalError as e:
|
||||||
# Rollback should be handled by the async with block on exception
|
logger.error(f"Database connection error while deleting list: {str(e)}", exc_info=True)
|
||||||
raise DatabaseConnectionError(f"Database connection error while deleting list: {str(e)}")
|
raise DatabaseConnectionError(f"Database connection error while deleting list: {str(e)}")
|
||||||
except SQLAlchemyError as e:
|
except SQLAlchemyError as e:
|
||||||
# Rollback should be handled by the async with block on exception
|
logger.error(f"Unexpected SQLAlchemy error while deleting list: {str(e)}", exc_info=True)
|
||||||
raise DatabaseTransactionError(f"Failed to delete list: {str(e)}")
|
raise DatabaseTransactionError(f"Failed to delete list: {str(e)}")
|
||||||
|
|
||||||
async def check_list_permission(db: AsyncSession, list_id: int, user_id: int, require_creator: bool = False) -> ListModel:
|
async def check_list_permission(db: AsyncSession, list_id: int, user_id: int, require_creator: bool = False) -> ListModel:
|
||||||
|
@ -7,6 +7,7 @@ from sqlalchemy.exc import SQLAlchemyError, OperationalError, IntegrityError
|
|||||||
from decimal import Decimal, ROUND_HALF_UP
|
from decimal import Decimal, ROUND_HALF_UP
|
||||||
from typing import List as PyList, Optional, Sequence
|
from typing import List as PyList, Optional, Sequence
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
|
import logging # Add logging import
|
||||||
|
|
||||||
from app.models import (
|
from app.models import (
|
||||||
Settlement as SettlementModel,
|
Settlement as SettlementModel,
|
||||||
@ -27,10 +28,12 @@ from app.core.exceptions import (
|
|||||||
ConflictError
|
ConflictError
|
||||||
)
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__) # Initialize logger
|
||||||
|
|
||||||
async def create_settlement(db: AsyncSession, settlement_in: SettlementCreate, current_user_id: int) -> SettlementModel:
|
async def create_settlement(db: AsyncSession, settlement_in: SettlementCreate, current_user_id: int) -> SettlementModel:
|
||||||
"""Creates a new settlement record."""
|
"""Creates a new settlement record."""
|
||||||
try:
|
try:
|
||||||
async with db.begin_nested() if db.in_transaction() else db.begin() as transaction:
|
async with db.begin_nested() if db.in_transaction() else db.begin():
|
||||||
payer = await db.get(UserModel, settlement_in.paid_by_user_id)
|
payer = await db.get(UserModel, settlement_in.paid_by_user_id)
|
||||||
if not payer:
|
if not payer:
|
||||||
raise UserNotFoundError(user_id=settlement_in.paid_by_user_id, identifier="Payer")
|
raise UserNotFoundError(user_id=settlement_in.paid_by_user_id, identifier="Payer")
|
||||||
@ -80,20 +83,21 @@ async def create_settlement(db: AsyncSession, settlement_in: SettlementCreate, c
|
|||||||
loaded_settlement = result.scalar_one_or_none()
|
loaded_settlement = result.scalar_one_or_none()
|
||||||
|
|
||||||
if loaded_settlement is None:
|
if loaded_settlement is None:
|
||||||
await transaction.rollback() # Should be handled by context manager
|
|
||||||
raise SettlementOperationError("Failed to load settlement after creation.")
|
raise SettlementOperationError("Failed to load settlement after creation.")
|
||||||
|
|
||||||
await transaction.commit()
|
|
||||||
return loaded_settlement
|
return loaded_settlement
|
||||||
except (UserNotFoundError, GroupNotFoundError, InvalidOperationError) as e:
|
except (UserNotFoundError, GroupNotFoundError, InvalidOperationError) as e:
|
||||||
# These are validation errors, re-raise them.
|
# These are validation errors, re-raise them.
|
||||||
# If a transaction was started, context manager handles rollback.
|
# If a transaction was started, context manager handles rollback.
|
||||||
raise
|
raise
|
||||||
except IntegrityError as e:
|
except IntegrityError as e:
|
||||||
|
logger.error(f"Database integrity error during settlement creation: {str(e)}", exc_info=True)
|
||||||
raise DatabaseIntegrityError(f"Failed to save settlement due to DB integrity: {str(e)}")
|
raise DatabaseIntegrityError(f"Failed to save settlement due to DB integrity: {str(e)}")
|
||||||
except OperationalError as e:
|
except OperationalError as e:
|
||||||
|
logger.error(f"Database connection error during settlement creation: {str(e)}", exc_info=True)
|
||||||
raise DatabaseConnectionError(f"DB connection error during settlement creation: {str(e)}")
|
raise DatabaseConnectionError(f"DB connection error during settlement creation: {str(e)}")
|
||||||
except SQLAlchemyError as e:
|
except SQLAlchemyError as e:
|
||||||
|
logger.error(f"Unexpected SQLAlchemy error during settlement creation: {str(e)}", exc_info=True)
|
||||||
raise DatabaseTransactionError(f"DB transaction error during settlement creation: {str(e)}")
|
raise DatabaseTransactionError(f"DB transaction error during settlement creation: {str(e)}")
|
||||||
|
|
||||||
|
|
||||||
@ -111,8 +115,10 @@ async def get_settlement_by_id(db: AsyncSession, settlement_id: int) -> Optional
|
|||||||
)
|
)
|
||||||
return result.scalars().first()
|
return result.scalars().first()
|
||||||
except OperationalError as e:
|
except OperationalError as e:
|
||||||
|
# Optional: logger.warning or info if needed for read operations
|
||||||
raise DatabaseConnectionError(f"DB connection error fetching settlement: {str(e)}")
|
raise DatabaseConnectionError(f"DB connection error fetching settlement: {str(e)}")
|
||||||
except SQLAlchemyError as e:
|
except SQLAlchemyError as e:
|
||||||
|
# Optional: logger.warning or info if needed for read operations
|
||||||
raise DatabaseQueryError(f"DB query error fetching settlement: {str(e)}")
|
raise DatabaseQueryError(f"DB query error fetching settlement: {str(e)}")
|
||||||
|
|
||||||
async def get_settlements_for_group(db: AsyncSession, group_id: int, skip: int = 0, limit: int = 100) -> Sequence[SettlementModel]:
|
async def get_settlements_for_group(db: AsyncSession, group_id: int, skip: int = 0, limit: int = 100) -> Sequence[SettlementModel]:
|
||||||
@ -176,7 +182,7 @@ async def update_settlement(db: AsyncSession, settlement_db: SettlementModel, se
|
|||||||
Assumes SettlementModel has version and updated_at fields.
|
Assumes SettlementModel has version and updated_at fields.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
async with db.begin_nested() if db.in_transaction() else db.begin() as transaction:
|
async with db.begin_nested() if db.in_transaction() else db.begin():
|
||||||
# Ensure the settlement_db passed is managed by the current session if not already.
|
# Ensure the settlement_db passed is managed by the current session if not already.
|
||||||
# This is usually true if fetched by an endpoint dependency using the same session.
|
# This is usually true if fetched by an endpoint dependency using the same session.
|
||||||
# If not, `db.add(settlement_db)` might be needed before modification if it's detached.
|
# If not, `db.add(settlement_db)` might be needed before modification if it's detached.
|
||||||
@ -228,20 +234,21 @@ async def update_settlement(db: AsyncSession, settlement_db: SettlementModel, se
|
|||||||
updated_settlement = result.scalar_one_or_none()
|
updated_settlement = result.scalar_one_or_none()
|
||||||
|
|
||||||
if updated_settlement is None: # Should not happen
|
if updated_settlement is None: # Should not happen
|
||||||
await transaction.rollback()
|
|
||||||
raise SettlementOperationError("Failed to load settlement after update.")
|
raise SettlementOperationError("Failed to load settlement after update.")
|
||||||
|
|
||||||
await transaction.commit()
|
|
||||||
return updated_settlement
|
return updated_settlement
|
||||||
except ConflictError as e: # ConflictError should be defined in exceptions
|
except ConflictError as e: # ConflictError should be defined in exceptions
|
||||||
raise
|
raise
|
||||||
except InvalidOperationError as e:
|
except InvalidOperationError as e:
|
||||||
raise
|
raise
|
||||||
except IntegrityError as e:
|
except IntegrityError as e:
|
||||||
|
logger.error(f"Database integrity error during settlement update: {str(e)}", exc_info=True)
|
||||||
raise DatabaseIntegrityError(f"Failed to update settlement due to DB integrity: {str(e)}")
|
raise DatabaseIntegrityError(f"Failed to update settlement due to DB integrity: {str(e)}")
|
||||||
except OperationalError as e:
|
except OperationalError as e:
|
||||||
|
logger.error(f"Database connection error during settlement update: {str(e)}", exc_info=True)
|
||||||
raise DatabaseConnectionError(f"DB connection error during settlement update: {str(e)}")
|
raise DatabaseConnectionError(f"DB connection error during settlement update: {str(e)}")
|
||||||
except SQLAlchemyError as e:
|
except SQLAlchemyError as e:
|
||||||
|
logger.error(f"Unexpected SQLAlchemy error during settlement update: {str(e)}", exc_info=True)
|
||||||
raise DatabaseTransactionError(f"DB transaction error during settlement update: {str(e)}")
|
raise DatabaseTransactionError(f"DB transaction error during settlement update: {str(e)}")
|
||||||
|
|
||||||
|
|
||||||
@ -251,7 +258,7 @@ async def delete_settlement(db: AsyncSession, settlement_db: SettlementModel, ex
|
|||||||
Assumes SettlementModel has a version field.
|
Assumes SettlementModel has a version field.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
async with db.begin_nested() if db.in_transaction() else db.begin() as transaction:
|
async with db.begin_nested() if db.in_transaction() else db.begin():
|
||||||
if expected_version is not None:
|
if expected_version is not None:
|
||||||
if not hasattr(settlement_db, 'version') or settlement_db.version != expected_version:
|
if not hasattr(settlement_db, 'version') or settlement_db.version != expected_version:
|
||||||
raise ConflictError( # Make sure ConflictError is defined
|
raise ConflictError( # Make sure ConflictError is defined
|
||||||
@ -260,12 +267,13 @@ async def delete_settlement(db: AsyncSession, settlement_db: SettlementModel, ex
|
|||||||
)
|
)
|
||||||
|
|
||||||
await db.delete(settlement_db)
|
await db.delete(settlement_db)
|
||||||
await transaction.commit()
|
|
||||||
except ConflictError as e: # ConflictError should be defined
|
except ConflictError as e: # ConflictError should be defined
|
||||||
raise
|
raise
|
||||||
except OperationalError as e:
|
except OperationalError as e:
|
||||||
|
logger.error(f"Database connection error during settlement deletion: {str(e)}", exc_info=True)
|
||||||
raise DatabaseConnectionError(f"DB connection error during settlement deletion: {str(e)}")
|
raise DatabaseConnectionError(f"DB connection error during settlement deletion: {str(e)}")
|
||||||
except SQLAlchemyError as e:
|
except SQLAlchemyError as e:
|
||||||
|
logger.error(f"Unexpected SQLAlchemy error during settlement deletion: {str(e)}", exc_info=True)
|
||||||
raise DatabaseTransactionError(f"DB transaction error during settlement deletion: {str(e)}")
|
raise DatabaseTransactionError(f"DB transaction error during settlement deletion: {str(e)}")
|
||||||
|
|
||||||
# Ensure SettlementOperationError and ConflictError are defined in app.core.exceptions
|
# Ensure SettlementOperationError and ConflictError are defined in app.core.exceptions
|
||||||
|
@ -4,6 +4,7 @@ from sqlalchemy.future import select
|
|||||||
from sqlalchemy.orm import selectinload # Ensure selectinload is imported
|
from sqlalchemy.orm import selectinload # Ensure selectinload is imported
|
||||||
from sqlalchemy.exc import SQLAlchemyError, IntegrityError, OperationalError
|
from sqlalchemy.exc import SQLAlchemyError, IntegrityError, OperationalError
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
import logging # Add logging import
|
||||||
|
|
||||||
from app.models import User as UserModel, UserGroup as UserGroupModel, Group as GroupModel # Import related models for selectinload
|
from app.models import User as UserModel, UserGroup as UserGroupModel, Group as GroupModel # Import related models for selectinload
|
||||||
from app.schemas.user import UserCreate
|
from app.schemas.user import UserCreate
|
||||||
@ -18,12 +19,13 @@ from app.core.exceptions import (
|
|||||||
UserOperationError # Add if specific user operation errors are needed
|
UserOperationError # Add if specific user operation errors are needed
|
||||||
)
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__) # Initialize logger
|
||||||
|
|
||||||
async def get_user_by_email(db: AsyncSession, email: str) -> Optional[UserModel]:
|
async def get_user_by_email(db: AsyncSession, email: str) -> Optional[UserModel]:
|
||||||
"""Fetches a user from the database by email, with common relationships."""
|
"""Fetches a user from the database by email, with common relationships."""
|
||||||
try:
|
try:
|
||||||
# db.begin() is not strictly necessary for a single read, but ensures atomicity if multiple reads were added.
|
# db.begin() is not strictly necessary for a single read, but ensures atomicity if multiple reads were added.
|
||||||
# For a single select, it can be omitted if preferred, session handles connection.
|
# For a single select, it can be omitted if preferred, session handles connection.
|
||||||
async with db.begin(): # Or remove if only a single select operation
|
|
||||||
stmt = (
|
stmt = (
|
||||||
select(UserModel)
|
select(UserModel)
|
||||||
.filter(UserModel.email == email)
|
.filter(UserModel.email == email)
|
||||||
@ -36,8 +38,10 @@ async def get_user_by_email(db: AsyncSession, email: str) -> Optional[UserModel]
|
|||||||
result = await db.execute(stmt)
|
result = await db.execute(stmt)
|
||||||
return result.scalars().first()
|
return result.scalars().first()
|
||||||
except OperationalError as e:
|
except OperationalError as e:
|
||||||
|
logger.error(f"Database connection error while fetching user by email '{email}': {str(e)}", exc_info=True)
|
||||||
raise DatabaseConnectionError(f"Failed to connect to database: {str(e)}")
|
raise DatabaseConnectionError(f"Failed to connect to database: {str(e)}")
|
||||||
except SQLAlchemyError as e:
|
except SQLAlchemyError as e:
|
||||||
|
logger.error(f"Unexpected SQLAlchemy error while fetching user by email '{email}': {str(e)}", exc_info=True)
|
||||||
raise DatabaseQueryError(f"Failed to query user: {str(e)}")
|
raise DatabaseQueryError(f"Failed to query user: {str(e)}")
|
||||||
|
|
||||||
async def create_user(db: AsyncSession, user_in: UserCreate) -> UserModel:
|
async def create_user(db: AsyncSession, user_in: UserCreate) -> UserModel:
|
||||||
@ -67,19 +71,19 @@ async def create_user(db: AsyncSession, user_in: UserCreate) -> UserModel:
|
|||||||
loaded_user = result.scalar_one_or_none()
|
loaded_user = result.scalar_one_or_none()
|
||||||
|
|
||||||
if loaded_user is None:
|
if loaded_user is None:
|
||||||
await transaction.rollback() # Should be handled by context manager, but explicit
|
|
||||||
raise UserOperationError("Failed to load user after creation.") # Define UserOperationError
|
raise UserOperationError("Failed to load user after creation.") # Define UserOperationError
|
||||||
|
|
||||||
await transaction.commit()
|
|
||||||
return loaded_user
|
return loaded_user
|
||||||
except IntegrityError as e:
|
except IntegrityError as e:
|
||||||
# Context manager handles rollback on error
|
logger.error(f"Database integrity error during user creation for email '{user_in.email}': {str(e)}", exc_info=True)
|
||||||
if "unique constraint" in str(e).lower() and ("users_email_key" in str(e).lower() or "ix_users_email" in str(e).lower()):
|
if "unique constraint" in str(e).lower() and ("users_email_key" in str(e).lower() or "ix_users_email" in str(e).lower()):
|
||||||
raise EmailAlreadyRegisteredError(email=user_in.email)
|
raise EmailAlreadyRegisteredError(email=user_in.email)
|
||||||
raise DatabaseIntegrityError(f"Failed to create user due to integrity issue: {str(e)}")
|
raise DatabaseIntegrityError(f"Failed to create user due to integrity issue: {str(e)}")
|
||||||
except OperationalError as e:
|
except OperationalError as e:
|
||||||
|
logger.error(f"Database connection error during user creation for email '{user_in.email}': {str(e)}", exc_info=True)
|
||||||
raise DatabaseConnectionError(f"Database connection error during user creation: {str(e)}")
|
raise DatabaseConnectionError(f"Database connection error during user creation: {str(e)}")
|
||||||
except SQLAlchemyError as e:
|
except SQLAlchemyError as e:
|
||||||
|
logger.error(f"Unexpected SQLAlchemy error during user creation for email '{user_in.email}': {str(e)}", exc_info=True)
|
||||||
raise DatabaseTransactionError(f"Failed to create user due to other DB error: {str(e)}")
|
raise DatabaseTransactionError(f"Failed to create user due to other DB error: {str(e)}")
|
||||||
|
|
||||||
# Ensure UserOperationError is defined in app.core.exceptions if used
|
# Ensure UserOperationError is defined in app.core.exceptions if used
|
||||||
|
Loading…
Reference in New Issue
Block a user