doe/be/app/crud/expense.py
2025-04-03 01:24:23 +02:00

88 lines
3.4 KiB
Python

from decimal import Decimal, ROUND_DOWN, ROUND_HALF_UP
from typing import List as PyList, Sequence, Optional

from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy.orm import selectinload, joinedload

from app.models import (
    Item as ItemModel,
    User as UserModel,
    UserGroup as UserGroupModel,
    ExpenseRecord as ExpenseRecordModel,
    ExpenseShare as ExpenseShareModel,
    SettlementActivity as SettlementActivityModel,
    SplitTypeEnum,
)
async def get_priced_items_for_list(db: AsyncSession, list_id: int) -> Sequence[ItemModel]:
    """Return every item belonging to ``list_id`` that has a price set.

    Items whose ``price`` is NULL are left out.
    """
    priced_items_stmt = (
        select(ItemModel)
        .where(ItemModel.list_id == list_id)
        .where(ItemModel.price.is_not(None))
    )
    rows = await db.execute(priced_items_stmt)
    return rows.scalars().all()
async def get_group_member_ids(db: AsyncSession, group_id: int) -> PyList[int]:
    """Return the user ids of all members of ``group_id``.

    Membership is read from the ``UserGroupModel`` association rows.

    Returns:
        A list of user ids (empty if the group has no members).
    """
    # Select the member id directly from the association model. The previous
    # query selected a UserModel column while filtering only on UserGroupModel
    # with no join, which produces a cartesian product (every user, repeated
    # once per membership row) rather than the group's members.
    # NOTE(review): assumes UserGroupModel exposes a ``user_id`` FK column —
    # confirm against app.models.
    result = await db.execute(
        select(UserGroupModel.user_id).where(UserGroupModel.group_id == group_id)
    )
    # scalars().all() returns a Sequence; materialize to honour the list annotation.
    return list(result.scalars().all())
def _split_amount_equally(total_amount: Decimal, num_participants: int) -> PyList[Decimal]:
    """Split ``total_amount`` into ``num_participants`` shares that sum exactly to it.

    Each share starts as the even split rounded *down* to the cent. The leftover
    whole cents (always fewer than ``num_participants``) go one each to the
    trailing participants, so shares differ by at most one cent and none is
    negative for a positive total. Any sub-cent residue (possible only when
    ``total_amount`` has more than two decimal places) is added to the last share.
    """
    cent = Decimal("0.01")
    base_share = (total_amount / Decimal(num_participants)).quantize(cent, rounding=ROUND_DOWN)
    remainder = total_amount - base_share * num_participants
    # Because base_share is floored, remainder < num_participants cents.
    extra_cents = int(remainder / cent)
    amounts = [base_share] * num_participants
    for offset in range(1, extra_cents + 1):
        amounts[-offset] += cent
    amounts[-1] += remainder - cent * extra_cents
    return amounts


async def create_expense_record_and_shares(
    db: AsyncSession,
    list_id: int,
    calculated_by_id: int,
    total_amount: Decimal,
    participant_ids: PyList[int],
    split_type: SplitTypeEnum = SplitTypeEnum.equal
) -> ExpenseRecordModel:
    """Persist an expense record for a list and split it equally into per-user shares.

    Args:
        db: Session used for the inserts; this function commits it.
        list_id: List the expense belongs to.
        calculated_by_id: User recorded as having performed the calculation.
        total_amount: Amount to split; must be greater than zero.
        participant_ids: Users who each receive one share, in this order.
        split_type: Stored on the record as-is. The amounts below are always
            split equally, whatever this value is.

    Returns:
        The committed record, with its ``shares`` relationship refreshed.

    Raises:
        ValueError: If ``participant_ids`` is empty or ``total_amount`` is not positive.
    """
    if not participant_ids or total_amount <= Decimal("0.00"):
        raise ValueError("Invalid participants or total amount.")

    db_expense_record = ExpenseRecordModel(
        list_id=list_id,
        calculated_by_id=calculated_by_id,
        total_amount=total_amount,
        participants=participant_ids,
        split_type=split_type,
        is_settled=False,
    )
    db.add(db_expense_record)
    # Flush so db_expense_record.id is populated before the shares reference it.
    await db.flush()

    # The previous ROUND_HALF_UP + "last share absorbs the rest" scheme could
    # produce a negative last share (e.g. 0.06 over 10 people gave nine shares
    # of 0.01 and a last share of -0.03); split with floor-and-distribute instead.
    amounts = _split_amount_equally(total_amount, len(participant_ids))
    shares_to_add = [
        ExpenseShareModel(
            expense_record_id=db_expense_record.id,
            user_id=user_id,
            amount_owed=amount,
            is_paid=False,
        )
        for user_id, amount in zip(participant_ids, amounts)
    ]
    db.add_all(shares_to_add)
    await db.commit()
    await db.refresh(db_expense_record, attribute_names=['shares'])
    return db_expense_record
async def get_expense_records_for_list(db: AsyncSession, list_id: int) -> Sequence[ExpenseRecordModel]:
    """Fetch all expense records for ``list_id``, most recently calculated first.

    Each record comes back with its shares (and each share's user) and its
    settlement activities (and each activity's payer and affected user)
    eagerly loaded. These are the same relationships get_expense_record_by_id
    loads, so records from either fetch can be serialized the same way.
    """
    result = await db.execute(
        select(ExpenseRecordModel)
        .where(ExpenseRecordModel.list_id == list_id)
        .options(
            selectinload(ExpenseRecordModel.shares).selectinload(ExpenseShareModel.user),
            # Also load payer/affected_user: with AsyncSession an unloaded
            # relationship cannot be lazy-loaded implicitly on attribute access.
            selectinload(ExpenseRecordModel.settlement_activities).options(
                joinedload(SettlementActivityModel.payer),
                joinedload(SettlementActivityModel.affected_user),
            ),
        )
        .order_by(ExpenseRecordModel.calculated_at.desc())
    )
    return result.scalars().unique().all()
async def get_expense_record_by_id(db: AsyncSession, record_id: int) -> Optional[ExpenseRecordModel]:
    """Fetch the expense record whose ``id`` equals ``record_id``, or ``None``.

    Shares (with their users) and settlement activities (with payer and
    affected user) are eagerly loaded alongside the record.
    """
    share_loader = selectinload(ExpenseRecordModel.shares).selectinload(ExpenseShareModel.user)
    settlement_loader = selectinload(ExpenseRecordModel.settlement_activities).options(
        joinedload(SettlementActivityModel.payer),
        joinedload(SettlementActivityModel.affected_user),
    )
    record_stmt = (
        select(ExpenseRecordModel)
        .options(share_loader, settlement_loader)
        .where(ExpenseRecordModel.id == record_id)
    )
    rows = await db.execute(record_stmt)
    return rows.scalars().first()