import logging
from datetime import datetime, timezone
from decimal import Decimal
from typing import List, Optional

from pydantic import BaseModel  # Using pydantic BaseModel directly for the placeholder
from sqlalchemy import select, func, update, delete
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload, joinedload

from app.models import (
    SettlementActivity,
    ExpenseSplit,
    Expense,
    User,
    ExpenseSplitStatusEnum,
    ExpenseOverallStatusEnum,
)

logger = logging.getLogger(__name__)

# Placeholder for the Pydantic schema - the real schema definition is a later step.
# TODO: switch to `from app.schemas.settlement_activity import SettlementActivityCreate`
# once that module exists (assumed path).


class SettlementActivityCreatePlaceholder(BaseModel):
    """Input payload for recording a payment against an ExpenseSplit."""

    expense_split_id: int
    paid_by_user_id: int
    amount_paid: Decimal
    # When omitted, create_settlement_activity stamps the current UTC time.
    paid_at: Optional[datetime] = None

    class Config:
        orm_mode = True  # Pydantic V1 style orm_mode
        # from_attributes = True  # Pydantic V2 style


async def update_expense_split_status(db: AsyncSession, expense_split_id: int) -> Optional[ExpenseSplit]:
    """Recompute the settlement status of one ExpenseSplit from its activities.

    Sums ``amount_paid`` over every SettlementActivity of the split (rounded
    to two decimal places) and sets the split's ``status``:

    * ``paid`` when the total reaches ``owed_amount``; ``paid_at`` becomes the
      latest activity ``paid_at`` (or the current UTC time if none has one);
    * ``partially_paid`` when the total is positive but short; ``paid_at`` is
      cleared;
    * ``unpaid`` otherwise; ``paid_at`` is cleared.

    The parent Expense is loaded, but its overall status is NOT changed here;
    call ``update_expense_overall_status`` for that.

    Args:
        db: Active async session. Changes are flushed, not committed.
        expense_split_id: Primary key of the split to recompute.

    Returns:
        The updated ExpenseSplit (with ``expense`` loaded), or None if no split
        has that id.
    """
    result = await db.execute(
        select(ExpenseSplit)
        .options(
            selectinload(ExpenseSplit.settlement_activities),
            joinedload(ExpenseSplit.expense),  # To get the parent expense easily
        )
        .where(ExpenseSplit.id == expense_split_id)
        # The split may already sit in the session's identity map with a stale
        # settlement_activities collection: activities created by setting only
        # the FK column do not update an already-loaded collection, and eager
        # loaders skip attributes that are already loaded. populate_existing
        # forces a re-read so freshly flushed activities are counted.
        # NOTE(review): this overwrites unflushed in-memory edits on the split /
        # expense; it relies on the session autoflushing (SQLAlchemy default).
        .execution_options(populate_existing=True)
    )
    expense_split = result.scalar_one_or_none()
    if expense_split is None:
        # Or raise an exception, depending on desired error handling
        return None

    activities = expense_split.settlement_activities
    total_paid = Decimal(sum(activity.amount_paid for activity in activities)).quantize(Decimal("0.01"))

    if total_paid >= expense_split.owed_amount:
        expense_split.status = ExpenseSplitStatusEnum.paid
        # default=None keeps max() from raising ValueError when no activity has
        # a paid_at; in that case fall back to "now".
        latest_paid_at = max(
            (activity.paid_at for activity in activities if activity.paid_at),
            default=None,
        )
        expense_split.paid_at = latest_paid_at if latest_paid_at else datetime.now(timezone.utc)
    elif total_paid > 0:
        expense_split.status = ExpenseSplitStatusEnum.partially_paid
        expense_split.paid_at = None  # Not fully paid yet
    else:
        expense_split.status = ExpenseSplitStatusEnum.unpaid
        expense_split.paid_at = None

    await db.flush()
    # Reload the written columns plus the parent expense for the caller.
    await db.refresh(expense_split, attribute_names=['status', 'paid_at', 'expense'])
    return expense_split


async def update_expense_overall_status(db: AsyncSession, expense_id: int) -> Optional[Expense]:
    """Recompute an Expense's ``overall_settlement_status`` from its splits.

    The result is ``paid`` when every split is paid, ``unpaid`` when no split is
    paid or partially paid (split statuses other than ``paid`` /
    ``partially_paid`` count as unpaid), and ``partially_paid`` for any mix. An
    expense with no splits is defensively marked ``unpaid``.

    Args:
        db: Active async session. Changes are flushed, not committed.
        expense_id: Primary key of the expense to recompute.

    Returns:
        The updated Expense, or None if no expense has that id.
    """
    result = await db.execute(
        select(Expense).options(selectinload(Expense.splits)).where(Expense.id == expense_id)
    )
    expense = result.scalar_one_or_none()
    if expense is None:
        # Or raise an exception
        return None

    statuses = [split.status for split in expense.splits]
    settled_states = (ExpenseSplitStatusEnum.paid, ExpenseSplitStatusEnum.partially_paid)

    if not statuses:
        # No splits should not happen for a valid expense; handle defensively.
        # Checked first because all() over an empty list would report "paid".
        new_status = ExpenseOverallStatusEnum.unpaid
    elif all(status == ExpenseSplitStatusEnum.paid for status in statuses):
        new_status = ExpenseOverallStatusEnum.paid
    elif all(status not in settled_states for status in statuses):
        new_status = ExpenseOverallStatusEnum.unpaid
    else:
        # Mix of paid / partially_paid / unpaid that is neither all paid nor all unpaid.
        new_status = ExpenseOverallStatusEnum.partially_paid

    expense.overall_settlement_status = new_status
    await db.flush()
    await db.refresh(expense, attribute_names=['overall_settlement_status'])
    return expense


