# app/crud/settlement.py
"""CRUD operations for settlement records.

A settlement row records an amount paid by one user (``paid_by_user_id``) to
another (``paid_to_user_id``) within a group.  Every function takes the
caller's ``AsyncSession`` and translates SQLAlchemy failures into the
application's own exception types from ``app.core.exceptions``.
"""
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy.orm import selectinload, joinedload
from sqlalchemy import or_
from sqlalchemy.exc import SQLAlchemyError, OperationalError, IntegrityError
from decimal import Decimal, ROUND_HALF_UP
from typing import List as PyList, Optional, Sequence
from datetime import datetime, timezone
import logging

from app.models import (
    Settlement as SettlementModel,
    User as UserModel,
    Group as GroupModel,
    UserGroup as UserGroupModel,
)
from app.schemas.expense import SettlementCreate, SettlementUpdate
from app.core.exceptions import (
    UserNotFoundError,
    GroupNotFoundError,
    InvalidOperationError,
    DatabaseConnectionError,
    DatabaseIntegrityError,
    DatabaseQueryError,
    DatabaseTransactionError,
    SettlementOperationError,
    ConflictError,
)

logger = logging.getLogger(__name__)


def _settlement_load_options():
    """Return the eager-load options shared by every settlement query here."""
    return (
        selectinload(SettlementModel.payer),
        selectinload(SettlementModel.payee),
        selectinload(SettlementModel.group),
        selectinload(SettlementModel.created_by_user),
    )


def _transaction(db: AsyncSession):
    """Return a savepoint if ``db`` is already in a transaction, else a new one."""
    return db.begin_nested() if db.in_transaction() else db.begin()


async def _fetch_with_relationships(
    db: AsyncSession, settlement_id: int
) -> Optional[SettlementModel]:
    """Load one settlement by primary key with its relationships populated."""
    stmt = (
        select(SettlementModel)
        .where(SettlementModel.id == settlement_id)
        .options(*_settlement_load_options())
    )
    result = await db.execute(stmt)
    return result.scalar_one_or_none()


async def create_settlement(
    db: AsyncSession, settlement_in: SettlementCreate, current_user_id: int
) -> SettlementModel:
    """Create a new settlement record.

    Args:
        db: Session to use; a savepoint is opened when it is already in a
            transaction, otherwise a new transaction is begun.
        settlement_in: Payload with group, payer, payee, amount and optional
            settlement date / description.
        current_user_id: Stored as ``created_by_user_id`` on the new row.

    Returns:
        The new settlement, re-fetched with payer, payee, group and
        created_by_user loaded.

    Raises:
        UserNotFoundError: Payer or payee does not exist.
        InvalidOperationError: Payer and payee are the same user.
        GroupNotFoundError: The group does not exist.
        SettlementOperationError: The row could not be re-loaded after flush.
        DatabaseIntegrityError: An ``IntegrityError`` was raised.
        DatabaseConnectionError: An ``OperationalError`` was raised.
        DatabaseTransactionError: Any other ``SQLAlchemyError`` was raised.
    """
    try:
        async with _transaction(db):
            payer = await db.get(UserModel, settlement_in.paid_by_user_id)
            if payer is None:
                raise UserNotFoundError(
                    user_id=settlement_in.paid_by_user_id, identifier="Payer"
                )

            payee = await db.get(UserModel, settlement_in.paid_to_user_id)
            if payee is None:
                raise UserNotFoundError(
                    user_id=settlement_in.paid_to_user_id, identifier="Payee"
                )

            if settlement_in.paid_by_user_id == settlement_in.paid_to_user_id:
                raise InvalidOperationError("Payer and Payee cannot be the same user.")

            group = await db.get(GroupModel, settlement_in.group_id)
            if group is None:
                raise GroupNotFoundError(settlement_in.group_id)

            # Permission check example (can be in API layer too)
            # if current_user_id not in [payer.id, payee.id]:
            #     is_member_stmt = select(UserGroupModel.id).where(
            #         UserGroupModel.group_id == group.id,
            #         UserGroupModel.user_id == current_user_id,
            #     ).limit(1)
            #     is_member_result = await db.execute(is_member_stmt)
            #     if not is_member_result.scalar_one_or_none():
            #         raise InvalidOperationError(
            #             "Settlement recorder must be part of the group or one of the parties."
            #         )

            db_settlement = SettlementModel(
                group_id=settlement_in.group_id,
                paid_by_user_id=settlement_in.paid_by_user_id,
                paid_to_user_id=settlement_in.paid_to_user_id,
                # Store amounts with exactly two decimal places.
                amount=settlement_in.amount.quantize(
                    Decimal("0.01"), rounding=ROUND_HALF_UP
                ),
                # Fall back to the current UTC time when no date is supplied.
                settlement_date=(
                    settlement_in.settlement_date
                    if settlement_in.settlement_date
                    else datetime.now(timezone.utc)
                ),
                description=settlement_in.description,
                created_by_user_id=current_user_id,
            )
            db.add(db_settlement)
            # Flush so the row exists (and has an id) before it is re-fetched.
            await db.flush()

            loaded_settlement = await _fetch_with_relationships(db, db_settlement.id)
            if loaded_settlement is None:
                raise SettlementOperationError(
                    "Failed to load settlement after creation."
                )
            return loaded_settlement
    except (UserNotFoundError, GroupNotFoundError, InvalidOperationError):
        # Validation errors propagate unchanged; the transaction context
        # manager has already rolled back.
        raise
    except IntegrityError as e:
        logger.error(
            f"Database integrity error during settlement creation: {str(e)}",
            exc_info=True,
        )
        raise DatabaseIntegrityError(
            f"Failed to save settlement due to DB integrity: {str(e)}"
        ) from e
    except OperationalError as e:
        logger.error(
            f"Database connection error during settlement creation: {str(e)}",
            exc_info=True,
        )
        raise DatabaseConnectionError(
            f"DB connection error during settlement creation: {str(e)}"
        ) from e
    except SQLAlchemyError as e:
        logger.error(
            f"Unexpected SQLAlchemy error during settlement creation: {str(e)}",
            exc_info=True,
        )
        raise DatabaseTransactionError(
            f"DB transaction error during settlement creation: {str(e)}"
        ) from e


