281 lines
13 KiB
Python
281 lines
13 KiB
Python
# app/crud/settlement.py
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
from sqlalchemy.future import select
|
|
from sqlalchemy.orm import selectinload, joinedload
|
|
from sqlalchemy import or_
|
|
from sqlalchemy.exc import SQLAlchemyError, OperationalError, IntegrityError
|
|
from decimal import Decimal, ROUND_HALF_UP
|
|
from typing import List as PyList, Optional, Sequence
|
|
from datetime import datetime, timezone
|
|
import logging # Add logging import
|
|
|
|
from app.models import (
|
|
Settlement as SettlementModel,
|
|
User as UserModel,
|
|
Group as GroupModel,
|
|
UserGroup as UserGroupModel
|
|
)
|
|
from app.schemas.expense import SettlementCreate, SettlementUpdate
|
|
from app.core.exceptions import (
|
|
UserNotFoundError,
|
|
GroupNotFoundError,
|
|
InvalidOperationError,
|
|
DatabaseConnectionError,
|
|
DatabaseIntegrityError,
|
|
DatabaseQueryError,
|
|
DatabaseTransactionError,
|
|
SettlementOperationError,
|
|
ConflictError
|
|
)
|
|
|
|
# Module-level logger, named after this module, used for DB error reporting below.
logger: logging.Logger = logging.getLogger(name=__name__)
|
|
|
|
async def create_settlement(db: AsyncSession, settlement_in: SettlementCreate, current_user_id: int) -> SettlementModel:
    """Create a new settlement record and return it with relationships loaded.

    Validates that the payer, payee and group exist and that payer and payee
    are different users. It then inserts the settlement and re-selects it with
    ``payer``, ``payee``, ``group`` and ``created_by_user`` eager-loaded.

    The work runs inside a SAVEPOINT (``begin_nested``) when the session is
    already in a transaction. Otherwise it runs inside a new transaction that
    is committed when the ``async with`` block exits.

    Args:
        db: Async SQLAlchemy session.
        settlement_in: Creation payload. ``amount`` is rounded half-up to two
            decimal places. A falsy ``settlement_date`` defaults to the
            current UTC time.
        current_user_id: Stored as ``created_by_user_id``.

    Returns:
        The newly created settlement with its relationships loaded.

    Raises:
        UserNotFoundError: The payer or payee does not exist.
        InvalidOperationError: The payer and payee are the same user.
        GroupNotFoundError: The group does not exist.
        SettlementOperationError: The row could not be re-loaded after insert.
        DatabaseIntegrityError: A database constraint was violated.
        DatabaseConnectionError: An operational / connection error occurred.
        DatabaseTransactionError: Any other SQLAlchemy error.
    """
    # NOTE(review): when the non-nested branch commits, the returned object may
    # have expired attributes if the session uses expire_on_commit=True; in
    # async code those cannot be lazily refreshed. Confirm session configuration.
    try:
        async with db.begin_nested() if db.in_transaction() else db.begin():
            payer = await db.get(UserModel, settlement_in.paid_by_user_id)
            if not payer:
                raise UserNotFoundError(user_id=settlement_in.paid_by_user_id, identifier="Payer")

            payee = await db.get(UserModel, settlement_in.paid_to_user_id)
            if not payee:
                raise UserNotFoundError(user_id=settlement_in.paid_to_user_id, identifier="Payee")

            if settlement_in.paid_by_user_id == settlement_in.paid_to_user_id:
                raise InvalidOperationError("Payer and Payee cannot be the same user.")

            group = await db.get(GroupModel, settlement_in.group_id)
            if not group:
                raise GroupNotFoundError(settlement_in.group_id)

            # TODO: authorization is not enforced here. Possible check, which
            # could also live in the API layer:
            # if current_user_id not in [payer.id, payee.id]:
            #     is_member_stmt = select(UserGroupModel.id).where(UserGroupModel.group_id == group.id, UserGroupModel.user_id == current_user_id).limit(1)
            #     is_member_result = await db.execute(is_member_stmt)
            #     if not is_member_result.scalar_one_or_none():
            #         raise InvalidOperationError("Settlement recorder must be part of the group or one of the parties.")

            db_settlement = SettlementModel(
                group_id=settlement_in.group_id,
                paid_by_user_id=settlement_in.paid_by_user_id,
                paid_to_user_id=settlement_in.paid_to_user_id,
                # Round half-up to two decimal places before saving.
                amount=settlement_in.amount.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP),
                settlement_date=settlement_in.settlement_date or datetime.now(timezone.utc),
                description=settlement_in.description,
                created_by_user_id=current_user_id,
            )
            db.add(db_settlement)
            # Flush emits the INSERT so db_settlement.id is populated for the re-select.
            await db.flush()

            # Re-select with eager loading so callers can read relationships
            # without triggering lazy loads, which async sessions do not allow implicitly.
            stmt = (
                select(SettlementModel)
                .where(SettlementModel.id == db_settlement.id)
                .options(
                    selectinload(SettlementModel.payer),
                    selectinload(SettlementModel.payee),
                    selectinload(SettlementModel.group),
                    selectinload(SettlementModel.created_by_user),
                )
            )
            result = await db.execute(stmt)
            loaded_settlement = result.scalar_one_or_none()

            if loaded_settlement is None:
                raise SettlementOperationError("Failed to load settlement after creation.")

            return loaded_settlement
    except (UserNotFoundError, GroupNotFoundError, InvalidOperationError):
        # Validation errors propagate unchanged; the transaction/savepoint
        # context manager has already rolled back.
        raise
    except IntegrityError as e:
        logger.exception("Database integrity error during settlement creation: %s", e)
        raise DatabaseIntegrityError(f"Failed to save settlement due to DB integrity: {str(e)}") from e
    except OperationalError as e:
        logger.exception("Database connection error during settlement creation: %s", e)
        raise DatabaseConnectionError(f"DB connection error during settlement creation: {str(e)}") from e
    except SQLAlchemyError as e:
        logger.exception("Unexpected SQLAlchemy error during settlement creation: %s", e)
        raise DatabaseTransactionError(f"DB transaction error during settlement creation: {str(e)}") from e
