Refactor database session management and exception handling across CRUD operations; streamline transaction handling in expense, group, invite, item, list, settlement, and user modules for improved reliability and clarity. Introduce specific operation errors for better error reporting.

This commit is contained in:
mohamad 2025-05-16 21:54:29 +02:00
parent 515534dcce
commit 7a88ea258a
9 changed files with 850 additions and 459 deletions

View File

@ -128,6 +128,14 @@ class DatabaseQueryError(HTTPException):
detail=detail detail=detail
) )
class ExpenseOperationError(HTTPException):
"""Raised when an expense operation fails."""
def __init__(self, detail: str):
super().__init__(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=detail
)
class OCRServiceUnavailableError(HTTPException): class OCRServiceUnavailableError(HTTPException):
"""Raised when the OCR service is unavailable.""" """Raised when the OCR service is unavailable."""
def __init__(self): def __init__(self):
@ -240,6 +248,22 @@ class ListStatusNotFoundError(HTTPException):
detail=f"Status for list {list_id} not found" detail=f"Status for list {list_id} not found"
) )
class InviteOperationError(HTTPException):
"""Raised when an invite operation fails."""
def __init__(self, detail: str):
super().__init__(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=detail
)
class SettlementOperationError(HTTPException):
"""Raised when a settlement operation fails."""
def __init__(self, detail: str):
super().__init__(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=detail
)
class ConflictError(HTTPException): class ConflictError(HTTPException):
"""Raised when an optimistic lock version conflict occurs.""" """Raised when an optimistic lock version conflict occurs."""
def __init__(self, detail: str): def __init__(self, detail: str):
@ -283,3 +307,19 @@ class JWTUnexpectedError(HTTPException):
detail=settings.JWT_UNEXPECTED_ERROR.format(error=error), detail=settings.JWT_UNEXPECTED_ERROR.format(error=error),
headers={settings.AUTH_HEADER_NAME: f"{settings.AUTH_HEADER_PREFIX} error=\"invalid_token\""} headers={settings.AUTH_HEADER_NAME: f"{settings.AUTH_HEADER_PREFIX} error=\"invalid_token\""}
) )
class ListOperationError(HTTPException):
"""Raised when a list operation fails."""
def __init__(self, detail: str):
super().__init__(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=detail
)
class ItemOperationError(HTTPException):
"""Raised when an item operation fails."""
def __init__(self, detail: str):
super().__init__(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=detail
)

View File

