mitlist/be/app/crud/settlement.py
2025-05-08 00:56:26 +02:00

168 lines
7.5 KiB
Python

# app/crud/settlement.py
from datetime import datetime, timezone
from decimal import ROUND_HALF_UP, Decimal
from typing import List as PyList, Optional, Sequence

from sqlalchemy import or_
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy.orm import joinedload, selectinload

from app.core.exceptions import GroupNotFoundError, InvalidOperationError, UserNotFoundError
from app.models import (
    Group as GroupModel,
    Settlement as SettlementModel,
    User as UserModel,
)
from app.schemas.expense import SettlementCreate, SettlementUpdate
async def create_settlement(db: AsyncSession, settlement_in: SettlementCreate, current_user_id: int) -> SettlementModel:
    """Create and persist a new settlement record.

    Validates that the payer, payee and group referenced by ``settlement_in``
    exist and that payer and payee are different users, then inserts the
    settlement. The amount is quantized to two decimal places using
    ``ROUND_HALF_UP``; ``settlement_date`` defaults to the current UTC time
    when not supplied.

    Args:
        db: Async SQLAlchemy session used for the lookups and the insert.
        settlement_in: Creation payload.
        current_user_id: ID of the requesting user. Not used yet; reserved
            for the permission check described in the TODO below.

    Returns:
        The committed settlement, refreshed with its ``payer``, ``payee`` and
        ``group`` attributes loaded.

    Raises:
        UserNotFoundError: If the payer or the payee does not exist.
        InvalidOperationError: If payer and payee are the same user, or if the
            commit/refresh fails (the session is rolled back first).
        GroupNotFoundError: If the group does not exist.
    """
    payer = await db.get(UserModel, settlement_in.paid_by_user_id)
    if not payer:
        raise UserNotFoundError(user_id=settlement_in.paid_by_user_id, identifier="Payer")
    payee = await db.get(UserModel, settlement_in.paid_to_user_id)
    if not payee:
        raise UserNotFoundError(user_id=settlement_in.paid_to_user_id, identifier="Payee")
    if settlement_in.paid_by_user_id == settlement_in.paid_to_user_id:
        raise InvalidOperationError("Payer and Payee cannot be the same user.")
    group = await db.get(GroupModel, settlement_in.group_id)
    if not group:
        raise GroupNotFoundError(settlement_in.group_id)

    # TODO: optionally require current_user_id to be one of the two parties or a
    # member of the group. This is arguably an API-level permission check; doing
    # it here would need a user-group membership model query (not imported here).

    db_settlement = SettlementModel(
        group_id=settlement_in.group_id,
        paid_by_user_id=settlement_in.paid_by_user_id,
        paid_to_user_id=settlement_in.paid_to_user_id,
        amount=settlement_in.amount.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP),
        settlement_date=settlement_in.settlement_date or datetime.now(timezone.utc),
        description=settlement_in.description,
        # created_by_user_id=current_user_id  # optional: who recorded this settlement
    )
    db.add(db_settlement)
    try:
        await db.commit()
        # Load the relationships onto the returned object for the caller.
        await db.refresh(db_settlement, attribute_names=["payer", "payee", "group"])
    except Exception as e:
        await db.rollback()
        raise InvalidOperationError(f"Failed to save settlement: {str(e)}") from e
    return db_settlement
async def get_settlement_by_id(db: AsyncSession, settlement_id: int) -> Optional[SettlementModel]:
    """Fetch one settlement by its primary key.

    The ``payer``, ``payee`` and ``group`` relationships are loaded eagerly
    with ``selectinload``.

    Returns:
        The matching settlement, or ``None`` if no row has that ID.
    """
    eager_relations = (
        selectinload(SettlementModel.payer),
        selectinload(SettlementModel.payee),
        selectinload(SettlementModel.group),
    )
    stmt = (
        select(SettlementModel)
        .where(SettlementModel.id == settlement_id)
        .options(*eager_relations)
    )
    rows = await db.execute(stmt)
    return rows.scalars().first()
async def get_settlements_for_group(db: AsyncSession, group_id: int, skip: int = 0, limit: int = 100) -> Sequence[SettlementModel]:
    """List the settlements recorded in a group, newest first.

    Rows are ordered by ``settlement_date`` and then ``created_at``, both
    descending, and paginated with ``skip``/``limit``. ``payer`` and ``payee``
    are loaded eagerly; the ``group`` relationship is not.
    """
    newest_first = (
        SettlementModel.settlement_date.desc(),
        SettlementModel.created_at.desc(),
    )
    stmt = (
        select(SettlementModel)
        .options(selectinload(SettlementModel.payer), selectinload(SettlementModel.payee))
        .where(SettlementModel.group_id == group_id)
        .order_by(*newest_first)
        .offset(skip)
        .limit(limit)
    )
    return (await db.execute(stmt)).scalars().all()