|
|
|
|
|
|
async def get_settlement_by_id(db: AsyncSession, settlement_id: int) -> Optional[SettlementModel]:
    """Fetch one settlement by id, with its relationships eager-loaded.

    Args:
        db: Async SQLAlchemy session.
        settlement_id: Primary key of the settlement.

    Returns:
        The settlement with ``payer``, ``payee``, ``group`` and
        ``created_by_user`` loaded, or ``None`` if no row matches.

    Raises:
        DatabaseConnectionError: An operational / connection error occurred.
        DatabaseQueryError: Any other SQLAlchemy error.
    """
    try:
        stmt = (
            select(SettlementModel)
            .options(
                selectinload(SettlementModel.payer),
                selectinload(SettlementModel.payee),
                selectinload(SettlementModel.group),
                selectinload(SettlementModel.created_by_user),
            )
            .where(SettlementModel.id == settlement_id)
        )
        result = await db.execute(stmt)
        return result.scalars().first()
    except OperationalError as e:
        raise DatabaseConnectionError(f"DB connection error fetching settlement: {str(e)}") from e
    except SQLAlchemyError as e:
        raise DatabaseQueryError(f"DB query error fetching settlement: {str(e)}") from e
|
|
|
|
async def get_settlements_for_group(db: AsyncSession, group_id: int, skip: int = 0, limit: int = 100) -> Sequence[SettlementModel]:
    """List a group's settlements, newest first, with relationships eager-loaded.

    Args:
        db: Async SQLAlchemy session.
        group_id: Only settlements with this ``group_id`` are returned.
        skip: Number of rows to skip (OFFSET).
        limit: Maximum number of rows to return (LIMIT).

    Returns:
        Settlements ordered by ``settlement_date`` desc, then ``created_at``
        desc, then ``id`` desc.

    Raises:
        DatabaseConnectionError: An operational / connection error occurred.
        DatabaseQueryError: Any other SQLAlchemy error.
    """
    try:
        stmt = (
            select(SettlementModel)
            .where(SettlementModel.group_id == group_id)
            # The id tiebreaker makes the ordering total, so OFFSET/LIMIT pages
            # never repeat or skip rows that share a date and created_at.
            .order_by(
                SettlementModel.settlement_date.desc(),
                SettlementModel.created_at.desc(),
                SettlementModel.id.desc(),
            )
            .offset(skip)
            .limit(limit)
            .options(
                selectinload(SettlementModel.payer),
                selectinload(SettlementModel.payee),
                selectinload(SettlementModel.group),
                selectinload(SettlementModel.created_by_user),
            )
        )
        result = await db.execute(stmt)
        return result.scalars().all()
    except OperationalError as e:
        raise DatabaseConnectionError(f"DB connection error fetching group settlements: {str(e)}") from e
    except SQLAlchemyError as e:
        raise DatabaseQueryError(f"DB query error fetching group settlements: {str(e)}") from e
