# app/crud/settlement.py
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy.orm import selectinload, joinedload
from sqlalchemy import or_
from decimal import Decimal, ROUND_HALF_UP
from typing import List as PyList, Optional, Sequence
from datetime import datetime, timezone

from app.models import (
    Settlement as SettlementModel,
    User as UserModel,
    Group as GroupModel,
)
from app.schemas.expense import SettlementCreate, SettlementUpdate
from app.core.exceptions import UserNotFoundError, GroupNotFoundError, InvalidOperationError


async def create_settlement(db: AsyncSession, settlement_in: SettlementCreate, current_user_id: int) -> SettlementModel:
    """Create and persist a new settlement record.

    Checks that the payer, payee and group exist and that payer and payee are
    different users. The amount is stored rounded to two decimal places
    (ROUND_HALF_UP). When ``settlement_in.settlement_date`` is falsy, the
    current UTC time is used instead.

    Args:
        db: Async database session used for lookups and the insert.
        settlement_in: Creation payload.
        current_user_id: ID of the user recording the settlement. Not used by
            this function at present (see the NOTE comments below).

    Returns:
        The committed settlement, refreshed with its ``payer``, ``payee`` and
        ``group`` relationships loaded.

    Raises:
        UserNotFoundError: If the payer or the payee does not exist.
        InvalidOperationError: If payer and payee are the same user, or if
            committing/refreshing the new row fails (the session is rolled back).
        GroupNotFoundError: If the group does not exist.
    """
    # Validate payer, payee and group exist (in this order, so the first
    # missing entity determines which error is raised).
    payer = await db.get(UserModel, settlement_in.paid_by_user_id)
    if not payer:
        raise UserNotFoundError(user_id=settlement_in.paid_by_user_id, identifier="Payer")

    payee = await db.get(UserModel, settlement_in.paid_to_user_id)
    if not payee:
        raise UserNotFoundError(user_id=settlement_in.paid_to_user_id, identifier="Payee")

    if settlement_in.paid_by_user_id == settlement_in.paid_to_user_id:
        raise InvalidOperationError("Payer and Payee cannot be the same user.")

    group = await db.get(GroupModel, settlement_in.group_id)
    if not group:
        raise GroupNotFoundError(settlement_in.group_id)

    # NOTE: no permission check is performed here. If stricter behavior is
    # wanted, verify that current_user_id is either one of the two parties or
    # a member of the group (e.g. via a user/group membership query) and raise
    # InvalidOperationError otherwise. This is usually an API-layer concern.

    db_settlement = SettlementModel(
        group_id=settlement_in.group_id,
        paid_by_user_id=settlement_in.paid_by_user_id,
        paid_to_user_id=settlement_in.paid_to_user_id,
        # Round money to cents using conventional half-up rounding rather than
        # Decimal's default banker's rounding (ROUND_HALF_EVEN).
        amount=settlement_in.amount.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP),
        settlement_date=settlement_in.settlement_date if settlement_in.settlement_date else datetime.now(timezone.utc),
        description=settlement_in.description,
        # NOTE: a created_by_user_id=current_user_id field could be recorded
        # here if the model gains one.
    )
    db.add(db_settlement)
    try:
        await db.commit()
        # Load relationships now so callers can access them without a lazy
        # load (which is not available implicitly under an AsyncSession).
        await db.refresh(db_settlement, attribute_names=["payer", "payee", "group"])
    except Exception as e:
        await db.rollback()
        raise InvalidOperationError(f"Failed to save settlement: {str(e)}") from e
    return db_settlement


async def get_settlement_by_id(db: AsyncSession, settlement_id: int) -> Optional[SettlementModel]:
    """Return the settlement with the given ID, or ``None`` if it does not exist.

    The ``payer``, ``payee`` and ``group`` relationships are eagerly loaded.
    """
    result = await db.execute(
        select(SettlementModel)
        .options(
            selectinload(SettlementModel.payer),
            selectinload(SettlementModel.payee),
            selectinload(SettlementModel.group),
        )
        .where(SettlementModel.id == settlement_id)
    )
    return result.scalars().first()


async def get_settlements_for_group(db: AsyncSession, group_id: int, skip: int = 0, limit: int = 100) -> Sequence[SettlementModel]:
    """Return a page of settlements belonging to ``group_id``.

    Results are ordered newest first by ``settlement_date``, then by
    ``created_at``. ``payer`` and ``payee`` are eagerly loaded.

    Args:
        db: Async database session.
        group_id: Group whose settlements are returned.
        skip: Number of rows to skip (OFFSET).
        limit: Maximum number of rows to return (LIMIT).
    """
    result = await db.execute(
        select(SettlementModel)
        .where(SettlementModel.group_id == group_id)
        .order_by(SettlementModel.settlement_date.desc(), SettlementModel.created_at.desc())
        .offset(skip)
        .limit(limit)
        .options(selectinload(SettlementModel.payer), selectinload(SettlementModel.payee))
    )
    return result.scalars().all()