async def get_settlements_involving_user(
    db: AsyncSession,
    user_id: int,
    group_id: Optional[int] = None,
    skip: int = 0,
    limit: int = 100,
) -> Sequence[SettlementModel]:
    """List settlements in which the user is either the payer or the payee.

    Args:
        db: Async SQLAlchemy session.
        user_id: User whose settlements (paid or received) are returned.
        group_id: If given, restrict the results to this group. ``None`` (the
            default) returns settlements from every group.
        skip: Number of rows to skip (pagination offset).
        limit: Maximum number of rows to return.

    Returns:
        Settlements ordered by ``settlement_date`` then ``created_at``, both
        descending, with ``payer``, ``payee`` and ``group`` eagerly loaded.
    """
    filters = [
        or_(
            SettlementModel.paid_by_user_id == user_id,
            SettlementModel.paid_to_user_id == user_id,
        )
    ]
    # Compare against None explicitly so a falsy but valid ID (0) still filters.
    if group_id is not None:
        filters.append(SettlementModel.group_id == group_id)

    query = (
        select(SettlementModel)
        .where(*filters)
        .order_by(SettlementModel.settlement_date.desc(), SettlementModel.created_at.desc())
        .offset(skip)
        .limit(limit)
        .options(
            selectinload(SettlementModel.payer),
            selectinload(SettlementModel.payee),
            selectinload(SettlementModel.group),
        )
    )
    result = await db.execute(query)
    return result.scalars().all()
async def update_settlement(db: AsyncSession, settlement_db: SettlementModel, settlement_in: SettlementUpdate) -> SettlementModel:
    """Update the mutable fields of an existing settlement.

    Only ``description`` and ``settlement_date`` may be changed. The caller's
    ``settlement_in.version`` must equal ``settlement_db.version``. On success
    the version is incremented and ``updated_at`` is set to the current UTC time.

    Args:
        db: Async SQLAlchemy session that owns ``settlement_db``.
        settlement_db: The persisted settlement to modify.
        settlement_in: Update payload. Only fields explicitly set on it
            (``exclude_unset=True``) are applied. ``version`` is used for the
            concurrency check and is never written directly.

    Returns:
        The committed and refreshed settlement.

    Raises:
        InvalidOperationError: If the payload has no ``version`` attribute, the
            version does not match, a non-updatable field is supplied, or the
            commit fails (the session is rolled back first).
    """
    allowed_to_update = {"description", "settlement_date"}

    # A payload without a version is a schema/caller error, not a concurrent
    # modification, so report it separately.
    if not hasattr(settlement_in, "version"):
        raise InvalidOperationError(
            f"Settlement (ID: {settlement_db.id}) update requires a version for optimistic locking."
        )
    if settlement_db.version != settlement_in.version:
        raise InvalidOperationError(
            f"Settlement (ID: {settlement_db.id}) has been modified. "
            f"Your version does not match current version {settlement_db.version}. Please refresh.",
            # status_code=status.HTTP_409_CONFLICT
        )

    update_data = settlement_in.model_dump(exclude_unset=True, exclude={"version"})

    # Validate every field before touching the ORM object, so a rejected request
    # does not leave settlement_db partially modified (dirty) in the session.
    disallowed = [field for field in update_data if field not in allowed_to_update]
    if disallowed:
        raise InvalidOperationError(
            f"Field '{disallowed[0]}' cannot be updated. "
            f"Only {', '.join(sorted(allowed_to_update))} are allowed for settlements."
        )

    for field, value in update_data.items():
        setattr(settlement_db, field, value)

    # The version is bumped even when no updatable field was supplied (unchanged
    # from the previous behaviour). NOTE(review): the version check above runs in
    # Python only; unless the model configures DB-level versioning, concurrent
    # writers are not fully guarded. TODO confirm against the model definition.
    settlement_db.version += 1
    settlement_db.updated_at = datetime.now(timezone.utc)
    try:
        await db.commit()
        await db.refresh(settlement_db)
    except Exception as e:
        await db.rollback()
        raise InvalidOperationError(f"Failed to update settlement: {str(e)}") from e
    return settlement_db
async def delete_settlement(db: AsyncSession, settlement_db: SettlementModel, expected_version: Optional[int] = None) -> None:
    """Delete a settlement and commit the change.

    Args:
        db: Async SQLAlchemy session that owns ``settlement_db``.
        settlement_db: The persisted settlement to remove.
        expected_version: If provided, must equal ``settlement_db.version``,
            otherwise the deletion is refused. ``None`` skips the check.

    Raises:
        InvalidOperationError: On a version mismatch (including a settlement
            without a ``version`` attribute when a version is expected), or if
            the delete/commit fails (the session is rolled back first).
    """
    if expected_version is not None:
        if not hasattr(settlement_db, "version") or settlement_db.version != expected_version:
            raise InvalidOperationError(
                f"Settlement (ID: {settlement_db.id}) cannot be deleted. "
                f"Expected version {expected_version} does not match current version. Please refresh.",
                # status_code=status.HTTP_409_CONFLICT
            )
    # NOTE(review): removing a settlement changes the group's balances; callers
    # presumably need to account for that. TODO confirm how balances are derived.
    try:
        # Inside the try so a failing delete is rolled back and wrapped the same
        # way as a failing commit.
        await db.delete(settlement_db)
        await db.commit()
    except Exception as e:
        await db.rollback()
        raise InvalidOperationError(f"Failed to delete settlement: {str(e)}") from e
# TODO: update_settlement is implemented above; revisit which fields may change and whether versioning should be enforced at the database level.
# TODO: delete_settlement is implemented above; revisit its implications for group balances.