|
|
|
|
|
|
async def get_settlements_involving_user(
    db: AsyncSession,
    user_id: int,
    group_id: Optional[int] = None,
    skip: int = 0,
    limit: int = 100
) -> Sequence[SettlementModel]:
    """List settlements where the user is payer or payee, newest first.

    Args:
        db: Async SQLAlchemy session.
        user_id: Matched against ``paid_by_user_id`` OR ``paid_to_user_id``.
        group_id: When not ``None``, also restrict to this group.
            ``0`` is treated as a real id, not as "no filter".
        skip: Number of rows to skip (OFFSET).
        limit: Maximum number of rows to return (LIMIT).

    Returns:
        Matching settlements with relationships eager-loaded, ordered by
        ``settlement_date`` desc, then ``created_at`` desc, then ``id`` desc.

    Raises:
        DatabaseConnectionError: An operational / connection error occurred.
        DatabaseQueryError: Any other SQLAlchemy error.
    """
    try:
        conditions = [
            or_(
                SettlementModel.paid_by_user_id == user_id,
                SettlementModel.paid_to_user_id == user_id,
            )
        ]
        # Compare against None explicitly so a falsy id such as 0 still filters.
        if group_id is not None:
            conditions.append(SettlementModel.group_id == group_id)

        query = (
            select(SettlementModel)
            .where(*conditions)
            # The id tiebreaker keeps OFFSET/LIMIT pagination stable.
            .order_by(
                SettlementModel.settlement_date.desc(),
                SettlementModel.created_at.desc(),
                SettlementModel.id.desc(),
            )
            .offset(skip)
            .limit(limit)
            .options(
                selectinload(SettlementModel.payer),
                selectinload(SettlementModel.payee),
                selectinload(SettlementModel.group),
                selectinload(SettlementModel.created_by_user),
            )
        )
        result = await db.execute(query)
        return result.scalars().all()
    except OperationalError as e:
        raise DatabaseConnectionError(f"DB connection error fetching user settlements: {str(e)}") from e
    except SQLAlchemyError as e:
        raise DatabaseQueryError(f"DB query error fetching user settlements: {str(e)}") from e