async def get_settlements_involving_user(
    db: AsyncSession,
    user_id: int,
    group_id: Optional[int] = None,
    skip: int = 0,
    limit: int = 100,
) -> Sequence[SettlementModel]:
    """Return a page of settlements where ``user_id`` is the payer or the payee.

    Results are ordered newest first by ``settlement_date``, then by
    ``created_at``. ``payer``, ``payee`` and ``group`` are eagerly loaded.

    Args:
        db: Async database session.
        user_id: User who must appear as payer or payee.
        group_id: If not ``None``, restrict results to this group.
        skip: Number of rows to skip (OFFSET).
        limit: Maximum number of rows to return (LIMIT).
    """
    # Build every WHERE criterion before ordering/pagination so the query reads
    # in SQL order and the optional group filter is not bolted on afterwards.
    conditions = [
        or_(
            SettlementModel.paid_by_user_id == user_id,
            SettlementModel.paid_to_user_id == user_id,
        )
    ]
    if group_id is not None:
        conditions.append(SettlementModel.group_id == group_id)

    query = (
        select(SettlementModel)
        .where(*conditions)
        .order_by(SettlementModel.settlement_date.desc(), SettlementModel.created_at.desc())
        .offset(skip)
        .limit(limit)
        .options(
            selectinload(SettlementModel.payer),
            selectinload(SettlementModel.payee),
            selectinload(SettlementModel.group),
        )
    )
    result = await db.execute(query)
    return result.scalars().all()


async def update_settlement(db: AsyncSession, settlement_db: SettlementModel, settlement_in: SettlementUpdate) -> SettlementModel:
    """Update an existing settlement using optimistic locking.

    Only ``description`` and ``settlement_date`` may be changed. The caller's
    ``settlement_in.version`` must equal ``settlement_db.version``. On success
    the version is incremented and ``updated_at`` is set to the current UTC
    time, even if no updatable fields were supplied.

    Args:
        db: Async database session that owns ``settlement_db``.
        settlement_db: The currently stored settlement.
        settlement_in: Update payload; expected to carry a ``version`` field.

    Returns:
        The committed and refreshed settlement.

    Raises:
        InvalidOperationError: If ``settlement_in`` has no ``version`` or it does
            not match, if a non-updatable field is supplied (nothing is modified
            in that case), or if the commit fails (the session is rolled back).
    """
    # If SettlementUpdate has no 'version' attribute, treat it as a mismatch.
    if not hasattr(settlement_in, 'version') or settlement_db.version != settlement_in.version:
        # NOTE: callers may want to surface this as HTTP 409 Conflict.
        raise InvalidOperationError(
            f"Settlement (ID: {settlement_db.id}) has been modified. "
            f"Your version does not match current version {settlement_db.version}. Please refresh.",
        )

    update_data = settlement_in.model_dump(exclude_unset=True, exclude={"version"})
    allowed_to_update = {"description", "settlement_date"}

    # Validate all fields *before* mutating the ORM object. Otherwise a request
    # mixing allowed and forbidden fields would leave settlement_db partially
    # modified (and dirty in the session) when the error is raised.
    disallowed = [field for field in update_data if field not in allowed_to_update]
    if disallowed:
        raise InvalidOperationError(
            f"Field '{disallowed[0]}' cannot be updated. Only {', '.join(sorted(allowed_to_update))} are allowed for settlements."
        )

    for field, value in update_data.items():
        setattr(settlement_db, field, value)

    # Bump the version even when no updatable fields were provided.
    # Assumes SettlementModel has a version column.
    settlement_db.version += 1
    settlement_db.updated_at = datetime.now(timezone.utc)

    try:
        await db.commit()
        await db.refresh(settlement_db)
    except Exception as e:
        await db.rollback()
        raise InvalidOperationError(f"Failed to update settlement: {str(e)}") from e
    return settlement_db


async def delete_settlement(db: AsyncSession, settlement_db: SettlementModel, expected_version: Optional[int] = None) -> None:
    """Delete a settlement, optionally guarded by an optimistic-lock version.

    Args:
        db: Async database session that owns ``settlement_db``.
        settlement_db: The settlement to delete.
        expected_version: If not ``None``, must equal ``settlement_db.version``;
            otherwise nothing is deleted.

    Raises:
        InvalidOperationError: If the version check fails, or if the commit
            fails (the session is rolled back).
    """
    if expected_version is not None:
        if not hasattr(settlement_db, 'version') or settlement_db.version != expected_version:
            # NOTE: callers may want to surface this as HTTP 409 Conflict.
            raise InvalidOperationError(
                f"Settlement (ID: {settlement_db.id}) cannot be deleted. "
                f"Expected version {expected_version} does not match current version. Please refresh.",
            )

    await db.delete(settlement_db)
    try:
        await db.commit()
    except Exception as e:
        await db.rollback()
        raise InvalidOperationError(f"Failed to delete settlement: {str(e)}") from e