async def get_settlement_by_id(
    db: AsyncSession, settlement_id: int
) -> Optional[SettlementModel]:
    """Return the settlement with ``settlement_id``, or ``None`` if absent.

    Payer, payee, group and created_by_user are eagerly loaded.

    Raises:
        DatabaseConnectionError: An ``OperationalError`` was raised.
        DatabaseQueryError: Any other ``SQLAlchemyError`` was raised.
    """
    try:
        result = await db.execute(
            select(SettlementModel)
            .options(*_settlement_load_options())
            .where(SettlementModel.id == settlement_id)
        )
        return result.scalars().first()
    except OperationalError as e:
        raise DatabaseConnectionError(
            f"DB connection error fetching settlement: {str(e)}"
        ) from e
    except SQLAlchemyError as e:
        raise DatabaseQueryError(
            f"DB query error fetching settlement: {str(e)}"
        ) from e


async def get_settlements_for_group(
    db: AsyncSession, group_id: int, skip: int = 0, limit: int = 100
) -> Sequence[SettlementModel]:
    """Return one page of a group's settlements, newest first.

    Rows are ordered by ``settlement_date`` then ``created_at``, both
    descending, and paged with ``skip`` / ``limit``.

    Raises:
        DatabaseConnectionError: An ``OperationalError`` was raised.
        DatabaseQueryError: Any other ``SQLAlchemyError`` was raised.
    """
    try:
        result = await db.execute(
            select(SettlementModel)
            .where(SettlementModel.group_id == group_id)
            .order_by(
                SettlementModel.settlement_date.desc(),
                SettlementModel.created_at.desc(),
            )
            .offset(skip)
            .limit(limit)
            .options(*_settlement_load_options())
        )
        return result.scalars().all()
    except OperationalError as e:
        raise DatabaseConnectionError(
            f"DB connection error fetching group settlements: {str(e)}"
        ) from e
    except SQLAlchemyError as e:
        raise DatabaseQueryError(
            f"DB query error fetching group settlements: {str(e)}"
        ) from e


async def get_settlements_involving_user(
    db: AsyncSession,
    user_id: int,
    group_id: Optional[int] = None,
    skip: int = 0,
    limit: int = 100,
) -> Sequence[SettlementModel]:
    """Return one page of settlements where ``user_id`` is payer or payee.

    Args:
        db: Session to query with.
        user_id: User who must appear as payer or payee.
        group_id: When not ``None``, restrict results to this group.
        skip: Number of rows to skip.
        limit: Maximum number of rows to return.

    Raises:
        DatabaseConnectionError: An ``OperationalError`` was raised.
        DatabaseQueryError: Any other ``SQLAlchemyError`` was raised.
    """
    try:
        query = select(SettlementModel).where(
            or_(
                SettlementModel.paid_by_user_id == user_id,
                SettlementModel.paid_to_user_id == user_id,
            )
        )
        # Compare against None explicitly so an id of 0 still filters instead
        # of silently widening the query to every group.
        if group_id is not None:
            query = query.where(SettlementModel.group_id == group_id)
        query = (
            query.order_by(
                SettlementModel.settlement_date.desc(),
                SettlementModel.created_at.desc(),
            )
            .offset(skip)
            .limit(limit)
            .options(*_settlement_load_options())
        )
        result = await db.execute(query)
        return result.scalars().all()
    except OperationalError as e:
        raise DatabaseConnectionError(
            f"DB connection error fetching user settlements: {str(e)}"
        ) from e
    except SQLAlchemyError as e:
        raise DatabaseQueryError(
            f"DB query error fetching user settlements: {str(e)}"
        ) from e