|
|
|
|
|
|
async def update_settlement(db: AsyncSession, settlement_db: SettlementModel, settlement_in: SettlementUpdate) -> SettlementModel:
    """Update a settlement's description and/or settlement_date.

    Uses a version check for optimistic locking. ``settlement_in.version`` must
    equal ``settlement_db.version``. On success the version is incremented and
    ``updated_at`` is set to the current UTC time. This happens even when no
    updatable field was supplied.

    Only ``description`` and ``settlement_date`` are applied, and only when
    explicitly set on ``settlement_in`` (``exclude_unset``). Other supplied
    fields are silently ignored.

    Assumes ``settlement_db`` is attached to ``db``. If it may be detached,
    the caller should ``db.add`` it first.

    Args:
        db: Async SQLAlchemy session.
        settlement_db: The persisted settlement to modify.
        settlement_in: Update payload; must expose a ``version`` attribute.

    Returns:
        The updated settlement, re-selected with relationships eager-loaded.

    Raises:
        InvalidOperationError: ``version`` is missing on the model or the input.
        ConflictError: The supplied version does not match the stored one.
        SettlementOperationError: The row could not be re-loaded after update.
        DatabaseIntegrityError: A database constraint was violated.
        DatabaseConnectionError: An operational / connection error occurred.
        DatabaseTransactionError: Any other SQLAlchemy error.
    """
    # NOTE(review): the version comparison runs in Python against the already
    # loaded row, so it does not by itself stop a concurrent writer that commits
    # between load and flush. SQLAlchemy's ``version_id_col`` mapper option, or an
    # ``UPDATE ... WHERE version = :v``, would make the check atomic.
    allowed_to_update = {"description", "settlement_date"}
    try:
        async with db.begin_nested() if db.in_transaction() else db.begin():
            if not hasattr(settlement_db, 'version') or not hasattr(settlement_in, 'version'):
                raise InvalidOperationError("Version field is missing in model or input for optimistic locking.")

            if settlement_db.version != settlement_in.version:
                raise ConflictError(
                    f"Settlement (ID: {settlement_db.id}) has been modified. "
                    f"Your version {settlement_in.version} does not match current version {settlement_db.version}. Please refresh."
                )

            update_data = settlement_in.model_dump(exclude_unset=True, exclude={"version"})
            for field, value in update_data.items():
                if field in allowed_to_update:
                    setattr(settlement_db, field, value)

            settlement_db.version += 1
            settlement_db.updated_at = datetime.now(timezone.utc)

            db.add(settlement_db)
            await db.flush()

            # Re-select with eager loading so callers can read relationships
            # without lazy loads in async code.
            stmt = (
                select(SettlementModel)
                .where(SettlementModel.id == settlement_db.id)
                .options(
                    selectinload(SettlementModel.payer),
                    selectinload(SettlementModel.payee),
                    selectinload(SettlementModel.group),
                    selectinload(SettlementModel.created_by_user),
                )
            )
            result = await db.execute(stmt)
            updated_settlement = result.scalar_one_or_none()

            if updated_settlement is None:
                raise SettlementOperationError("Failed to load settlement after update.")

            return updated_settlement
    except (ConflictError, InvalidOperationError):
        # Domain errors propagate unchanged; the context manager has rolled back.
        raise
    except IntegrityError as e:
        logger.exception("Database integrity error during settlement update: %s", e)
        raise DatabaseIntegrityError(f"Failed to update settlement due to DB integrity: {str(e)}") from e
    except OperationalError as e:
        logger.exception("Database connection error during settlement update: %s", e)
        raise DatabaseConnectionError(f"DB connection error during settlement update: {str(e)}") from e
    except SQLAlchemyError as e:
        logger.exception("Unexpected SQLAlchemy error during settlement update: %s", e)
        raise DatabaseTransactionError(f"DB transaction error during settlement update: {str(e)}") from e
|
|
|
|
|
|
async def delete_settlement(db: AsyncSession, settlement_db: SettlementModel, expected_version: Optional[int] = None) -> None:
    """Delete a settlement, optionally guarded by an optimistic-lock version.

    Args:
        db: Async SQLAlchemy session.
        settlement_db: The persisted settlement to delete.
        expected_version: When not ``None``, the delete only proceeds if
            ``settlement_db.version`` equals this value. A model without a
            ``version`` attribute is treated as a mismatch.

    Raises:
        ConflictError: ``expected_version`` was given and does not match.
        DatabaseConnectionError: An operational / connection error occurred.
        DatabaseTransactionError: Any other SQLAlchemy error.
    """
    try:
        async with db.begin_nested() if db.in_transaction() else db.begin():
            if expected_version is not None:
                # Use getattr so a model lacking ``version`` yields ConflictError.
                # Reading settlement_db.version directly in the message would
                # raise AttributeError instead.
                current_version = getattr(settlement_db, 'version', None)
                if current_version != expected_version:
                    raise ConflictError(
                        f"Settlement (ID: {settlement_db.id}) cannot be deleted. "
                        f"Expected version {expected_version} does not match current version {current_version}. Please refresh."
                    )

            # The DELETE is flushed when the transaction/savepoint commits on exit.
            await db.delete(settlement_db)
    except ConflictError:
        raise
    except OperationalError as e:
        logger.exception("Database connection error during settlement deletion: %s", e)
        raise DatabaseConnectionError(f"DB connection error during settlement deletion: {str(e)}") from e
    except SQLAlchemyError as e:
        logger.exception("Unexpected SQLAlchemy error during settlement deletion: %s", e)
        raise DatabaseTransactionError(f"DB transaction error during settlement deletion: {str(e)}") from e
|
|
|
|
# SettlementOperationError and ConflictError are imported from app.core.exceptions
# at the top of this module and must be defined there, for example:
#   class SettlementOperationError(AppException): pass
#   class ConflictError(AppException): status_code = 409