async def create_settlement_activity(
    db: AsyncSession,
    settlement_activity_in: SettlementActivityCreatePlaceholder,
    current_user_id: int,
) -> Optional[SettlementActivity]:
    """Record a payment and cascade the status change to its split and expense.

    Args:
        db: Active async session. Changes are flushed, not committed.
        settlement_activity_in: The payment to record. ``paid_at`` defaults to
            the current UTC time when omitted.
        current_user_id: The user recording the activity (stored as
            ``created_by_user_id``).

    Returns:
        The new SettlementActivity with ``split``, ``payer`` and ``creator``
        loaded, or None if the referenced ExpenseSplit or paying User does not
        exist.
    """
    # Validate ExpenseSplit
    split_result = await db.execute(
        select(ExpenseSplit).where(ExpenseSplit.id == settlement_activity_in.expense_split_id)
    )
    if split_result.scalar_one_or_none() is None:
        # Consider raising an HTTPException in an API layer
        return None

    # Validate the paying user
    user_result = await db.execute(select(User).where(User.id == settlement_activity_in.paid_by_user_id))
    if user_result.scalar_one_or_none() is None:
        return None

    db_settlement_activity = SettlementActivity(
        expense_split_id=settlement_activity_in.expense_split_id,
        paid_by_user_id=settlement_activity_in.paid_by_user_id,
        amount_paid=settlement_activity_in.amount_paid,
        paid_at=settlement_activity_in.paid_at or datetime.now(timezone.utc),
        created_by_user_id=current_user_id,  # The user recording the activity
    )
    db.add(db_settlement_activity)
    await db.flush()  # Persist the row so the status recomputation below sees it

    updated_split = await update_expense_split_status(db, expense_split_id=db_settlement_activity.expense_split_id)
    if updated_split is not None and updated_split.expense_id:
        await update_expense_overall_status(db, expense_id=updated_split.expense_id)
    else:
        # The split was validated above in this same session, so this signals
        # a concurrent delete or a split without an expense. Don't fail the
        # insert, but make the skipped overall-status update visible.
        logger.warning(
            "Skipped overall status update for settlement activity on expense split %s: "
            "split missing or has no expense_id",
            db_settlement_activity.expense_split_id,
        )

    await db.refresh(db_settlement_activity, attribute_names=['split', 'payer', 'creator'])
    return db_settlement_activity


async def get_settlement_activity_by_id(
    db: AsyncSession, settlement_activity_id: int
) -> Optional[SettlementActivity]:
    """Fetch one SettlementActivity by id, or None if it does not exist.

    Eager-loads the split (with its parent expense), the paying user and the
    recording user.
    """
    result = await db.execute(
        select(SettlementActivity)
        .options(
            selectinload(SettlementActivity.split).selectinload(ExpenseSplit.expense),  # Split and its parent expense
            selectinload(SettlementActivity.payer),  # The user who paid
            selectinload(SettlementActivity.creator),  # The user who created the record
        )
        .where(SettlementActivity.id == settlement_activity_id)
    )
    return result.scalar_one_or_none()


async def get_settlement_activities_for_split(
    db: AsyncSession, expense_split_id: int, skip: int = 0, limit: int = 100
) -> List[SettlementActivity]:
    """List the SettlementActivity records of one expense split.

    Results are ordered newest first (by ``paid_at``, then ``created_at``) and
    paginated with ``skip`` / ``limit``. The payer and creator users are
    eager-loaded.
    """
    result = await db.execute(
        select(SettlementActivity)
        .where(SettlementActivity.expense_split_id == expense_split_id)
        .options(
            selectinload(SettlementActivity.payer),  # The user who paid
            selectinload(SettlementActivity.creator),  # The user who created the record
        )
        .order_by(SettlementActivity.paid_at.desc(), SettlementActivity.created_at.desc())
        .offset(skip)
        .limit(limit)
    )
    # .all() is typed as a Sequence; return a real list to match the annotation.
    return list(result.scalars().all())


# Further CRUD operations like update/delete can be added later if needed.