async def update_settlement(
    db: AsyncSession, settlement_db: SettlementModel, settlement_in: SettlementUpdate
) -> SettlementModel:
    """Update an existing settlement using optimistic locking.

    Only ``description`` and ``settlement_date`` are written; any other field
    present in ``settlement_in`` is silently ignored.  The version is bumped
    and ``updated_at`` refreshed whenever the versions match, even if no
    updatable field was supplied.

    Assumes ``SettlementUpdate`` carries a ``version`` field and that
    ``SettlementModel`` has ``version`` and ``updated_at`` columns.

    Raises:
        InvalidOperationError: ``version`` is missing on the model or input.
        ConflictError: The supplied version differs from the stored one.
        SettlementOperationError: The row could not be re-loaded after flush.
        DatabaseIntegrityError: An ``IntegrityError`` was raised.
        DatabaseConnectionError: An ``OperationalError`` was raised.
        DatabaseTransactionError: Any other ``SQLAlchemyError`` was raised.
    """
    try:
        async with _transaction(db):
            # settlement_db is expected to be attached to this session already
            # (e.g. fetched by an endpoint dependency using the same session).
            if not hasattr(settlement_db, 'version') or not hasattr(settlement_in, 'version'):
                raise InvalidOperationError(
                    "Version field is missing in model or input for optimistic locking."
                )

            if settlement_db.version != settlement_in.version:
                raise ConflictError(
                    f"Settlement (ID: {settlement_db.id}) has been modified. "
                    f"Your version {settlement_in.version} does not match current version {settlement_db.version}. Please refresh."
                )

            update_data = settlement_in.model_dump(
                exclude_unset=True, exclude={"version"}
            )
            allowed_to_update = {"description", "settlement_date"}
            for field, value in update_data.items():
                # Fields outside the allow-list are ignored rather than rejected.
                if field in allowed_to_update:
                    setattr(settlement_db, field, value)

            settlement_db.version += 1
            settlement_db.updated_at = datetime.now(timezone.utc)

            db.add(settlement_db)  # Mark as dirty
            await db.flush()

            updated_settlement = await _fetch_with_relationships(db, settlement_db.id)
            if updated_settlement is None:  # Should not happen
                raise SettlementOperationError(
                    "Failed to load settlement after update."
                )
            return updated_settlement
    except (ConflictError, InvalidOperationError):
        # Domain errors propagate unchanged.
        raise
    except IntegrityError as e:
        logger.error(
            f"Database integrity error during settlement update: {str(e)}",
            exc_info=True,
        )
        raise DatabaseIntegrityError(
            f"Failed to update settlement due to DB integrity: {str(e)}"
        ) from e
    except OperationalError as e:
        logger.error(
            f"Database connection error during settlement update: {str(e)}",
            exc_info=True,
        )
        raise DatabaseConnectionError(
            f"DB connection error during settlement update: {str(e)}"
        ) from e
    except SQLAlchemyError as e:
        logger.error(
            f"Unexpected SQLAlchemy error during settlement update: {str(e)}",
            exc_info=True,
        )
        raise DatabaseTransactionError(
            f"DB transaction error during settlement update: {str(e)}"
        ) from e


async def delete_settlement(
    db: AsyncSession,
    settlement_db: SettlementModel,
    expected_version: Optional[int] = None,
) -> None:
    """Delete a settlement, optionally enforcing an optimistic-lock version.

    Args:
        db: Session to use; a savepoint is opened when it is already in a
            transaction, otherwise a new transaction is begun.
        settlement_db: The settlement instance to delete.
        expected_version: When not ``None``, the delete is refused unless it
            equals the model's current ``version``.

    Raises:
        ConflictError: ``expected_version`` was given and does not match (or
            the model has no ``version`` attribute).
        DatabaseConnectionError: An ``OperationalError`` was raised.
        DatabaseTransactionError: Any other ``SQLAlchemyError`` was raised.
    """
    try:
        async with _transaction(db):
            if expected_version is not None:
                # Read the version once with a default so a model lacking the
                # attribute yields ConflictError, not AttributeError while the
                # message is being formatted.
                current_version = getattr(settlement_db, 'version', None)
                if current_version != expected_version:
                    raise ConflictError(
                        f"Settlement (ID: {settlement_db.id}) cannot be deleted. "
                        f"Expected version {expected_version} does not match current version {current_version}. Please refresh."
                    )
            await db.delete(settlement_db)
    except ConflictError:
        raise
    except OperationalError as e:
        logger.error(
            f"Database connection error during settlement deletion: {str(e)}",
            exc_info=True,
        )
        raise DatabaseConnectionError(
            f"DB connection error during settlement deletion: {str(e)}"
        ) from e
    except SQLAlchemyError as e:
        logger.error(
            f"Unexpected SQLAlchemy error during settlement deletion: {str(e)}",
            exc_info=True,
        )
        raise DatabaseTransactionError(
            f"DB transaction error during settlement deletion: {str(e)}"
        ) from e


# NOTE: SettlementOperationError and ConflictError must be defined in
# app.core.exceptions, for example:
#     class SettlementOperationError(AppException): pass
#     class ConflictError(AppException): status_code = 409