88 lines
3.4 KiB
Python
88 lines
3.4 KiB
Python
from sqlalchemy.ext.asyncio import AsyncSession
|
|
from sqlalchemy.future import select
|
|
from sqlalchemy.orm import selectinload, joinedload
|
|
from typing import List as PyList, Sequence, Optional
|
|
from decimal import Decimal, ROUND_HALF_UP
|
|
|
|
from app.models import (
|
|
Item as ItemModel,
|
|
User as UserModel,
|
|
UserGroup as UserGroupModel,
|
|
ExpenseRecord as ExpenseRecordModel,
|
|
ExpenseShare as ExpenseShareModel,
|
|
SettlementActivity as SettlementActivityModel,
|
|
SplitTypeEnum,
|
|
)
|
|
|
|
async def get_priced_items_for_list(db: AsyncSession, list_id: int) -> Sequence[ItemModel]:
    """Return the items of a list that have a price set.

    Args:
        db: Async session used to execute the query.
        list_id: Value matched against ``ItemModel.list_id``.

    Returns:
        All ``ItemModel`` rows for the list whose ``price`` is not NULL.
    """
    # Chained .where() clauses are combined with AND.
    priced_items_stmt = (
        select(ItemModel)
        .where(ItemModel.list_id == list_id)
        .where(ItemModel.price.is_not(None))
    )
    rows = await db.execute(priced_items_stmt)
    return rows.scalars().all()
|
|
|
|
async def get_group_member_ids(db: AsyncSession, group_id: int) -> PyList[int]:
    """Return the user ids of every member of a group.

    Args:
        db: Async session used to execute the query.
        group_id: Value matched against ``UserGroupModel.group_id``.

    Returns:
        A list of user ids, one per membership row of the group; empty when
        the group has no members.
    """
    # Select and filter on the same (membership) table. Selecting a column
    # from UserModel while filtering only on UserGroupModel, with no join
    # between them, would produce an unconstrained cross product.
    # NOTE(review): assumes UserGroupModel exposes a ``user_id`` column
    # alongside ``group_id`` - confirm against app.models.
    member_ids_stmt = select(UserGroupModel.user_id).where(
        UserGroupModel.group_id == group_id
    )
    result = await db.execute(member_ids_stmt)
    # Materialise as a real list so the value matches the annotation.
    return list(result.scalars().all())
|
|
|
|
async def create_expense_record_and_shares(
    db: AsyncSession,
    list_id: int,
    calculated_by_id: int,
    total_amount: Decimal,
    participant_ids: PyList[int],
    split_type: SplitTypeEnum = SplitTypeEnum.equal
) -> ExpenseRecordModel:
    """Create an expense record and one share row per participant, then commit.

    The total is divided equally: every participant except the last owes the
    per-person amount rounded to cents, and the last participant owes whatever
    remains so the shares always sum to ``total_amount``. ``split_type`` is
    stored on the record but does not change how the shares are computed here.

    Args:
        db: Async session; this function flushes, commits and refreshes on it.
        list_id: Stored on the expense record.
        calculated_by_id: Stored on the expense record.
        total_amount: Amount to split; must be greater than zero.
        participant_ids: User ids that each receive a share; must be non-empty.
        split_type: Stored on the expense record.

    Returns:
        The persisted expense record with its ``shares`` attribute refreshed.

    Raises:
        ValueError: If ``participant_ids`` is empty or ``total_amount`` is not
            positive.
    """
    # Imported locally because the module-level import only provides
    # ROUND_HALF_UP.
    from decimal import ROUND_DOWN

    if not participant_ids or total_amount <= Decimal("0.00"):
        raise ValueError("Invalid participants or total amount.")

    db_expense_record = ExpenseRecordModel(
        list_id=list_id,
        calculated_by_id=calculated_by_id,
        total_amount=total_amount,
        participants=participant_ids,
        split_type=split_type,
        is_settled=False,
    )
    db.add(db_expense_record)
    # Flush before building the shares, which read db_expense_record.id below.
    await db.flush()

    num_participants = len(participant_ids)
    cent = Decimal("0.01")
    exact_share = total_amount / Decimal(num_participants)

    individual_share = exact_share.quantize(cent, rounding=ROUND_HALF_UP)
    last_share = total_amount - individual_share * (num_participants - 1)
    if last_share < 0:
        # Rounding up can overshoot the total when there are many participants
        # relative to the amount (e.g. 0.05 split 10 ways -> 9 * 0.01 = 0.09),
        # which would leave the last participant with a negative share.
        # Rounding down never overshoots, so the remainder stays positive.
        individual_share = exact_share.quantize(cent, rounding=ROUND_DOWN)
        last_share = total_amount - individual_share * (num_participants - 1)

    last_index = num_participants - 1
    shares_to_add = [
        ExpenseShareModel(
            expense_record_id=db_expense_record.id,
            user_id=user_id,
            amount_owed=last_share if i == last_index else individual_share,
            is_paid=False,
        )
        for i, user_id in enumerate(participant_ids)
    ]

    db.add_all(shares_to_add)
    await db.commit()
    await db.refresh(db_expense_record, attribute_names=['shares'])
    return db_expense_record
|
|
|
|
# Fetch all expense records for a list
|
|
async def get_expense_records_for_list(db: AsyncSession, list_id: int) -> Sequence[ExpenseRecordModel]:
    """Return all expense records of a list, ordered by ``calculated_at`` descending.

    Each record is loaded together with its shares (and each share's user)
    and its settlement activities.

    Args:
        db: Async session used to execute the query.
        list_id: Value matched against ``ExpenseRecordModel.list_id``.

    Returns:
        The de-duplicated matching ``ExpenseRecordModel`` rows.
    """
    eager_loads = (
        selectinload(ExpenseRecordModel.shares).selectinload(ExpenseShareModel.user),
        selectinload(ExpenseRecordModel.settlement_activities),
    )
    records_stmt = (
        select(ExpenseRecordModel)
        .where(ExpenseRecordModel.list_id == list_id)
        .options(*eager_loads)
        .order_by(ExpenseRecordModel.calculated_at.desc())
    )
    rows = await db.execute(records_stmt)
    return rows.scalars().unique().all()
|
|
|
|
# Fetch a specific expense record by ID
|
|
async def get_expense_record_by_id(db: AsyncSession, record_id: int) -> Optional[ExpenseRecordModel]:
    """Return a single expense record by id, or ``None`` when no row matches.

    The record is loaded together with its shares (and each share's user) and
    its settlement activities (and each activity's payer and affected user).

    Args:
        db: Async session used to execute the query.
        record_id: Value matched against ``ExpenseRecordModel.id``.

    Returns:
        The first matching ``ExpenseRecordModel``, or ``None``.
    """
    shares_with_users = selectinload(ExpenseRecordModel.shares).selectinload(
        ExpenseShareModel.user
    )
    activities_with_users = selectinload(
        ExpenseRecordModel.settlement_activities
    ).options(
        joinedload(SettlementActivityModel.payer),
        joinedload(SettlementActivityModel.affected_user),
    )
    record_stmt = (
        select(ExpenseRecordModel)
        .where(ExpenseRecordModel.id == record_id)
        .options(shares_with_users, activities_with_users)
    )
    rows = await db.execute(record_stmt)
    return rows.scalars().first()