@ -4,7 +4,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select from sqlalchemy.future import select
from sqlalchemy.orm import selectinload, joinedload from sqlalchemy.orm import selectinload, joinedload
from decimal import Decimal, ROUND_HALF_UP, InvalidOperation as DecimalInvalidOperation from decimal import Decimal, ROUND_HALF_UP, InvalidOperation as DecimalInvalidOperation
from typing import Callable, List as PyList, Optional, Sequence, Dict, defaultdict from typing import Callable, List as PyList, Optional, Sequence, Dict, defaultdict, Any
from datetime import datetime, timezone # Added timezone from datetime import datetime, timezone # Added timezone
from app.models import ( from app.models import (
@ -23,7 +23,12 @@ from app.core.exceptions import (
ListNotFoundError, ListNotFoundError,
GroupNotFoundError, GroupNotFoundError,
UserNotFoundError, UserNotFoundError,
InvalidOperationError # Import the new exception InvalidOperationError, # Import the new exception
DatabaseConnectionError, # Added
DatabaseIntegrityError, # Added
DatabaseQueryError, # Added
DatabaseTransactionError,# Added
ExpenseOperationError # Added specific exception
) )
# Placeholder for InvalidOperationError if not defined in app.core.exceptions # Placeholder for InvalidOperationError if not defined in app.core.exceptions
@ -108,60 +113,97 @@ async def create_expense(db: AsyncSession, expense_in: ExpenseCreate, current_us
GroupNotFoundError: If specified group doesn't exist GroupNotFoundError: If specified group doesn't exist
InvalidOperationError: For various validation failures InvalidOperationError: For various validation failures
""" """
# Helper function to round decimals consistently
def round_money(amount: Decimal) -> Decimal:
return amount.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP)
# 1. Context Validation
# Validate basic context requirements first
if not expense_in.list_id and not expense_in.group_id:
raise InvalidOperationError("Expense must be associated with a list or a group.")
# 2. User Validation
payer = await db.get(UserModel, expense_in.paid_by_user_id)
if not payer:
raise UserNotFoundError(user_id=expense_in.paid_by_user_id)
# 3. List/Group Context Resolution
final_group_id = await _resolve_expense_context(db, expense_in)
# 4. Create the expense object
db_expense = ExpenseModel(
description=expense_in.description,
total_amount=round_money(expense_in.total_amount),
currency=expense_in.currency or "USD",
expense_date=expense_in.expense_date or datetime.now(timezone.utc),
split_type=expense_in.split_type,
list_id=expense_in.list_id,
group_id=final_group_id,
item_id=expense_in.item_id,
paid_by_user_id=expense_in.paid_by_user_id,
created_by_user_id=current_user_id # Track who created this expense
)
# 5. Generate splits based on split type
splits_to_create = await _generate_expense_splits(db, db_expense, expense_in, round_money)
# 6. Single transaction for expense and all splits
try: try:
db.add(db_expense) async with db.begin_nested() if db.in_transaction() else db.begin() as transaction:
await db.flush() # Get expense ID without committing # 1. Validate payer
payer = await db.get(UserModel, expense_in.paid_by_user_id)
if not payer:
raise UserNotFoundError(user_id=expense_in.paid_by_user_id, identifier="Payer")
# Update all splits with the expense ID # 2. Context Resolution and Validation (now part of the transaction)
for split in splits_to_create: if not expense_in.list_id and not expense_in.group_id and not expense_in.item_id:
split.expense_id = db_expense.id raise InvalidOperationError("Expense must be associated with a list, a group, or an item.")
db.add_all(splits_to_create) final_group_id = await _resolve_expense_context(db, expense_in)
await db.commit() # Further validation for item_id if provided
db_item_instance = None
if expense_in.item_id:
db_item_instance = await db.get(ItemModel, expense_in.item_id)
if not db_item_instance:
raise InvalidOperationError(f"Item with ID {expense_in.item_id} not found.")
# Potentially link item's list/group if not already set on expense_in
if db_item_instance.list_id and not expense_in.list_id:
expense_in.list_id = db_item_instance.list_id
# Re-resolve context if list_id was derived from item
final_group_id = await _resolve_expense_context(db, expense_in)
except Exception as e: # 3. Create the ExpenseModel instance
await db.rollback() db_expense = ExpenseModel(
logger.error(f"Failed to save expense: {str(e)}", exc_info=True) description=expense_in.description,
raise InvalidOperationError(f"Failed to save expense: {str(e)}") total_amount=_round_money(expense_in.total_amount),
currency=expense_in.currency or "USD",
expense_date=expense_in.expense_date or datetime.now(timezone.utc),
split_type=expense_in.split_type,
list_id=expense_in.list_id,
group_id=final_group_id, # Use resolved group_id
item_id=expense_in.item_id,
paid_by_user_id=expense_in.paid_by_user_id,
created_by_user_id=current_user_id
)
db.add(db_expense)
await db.flush() # Get expense ID
# Refresh to get the splits relationship populated # 4. Generate splits (passing current_user_id through kwargs if needed by specific split types)
await db.refresh(db_expense, attribute_names=["splits"]) splits_to_create = await _generate_expense_splits(
return db_expense db=db,
expense_model=db_expense,
expense_in=expense_in,
current_user_id=current_user_id # Pass for item-based splits needing creator info
)
for split_model in splits_to_create:
split_model.expense_id = db_expense.id # Set FK after db_expense has ID
db.add_all(splits_to_create)
await db.flush() # Persist splits
# 5. Re-fetch the expense with all necessary relationships for the response
stmt = (
select(ExpenseModel)
.where(ExpenseModel.id == db_expense.id)
.options(
selectinload(ExpenseModel.paid_by_user),
selectinload(ExpenseModel.created_by_user), # If you have this relationship
selectinload(ExpenseModel.list),
selectinload(ExpenseModel.group),
selectinload(ExpenseModel.item),
selectinload(ExpenseModel.splits).selectinload(ExpenseSplitModel.user)
)
)
result = await db.execute(stmt)
loaded_expense = result.scalar_one_or_none()
if loaded_expense is None:
await transaction.rollback() # Should be handled by context manager
raise ExpenseOperationError("Failed to load expense after creation.")
await transaction.commit()
return loaded_expense
except (UserNotFoundError, ListNotFoundError, GroupNotFoundError, InvalidOperationError) as e:
# These are business logic validation errors, re-raise them.
# If a transaction was started, the context manager handles rollback.
raise
except IntegrityError as e:
# Context manager handles rollback.
logger.error(f"Database integrity error during expense creation: {str(e)}", exc_info=True)
raise DatabaseIntegrityError(f"Failed to save expense due to database integrity issue: {str(e)}")
except OperationalError as e:
logger.error(f"Database connection error during expense creation: {str(e)}", exc_info=True)
raise DatabaseConnectionError(f"Database connection error during expense creation: {str(e)}")
except SQLAlchemyError as e:
# Context manager handles rollback.
logger.error(f"Unexpected SQLAlchemy error during expense creation: {str(e)}", exc_info=True)
raise DatabaseTransactionError(f"Failed to save expense due to a database transaction error: {str(e)}")
async def _resolve_expense_context(db: AsyncSession, expense_in: ExpenseCreate) -> Optional[int]: async def _resolve_expense_context(db: AsyncSession, expense_in: ExpenseCreate) -> Optional[int]:
@ -197,39 +239,32 @@ async def _resolve_expense_context(db: AsyncSession, expense_in: ExpenseCreate)
async def _generate_expense_splits( async def _generate_expense_splits(
db: AsyncSession, db: AsyncSession,
db_expense: ExpenseModel, expense_model: ExpenseModel,
expense_in: ExpenseCreate, expense_in: ExpenseCreate,
round_money: Callable[[Decimal], Decimal] **kwargs: Any
) -> PyList[ExpenseSplitModel]: ) -> PyList[ExpenseSplitModel]:
"""Generates appropriate expense splits based on split type.""" """Generates appropriate expense splits based on split type."""
splits_to_create: PyList[ExpenseSplitModel] = [] splits_to_create: PyList[ExpenseSplitModel] = []
# Pass db to split creation helpers if they need to fetch more data (e.g., item details for item-based)
common_args = {"db": db, "expense_model": expense_model, "expense_in": expense_in, "round_money_func": _round_money, "kwargs": kwargs}
# Create splits based on the split type # Create splits based on the split type
if expense_in.split_type == SplitTypeEnum.EQUAL: if expense_in.split_type == SplitTypeEnum.EQUAL:
splits_to_create = await _create_equal_splits( splits_to_create = await _create_equal_splits(**common_args)
db, db_expense, expense_in, round_money
)
elif expense_in.split_type == SplitTypeEnum.EXACT_AMOUNTS: elif expense_in.split_type == SplitTypeEnum.EXACT_AMOUNTS:
splits_to_create = await _create_exact_amount_splits( splits_to_create = await _create_exact_amount_splits(**common_args)
db, db_expense, expense_in, round_money
)
elif expense_in.split_type == SplitTypeEnum.PERCENTAGE: elif expense_in.split_type == SplitTypeEnum.PERCENTAGE:
splits_to_create = await _create_percentage_splits( splits_to_create = await _create_percentage_splits(**common_args)
db, db_expense, expense_in, round_money
)
elif expense_in.split_type == SplitTypeEnum.SHARES: elif expense_in.split_type == SplitTypeEnum.SHARES:
splits_to_create = await _create_shares_splits( splits_to_create = await _create_shares_splits(**common_args)
db, db_expense, expense_in, round_money
)
elif expense_in.split_type == SplitTypeEnum.ITEM_BASED: elif expense_in.split_type == SplitTypeEnum.ITEM_BASED:
splits_to_create = await _create_item_based_splits( splits_to_create = await _create_item_based_splits(**common_args)
db, db_expense, expense_in, round_money
)
else: else:
raise InvalidOperationError(f"Unsupported split type: {expense_in.split_type.value}") raise InvalidOperationError(f"Unsupported split type: {expense_in.split_type.value}")
@ -240,29 +275,24 @@ async def _generate_expense_splits(
return splits_to_create return splits_to_create
async def _create_equal_splits( async def _create_equal_splits(db: AsyncSession, expense_model: ExpenseModel, expense_in: ExpenseCreate, round_money_func: Callable[[Decimal], Decimal], **kwargs: Any) -> PyList[ExpenseSplitModel]:
db: AsyncSession,
db_expense: ExpenseModel,
expense_in: ExpenseCreate,
round_money: Callable[[Decimal], Decimal]
) -> PyList[ExpenseSplitModel]:
"""Creates equal splits among users.""" """Creates equal splits among users."""
users_for_splitting = await get_users_for_splitting( users_for_splitting = await get_users_for_splitting(
db, db_expense.group_id, expense_in.list_id, expense_in.paid_by_user_id db, expense_model.group_id, expense_model.list_id, expense_model.paid_by_user_id
) )
if not users_for_splitting: if not users_for_splitting:
raise InvalidOperationError("No users found for EQUAL split.") raise InvalidOperationError("No users found for EQUAL split.")
num_users = len(users_for_splitting) num_users = len(users_for_splitting)
amount_per_user = round_money(db_expense.total_amount / Decimal(num_users)) amount_per_user = round_money_func(expense_model.total_amount / Decimal(num_users))
remainder = db_expense.total_amount - (amount_per_user * num_users) remainder = expense_model.total_amount - (amount_per_user * num_users)
splits = [] splits = []
for i, user in enumerate(users_for_splitting): for i, user in enumerate(users_for_splitting):
split_amount = amount_per_user split_amount = amount_per_user
if i == 0 and remainder != Decimal('0'): if i == 0 and remainder != Decimal('0'):
split_amount = round_money(amount_per_user + remainder) split_amount = round_money_func(amount_per_user + remainder)
splits.append(ExpenseSplitModel( splits.append(ExpenseSplitModel(
user_id=user.id, user_id=user.id,
@ -272,12 +302,7 @@ async def _create_equal_splits(
return splits return splits
async def _create_exact_amount_splits( async def _create_exact_amount_splits(db: AsyncSession, expense_model: ExpenseModel, expense_in: ExpenseCreate, round_money_func: Callable[[Decimal], Decimal], **kwargs: Any) -> PyList[ExpenseSplitModel]:
db: AsyncSession,
db_expense: ExpenseModel,
expense_in: ExpenseCreate,
round_money: Callable[[Decimal], Decimal]
) -> PyList[ExpenseSplitModel]:
"""Creates splits with exact amounts.""" """Creates splits with exact amounts."""
if not expense_in.splits_in: if not expense_in.splits_in:
@ -293,7 +318,7 @@ async def _create_exact_amount_splits(
if split_in.owed_amount <= Decimal('0'): if split_in.owed_amount <= Decimal('0'):
raise InvalidOperationError(f"Owed amount for user {split_in.user_id} must be positive.") raise InvalidOperationError(f"Owed amount for user {split_in.user_id} must be positive.")
rounded_amount = round_money(split_in.owed_amount) rounded_amount = round_money_func(split_in.owed_amount)
current_total += rounded_amount current_total += rounded_amount
splits.append(ExpenseSplitModel( splits.append(ExpenseSplitModel(
@ -301,20 +326,15 @@ async def _create_exact_amount_splits(
owed_amount=rounded_amount owed_amount=rounded_amount
)) ))
if round_money(current_total) != db_expense.total_amount: if round_money_func(current_total) != expense_model.total_amount:
raise InvalidOperationError( raise InvalidOperationError(
f"Sum of exact split amounts ({current_total}) != expense total ({db_expense.total_amount})." f"Sum of exact split amounts ({current_total}) != expense total ({expense_model.total_amount})."
) )
return splits return splits
async def _create_percentage_splits( async def _create_percentage_splits(db: AsyncSession, expense_model: ExpenseModel, expense_in: ExpenseCreate, round_money_func: Callable[[Decimal], Decimal], **kwargs: Any) -> PyList[ExpenseSplitModel]:
db: AsyncSession,
db_expense: ExpenseModel,
expense_in: ExpenseCreate,
round_money: Callable[[Decimal], Decimal]
) -> PyList[ExpenseSplitModel]:
"""Creates splits based on percentages.""" """Creates splits based on percentages."""
if not expense_in.splits_in: if not expense_in.splits_in:
@ -334,7 +354,7 @@ async def _create_percentage_splits(
) )
total_percentage += split_in.share_percentage total_percentage += split_in.share_percentage
owed_amount = round_money(db_expense.total_amount * (split_in.share_percentage / Decimal("100"))) owed_amount = round_money_func(expense_model.total_amount * (split_in.share_percentage / Decimal("100")))
current_total += owed_amount current_total += owed_amount
splits.append(ExpenseSplitModel( splits.append(ExpenseSplitModel(
@ -343,23 +363,18 @@ async def _create_percentage_splits(
share_percentage=split_in.share_percentage share_percentage=split_in.share_percentage
)) ))
if round_money(total_percentage) != Decimal("100.00"): if round_money_func(total_percentage) != Decimal("100.00"):
raise InvalidOperationError(f"Sum of percentages ({total_percentage}%) is not 100%.") raise InvalidOperationError(f"Sum of percentages ({total_percentage}%) is not 100%.")
# Adjust for rounding differences # Adjust for rounding differences
if current_total != db_expense.total_amount and splits: if current_total != expense_model.total_amount and splits:
diff = db_expense.total_amount - current_total diff = expense_model.total_amount - current_total
splits[-1].owed_amount = round_money(splits[-1].owed_amount + diff) splits[-1].owed_amount = round_money_func(splits[-1].owed_amount + diff)
return splits return splits
async def _create_shares_splits( async def _create_shares_splits(db: AsyncSession, expense_model: ExpenseModel, expense_in: ExpenseCreate, round_money_func: Callable[[Decimal], Decimal], **kwargs: Any) -> PyList[ExpenseSplitModel]:
db: AsyncSession,
db_expense: ExpenseModel,
expense_in: ExpenseCreate,
round_money: Callable[[Decimal], Decimal]
) -> PyList[ExpenseSplitModel]:
"""Creates splits based on shares.""" """Creates splits based on shares."""
if not expense_in.splits_in: if not expense_in.splits_in:
@ -381,7 +396,7 @@ async def _create_shares_splits(
raise InvalidOperationError(f"Invalid share units for user {split_in.user_id}.") raise InvalidOperationError(f"Invalid share units for user {split_in.user_id}.")
share_ratio = Decimal(split_in.share_units) / Decimal(total_shares) share_ratio = Decimal(split_in.share_units) / Decimal(total_shares)
owed_amount = round_money(db_expense.total_amount * share_ratio) owed_amount = round_money_func(expense_model.total_amount * share_ratio)
current_total += owed_amount current_total += owed_amount
splits.append(ExpenseSplitModel( splits.append(ExpenseSplitModel(
@ -391,31 +406,26 @@ async def _create_shares_splits(
)) ))
# Adjust for rounding differences # Adjust for rounding differences
if current_total != db_expense.total_amount and splits: if current_total != expense_model.total_amount and splits:
diff = db_expense.total_amount - current_total diff = expense_model.total_amount - current_total
splits[-1].owed_amount = round_money(splits[-1].owed_amount + diff) splits[-1].owed_amount = round_money_func(splits[-1].owed_amount + diff)
return splits return splits
async def _create_item_based_splits( async def _create_item_based_splits(db: AsyncSession, expense_model: ExpenseModel, expense_in: ExpenseCreate, round_money_func: Callable[[Decimal], Decimal], **kwargs: Any) -> PyList[ExpenseSplitModel]:
db: AsyncSession,
db_expense: ExpenseModel,
expense_in: ExpenseCreate,
round_money: Callable[[Decimal], Decimal]
) -> PyList[ExpenseSplitModel]:
"""Creates splits based on items in a shopping list.""" """Creates splits based on items in a shopping list."""
if not expense_in.list_id: if not expense_model.list_id:
raise InvalidOperationError("ITEM_BASED expenses must be associated with a list_id.") raise InvalidOperationError("ITEM_BASED expenses must be associated with a list_id.")
if expense_in.splits_in: if expense_in.splits_in:
logger.warning("Ignoring provided 'splits_in' data for ITEM_BASED expense.") logger.warning("Ignoring provided 'splits_in' data for ITEM_BASED expense.")
# Build query to fetch relevant items # Build query to fetch relevant items
items_query = select(ItemModel).where(ItemModel.list_id == expense_in.list_id) items_query = select(ItemModel).where(ItemModel.list_id == expense_model.list_id)
if expense_in.item_id: if expense_model.item_id:
items_query = items_query.where(ItemModel.id == expense_in.item_id) items_query = items_query.where(ItemModel.id == expense_model.item_id)
else: else:
items_query = items_query.where(ItemModel.price.isnot(None) & (ItemModel.price > Decimal("0"))) items_query = items_query.where(ItemModel.price.isnot(None) & (ItemModel.price > Decimal("0")))
@ -425,9 +435,9 @@ async def _create_item_based_splits(
if not relevant_items: if not relevant_items:
error_msg = ( error_msg = (
f"Specified item ID {expense_in.item_id} not found in list {expense_in.list_id}." f"Specified item ID {expense_model.item_id} not found in list {expense_model.list_id}."
if expense_in.item_id else if expense_model.item_id else
f"List {expense_in.list_id} has no priced items to base the expense on." f"List {expense_model.list_id} has no priced items to base the expense on."
) )
raise InvalidOperationError(error_msg) raise InvalidOperationError(error_msg)
@ -438,9 +448,9 @@ async def _create_item_based_splits(
for item in relevant_items: for item in relevant_items:
if item.price is None or item.price <= Decimal("0"): if item.price is None or item.price <= Decimal("0"):
if expense_in.item_id: if expense_model.item_id:
raise InvalidOperationError( raise InvalidOperationError(
f"Item ID {expense_in.item_id} must have a positive price for ITEM_BASED expense." f"Item ID {expense_model.item_id} must have a positive price for ITEM_BASED expense."
) )
continue continue
@ -454,13 +464,13 @@ async def _create_item_based_splits(
if processed_items == 0: if processed_items == 0:
raise InvalidOperationError( raise InvalidOperationError(
f"No items with positive prices found in list {expense_in.list_id} to create ITEM_BASED expense." f"No items with positive prices found in list {expense_model.list_id} to create ITEM_BASED expense."
) )
# Validate total matches calculated total # Validate total matches calculated total
if round_money(calculated_total) != db_expense.total_amount: if round_money_func(calculated_total) != expense_model.total_amount:
raise InvalidOperationError( raise InvalidOperationError(
f"Expense total amount ({db_expense.total_amount}) does not match the " f"Expense total amount ({expense_model.total_amount}) does not match the "
f"calculated total from item prices ({calculated_total})." f"calculated total from item prices ({calculated_total})."
) )
@ -469,7 +479,7 @@ async def _create_item_based_splits(
for user_id, owed_amount in user_owed_amounts.items(): for user_id, owed_amount in user_owed_amounts.items():
splits.append(ExpenseSplitModel( splits.append(ExpenseSplitModel(
user_id=user_id, user_id=user_id,
owed_amount=round_money(owed_amount) owed_amount=round_money_func(owed_amount)
)) ))
return splits return splits

View File

@ -4,7 +4,7 @@ from sqlalchemy.future import select
from sqlalchemy.orm import selectinload # For eager loading members from sqlalchemy.orm import selectinload # For eager loading members
from sqlalchemy.exc import SQLAlchemyError, IntegrityError, OperationalError from sqlalchemy.exc import SQLAlchemyError, IntegrityError, OperationalError
from typing import Optional, List from typing import Optional, List
from sqlalchemy import func from sqlalchemy import delete, func
from app.models import User as UserModel, Group as GroupModel, UserGroup as UserGroupModel from app.models import User as UserModel, Group as GroupModel, UserGroup as UserGroupModel
from app.schemas.group import GroupCreate from app.schemas.group import GroupCreate
@ -24,10 +24,23 @@ from app.core.exceptions import (
async def create_group(db: AsyncSession, group_in: GroupCreate, creator_id: int) -> GroupModel: async def create_group(db: AsyncSession, group_in: GroupCreate, creator_id: int) -> GroupModel:
"""Creates a group and adds the creator as the owner.""" """Creates a group and adds the creator as the owner."""
try: try:
async with db.begin(): # Defensive check: if a transaction is already active, try to roll it back.
# This is unusual and suggests an issue upstream (e.g., middleware or session configuration).
if db.in_transaction():
# Log this occurrence if possible, as it's unexpected.
# import logging; logging.warning("Transaction already active on session entering create_group. Attempting rollback.")
try:
await db.rollback() # Attempt to clear any existing transaction
except SQLAlchemyError as e_rb:
# Log e_rb if possible
# import logging; logging.error(f"Error rolling back pre-existing transaction: {e_rb}")
# Re-raise or handle as a critical error, as the session state is uncertain.
raise DatabaseTransactionError(f"Session had an active transaction that could not be rolled back: {e_rb}")
async with db.begin(): # Now attempt to start a clean transaction
db_group = GroupModel(name=group_in.name, created_by_id=creator_id) db_group = GroupModel(name=group_in.name, created_by_id=creator_id)
db.add(db_group) db.add(db_group)
await db.flush() await db.flush() # Assigns ID to db_group
db_user_group = UserGroupModel( db_user_group = UserGroupModel(
user_id=creator_id, user_id=creator_id,
@ -35,15 +48,30 @@ async def create_group(db: AsyncSession, group_in: GroupCreate, creator_id: int)
role=UserRoleEnum.owner role=UserRoleEnum.owner
) )
db.add(db_user_group) db.add(db_user_group)
await db.flush() await db.flush() # Commits user_group, links to group
await db.refresh(db_group)
return db_group # After creation and linking, explicitly load the group with its member associations and users
stmt = (
select(GroupModel)
.where(GroupModel.id == db_group.id)
.options(
selectinload(GroupModel.member_associations).selectinload(UserGroupModel.user)
)
)
result = await db.execute(stmt)
loaded_group = result.scalar_one_or_none()
if loaded_group is None:
# This should not happen if we just created it, but as a safeguard
raise GroupOperationError("Failed to load group after creation.")
return loaded_group
except IntegrityError as e: except IntegrityError as e:
raise DatabaseIntegrityError(f"Failed to create group: {str(e)}") raise DatabaseIntegrityError(f"Failed to create group due to integrity issue: {str(e)}")
except OperationalError as e: except OperationalError as e:
raise DatabaseConnectionError(f"Database connection error: {str(e)}") raise DatabaseConnectionError(f"Database connection error during group creation: {str(e)}")
except SQLAlchemyError as e: except SQLAlchemyError as e:
raise DatabaseTransactionError(f"Failed to create group: {str(e)}") raise DatabaseTransactionError(f"Database transaction error during group creation: {str(e)}")
async def get_user_groups(db: AsyncSession, user_id: int) -> List[GroupModel]: async def get_user_groups(db: AsyncSession, user_id: int) -> List[GroupModel]:
"""Gets all groups a user is a member of.""" """Gets all groups a user is a member of."""
@ -52,7 +80,9 @@ async def get_user_groups(db: AsyncSession, user_id: int) -> List[GroupModel]:
select(GroupModel) select(GroupModel)
.join(UserGroupModel) .join(UserGroupModel)
.where(UserGroupModel.user_id == user_id) .where(UserGroupModel.user_id == user_id)
.options(selectinload(GroupModel.member_associations)) .options(
selectinload(GroupModel.member_associations).selectinload(UserGroupModel.user)
)
) )
return result.scalars().all() return result.scalars().all()
except OperationalError as e: except OperationalError as e:
@ -106,18 +136,34 @@ async def get_user_role_in_group(db: AsyncSession, group_id: int, user_id: int)
async def add_user_to_group(db: AsyncSession, group_id: int, user_id: int, role: UserRoleEnum = UserRoleEnum.member) -> Optional[UserGroupModel]: async def add_user_to_group(db: AsyncSession, group_id: int, user_id: int, role: UserRoleEnum = UserRoleEnum.member) -> Optional[UserGroupModel]:
"""Adds a user to a group if they aren't already a member.""" """Adds a user to a group if they aren't already a member."""
try: try:
async with db.begin(): # Check if user is already a member before starting a transaction
existing = await db.execute( existing_stmt = select(UserGroupModel.id).where(UserGroupModel.group_id == group_id, UserGroupModel.user_id == user_id)
select(UserGroupModel).where(UserGroupModel.group_id == group_id, UserGroupModel.user_id == user_id) existing_result = await db.execute(existing_stmt)
) if existing_result.scalar_one_or_none():
if existing.scalar_one_or_none(): return None
return None
# Use a single transaction
async with db.begin():
db_user_group = UserGroupModel(user_id=user_id, group_id=group_id, role=role) db_user_group = UserGroupModel(user_id=user_id, group_id=group_id, role=role)
db.add(db_user_group) db.add(db_user_group)
await db.flush() await db.flush() # Assigns ID to db_user_group
await db.refresh(db_user_group)
return db_user_group # Eagerly load the 'user' and 'group' relationships for the response
stmt = (
select(UserGroupModel)
.where(UserGroupModel.id == db_user_group.id)
.options(
selectinload(UserGroupModel.user),
selectinload(UserGroupModel.group)
)
)
result = await db.execute(stmt)
loaded_user_group = result.scalar_one_or_none()
if loaded_user_group is None:
raise GroupOperationError(f"Failed to load user group association after adding user {user_id} to group {group_id}.")
return loaded_user_group
except IntegrityError as e: except IntegrityError as e:
raise DatabaseIntegrityError(f"Failed to add user to group: {str(e)}") raise DatabaseIntegrityError(f"Failed to add user to group: {str(e)}")
except OperationalError as e: except OperationalError as e:

View File

@ -3,10 +3,19 @@ import secrets
from datetime import datetime, timedelta, timezone from datetime import datetime, timedelta, timezone
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select from sqlalchemy.future import select
from sqlalchemy.orm import selectinload # Ensure selectinload is imported
from sqlalchemy import delete # Import delete statement from sqlalchemy import delete # Import delete statement
from sqlalchemy.exc import SQLAlchemyError, OperationalError, IntegrityError
from typing import Optional from typing import Optional
from app.models import Invite as InviteModel from app.models import Invite as InviteModel, Group as GroupModel, User as UserModel # Import related models for selectinload
from app.core.exceptions import (
DatabaseConnectionError,
DatabaseIntegrityError,
DatabaseQueryError,
DatabaseTransactionError,
InviteOperationError # Add new specific exception
)
# Invite codes should be reasonably unique, but handle potential collision # Invite codes should be reasonably unique, but handle potential collision
MAX_CODE_GENERATION_ATTEMPTS = 5 MAX_CODE_GENERATION_ATTEMPTS = 5
@ -14,56 +23,121 @@ MAX_CODE_GENERATION_ATTEMPTS = 5
async def create_invite(db: AsyncSession, group_id: int, creator_id: int, expires_in_days: int = 7) -> Optional[InviteModel]: async def create_invite(db: AsyncSession, group_id: int, creator_id: int, expires_in_days: int = 7) -> Optional[InviteModel]:
"""Creates a new invite code for a group.""" """Creates a new invite code for a group."""
expires_at = datetime.now(timezone.utc) + timedelta(days=expires_in_days) expires_at = datetime.now(timezone.utc) + timedelta(days=expires_in_days)
code = None
attempts = 0
# Generate a unique code, retrying if a collision occurs (highly unlikely but safe) potential_code = None
while attempts < MAX_CODE_GENERATION_ATTEMPTS: for attempt in range(MAX_CODE_GENERATION_ATTEMPTS):
attempts += 1
potential_code = secrets.token_urlsafe(16) potential_code = secrets.token_urlsafe(16)
# Check if an *active* invite with this code already exists # Check if an *active* invite with this code already exists (outside main transaction for now)
existing = await db.execute( # Ideally, unique constraint on (code, is_active=true) in DB and catch IntegrityError.
select(InviteModel.id).where(InviteModel.code == potential_code, InviteModel.is_active == True).limit(1) # This check reduces collision chance before attempting transaction.
) existing_check_stmt = select(InviteModel.id).where(InviteModel.code == potential_code, InviteModel.is_active == True).limit(1)
if existing.scalar_one_or_none() is None: existing_result = await db.execute(existing_check_stmt)
code = potential_code if existing_result.scalar_one_or_none() is None:
break break # Found a potentially unique code
if attempt == MAX_CODE_GENERATION_ATTEMPTS - 1:
raise InviteOperationError("Failed to generate a unique invite code after several attempts.")
if code is None: try:
# Failed to generate a unique code after several attempts async with db.begin_nested() if db.in_transaction() else db.begin() as transaction:
return None # Final check within transaction to be absolutely sure before insert
final_check_stmt = select(InviteModel.id).where(InviteModel.code == potential_code, InviteModel.is_active == True).limit(1)
final_check_result = await db.execute(final_check_stmt)
if final_check_result.scalar_one_or_none() is not None:
# Extremely unlikely if previous check passed, but handles race condition
await transaction.rollback()
raise InviteOperationError("Invite code collision detected during transaction.")
db_invite = InviteModel( db_invite = InviteModel(
code=code, code=potential_code,
group_id=group_id, group_id=group_id,
created_by_id=creator_id, created_by_id=creator_id,
expires_at=expires_at, expires_at=expires_at,
is_active=True is_active=True
) )
db.add(db_invite) db.add(db_invite)
await db.commit() await db.flush() # Assigns ID
await db.refresh(db_invite)
return db_invite # Re-fetch with relationships
stmt = (
select(InviteModel)
.where(InviteModel.id == db_invite.id)
.options(
selectinload(InviteModel.group),
selectinload(InviteModel.creator)
)
)
result = await db.execute(stmt)
loaded_invite = result.scalar_one_or_none()
if loaded_invite is None:
await transaction.rollback()
raise InviteOperationError("Failed to load invite after creation.")
await transaction.commit()
return loaded_invite
except IntegrityError as e: # Catch if DB unique constraint on code is violated
# Rollback handled by context manager
raise DatabaseIntegrityError(f"Failed to create invite due to DB integrity: {str(e)}")
except OperationalError as e:
raise DatabaseConnectionError(f"DB connection error during invite creation: {str(e)}")
except SQLAlchemyError as e:
raise DatabaseTransactionError(f"DB transaction error during invite creation: {str(e)}")
async def get_active_invite_by_code(db: AsyncSession, code: str) -> Optional[InviteModel]: async def get_active_invite_by_code(db: AsyncSession, code: str) -> Optional[InviteModel]:
"""Gets an active and non-expired invite by its code.""" """Gets an active and non-expired invite by its code."""
now = datetime.now(timezone.utc) now = datetime.now(timezone.utc)
result = await db.execute( try:
select(InviteModel).where( stmt = (
InviteModel.code == code, select(InviteModel).where(
InviteModel.is_active == True, InviteModel.code == code,
InviteModel.expires_at > now InviteModel.is_active == True,
InviteModel.expires_at > now
)
.options(
selectinload(InviteModel.group),
selectinload(InviteModel.creator)
)
) )
) result = await db.execute(stmt)
return result.scalars().first() return result.scalars().first()
except OperationalError as e:
raise DatabaseConnectionError(f"DB connection error fetching invite: {str(e)}")
except SQLAlchemyError as e:
raise DatabaseQueryError(f"DB query error fetching invite: {str(e)}")
async def deactivate_invite(db: AsyncSession, invite: InviteModel) -> InviteModel: async def deactivate_invite(db: AsyncSession, invite: InviteModel) -> InviteModel:
"""Marks an invite as inactive (used).""" """Marks an invite as inactive (used) and reloads with relationships."""
invite.is_active = False try:
db.add(invite) # Add to session to track change async with db.begin_nested() if db.in_transaction() else db.begin() as transaction:
await db.commit() invite.is_active = False
await db.refresh(invite) db.add(invite) # Add to session to track change
return invite await db.flush() # Persist is_active change
# Re-fetch with relationships
stmt = (
select(InviteModel)
.where(InviteModel.id == invite.id)
.options(
selectinload(InviteModel.group),
selectinload(InviteModel.creator)
)
)
result = await db.execute(stmt)
updated_invite = result.scalar_one_or_none()
if updated_invite is None: # Should not happen as invite is passed in
await transaction.rollback()
raise InviteOperationError("Failed to load invite after deactivation.")
await transaction.commit()
return updated_invite
except OperationalError as e:
raise DatabaseConnectionError(f"DB connection error deactivating invite: {str(e)}")
except SQLAlchemyError as e:
raise DatabaseTransactionError(f"DB transaction error deactivating invite: {str(e)}")
# Ensure InviteOperationError is defined in app.core.exceptions
# Example: class InviteOperationError(AppException): pass
# Optional: Function to periodically delete old, inactive invites # Optional: Function to periodically delete old, inactive invites
# async def cleanup_old_invites(db: AsyncSession, older_than_days: int = 30): ... # async def cleanup_old_invites(db: AsyncSession, older_than_days: int = 30): ...

View File

@ -1,12 +1,13 @@
# app/crud/item.py # app/crud/item.py
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select from sqlalchemy.future import select
from sqlalchemy.orm import selectinload # Ensure selectinload is imported
from sqlalchemy import delete as sql_delete, update as sql_update # Use aliases from sqlalchemy import delete as sql_delete, update as sql_update # Use aliases
from sqlalchemy.exc import SQLAlchemyError, IntegrityError, OperationalError from sqlalchemy.exc import SQLAlchemyError, IntegrityError, OperationalError
from typing import Optional, List as PyList from typing import Optional, List as PyList
from datetime import datetime, timezone from datetime import datetime, timezone
from app.models import Item as ItemModel from app.models import Item as ItemModel, User as UserModel # Import UserModel for type hints if needed for selectinload
from app.schemas.item import ItemCreate, ItemUpdate from app.schemas.item import ItemCreate, ItemUpdate
from app.core.exceptions import ( from app.core.exceptions import (
ItemNotFoundError, ItemNotFoundError,
@ -14,46 +15,65 @@ from app.core.exceptions import (
DatabaseIntegrityError, DatabaseIntegrityError,
DatabaseQueryError, DatabaseQueryError,
DatabaseTransactionError, DatabaseTransactionError,
ConflictError ConflictError,
ItemOperationError # Add if specific item operation errors are needed
) )
async def create_item(db: AsyncSession, item_in: ItemCreate, list_id: int, user_id: int) -> ItemModel: async def create_item(db: AsyncSession, item_in: ItemCreate, list_id: int, user_id: int) -> ItemModel:
"""Creates a new item record for a specific list.""" """Creates a new item record for a specific list."""
try: try:
db_item = ItemModel( async with db.begin_nested() if db.in_transaction() else db.begin() as transaction:
name=item_in.name, db_item = ItemModel(
quantity=item_in.quantity, name=item_in.name,
list_id=list_id, quantity=item_in.quantity,
added_by_id=user_id, list_id=list_id,
is_complete=False # Default on creation added_by_id=user_id,
# version is implicitly set to 1 by model default is_complete=False
) )
db.add(db_item) db.add(db_item)
await db.flush() await db.flush() # Assigns ID
await db.refresh(db_item)
await db.commit() # Explicitly commit here # Re-fetch with relationships
return db_item stmt = (
select(ItemModel)
.where(ItemModel.id == db_item.id)
.options(
selectinload(ItemModel.added_by_user),
selectinload(ItemModel.completed_by_user) # Will be None but loads relationship
)
)
result = await db.execute(stmt)
loaded_item = result.scalar_one_or_none()
if loaded_item is None:
await transaction.rollback() # Should be handled by context manager on raise, but explicit for clarity
raise ItemOperationError("Failed to load item after creation.") # Define ItemOperationError
await transaction.commit()
return loaded_item
except IntegrityError as e: except IntegrityError as e:
await db.rollback() # Rollback on integrity error # Context manager handles rollback on error
raise DatabaseIntegrityError(f"Failed to create item: {str(e)}") raise DatabaseIntegrityError(f"Failed to create item: {str(e)}")
except OperationalError as e: except OperationalError as e:
await db.rollback() # Rollback on operational error
raise DatabaseConnectionError(f"Database connection error: {str(e)}") raise DatabaseConnectionError(f"Database connection error: {str(e)}")
except SQLAlchemyError as e: except SQLAlchemyError as e:
await db.rollback() # Rollback on other SQLAlchemy errors
raise DatabaseTransactionError(f"Failed to create item: {str(e)}") raise DatabaseTransactionError(f"Failed to create item: {str(e)}")
except Exception as e: # Catch any other exception and attempt rollback # Removed generic Exception block as SQLAlchemyError should cover DB issues,
await db.rollback() # and context manager handles rollback.
raise # Re-raise the original exception
async def get_items_by_list_id(db: AsyncSession, list_id: int) -> PyList[ItemModel]: async def get_items_by_list_id(db: AsyncSession, list_id: int) -> PyList[ItemModel]:
"""Gets all items belonging to a specific list, ordered by creation time.""" """Gets all items belonging to a specific list, ordered by creation time."""
try: try:
result = await db.execute( stmt = (
select(ItemModel) select(ItemModel)
.where(ItemModel.list_id == list_id) .where(ItemModel.list_id == list_id)
.order_by(ItemModel.created_at.asc()) # Or desc() if preferred .options(
selectinload(ItemModel.added_by_user),
selectinload(ItemModel.completed_by_user)
)
.order_by(ItemModel.created_at.asc())
) )
result = await db.execute(stmt)
return result.scalars().all() return result.scalars().all()
except OperationalError as e: except OperationalError as e:
raise DatabaseConnectionError(f"Failed to connect to database: {str(e)}") raise DatabaseConnectionError(f"Failed to connect to database: {str(e)}")
@ -63,7 +83,16 @@ async def get_items_by_list_id(db: AsyncSession, list_id: int) -> PyList[ItemMod
async def get_item_by_id(db: AsyncSession, item_id: int) -> Optional[ItemModel]: async def get_item_by_id(db: AsyncSession, item_id: int) -> Optional[ItemModel]:
"""Gets a single item by its ID.""" """Gets a single item by its ID."""
try: try:
result = await db.execute(select(ItemModel).where(ItemModel.id == item_id)) stmt = (
select(ItemModel)
.where(ItemModel.id == item_id)
.options(
selectinload(ItemModel.added_by_user),
selectinload(ItemModel.completed_by_user),
selectinload(ItemModel.list) # Often useful to get the parent list
)
)
result = await db.execute(stmt)
return result.scalars().first() return result.scalars().first()
except OperationalError as e: except OperationalError as e:
raise DatabaseConnectionError(f"Failed to connect to database: {str(e)}") raise DatabaseConnectionError(f"Failed to connect to database: {str(e)}")
@ -73,59 +102,72 @@ async def get_item_by_id(db: AsyncSession, item_id: int) -> Optional[ItemModel]:
async def update_item(db: AsyncSession, item_db: ItemModel, item_in: ItemUpdate, user_id: int) -> ItemModel: async def update_item(db: AsyncSession, item_db: ItemModel, item_in: ItemUpdate, user_id: int) -> ItemModel:
"""Updates an existing item record, checking for version conflicts.""" """Updates an existing item record, checking for version conflicts."""
try: try:
# Check version conflict async with db.begin_nested() if db.in_transaction() else db.begin() as transaction:
if item_db.version != item_in.version: if item_db.version != item_in.version:
raise ConflictError( # No need to rollback here, as the transaction hasn't committed.
f"Item '{item_db.name}' (ID: {item_db.id}) has been modified. " # The context manager will handle rollback if an exception is raised.
f"Your version is {item_in.version}, current version is {item_db.version}. Please refresh." raise ConflictError(
f"Item '{item_db.name}' (ID: {item_db.id}) has been modified. "
f"Your version is {item_in.version}, current version is {item_db.version}. Please refresh."
)
update_data = item_in.model_dump(exclude_unset=True, exclude={'version'})
if 'is_complete' in update_data:
if update_data['is_complete'] is True:
if item_db.completed_by_id is None:
update_data['completed_by_id'] = user_id
else:
update_data['completed_by_id'] = None
for key, value in update_data.items():
setattr(item_db, key, value)
item_db.version += 1
db.add(item_db) # Mark as dirty
await db.flush()
# Re-fetch with relationships
stmt = (
select(ItemModel)
.where(ItemModel.id == item_db.id)
.options(
selectinload(ItemModel.added_by_user),
selectinload(ItemModel.completed_by_user),
selectinload(ItemModel.list)
)
) )
result = await db.execute(stmt)
updated_item = result.scalar_one_or_none()
update_data = item_in.model_dump(exclude_unset=True, exclude={'version'}) # Exclude version if updated_item is None: # Should not happen
# Rollback will be handled by context manager on raise
raise ItemOperationError("Failed to load item after update.")
# Special handling for is_complete await transaction.commit()
if 'is_complete' in update_data: return updated_item
if update_data['is_complete'] is True:
if item_db.completed_by_id is None: # Only set if not already completed by someone
update_data['completed_by_id'] = user_id
else:
update_data['completed_by_id'] = None # Clear if marked incomplete
# Apply updates
for key, value in update_data.items():
setattr(item_db, key, value)
item_db.version += 1 # Increment version
db.add(item_db)
await db.flush()
await db.refresh(item_db)
# Commit the transaction if not part of a larger transaction
await db.commit()
return item_db
except IntegrityError as e: except IntegrityError as e:
await db.rollback()
raise DatabaseIntegrityError(f"Failed to update item due to integrity constraint: {str(e)}") raise DatabaseIntegrityError(f"Failed to update item due to integrity constraint: {str(e)}")
except OperationalError as e: except OperationalError as e:
await db.rollback()
raise DatabaseConnectionError(f"Database connection error while updating item: {str(e)}") raise DatabaseConnectionError(f"Database connection error while updating item: {str(e)}")
except ConflictError: # Re-raise ConflictError except ConflictError: # Re-raise ConflictError, rollback handled by context manager
await db.rollback()
raise raise
except SQLAlchemyError as e: except SQLAlchemyError as e:
await db.rollback()
raise DatabaseTransactionError(f"Failed to update item: {str(e)}") raise DatabaseTransactionError(f"Failed to update item: {str(e)}")
async def delete_item(db: AsyncSession, item_db: ItemModel) -> None: async def delete_item(db: AsyncSession, item_db: ItemModel) -> None:
"""Deletes an item record. Version check should be done by the caller (API endpoint).""" """Deletes an item record. Version check should be done by the caller (API endpoint)."""
try: try:
await db.delete(item_db) async with db.begin_nested() if db.in_transaction() else db.begin() as transaction:
await db.commit() await db.delete(item_db)
return None await transaction.commit()
# No return needed for None
except OperationalError as e: except OperationalError as e:
await db.rollback() # Rollback handled by context manager
raise DatabaseConnectionError(f"Database connection error while deleting item: {str(e)}") raise DatabaseConnectionError(f"Database connection error while deleting item: {str(e)}")
except SQLAlchemyError as e: except SQLAlchemyError as e:
await db.rollback() # Rollback handled by context manager
raise DatabaseTransactionError(f"Failed to delete item: {str(e)}") raise DatabaseTransactionError(f"Failed to delete item: {str(e)}")
# Ensure ItemOperationError is defined in app.core.exceptions if used
# Example: class ItemOperationError(AppException): pass

View File

@ -17,15 +17,14 @@ from app.core.exceptions import (
DatabaseIntegrityError, DatabaseIntegrityError,
DatabaseQueryError, DatabaseQueryError,
DatabaseTransactionError, DatabaseTransactionError,
ConflictError ConflictError,
ListOperationError
) )
async def create_list(db: AsyncSession, list_in: ListCreate, creator_id: int) -> ListModel: async def create_list(db: AsyncSession, list_in: ListCreate, creator_id: int) -> ListModel:
"""Creates a new list record.""" """Creates a new list record."""
try: try:
# Check if we're already in a transaction async with db.begin_nested() if db.in_transaction() else db.begin() as transaction:
if db.in_transaction():
# If we're already in a transaction, just create the list
db_list = ListModel( db_list = ListModel(
name=list_in.name, name=list_in.name,
description=list_in.description, description=list_in.description,
@ -34,23 +33,27 @@ async def create_list(db: AsyncSession, list_in: ListCreate, creator_id: int) ->
is_complete=False is_complete=False
) )
db.add(db_list) db.add(db_list)
await db.flush() await db.flush() # Assigns ID
await db.refresh(db_list)
return db_list # Re-fetch with relationships for the response
else: stmt = (
# If no transaction is active, start one select(ListModel)
async with db.begin(): .where(ListModel.id == db_list.id)
db_list = ListModel( .options(
name=list_in.name, selectinload(ListModel.creator),
description=list_in.description, selectinload(ListModel.group)
group_id=list_in.group_id, # selectinload(ListModel.items) # Optionally add if items are always needed in response
created_by_id=creator_id,
is_complete=False
) )
db.add(db_list) )
await db.flush() result = await db.execute(stmt)
await db.refresh(db_list) loaded_list = result.scalar_one_or_none()
return db_list
if loaded_list is None:
await transaction.rollback()
raise ListOperationError("Failed to load list after creation.")
await transaction.commit()
return loaded_list
except IntegrityError as e: except IntegrityError as e:
raise DatabaseIntegrityError(f"Failed to create list: {str(e)}") raise DatabaseIntegrityError(f"Failed to create list: {str(e)}")
except OperationalError as e: except OperationalError as e:
@ -66,14 +69,22 @@ async def get_lists_for_user(db: AsyncSession, user_id: int) -> PyList[ListModel
) )
user_group_ids = group_ids_result.scalars().all() user_group_ids = group_ids_result.scalars().all()
# Build conditions for the OR clause dynamically
conditions = [ conditions = [
and_(ListModel.created_by_id == user_id, ListModel.group_id.is_(None)) and_(ListModel.created_by_id == user_id, ListModel.group_id.is_(None))
] ]
if user_group_ids: # Only add the IN clause if there are group IDs if user_group_ids:
conditions.append(ListModel.group_id.in_(user_group_ids)) conditions.append(ListModel.group_id.in_(user_group_ids))
query = select(ListModel).where(or_(*conditions)).order_by(ListModel.updated_at.desc()) query = (
select(ListModel)
.where(or_(*conditions))
.options(
selectinload(ListModel.creator),
selectinload(ListModel.group)
# selectinload(ListModel.items) # Consider if items are needed for list previews
)
.order_by(ListModel.updated_at.desc())
)
result = await db.execute(query) result = await db.execute(query)
return result.scalars().all() return result.scalars().all()
@ -85,11 +96,17 @@ async def get_lists_for_user(db: AsyncSession, user_id: int) -> PyList[ListModel
async def get_list_by_id(db: AsyncSession, list_id: int, load_items: bool = False) -> Optional[ListModel]: async def get_list_by_id(db: AsyncSession, list_id: int, load_items: bool = False) -> Optional[ListModel]:
"""Gets a single list by ID, optionally loading its items.""" """Gets a single list by ID, optionally loading its items."""
try: try:
query = select(ListModel).where(ListModel.id == list_id) query = (
select(ListModel)
.where(ListModel.id == list_id)
.options(
selectinload(ListModel.creator),
selectinload(ListModel.group)
)
)
if load_items: if load_items:
query = query.options( query = query.options(
selectinload(ListModel.items) selectinload(ListModel.items).options(
.options(
joinedload(ItemModel.added_by_user), joinedload(ItemModel.added_by_user),
joinedload(ItemModel.completed_by_user) joinedload(ItemModel.completed_by_user)
) )
@ -104,8 +121,9 @@ async def get_list_by_id(db: AsyncSession, list_id: int, load_items: bool = Fals
async def update_list(db: AsyncSession, list_db: ListModel, list_in: ListUpdate) -> ListModel: async def update_list(db: AsyncSession, list_db: ListModel, list_in: ListUpdate) -> ListModel:
"""Updates an existing list record, checking for version conflicts.""" """Updates an existing list record, checking for version conflicts."""
try: try:
async with db.begin(): async with db.begin_nested() if db.in_transaction() else db.begin() as transaction:
if list_db.version != list_in.version: if list_db.version != list_in.version: # list_db here is the one passed in, pre-loaded by API layer
await transaction.rollback() # Rollback before raising
raise ConflictError( raise ConflictError(
f"List '{list_db.name}' (ID: {list_db.id}) has been modified. " f"List '{list_db.name}' (ID: {list_db.id}) has been modified. "
f"Your version is {list_in.version}, current version is {list_db.version}. Please refresh." f"Your version is {list_in.version}, current version is {list_db.version}. Please refresh."
@ -118,34 +136,54 @@ async def update_list(db: AsyncSession, list_db: ListModel, list_in: ListUpdate)
list_db.version += 1 list_db.version += 1
db.add(list_db) db.add(list_db) # Add the already attached list_db to mark it dirty for the session
await db.flush() await db.flush()
await db.refresh(list_db)
return list_db # Re-fetch with relationships for the response
stmt = (
select(ListModel)
.where(ListModel.id == list_db.id)
.options(
selectinload(ListModel.creator),
selectinload(ListModel.group)
# selectinload(ListModel.items) # Optionally add if items are always needed in response
)
)
result = await db.execute(stmt)
updated_list = result.scalar_one_or_none()
if updated_list is None: # Should not happen
await transaction.rollback()
raise ListOperationError("Failed to load list after update.")
await transaction.commit()
return updated_list
except IntegrityError as e: except IntegrityError as e:
await db.rollback() # Ensure rollback if not handled by context manager (though it should be)
if db.in_transaction(): await db.rollback()
raise DatabaseIntegrityError(f"Failed to update list due to integrity constraint: {str(e)}") raise DatabaseIntegrityError(f"Failed to update list due to integrity constraint: {str(e)}")
except OperationalError as e: except OperationalError as e:
await db.rollback() if db.in_transaction(): await db.rollback()
raise DatabaseConnectionError(f"Database connection error while updating list: {str(e)}") raise DatabaseConnectionError(f"Database connection error while updating list: {str(e)}")
except ConflictError: except ConflictError:
await db.rollback() # Already rolled back or will be by context manager if transaction was started here
raise raise
except SQLAlchemyError as e: except SQLAlchemyError as e:
await db.rollback() if db.in_transaction(): await db.rollback()
raise DatabaseTransactionError(f"Failed to update list: {str(e)}") raise DatabaseTransactionError(f"Failed to update list: {str(e)}")
async def delete_list(db: AsyncSession, list_db: ListModel) -> None: async def delete_list(db: AsyncSession, list_db: ListModel) -> None:
"""Deletes a list record. Version check should be done by the caller (API endpoint).""" """Deletes a list record. Version check should be done by the caller (API endpoint)."""
try: try:
async with db.begin(): async with db.begin_nested() if db.in_transaction() else db.begin() as transaction: # Standardize transaction
await db.delete(list_db) await db.delete(list_db)
return None await transaction.commit() # Explicit commit
# return None # Already implicitly returns None
except OperationalError as e: except OperationalError as e:
await db.rollback() # Rollback should be handled by the async with block on exception
raise DatabaseConnectionError(f"Database connection error while deleting list: {str(e)}") raise DatabaseConnectionError(f"Database connection error while deleting list: {str(e)}")
except SQLAlchemyError as e: except SQLAlchemyError as e:
await db.rollback() # Rollback should be handled by the async with block on exception
raise DatabaseTransactionError(f"Failed to delete list: {str(e)}") raise DatabaseTransactionError(f"Failed to delete list: {str(e)}")
async def check_list_permission(db: AsyncSession, list_id: int, user_id: int, require_creator: bool = False) -> ListModel: async def check_list_permission(db: AsyncSession, list_id: int, user_id: int, require_creator: bool = False) -> ListModel:
@ -212,39 +250,48 @@ async def get_list_by_name_and_group(
db: AsyncSession, db: AsyncSession,
name: str, name: str,
group_id: Optional[int], group_id: Optional[int],
user_id: int user_id: int # user_id is for permission check, not direct list attribute
) -> Optional[ListModel]: ) -> Optional[ListModel]:
""" """
Gets a list by name and group, ensuring the user has permission to access it. Gets a list by name and group, ensuring the user has permission to access it.
Used for conflict resolution when creating lists. Used for conflict resolution when creating lists.
""" """
try: try:
# Build the base query # Base query for the list itself
query = select(ListModel).where(ListModel.name == name) base_query = select(ListModel).where(ListModel.name == name)
# Add group condition
if group_id is not None: if group_id is not None:
query = query.where(ListModel.group_id == group_id) base_query = base_query.where(ListModel.group_id == group_id)
else: else:
query = query.where(ListModel.group_id.is_(None)) base_query = base_query.where(ListModel.group_id.is_(None))
# Add permission conditions # Add eager loading for common relationships
conditions = [ base_query = base_query.options(
ListModel.created_by_id == user_id # User is creator selectinload(ListModel.creator),
] selectinload(ListModel.group)
if group_id is not None: )
# User is member of the group
conditions.append(
and_(
ListModel.group_id == group_id,
ListModel.created_by_id != user_id # Not the creator
)
)
query = query.where(or_(*conditions)) list_result = await db.execute(base_query)
target_list = list_result.scalar_one_or_none()
if not target_list:
return None
# Permission check
is_creator = target_list.created_by_id == user_id
if is_creator:
return target_list
if target_list.group_id:
from app.crud.group import is_user_member # Assuming this is a quick check not needing its own transaction
is_member_of_group = await is_user_member(db, group_id=target_list.group_id, user_id=user_id)
if is_member_of_group:
return target_list
# If not creator and (not a group list or not a member of the group list)
return None
result = await db.execute(query)
return result.scalars().first()
except OperationalError as e: except OperationalError as e:
raise DatabaseConnectionError(f"Failed to connect to database: {str(e)}") raise DatabaseConnectionError(f"Failed to connect to database: {str(e)}")
except SQLAlchemyError as e: except SQLAlchemyError as e:

View File

@ -3,84 +3,135 @@ from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select from sqlalchemy.future import select
from sqlalchemy.orm import selectinload, joinedload from sqlalchemy.orm import selectinload, joinedload
from sqlalchemy import or_ from sqlalchemy import or_
from decimal import Decimal from sqlalchemy.exc import SQLAlchemyError, OperationalError, IntegrityError
from decimal import Decimal, ROUND_HALF_UP
from typing import List as PyList, Optional, Sequence from typing import List as PyList, Optional, Sequence
from datetime import datetime, timezone from datetime import datetime, timezone
from app.models import ( from app.models import (
Settlement as SettlementModel, Settlement as SettlementModel,
User as UserModel, User as UserModel,
Group as GroupModel Group as GroupModel,
UserGroup as UserGroupModel
)
from app.schemas.expense import SettlementCreate, SettlementUpdate
from app.core.exceptions import (
UserNotFoundError,
GroupNotFoundError,
InvalidOperationError,
DatabaseConnectionError,
DatabaseIntegrityError,
DatabaseQueryError,
DatabaseTransactionError,
SettlementOperationError,
ConflictError
) )
from app.schemas.expense import SettlementCreate, SettlementUpdate # SettlementUpdate not used yet
from app.core.exceptions import UserNotFoundError, GroupNotFoundError, InvalidOperationError
async def create_settlement(db: AsyncSession, settlement_in: SettlementCreate, current_user_id: int) -> SettlementModel: async def create_settlement(db: AsyncSession, settlement_in: SettlementCreate, current_user_id: int) -> SettlementModel:
"""Creates a new settlement record.""" """Creates a new settlement record."""
# Validate Payer, Payee, and Group exist
payer = await db.get(UserModel, settlement_in.paid_by_user_id)
if not payer:
raise UserNotFoundError(user_id=settlement_in.paid_by_user_id, identifier="Payer")
payee = await db.get(UserModel, settlement_in.paid_to_user_id)
if not payee:
raise UserNotFoundError(user_id=settlement_in.paid_to_user_id, identifier="Payee")
if settlement_in.paid_by_user_id == settlement_in.paid_to_user_id:
raise InvalidOperationError("Payer and Payee cannot be the same user.")
group = await db.get(GroupModel, settlement_in.group_id)
if not group:
raise GroupNotFoundError(settlement_in.group_id)
# Optional: Check if current_user_id is part of the group or is one of the parties involved
# This is more of an API-level permission check but could be added here if strict.
# For example: if current_user_id not in [settlement_in.paid_by_user_id, settlement_in.paid_to_user_id]:
# is_in_group = await db.execute(select(UserGroupModel).where(UserGroupModel.group_id == group.id, UserGroupModel.user_id == current_user_id))
# if not is_in_group.first():
# raise InvalidOperationError("You can only record settlements you are part of or for groups you belong to.")
db_settlement = SettlementModel(
group_id=settlement_in.group_id,
paid_by_user_id=settlement_in.paid_by_user_id,
paid_to_user_id=settlement_in.paid_to_user_id,
amount=settlement_in.amount.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP),
settlement_date=settlement_in.settlement_date if settlement_in.settlement_date else datetime.now(timezone.utc),
description=settlement_in.description
# created_by_user_id = current_user_id # Optional: Who recorded this settlement
)
db.add(db_settlement)
try: try:
await db.commit() async with db.begin_nested() if db.in_transaction() else db.begin() as transaction:
await db.refresh(db_settlement, attribute_names=["payer", "payee", "group"]) payer = await db.get(UserModel, settlement_in.paid_by_user_id)
except Exception as e: if not payer:
await db.rollback() raise UserNotFoundError(user_id=settlement_in.paid_by_user_id, identifier="Payer")
raise InvalidOperationError(f"Failed to save settlement: {str(e)}")
payee = await db.get(UserModel, settlement_in.paid_to_user_id)
if not payee:
raise UserNotFoundError(user_id=settlement_in.paid_to_user_id, identifier="Payee")
if settlement_in.paid_by_user_id == settlement_in.paid_to_user_id:
raise InvalidOperationError("Payer and Payee cannot be the same user.")
group = await db.get(GroupModel, settlement_in.group_id)
if not group:
raise GroupNotFoundError(settlement_in.group_id)
# Permission check example (can be in API layer too)
# if current_user_id not in [payer.id, payee.id]:
# is_member_stmt = select(UserGroupModel.id).where(UserGroupModel.group_id == group.id, UserGroupModel.user_id == current_user_id).limit(1)
# is_member_result = await db.execute(is_member_stmt)
# if not is_member_result.scalar_one_or_none():
# raise InvalidOperationError("Settlement recorder must be part of the group or one of the parties.")
db_settlement = SettlementModel(
group_id=settlement_in.group_id,
paid_by_user_id=settlement_in.paid_by_user_id,
paid_to_user_id=settlement_in.paid_to_user_id,
amount=settlement_in.amount.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP),
settlement_date=settlement_in.settlement_date if settlement_in.settlement_date else datetime.now(timezone.utc),
description=settlement_in.description
# created_by_user_id = current_user_id # Optional: Who recorded this settlement
)
db.add(db_settlement)
await db.flush()
# Re-fetch with relationships
stmt = (
select(SettlementModel)
.where(SettlementModel.id == db_settlement.id)
.options(
selectinload(SettlementModel.payer),
selectinload(SettlementModel.payee),
selectinload(SettlementModel.group)
)
)
result = await db.execute(stmt)
loaded_settlement = result.scalar_one_or_none()
if loaded_settlement is None:
await transaction.rollback() # Should be handled by context manager
raise SettlementOperationError("Failed to load settlement after creation.")
await transaction.commit()
return loaded_settlement
except (UserNotFoundError, GroupNotFoundError, InvalidOperationError) as e:
# These are validation errors, re-raise them.
# If a transaction was started, context manager handles rollback.
raise
except IntegrityError as e:
raise DatabaseIntegrityError(f"Failed to save settlement due to DB integrity: {str(e)}")
except OperationalError as e:
raise DatabaseConnectionError(f"DB connection error during settlement creation: {str(e)}")
except SQLAlchemyError as e:
raise DatabaseTransactionError(f"DB transaction error during settlement creation: {str(e)}")
return db_settlement
async def get_settlement_by_id(db: AsyncSession, settlement_id: int) -> Optional[SettlementModel]: async def get_settlement_by_id(db: AsyncSession, settlement_id: int) -> Optional[SettlementModel]:
result = await db.execute( try:
select(SettlementModel) result = await db.execute(
.options( select(SettlementModel)
selectinload(SettlementModel.payer), .options(
selectinload(SettlementModel.payee), selectinload(SettlementModel.payer),
selectinload(SettlementModel.group) selectinload(SettlementModel.payee),
selectinload(SettlementModel.group)
)
.where(SettlementModel.id == settlement_id)
) )
.where(SettlementModel.id == settlement_id) return result.scalars().first()
) except OperationalError as e:
return result.scalars().first() raise DatabaseConnectionError(f"DB connection error fetching settlement: {str(e)}")
except SQLAlchemyError as e:
raise DatabaseQueryError(f"DB query error fetching settlement: {str(e)}")
async def get_settlements_for_group(db: AsyncSession, group_id: int, skip: int = 0, limit: int = 100) -> Sequence[SettlementModel]: async def get_settlements_for_group(db: AsyncSession, group_id: int, skip: int = 0, limit: int = 100) -> Sequence[SettlementModel]:
result = await db.execute( try:
select(SettlementModel) result = await db.execute(
.where(SettlementModel.group_id == group_id) select(SettlementModel)
.order_by(SettlementModel.settlement_date.desc(), SettlementModel.created_at.desc()) .where(SettlementModel.group_id == group_id)
.offset(skip).limit(limit) .order_by(SettlementModel.settlement_date.desc(), SettlementModel.created_at.desc())
.options(selectinload(SettlementModel.payer), selectinload(SettlementModel.payee)) .offset(skip).limit(limit)
) .options(
return result.scalars().all() selectinload(SettlementModel.payer),
selectinload(SettlementModel.payee),
selectinload(SettlementModel.group)
)
)
return result.scalars().all()
except OperationalError as e:
raise DatabaseConnectionError(f"DB connection error fetching group settlements: {str(e)}")
except SQLAlchemyError as e:
raise DatabaseQueryError(f"DB query error fetching group settlements: {str(e)}")
async def get_settlements_involving_user( async def get_settlements_involving_user(
db: AsyncSession, db: AsyncSession,
@ -89,18 +140,28 @@ async def get_settlements_involving_user(
skip: int = 0, skip: int = 0,
limit: int = 100 limit: int = 100
) -> Sequence[SettlementModel]: ) -> Sequence[SettlementModel]:
query = ( try:
select(SettlementModel) query = (
.where(or_(SettlementModel.paid_by_user_id == user_id, SettlementModel.paid_to_user_id == user_id)) select(SettlementModel)
.order_by(SettlementModel.settlement_date.desc(), SettlementModel.created_at.desc()) .where(or_(SettlementModel.paid_by_user_id == user_id, SettlementModel.paid_to_user_id == user_id))
.offset(skip).limit(limit) .order_by(SettlementModel.settlement_date.desc(), SettlementModel.created_at.desc())
.options(selectinload(SettlementModel.payer), selectinload(SettlementModel.payee), selectinload(SettlementModel.group)) .offset(skip).limit(limit)
) .options(
if group_id: selectinload(SettlementModel.payer),
query = query.where(SettlementModel.group_id == group_id) selectinload(SettlementModel.payee),
selectinload(SettlementModel.group)
)
)
if group_id:
query = query.where(SettlementModel.group_id == group_id)
result = await db.execute(query)
return result.scalars().all()
except OperationalError as e:
raise DatabaseConnectionError(f"DB connection error fetching user settlements: {str(e)}")
except SQLAlchemyError as e:
raise DatabaseQueryError(f"DB query error fetching user settlements: {str(e)}")
result = await db.execute(query)
return result.scalars().all()
async def update_settlement(db: AsyncSession, settlement_db: SettlementModel, settlement_in: SettlementUpdate) -> SettlementModel: async def update_settlement(db: AsyncSession, settlement_db: SettlementModel, settlement_in: SettlementUpdate) -> SettlementModel:
""" """
@ -108,58 +169,100 @@ async def update_settlement(db: AsyncSession, settlement_db: SettlementModel, se
Only allows updates to description and settlement_date. Only allows updates to description and settlement_date.
Requires version matching for optimistic locking. Requires version matching for optimistic locking.
Assumes SettlementUpdate schema includes a version field. Assumes SettlementUpdate schema includes a version field.
Assumes SettlementModel has version and updated_at fields.
""" """
# Check if SettlementUpdate schema has 'version'. If not, this check needs to be adapted or version passed differently.
if not hasattr(settlement_in, 'version') or settlement_db.version != settlement_in.version:
raise InvalidOperationError(
f"Settlement (ID: {settlement_db.id}) has been modified. "
f"Your version does not match current version {settlement_db.version}. Please refresh.",
# status_code=status.HTTP_409_CONFLICT
)
update_data = settlement_in.model_dump(exclude_unset=True, exclude={"version"})
allowed_to_update = {"description", "settlement_date"}
updated_something = False
for field, value in update_data.items():
if field in allowed_to_update:
setattr(settlement_db, field, value)
updated_something = True
else:
raise InvalidOperationError(f"Field '{field}' cannot be updated. Only {', '.join(allowed_to_update)} are allowed for settlements.")
if not updated_something and not settlement_in.model_fields_set.intersection(allowed_to_update):
pass # No actual updatable fields provided, but version matched.
settlement_db.version += 1 # Assuming SettlementModel has a version field, add if missing
settlement_db.updated_at = datetime.now(timezone.utc)
try: try:
await db.commit() async with db.begin_nested() if db.in_transaction() else db.begin() as transaction:
await db.refresh(settlement_db) # Ensure the settlement_db passed is managed by the current session if not already.
except Exception as e: # This is usually true if fetched by an endpoint dependency using the same session.
await db.rollback() # If not, `db.add(settlement_db)` might be needed before modification if it's detached.
raise InvalidOperationError(f"Failed to update settlement: {str(e)}")
if not hasattr(settlement_db, 'version') or not hasattr(settlement_in, 'version'):
raise InvalidOperationError("Version field is missing in model or input for optimistic locking.")
if settlement_db.version != settlement_in.version:
raise ConflictError( # Make sure ConflictError is defined in exceptions
f"Settlement (ID: {settlement_db.id}) has been modified. "
f"Your version {settlement_in.version} does not match current version {settlement_db.version}. Please refresh."
)
update_data = settlement_in.model_dump(exclude_unset=True, exclude={"version"})
allowed_to_update = {"description", "settlement_date"}
updated_something = False
for field, value in update_data.items():
if field in allowed_to_update:
setattr(settlement_db, field, value)
updated_something = True
# Silently ignore fields not allowed to update or raise error:
# else:
# raise InvalidOperationError(f"Field '{field}' cannot be updated.")
if not updated_something and not settlement_in.model_fields_set.intersection(allowed_to_update):
# No updatable fields were actually provided, or they didn't change
# Still, we might want to return the re-loaded settlement if version matched.
pass
settlement_db.version += 1
settlement_db.updated_at = datetime.now(timezone.utc) # Ensure model has this field
db.add(settlement_db) # Mark as dirty
await db.flush()
# Re-fetch with relationships
stmt = (
select(SettlementModel)
.where(SettlementModel.id == settlement_db.id)
.options(
selectinload(SettlementModel.payer),
selectinload(SettlementModel.payee),
selectinload(SettlementModel.group)
)
)
result = await db.execute(stmt)
updated_settlement = result.scalar_one_or_none()
if updated_settlement is None: # Should not happen
await transaction.rollback()
raise SettlementOperationError("Failed to load settlement after update.")
await transaction.commit()
return updated_settlement
except ConflictError as e: # ConflictError should be defined in exceptions
raise
except InvalidOperationError as e:
raise
except IntegrityError as e:
raise DatabaseIntegrityError(f"Failed to update settlement due to DB integrity: {str(e)}")
except OperationalError as e:
raise DatabaseConnectionError(f"DB connection error during settlement update: {str(e)}")
except SQLAlchemyError as e:
raise DatabaseTransactionError(f"DB transaction error during settlement update: {str(e)}")
return settlement_db
async def delete_settlement(db: AsyncSession, settlement_db: SettlementModel, expected_version: Optional[int] = None) -> None: async def delete_settlement(db: AsyncSession, settlement_db: SettlementModel, expected_version: Optional[int] = None) -> None:
""" """
Deletes a settlement. Requires version matching if expected_version is provided. Deletes a settlement. Requires version matching if expected_version is provided.
Assumes SettlementModel has a version field. Assumes SettlementModel has a version field.
""" """
if expected_version is not None:
if not hasattr(settlement_db, 'version') or settlement_db.version != expected_version:
raise InvalidOperationError(
f"Settlement (ID: {settlement_db.id}) cannot be deleted. "
f"Expected version {expected_version} does not match current version. Please refresh.",
# status_code=status.HTTP_409_CONFLICT
)
await db.delete(settlement_db)
try: try:
await db.commit() async with db.begin_nested() if db.in_transaction() else db.begin() as transaction:
except Exception as e: if expected_version is not None:
await db.rollback() if not hasattr(settlement_db, 'version') or settlement_db.version != expected_version:
raise InvalidOperationError(f"Failed to delete settlement: {str(e)}") raise ConflictError( # Make sure ConflictError is defined
return None f"Settlement (ID: {settlement_db.id}) cannot be deleted. "
f"Expected version {expected_version} does not match current version {settlement_db.version}. Please refresh."
)
await db.delete(settlement_db)
await transaction.commit()
except ConflictError as e: # ConflictError should be defined
raise
except OperationalError as e:
raise DatabaseConnectionError(f"DB connection error during settlement deletion: {str(e)}")
except SQLAlchemyError as e:
raise DatabaseTransactionError(f"DB transaction error during settlement deletion: {str(e)}")
# Ensure SettlementOperationError and ConflictError are defined in app.core.exceptions
# Example: class SettlementOperationError(AppException): pass
# Example: class ConflictError(AppException): status_code = 409

View File

@ -1,10 +1,11 @@
# app/crud/user.py # app/crud/user.py
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select from sqlalchemy.future import select
from sqlalchemy.orm import selectinload # Ensure selectinload is imported
from sqlalchemy.exc import SQLAlchemyError, IntegrityError, OperationalError from sqlalchemy.exc import SQLAlchemyError, IntegrityError, OperationalError
from typing import Optional from typing import Optional
from app.models import User as UserModel # Alias to avoid name clash from app.models import User as UserModel, UserGroup as UserGroupModel, Group as GroupModel # Import related models for selectinload
from app.schemas.user import UserCreate from app.schemas.user import UserCreate
from app.core.security import hash_password from app.core.security import hash_password
from app.core.exceptions import ( from app.core.exceptions import (
@ -13,14 +14,26 @@ from app.core.exceptions import (
DatabaseConnectionError, DatabaseConnectionError,
DatabaseIntegrityError, DatabaseIntegrityError,
DatabaseQueryError, DatabaseQueryError,
DatabaseTransactionError DatabaseTransactionError,
UserOperationError # Add if specific user operation errors are needed
) )
async def get_user_by_email(db: AsyncSession, email: str) -> Optional[UserModel]: async def get_user_by_email(db: AsyncSession, email: str) -> Optional[UserModel]:
"""Fetches a user from the database by email.""" """Fetches a user from the database by email, with common relationships."""
try: try:
async with db.begin(): # db.begin() is not strictly necessary for a single read, but ensures atomicity if multiple reads were added.
result = await db.execute(select(UserModel).filter(UserModel.email == email)) # For a single select, it can be omitted if preferred, session handles connection.
async with db.begin(): # Or remove if only a single select operation
stmt = (
select(UserModel)
.filter(UserModel.email == email)
.options(
selectinload(UserModel.group_associations).selectinload(UserGroupModel.group), # Groups user is member of
selectinload(UserModel.created_groups) # Groups user created
# Add other relationships as needed by UserPublic schema
)
)
result = await db.execute(stmt)
return result.scalars().first() return result.scalars().first()
except OperationalError as e: except OperationalError as e:
raise DatabaseConnectionError(f"Failed to connect to database: {str(e)}") raise DatabaseConnectionError(f"Failed to connect to database: {str(e)}")
@ -28,24 +41,46 @@ async def get_user_by_email(db: AsyncSession, email: str) -> Optional[UserModel]
raise DatabaseQueryError(f"Failed to query user: {str(e)}") raise DatabaseQueryError(f"Failed to query user: {str(e)}")
async def create_user(db: AsyncSession, user_in: UserCreate) -> UserModel: async def create_user(db: AsyncSession, user_in: UserCreate) -> UserModel:
"""Creates a new user record in the database.""" """Creates a new user record in the database with common relationships loaded."""
try: try:
async with db.begin(): async with db.begin_nested() if db.in_transaction() else db.begin() as transaction:
_hashed_password = hash_password(user_in.password) _hashed_password = hash_password(user_in.password)
db_user = UserModel( db_user = UserModel(
email=user_in.email, email=user_in.email,
password_hash=_hashed_password, hashed_password=_hashed_password, # Field name in model is hashed_password
name=user_in.name name=user_in.name
) )
db.add(db_user) db.add(db_user)
await db.flush() # Flush to get DB-generated values await db.flush() # Flush to get DB-generated values like ID
await db.refresh(db_user)
return db_user # Re-fetch with relationships
stmt = (
select(UserModel)
.where(UserModel.id == db_user.id)
.options(
selectinload(UserModel.group_associations).selectinload(UserGroupModel.group),
selectinload(UserModel.created_groups)
# Add other relationships as needed by UserPublic schema
)
)
result = await db.execute(stmt)
loaded_user = result.scalar_one_or_none()
if loaded_user is None:
await transaction.rollback() # Should be handled by context manager, but explicit
raise UserOperationError("Failed to load user after creation.") # Define UserOperationError
await transaction.commit()
return loaded_user
except IntegrityError as e: except IntegrityError as e:
if "unique constraint" in str(e).lower(): # Context manager handles rollback on error
raise EmailAlreadyRegisteredError() if "unique constraint" in str(e).lower() and ("users_email_key" in str(e).lower() or "ix_users_email" in str(e).lower()):
raise DatabaseIntegrityError(f"Failed to create user: {str(e)}") raise EmailAlreadyRegisteredError(email=user_in.email)
raise DatabaseIntegrityError(f"Failed to create user due to integrity issue: {str(e)}")
except OperationalError as e: except OperationalError as e:
raise DatabaseConnectionError(f"Database connection error: {str(e)}") raise DatabaseConnectionError(f"Database connection error during user creation: {str(e)}")
except SQLAlchemyError as e: except SQLAlchemyError as e:
raise DatabaseTransactionError(f"Failed to create user: {str(e)}") raise DatabaseTransactionError(f"Failed to create user due to other DB error: {str(e)}")
# Ensure UserOperationError is defined in app.core.exceptions if used
# Example: class UserOperationError(AppException): pass

View File

@ -36,15 +36,9 @@ async def get_async_session() -> AsyncSession: # type: ignore
Ensures the session is closed after the request. Ensures the session is closed after the request.
""" """
async with AsyncSessionLocal() as session: async with AsyncSessionLocal() as session:
try: yield session
yield session # The 'async with' block handles session.close() automatically.
# Commit the transaction if no errors occurred # Commit/rollback should be handled by the functions using the session.
await session.commit()
except Exception:
await session.rollback()
raise
finally:
await session.close() # Not strictly necessary with async context manager, but explicit
# Alias for backward compatibility # Alias for backward compatibility
get_db = get_async_session get_db = get_async_session