from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy.orm import selectinload, joinedload
from typing import List as PyList, Sequence, Optional
from decimal import Decimal, ROUND_HALF_UP, ROUND_DOWN

from app.models import (
    Item as ItemModel,
    User as UserModel,
    UserGroup as UserGroupModel,
    ExpenseRecord as ExpenseRecordModel,
    ExpenseShare as ExpenseShareModel,
    SettlementActivity as SettlementActivityModel,
    SplitTypeEnum,
)

# Smallest currency unit that individual shares are rounded to.
_CENT = Decimal("0.01")


async def get_priced_items_for_list(db: AsyncSession, list_id: int) -> Sequence[ItemModel]:
    """Return every item on list ``list_id`` whose ``price`` is not NULL."""
    result = await db.execute(
        select(ItemModel).where(
            ItemModel.list_id == list_id,
            ItemModel.price.is_not(None),
        )
    )
    return result.scalars().all()


async def get_group_member_ids(db: AsyncSession, group_id: int) -> PyList[int]:
    """Return the user ids of all members of group ``group_id``.

    The ids are read straight from the membership (``UserGroup``) table.
    Selecting from ``User`` while filtering only on ``UserGroup`` with no
    join condition would produce a cross product instead of the members.
    """
    # NOTE(review): assumes UserGroup exposes a ``user_id`` column — confirm against app.models.
    result = await db.execute(
        select(UserGroupModel.user_id).where(UserGroupModel.group_id == group_id)
    )
    # scalars().all() yields a Sequence; materialize a list to honour the annotation.
    return list(result.scalars().all())


def _split_amount_evenly(total: Decimal, count: int) -> PyList[Decimal]:
    """Split ``total`` into ``count`` cent-rounded shares that sum exactly to ``total``.

    Each share starts at ``total / count`` rounded *down* to the cent. The
    leftover whole cents go one each to the first shares, so shares differ by
    at most one cent and none can become negative for a positive ``total``.
    Any sub-cent residue (only possible when ``total`` has more than two
    decimal places) is added to the last share.
    """
    base = (total / Decimal(count)).quantize(_CENT, rounding=ROUND_DOWN)
    remainder = total - base * count
    # base was rounded down by less than one cent, so remainder < count cents
    # and extra_cents <= count - 1: every extra cent lands on a distinct share.
    extra_cents = int(remainder // _CENT)
    shares = [base + _CENT if i < extra_cents else base for i in range(count)]
    shares[-1] += remainder - _CENT * extra_cents
    return shares


async def create_expense_record_and_shares(
    db: AsyncSession,
    list_id: int,
    calculated_by_id: int,
    total_amount: Decimal,
    participant_ids: PyList[int],
    split_type: SplitTypeEnum = SplitTypeEnum.equal,
) -> ExpenseRecordModel:
    """Create an expense record for a list plus one unpaid share per participant, then commit.

    ``total_amount`` is divided evenly across ``participant_ids`` in cents.
    The shares sum exactly to ``total_amount`` and differ by at most one cent.
    ``split_type`` is stored on the record, but this function always computes
    an equal split.

    Returns:
        The committed record, with its ``shares`` relationship refreshed.

    Raises:
        ValueError: if ``participant_ids`` is empty or ``total_amount`` is not positive.
    """
    if not participant_ids or total_amount <= Decimal("0.00"):
        raise ValueError("Invalid participants or total amount.")

    db_expense_record = ExpenseRecordModel(
        list_id=list_id,
        calculated_by_id=calculated_by_id,
        total_amount=total_amount,
        participants=participant_ids,
        split_type=split_type,
        is_settled=False,
    )
    db.add(db_expense_record)
    # Flush so db_expense_record.id is available for the shares built below.
    await db.flush()

    amounts = _split_amount_evenly(total_amount, len(participant_ids))
    shares_to_add = [
        ExpenseShareModel(
            expense_record_id=db_expense_record.id,
            user_id=user_id,
            amount_owed=amount,
            is_paid=False,
        )
        for user_id, amount in zip(participant_ids, amounts)
    ]
    db.add_all(shares_to_add)
    await db.commit()
    await db.refresh(db_expense_record, attribute_names=["shares"])
    return db_expense_record


async def get_expense_records_for_list(db: AsyncSession, list_id: int) -> Sequence[ExpenseRecordModel]:
    """Fetch all expense records for list ``list_id``, newest ``calculated_at`` first.

    Shares (with their users) and settlement activities (with payer and
    affected user) are eager-loaded. This matches
    ``get_expense_record_by_id`` and avoids lazy loads on an async session.
    """
    result = await db.execute(
        select(ExpenseRecordModel)
        .where(ExpenseRecordModel.list_id == list_id)
        .options(
            selectinload(ExpenseRecordModel.shares).selectinload(ExpenseShareModel.user),
            selectinload(ExpenseRecordModel.settlement_activities).options(
                joinedload(SettlementActivityModel.payer),
                joinedload(SettlementActivityModel.affected_user),
            ),
        )
        .order_by(ExpenseRecordModel.calculated_at.desc())
    )
    return result.scalars().unique().all()


async def get_expense_record_by_id(db: AsyncSession, record_id: int) -> Optional[ExpenseRecordModel]:
    """Fetch the expense record with id ``record_id``, or ``None`` if it does not exist.

    Shares (with their users) and settlement activities (with payer and
    affected user) are eager-loaded.
    """
    result = await db.execute(
        select(ExpenseRecordModel)
        .where(ExpenseRecordModel.id == record_id)
        .options(
            selectinload(ExpenseRecordModel.shares).selectinload(ExpenseShareModel.user),
            selectinload(ExpenseRecordModel.settlement_activities).options(
                joinedload(SettlementActivityModel.payer),
                joinedload(SettlementActivityModel.affected_user),
            ),
        )
    )
    return result.scalars().first()