Refactor database session management and exception handling across CRUD operations; streamline transaction handling in expense, group, invite, item, list, settlement, and user modules for improved reliability and clarity. Introduce specific operation errors for better error reporting.
This commit is contained in:
parent
515534dcce
commit
7a88ea258a
@ -128,6 +128,14 @@ class DatabaseQueryError(HTTPException):
|
||||
detail=detail
|
||||
)
|
||||
|
||||
class ExpenseOperationError(HTTPException):
|
||||
"""Raised when an expense operation fails."""
|
||||
def __init__(self, detail: str):
|
||||
super().__init__(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=detail
|
||||
)
|
||||
|
||||
class OCRServiceUnavailableError(HTTPException):
|
||||
"""Raised when the OCR service is unavailable."""
|
||||
def __init__(self):
|
||||
@ -240,6 +248,22 @@ class ListStatusNotFoundError(HTTPException):
|
||||
detail=f"Status for list {list_id} not found"
|
||||
)
|
||||
|
||||
class InviteOperationError(HTTPException):
|
||||
"""Raised when an invite operation fails."""
|
||||
def __init__(self, detail: str):
|
||||
super().__init__(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=detail
|
||||
)
|
||||
|
||||
class SettlementOperationError(HTTPException):
|
||||
"""Raised when a settlement operation fails."""
|
||||
def __init__(self, detail: str):
|
||||
super().__init__(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=detail
|
||||
)
|
||||
|
||||
class ConflictError(HTTPException):
|
||||
"""Raised when an optimistic lock version conflict occurs."""
|
||||
def __init__(self, detail: str):
|
||||
@ -283,3 +307,19 @@ class JWTUnexpectedError(HTTPException):
|
||||
detail=settings.JWT_UNEXPECTED_ERROR.format(error=error),
|
||||
headers={settings.AUTH_HEADER_NAME: f"{settings.AUTH_HEADER_PREFIX} error=\"invalid_token\""}
|
||||
)
|
||||
|
||||
class ListOperationError(HTTPException):
|
||||
"""Raised when a list operation fails."""
|
||||
def __init__(self, detail: str):
|
||||
super().__init__(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=detail
|
||||
)
|
||||
|
||||
class ItemOperationError(HTTPException):
|
||||
"""Raised when an item operation fails."""
|
||||
def __init__(self, detail: str):
|
||||
super().__init__(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=detail
|
||||
)
|
@ -4,7 +4,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
from sqlalchemy.orm import selectinload, joinedload
|
||||
from decimal import Decimal, ROUND_HALF_UP, InvalidOperation as DecimalInvalidOperation
|
||||
from typing import Callable, List as PyList, Optional, Sequence, Dict, defaultdict
|
||||
from typing import Callable, List as PyList, Optional, Sequence, Dict, defaultdict, Any
|
||||
from datetime import datetime, timezone # Added timezone
|
||||
|
||||
from app.models import (
|
||||
@ -23,7 +23,12 @@ from app.core.exceptions import (
|
||||
ListNotFoundError,
|
||||
GroupNotFoundError,
|
||||
UserNotFoundError,
|
||||
InvalidOperationError # Import the new exception
|
||||
InvalidOperationError, # Import the new exception
|
||||
DatabaseConnectionError, # Added
|
||||
DatabaseIntegrityError, # Added
|
||||
DatabaseQueryError, # Added
|
||||
DatabaseTransactionError,# Added
|
||||
ExpenseOperationError # Added specific exception
|
||||
)
|
||||
|
||||
# Placeholder for InvalidOperationError if not defined in app.core.exceptions
|
||||
@ -108,60 +113,97 @@ async def create_expense(db: AsyncSession, expense_in: ExpenseCreate, current_us
|
||||
GroupNotFoundError: If specified group doesn't exist
|
||||
InvalidOperationError: For various validation failures
|
||||
"""
|
||||
# Helper function to round decimals consistently
|
||||
def round_money(amount: Decimal) -> Decimal:
|
||||
return amount.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP)
|
||||
|
||||
# 1. Context Validation
|
||||
# Validate basic context requirements first
|
||||
if not expense_in.list_id and not expense_in.group_id:
|
||||
raise InvalidOperationError("Expense must be associated with a list or a group.")
|
||||
|
||||
# 2. User Validation
|
||||
payer = await db.get(UserModel, expense_in.paid_by_user_id)
|
||||
if not payer:
|
||||
raise UserNotFoundError(user_id=expense_in.paid_by_user_id)
|
||||
|
||||
# 3. List/Group Context Resolution
|
||||
final_group_id = await _resolve_expense_context(db, expense_in)
|
||||
|
||||
# 4. Create the expense object
|
||||
db_expense = ExpenseModel(
|
||||
description=expense_in.description,
|
||||
total_amount=round_money(expense_in.total_amount),
|
||||
currency=expense_in.currency or "USD",
|
||||
expense_date=expense_in.expense_date or datetime.now(timezone.utc),
|
||||
split_type=expense_in.split_type,
|
||||
list_id=expense_in.list_id,
|
||||
group_id=final_group_id,
|
||||
item_id=expense_in.item_id,
|
||||
paid_by_user_id=expense_in.paid_by_user_id,
|
||||
created_by_user_id=current_user_id # Track who created this expense
|
||||
)
|
||||
|
||||
# 5. Generate splits based on split type
|
||||
splits_to_create = await _generate_expense_splits(db, db_expense, expense_in, round_money)
|
||||
|
||||
# 6. Single transaction for expense and all splits
|
||||
try:
|
||||
db.add(db_expense)
|
||||
await db.flush() # Get expense ID without committing
|
||||
async with db.begin_nested() if db.in_transaction() else db.begin() as transaction:
|
||||
# 1. Validate payer
|
||||
payer = await db.get(UserModel, expense_in.paid_by_user_id)
|
||||
if not payer:
|
||||
raise UserNotFoundError(user_id=expense_in.paid_by_user_id, identifier="Payer")
|
||||
|
||||
# Update all splits with the expense ID
|
||||
for split in splits_to_create:
|
||||
split.expense_id = db_expense.id
|
||||
# 2. Context Resolution and Validation (now part of the transaction)
|
||||
if not expense_in.list_id and not expense_in.group_id and not expense_in.item_id:
|
||||
raise InvalidOperationError("Expense must be associated with a list, a group, or an item.")
|
||||
|
||||
db.add_all(splits_to_create)
|
||||
await db.commit()
|
||||
final_group_id = await _resolve_expense_context(db, expense_in)
|
||||
# Further validation for item_id if provided
|
||||
db_item_instance = None
|
||||
if expense_in.item_id:
|
||||
db_item_instance = await db.get(ItemModel, expense_in.item_id)
|
||||
if not db_item_instance:
|
||||
raise InvalidOperationError(f"Item with ID {expense_in.item_id} not found.")
|
||||
# Potentially link item's list/group if not already set on expense_in
|
||||
if db_item_instance.list_id and not expense_in.list_id:
|
||||
expense_in.list_id = db_item_instance.list_id
|
||||
# Re-resolve context if list_id was derived from item
|
||||
final_group_id = await _resolve_expense_context(db, expense_in)
|
||||
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(f"Failed to save expense: {str(e)}", exc_info=True)
|
||||
raise InvalidOperationError(f"Failed to save expense: {str(e)}")
|
||||
# 3. Create the ExpenseModel instance
|
||||
db_expense = ExpenseModel(
|
||||
description=expense_in.description,
|
||||
total_amount=_round_money(expense_in.total_amount),
|
||||
currency=expense_in.currency or "USD",
|
||||
expense_date=expense_in.expense_date or datetime.now(timezone.utc),
|
||||
split_type=expense_in.split_type,
|
||||
list_id=expense_in.list_id,
|
||||
group_id=final_group_id, # Use resolved group_id
|
||||
item_id=expense_in.item_id,
|
||||
paid_by_user_id=expense_in.paid_by_user_id,
|
||||
created_by_user_id=current_user_id
|
||||
)
|
||||
db.add(db_expense)
|
||||
await db.flush() # Get expense ID
|
||||
|
||||
# Refresh to get the splits relationship populated
|
||||
await db.refresh(db_expense, attribute_names=["splits"])
|
||||
return db_expense
|
||||
# 4. Generate splits (passing current_user_id through kwargs if needed by specific split types)
|
||||
splits_to_create = await _generate_expense_splits(
|
||||
db=db,
|
||||
expense_model=db_expense,
|
||||
expense_in=expense_in,
|
||||
current_user_id=current_user_id # Pass for item-based splits needing creator info
|
||||
)
|
||||
|
||||
for split_model in splits_to_create:
|
||||
split_model.expense_id = db_expense.id # Set FK after db_expense has ID
|
||||
db.add_all(splits_to_create)
|
||||
await db.flush() # Persist splits
|
||||
|
||||
# 5. Re-fetch the expense with all necessary relationships for the response
|
||||
stmt = (
|
||||
select(ExpenseModel)
|
||||
.where(ExpenseModel.id == db_expense.id)
|
||||
.options(
|
||||
selectinload(ExpenseModel.paid_by_user),
|
||||
selectinload(ExpenseModel.created_by_user), # If you have this relationship
|
||||
selectinload(ExpenseModel.list),
|
||||
selectinload(ExpenseModel.group),
|
||||
selectinload(ExpenseModel.item),
|
||||
selectinload(ExpenseModel.splits).selectinload(ExpenseSplitModel.user)
|
||||
)
|
||||
)
|
||||
result = await db.execute(stmt)
|
||||
loaded_expense = result.scalar_one_or_none()
|
||||
|
||||
if loaded_expense is None:
|
||||
await transaction.rollback() # Should be handled by context manager
|
||||
raise ExpenseOperationError("Failed to load expense after creation.")
|
||||
|
||||
await transaction.commit()
|
||||
return loaded_expense
|
||||
|
||||
except (UserNotFoundError, ListNotFoundError, GroupNotFoundError, InvalidOperationError) as e:
|
||||
# These are business logic validation errors, re-raise them.
|
||||
# If a transaction was started, the context manager handles rollback.
|
||||
raise
|
||||
except IntegrityError as e:
|
||||
# Context manager handles rollback.
|
||||
logger.error(f"Database integrity error during expense creation: {str(e)}", exc_info=True)
|
||||
raise DatabaseIntegrityError(f"Failed to save expense due to database integrity issue: {str(e)}")
|
||||
except OperationalError as e:
|
||||
logger.error(f"Database connection error during expense creation: {str(e)}", exc_info=True)
|
||||
raise DatabaseConnectionError(f"Database connection error during expense creation: {str(e)}")
|
||||
except SQLAlchemyError as e:
|
||||
# Context manager handles rollback.
|
||||
logger.error(f"Unexpected SQLAlchemy error during expense creation: {str(e)}", exc_info=True)
|
||||
raise DatabaseTransactionError(f"Failed to save expense due to a database transaction error: {str(e)}")
|
||||
|
||||
|
||||
async def _resolve_expense_context(db: AsyncSession, expense_in: ExpenseCreate) -> Optional[int]:
|
||||
@ -197,39 +239,32 @@ async def _resolve_expense_context(db: AsyncSession, expense_in: ExpenseCreate)
|
||||
|
||||
async def _generate_expense_splits(
|
||||
db: AsyncSession,
|
||||
db_expense: ExpenseModel,
|
||||
expense_model: ExpenseModel,
|
||||
expense_in: ExpenseCreate,
|
||||
round_money: Callable[[Decimal], Decimal]
|
||||
**kwargs: Any
|
||||
) -> PyList[ExpenseSplitModel]:
|
||||
"""Generates appropriate expense splits based on split type."""
|
||||
|
||||
splits_to_create: PyList[ExpenseSplitModel] = []
|
||||
|
||||
# Pass db to split creation helpers if they need to fetch more data (e.g., item details for item-based)
|
||||
common_args = {"db": db, "expense_model": expense_model, "expense_in": expense_in, "round_money_func": _round_money, "kwargs": kwargs}
|
||||
|
||||
# Create splits based on the split type
|
||||
if expense_in.split_type == SplitTypeEnum.EQUAL:
|
||||
splits_to_create = await _create_equal_splits(
|
||||
db, db_expense, expense_in, round_money
|
||||
)
|
||||
splits_to_create = await _create_equal_splits(**common_args)
|
||||
|
||||
elif expense_in.split_type == SplitTypeEnum.EXACT_AMOUNTS:
|
||||
splits_to_create = await _create_exact_amount_splits(
|
||||
db, db_expense, expense_in, round_money
|
||||
)
|
||||
splits_to_create = await _create_exact_amount_splits(**common_args)
|
||||
|
||||
elif expense_in.split_type == SplitTypeEnum.PERCENTAGE:
|
||||
splits_to_create = await _create_percentage_splits(
|
||||
db, db_expense, expense_in, round_money
|
||||
)
|
||||
splits_to_create = await _create_percentage_splits(**common_args)
|
||||
|
||||
elif expense_in.split_type == SplitTypeEnum.SHARES:
|
||||
splits_to_create = await _create_shares_splits(
|
||||
db, db_expense, expense_in, round_money
|
||||
)
|
||||
splits_to_create = await _create_shares_splits(**common_args)
|
||||
|
||||
elif expense_in.split_type == SplitTypeEnum.ITEM_BASED:
|
||||
splits_to_create = await _create_item_based_splits(
|
||||
db, db_expense, expense_in, round_money
|
||||
)
|
||||
splits_to_create = await _create_item_based_splits(**common_args)
|
||||
|
||||
else:
|
||||
raise InvalidOperationError(f"Unsupported split type: {expense_in.split_type.value}")
|
||||
@ -240,29 +275,24 @@ async def _generate_expense_splits(
|
||||
return splits_to_create
|
||||
|
||||
|
||||
async def _create_equal_splits(
|
||||
db: AsyncSession,
|
||||
db_expense: ExpenseModel,
|
||||
expense_in: ExpenseCreate,
|
||||
round_money: Callable[[Decimal], Decimal]
|
||||
) -> PyList[ExpenseSplitModel]:
|
||||
async def _create_equal_splits(db: AsyncSession, expense_model: ExpenseModel, expense_in: ExpenseCreate, round_money_func: Callable[[Decimal], Decimal], **kwargs: Any) -> PyList[ExpenseSplitModel]:
|
||||
"""Creates equal splits among users."""
|
||||
|
||||
users_for_splitting = await get_users_for_splitting(
|
||||
db, db_expense.group_id, expense_in.list_id, expense_in.paid_by_user_id
|
||||
db, expense_model.group_id, expense_model.list_id, expense_model.paid_by_user_id
|
||||
)
|
||||
if not users_for_splitting:
|
||||
raise InvalidOperationError("No users found for EQUAL split.")
|
||||
|
||||
num_users = len(users_for_splitting)
|
||||
amount_per_user = round_money(db_expense.total_amount / Decimal(num_users))
|
||||
remainder = db_expense.total_amount - (amount_per_user * num_users)
|
||||
amount_per_user = round_money_func(expense_model.total_amount / Decimal(num_users))
|
||||
remainder = expense_model.total_amount - (amount_per_user * num_users)
|
||||
|
||||
splits = []
|
||||
for i, user in enumerate(users_for_splitting):
|
||||
split_amount = amount_per_user
|
||||
if i == 0 and remainder != Decimal('0'):
|
||||
split_amount = round_money(amount_per_user + remainder)
|
||||
split_amount = round_money_func(amount_per_user + remainder)
|
||||
|
||||
splits.append(ExpenseSplitModel(
|
||||
user_id=user.id,
|
||||
@ -272,12 +302,7 @@ async def _create_equal_splits(
|
||||
return splits
|
||||
|
||||
|
||||
async def _create_exact_amount_splits(
|
||||
db: AsyncSession,
|
||||
db_expense: ExpenseModel,
|
||||
expense_in: ExpenseCreate,
|
||||
round_money: Callable[[Decimal], Decimal]
|
||||
) -> PyList[ExpenseSplitModel]:
|
||||
async def _create_exact_amount_splits(db: AsyncSession, expense_model: ExpenseModel, expense_in: ExpenseCreate, round_money_func: Callable[[Decimal], Decimal], **kwargs: Any) -> PyList[ExpenseSplitModel]:
|
||||
"""Creates splits with exact amounts."""
|
||||
|
||||
if not expense_in.splits_in:
|
||||
@ -293,7 +318,7 @@ async def _create_exact_amount_splits(
|
||||
if split_in.owed_amount <= Decimal('0'):
|
||||
raise InvalidOperationError(f"Owed amount for user {split_in.user_id} must be positive.")
|
||||
|
||||
rounded_amount = round_money(split_in.owed_amount)
|
||||
rounded_amount = round_money_func(split_in.owed_amount)
|
||||
current_total += rounded_amount
|
||||
|
||||
splits.append(ExpenseSplitModel(
|
||||
@ -301,20 +326,15 @@ async def _create_exact_amount_splits(
|
||||
owed_amount=rounded_amount
|
||||
))
|
||||
|
||||
if round_money(current_total) != db_expense.total_amount:
|
||||
if round_money_func(current_total) != expense_model.total_amount:
|
||||
raise InvalidOperationError(
|
||||
f"Sum of exact split amounts ({current_total}) != expense total ({db_expense.total_amount})."
|
||||
f"Sum of exact split amounts ({current_total}) != expense total ({expense_model.total_amount})."
|
||||
)
|
||||
|
||||
return splits
|
||||
|
||||
|
||||
async def _create_percentage_splits(
|
||||
db: AsyncSession,
|
||||
db_expense: ExpenseModel,
|
||||
expense_in: ExpenseCreate,
|
||||
round_money: Callable[[Decimal], Decimal]
|
||||
) -> PyList[ExpenseSplitModel]:
|
||||
async def _create_percentage_splits(db: AsyncSession, expense_model: ExpenseModel, expense_in: ExpenseCreate, round_money_func: Callable[[Decimal], Decimal], **kwargs: Any) -> PyList[ExpenseSplitModel]:
|
||||
"""Creates splits based on percentages."""
|
||||
|
||||
if not expense_in.splits_in:
|
||||
@ -334,7 +354,7 @@ async def _create_percentage_splits(
|
||||
)
|
||||
|
||||
total_percentage += split_in.share_percentage
|
||||
owed_amount = round_money(db_expense.total_amount * (split_in.share_percentage / Decimal("100")))
|
||||
owed_amount = round_money_func(expense_model.total_amount * (split_in.share_percentage / Decimal("100")))
|
||||
current_total += owed_amount
|
||||
|
||||
splits.append(ExpenseSplitModel(
|
||||
@ -343,23 +363,18 @@ async def _create_percentage_splits(
|
||||
share_percentage=split_in.share_percentage
|
||||
))
|
||||
|
||||
if round_money(total_percentage) != Decimal("100.00"):
|
||||
if round_money_func(total_percentage) != Decimal("100.00"):
|
||||
raise InvalidOperationError(f"Sum of percentages ({total_percentage}%) is not 100%.")
|
||||
|
||||
# Adjust for rounding differences
|
||||
if current_total != db_expense.total_amount and splits:
|
||||
diff = db_expense.total_amount - current_total
|
||||
splits[-1].owed_amount = round_money(splits[-1].owed_amount + diff)
|
||||
if current_total != expense_model.total_amount and splits:
|
||||
diff = expense_model.total_amount - current_total
|
||||
splits[-1].owed_amount = round_money_func(splits[-1].owed_amount + diff)
|
||||
|
||||
return splits
|
||||
|
||||
|
||||
async def _create_shares_splits(
|
||||
db: AsyncSession,
|
||||
db_expense: ExpenseModel,
|
||||
expense_in: ExpenseCreate,
|
||||
round_money: Callable[[Decimal], Decimal]
|
||||
) -> PyList[ExpenseSplitModel]:
|
||||
async def _create_shares_splits(db: AsyncSession, expense_model: ExpenseModel, expense_in: ExpenseCreate, round_money_func: Callable[[Decimal], Decimal], **kwargs: Any) -> PyList[ExpenseSplitModel]:
|
||||
"""Creates splits based on shares."""
|
||||
|
||||
if not expense_in.splits_in:
|
||||
@ -381,7 +396,7 @@ async def _create_shares_splits(
|
||||
raise InvalidOperationError(f"Invalid share units for user {split_in.user_id}.")
|
||||
|
||||
share_ratio = Decimal(split_in.share_units) / Decimal(total_shares)
|
||||
owed_amount = round_money(db_expense.total_amount * share_ratio)
|
||||
owed_amount = round_money_func(expense_model.total_amount * share_ratio)
|
||||
current_total += owed_amount
|
||||
|
||||
splits.append(ExpenseSplitModel(
|
||||
@ -391,31 +406,26 @@ async def _create_shares_splits(
|
||||
))
|
||||
|
||||
# Adjust for rounding differences
|
||||
if current_total != db_expense.total_amount and splits:
|
||||
diff = db_expense.total_amount - current_total
|
||||
splits[-1].owed_amount = round_money(splits[-1].owed_amount + diff)
|
||||
if current_total != expense_model.total_amount and splits:
|
||||
diff = expense_model.total_amount - current_total
|
||||
splits[-1].owed_amount = round_money_func(splits[-1].owed_amount + diff)
|
||||
|
||||
return splits
|
||||
|
||||
|
||||
async def _create_item_based_splits(
|
||||
db: AsyncSession,
|
||||
db_expense: ExpenseModel,
|
||||
expense_in: ExpenseCreate,
|
||||
round_money: Callable[[Decimal], Decimal]
|
||||
) -> PyList[ExpenseSplitModel]:
|
||||
async def _create_item_based_splits(db: AsyncSession, expense_model: ExpenseModel, expense_in: ExpenseCreate, round_money_func: Callable[[Decimal], Decimal], **kwargs: Any) -> PyList[ExpenseSplitModel]:
|
||||
"""Creates splits based on items in a shopping list."""
|
||||
|
||||
if not expense_in.list_id:
|
||||
if not expense_model.list_id:
|
||||
raise InvalidOperationError("ITEM_BASED expenses must be associated with a list_id.")
|
||||
|
||||
if expense_in.splits_in:
|
||||
logger.warning("Ignoring provided 'splits_in' data for ITEM_BASED expense.")
|
||||
|
||||
# Build query to fetch relevant items
|
||||
items_query = select(ItemModel).where(ItemModel.list_id == expense_in.list_id)
|
||||
if expense_in.item_id:
|
||||
items_query = items_query.where(ItemModel.id == expense_in.item_id)
|
||||
items_query = select(ItemModel).where(ItemModel.list_id == expense_model.list_id)
|
||||
if expense_model.item_id:
|
||||
items_query = items_query.where(ItemModel.id == expense_model.item_id)
|
||||
else:
|
||||
items_query = items_query.where(ItemModel.price.isnot(None) & (ItemModel.price > Decimal("0")))
|
||||
|
||||
@ -425,9 +435,9 @@ async def _create_item_based_splits(
|
||||
|
||||
if not relevant_items:
|
||||
error_msg = (
|
||||
f"Specified item ID {expense_in.item_id} not found in list {expense_in.list_id}."
|
||||
if expense_in.item_id else
|
||||
f"List {expense_in.list_id} has no priced items to base the expense on."
|
||||
f"Specified item ID {expense_model.item_id} not found in list {expense_model.list_id}."
|
||||
if expense_model.item_id else
|
||||
f"List {expense_model.list_id} has no priced items to base the expense on."
|
||||
)
|
||||
raise InvalidOperationError(error_msg)
|
||||
|
||||
@ -438,9 +448,9 @@ async def _create_item_based_splits(
|
||||
|
||||
for item in relevant_items:
|
||||
if item.price is None or item.price <= Decimal("0"):
|
||||
if expense_in.item_id:
|
||||
if expense_model.item_id:
|
||||
raise InvalidOperationError(
|
||||
f"Item ID {expense_in.item_id} must have a positive price for ITEM_BASED expense."
|
||||
f"Item ID {expense_model.item_id} must have a positive price for ITEM_BASED expense."
|
||||
)
|
||||
continue
|
||||
|
||||
@ -454,13 +464,13 @@ async def _create_item_based_splits(
|
||||
|
||||
if processed_items == 0:
|
||||
raise InvalidOperationError(
|
||||
f"No items with positive prices found in list {expense_in.list_id} to create ITEM_BASED expense."
|
||||
f"No items with positive prices found in list {expense_model.list_id} to create ITEM_BASED expense."
|
||||
)
|
||||
|
||||
# Validate total matches calculated total
|
||||
if round_money(calculated_total) != db_expense.total_amount:
|
||||
if round_money_func(calculated_total) != expense_model.total_amount:
|
||||
raise InvalidOperationError(
|
||||
f"Expense total amount ({db_expense.total_amount}) does not match the "
|
||||
f"Expense total amount ({expense_model.total_amount}) does not match the "
|
||||
f"calculated total from item prices ({calculated_total})."
|
||||
)
|
||||
|
||||
@ -469,7 +479,7 @@ async def _create_item_based_splits(
|
||||
for user_id, owed_amount in user_owed_amounts.items():
|
||||
splits.append(ExpenseSplitModel(
|
||||
user_id=user_id,
|
||||
owed_amount=round_money(owed_amount)
|
||||
owed_amount=round_money_func(owed_amount)
|
||||
))
|
||||
|
||||
return splits
|
||||
|
@ -4,7 +4,7 @@ from sqlalchemy.future import select
|
||||
from sqlalchemy.orm import selectinload # For eager loading members
|
||||
from sqlalchemy.exc import SQLAlchemyError, IntegrityError, OperationalError
|
||||
from typing import Optional, List
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy import delete, func
|
||||
|
||||
from app.models import User as UserModel, Group as GroupModel, UserGroup as UserGroupModel
|
||||
from app.schemas.group import GroupCreate
|
||||
@ -24,10 +24,23 @@ from app.core.exceptions import (
|
||||
async def create_group(db: AsyncSession, group_in: GroupCreate, creator_id: int) -> GroupModel:
|
||||
"""Creates a group and adds the creator as the owner."""
|
||||
try:
|
||||
async with db.begin():
|
||||
# Defensive check: if a transaction is already active, try to roll it back.
|
||||
# This is unusual and suggests an issue upstream (e.g., middleware or session configuration).
|
||||
if db.in_transaction():
|
||||
# Log this occurrence if possible, as it's unexpected.
|
||||
# import logging; logging.warning("Transaction already active on session entering create_group. Attempting rollback.")
|
||||
try:
|
||||
await db.rollback() # Attempt to clear any existing transaction
|
||||
except SQLAlchemyError as e_rb:
|
||||
# Log e_rb if possible
|
||||
# import logging; logging.error(f"Error rolling back pre-existing transaction: {e_rb}")
|
||||
# Re-raise or handle as a critical error, as the session state is uncertain.
|
||||
raise DatabaseTransactionError(f"Session had an active transaction that could not be rolled back: {e_rb}")
|
||||
|
||||
async with db.begin(): # Now attempt to start a clean transaction
|
||||
db_group = GroupModel(name=group_in.name, created_by_id=creator_id)
|
||||
db.add(db_group)
|
||||
await db.flush()
|
||||
await db.flush() # Assigns ID to db_group
|
||||
|
||||
db_user_group = UserGroupModel(
|
||||
user_id=creator_id,
|
||||
@ -35,15 +48,30 @@ async def create_group(db: AsyncSession, group_in: GroupCreate, creator_id: int)
|
||||
role=UserRoleEnum.owner
|
||||
)
|
||||
db.add(db_user_group)
|
||||
await db.flush()
|
||||
await db.refresh(db_group)
|
||||
return db_group
|
||||
await db.flush() # Commits user_group, links to group
|
||||
|
||||
# After creation and linking, explicitly load the group with its member associations and users
|
||||
stmt = (
|
||||
select(GroupModel)
|
||||
.where(GroupModel.id == db_group.id)
|
||||
.options(
|
||||
selectinload(GroupModel.member_associations).selectinload(UserGroupModel.user)
|
||||
)
|
||||
)
|
||||
result = await db.execute(stmt)
|
||||
loaded_group = result.scalar_one_or_none()
|
||||
|
||||
if loaded_group is None:
|
||||
# This should not happen if we just created it, but as a safeguard
|
||||
raise GroupOperationError("Failed to load group after creation.")
|
||||
|
||||
return loaded_group
|
||||
except IntegrityError as e:
|
||||
raise DatabaseIntegrityError(f"Failed to create group: {str(e)}")
|
||||
raise DatabaseIntegrityError(f"Failed to create group due to integrity issue: {str(e)}")
|
||||
except OperationalError as e:
|
||||
raise DatabaseConnectionError(f"Database connection error: {str(e)}")
|
||||
raise DatabaseConnectionError(f"Database connection error during group creation: {str(e)}")
|
||||
except SQLAlchemyError as e:
|
||||
raise DatabaseTransactionError(f"Failed to create group: {str(e)}")
|
||||
raise DatabaseTransactionError(f"Database transaction error during group creation: {str(e)}")
|
||||
|
||||
async def get_user_groups(db: AsyncSession, user_id: int) -> List[GroupModel]:
|
||||
"""Gets all groups a user is a member of."""
|
||||
@ -52,7 +80,9 @@ async def get_user_groups(db: AsyncSession, user_id: int) -> List[GroupModel]:
|
||||
select(GroupModel)
|
||||
.join(UserGroupModel)
|
||||
.where(UserGroupModel.user_id == user_id)
|
||||
.options(selectinload(GroupModel.member_associations))
|
||||
.options(
|
||||
selectinload(GroupModel.member_associations).selectinload(UserGroupModel.user)
|
||||
)
|
||||
)
|
||||
return result.scalars().all()
|
||||
except OperationalError as e:
|
||||
@ -106,18 +136,34 @@ async def get_user_role_in_group(db: AsyncSession, group_id: int, user_id: int)
|
||||
async def add_user_to_group(db: AsyncSession, group_id: int, user_id: int, role: UserRoleEnum = UserRoleEnum.member) -> Optional[UserGroupModel]:
|
||||
"""Adds a user to a group if they aren't already a member."""
|
||||
try:
|
||||
async with db.begin():
|
||||
existing = await db.execute(
|
||||
select(UserGroupModel).where(UserGroupModel.group_id == group_id, UserGroupModel.user_id == user_id)
|
||||
)
|
||||
if existing.scalar_one_or_none():
|
||||
return None
|
||||
# Check if user is already a member before starting a transaction
|
||||
existing_stmt = select(UserGroupModel.id).where(UserGroupModel.group_id == group_id, UserGroupModel.user_id == user_id)
|
||||
existing_result = await db.execute(existing_stmt)
|
||||
if existing_result.scalar_one_or_none():
|
||||
return None
|
||||
|
||||
# Use a single transaction
|
||||
async with db.begin():
|
||||
db_user_group = UserGroupModel(user_id=user_id, group_id=group_id, role=role)
|
||||
db.add(db_user_group)
|
||||
await db.flush()
|
||||
await db.refresh(db_user_group)
|
||||
return db_user_group
|
||||
await db.flush() # Assigns ID to db_user_group
|
||||
|
||||
# Eagerly load the 'user' and 'group' relationships for the response
|
||||
stmt = (
|
||||
select(UserGroupModel)
|
||||
.where(UserGroupModel.id == db_user_group.id)
|
||||
.options(
|
||||
selectinload(UserGroupModel.user),
|
||||
selectinload(UserGroupModel.group)
|
||||
)
|
||||
)
|
||||
result = await db.execute(stmt)
|
||||
loaded_user_group = result.scalar_one_or_none()
|
||||
|
||||
if loaded_user_group is None:
|
||||
raise GroupOperationError(f"Failed to load user group association after adding user {user_id} to group {group_id}.")
|
||||
|
||||
return loaded_user_group
|
||||
except IntegrityError as e:
|
||||
raise DatabaseIntegrityError(f"Failed to add user to group: {str(e)}")
|
||||
except OperationalError as e:
|
||||
|
@ -3,10 +3,19 @@ import secrets
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
from sqlalchemy.orm import selectinload # Ensure selectinload is imported
|
||||
from sqlalchemy import delete # Import delete statement
|
||||
from sqlalchemy.exc import SQLAlchemyError, OperationalError, IntegrityError
|
||||
from typing import Optional
|
||||
|
||||
from app.models import Invite as InviteModel
|
||||
from app.models import Invite as InviteModel, Group as GroupModel, User as UserModel # Import related models for selectinload
|
||||
from app.core.exceptions import (
|
||||
DatabaseConnectionError,
|
||||
DatabaseIntegrityError,
|
||||
DatabaseQueryError,
|
||||
DatabaseTransactionError,
|
||||
InviteOperationError # Add new specific exception
|
||||
)
|
||||
|
||||
# Invite codes should be reasonably unique, but handle potential collision
|
||||
MAX_CODE_GENERATION_ATTEMPTS = 5
|
||||
@ -14,56 +23,121 @@ MAX_CODE_GENERATION_ATTEMPTS = 5
|
||||
async def create_invite(db: AsyncSession, group_id: int, creator_id: int, expires_in_days: int = 7) -> Optional[InviteModel]:
|
||||
"""Creates a new invite code for a group."""
|
||||
expires_at = datetime.now(timezone.utc) + timedelta(days=expires_in_days)
|
||||
code = None
|
||||
attempts = 0
|
||||
|
||||
# Generate a unique code, retrying if a collision occurs (highly unlikely but safe)
|
||||
while attempts < MAX_CODE_GENERATION_ATTEMPTS:
|
||||
attempts += 1
|
||||
potential_code = None
|
||||
for attempt in range(MAX_CODE_GENERATION_ATTEMPTS):
|
||||
potential_code = secrets.token_urlsafe(16)
|
||||
# Check if an *active* invite with this code already exists
|
||||
existing = await db.execute(
|
||||
select(InviteModel.id).where(InviteModel.code == potential_code, InviteModel.is_active == True).limit(1)
|
||||
)
|
||||
if existing.scalar_one_or_none() is None:
|
||||
code = potential_code
|
||||
break
|
||||
# Check if an *active* invite with this code already exists (outside main transaction for now)
|
||||
# Ideally, unique constraint on (code, is_active=true) in DB and catch IntegrityError.
|
||||
# This check reduces collision chance before attempting transaction.
|
||||
existing_check_stmt = select(InviteModel.id).where(InviteModel.code == potential_code, InviteModel.is_active == True).limit(1)
|
||||
existing_result = await db.execute(existing_check_stmt)
|
||||
if existing_result.scalar_one_or_none() is None:
|
||||
break # Found a potentially unique code
|
||||
if attempt == MAX_CODE_GENERATION_ATTEMPTS - 1:
|
||||
raise InviteOperationError("Failed to generate a unique invite code after several attempts.")
|
||||
|
||||
if code is None:
|
||||
# Failed to generate a unique code after several attempts
|
||||
return None
|
||||
try:
|
||||
async with db.begin_nested() if db.in_transaction() else db.begin() as transaction:
|
||||
# Final check within transaction to be absolutely sure before insert
|
||||
final_check_stmt = select(InviteModel.id).where(InviteModel.code == potential_code, InviteModel.is_active == True).limit(1)
|
||||
final_check_result = await db.execute(final_check_stmt)
|
||||
if final_check_result.scalar_one_or_none() is not None:
|
||||
# Extremely unlikely if previous check passed, but handles race condition
|
||||
await transaction.rollback()
|
||||
raise InviteOperationError("Invite code collision detected during transaction.")
|
||||
|
||||
db_invite = InviteModel(
|
||||
code=code,
|
||||
group_id=group_id,
|
||||
created_by_id=creator_id,
|
||||
expires_at=expires_at,
|
||||
is_active=True
|
||||
)
|
||||
db.add(db_invite)
|
||||
await db.commit()
|
||||
await db.refresh(db_invite)
|
||||
return db_invite
|
||||
db_invite = InviteModel(
|
||||
code=potential_code,
|
||||
group_id=group_id,
|
||||
created_by_id=creator_id,
|
||||
expires_at=expires_at,
|
||||
is_active=True
|
||||
)
|
||||
db.add(db_invite)
|
||||
await db.flush() # Assigns ID
|
||||
|
||||
# Re-fetch with relationships
|
||||
stmt = (
|
||||
select(InviteModel)
|
||||
.where(InviteModel.id == db_invite.id)
|
||||
.options(
|
||||
selectinload(InviteModel.group),
|
||||
selectinload(InviteModel.creator)
|
||||
)
|
||||
)
|
||||
result = await db.execute(stmt)
|
||||
loaded_invite = result.scalar_one_or_none()
|
||||
|
||||
if loaded_invite is None:
|
||||
await transaction.rollback()
|
||||
raise InviteOperationError("Failed to load invite after creation.")
|
||||
|
||||
await transaction.commit()
|
||||
return loaded_invite
|
||||
except IntegrityError as e: # Catch if DB unique constraint on code is violated
|
||||
# Rollback handled by context manager
|
||||
raise DatabaseIntegrityError(f"Failed to create invite due to DB integrity: {str(e)}")
|
||||
except OperationalError as e:
|
||||
raise DatabaseConnectionError(f"DB connection error during invite creation: {str(e)}")
|
||||
except SQLAlchemyError as e:
|
||||
raise DatabaseTransactionError(f"DB transaction error during invite creation: {str(e)}")
|
||||
|
||||
async def get_active_invite_by_code(db: AsyncSession, code: str) -> Optional[InviteModel]:
|
||||
"""Gets an active and non-expired invite by its code."""
|
||||
now = datetime.now(timezone.utc)
|
||||
result = await db.execute(
|
||||
select(InviteModel).where(
|
||||
InviteModel.code == code,
|
||||
InviteModel.is_active == True,
|
||||
InviteModel.expires_at > now
|
||||
try:
|
||||
stmt = (
|
||||
select(InviteModel).where(
|
||||
InviteModel.code == code,
|
||||
InviteModel.is_active == True,
|
||||
InviteModel.expires_at > now
|
||||
)
|
||||
.options(
|
||||
selectinload(InviteModel.group),
|
||||
selectinload(InviteModel.creator)
|
||||
)
|
||||
)
|
||||
)
|
||||
return result.scalars().first()
|
||||
result = await db.execute(stmt)
|
||||
return result.scalars().first()
|
||||
except OperationalError as e:
|
||||
raise DatabaseConnectionError(f"DB connection error fetching invite: {str(e)}")
|
||||
except SQLAlchemyError as e:
|
||||
raise DatabaseQueryError(f"DB query error fetching invite: {str(e)}")
|
||||
|
||||
async def deactivate_invite(db: AsyncSession, invite: InviteModel) -> InviteModel:
|
||||
"""Marks an invite as inactive (used)."""
|
||||
invite.is_active = False
|
||||
db.add(invite) # Add to session to track change
|
||||
await db.commit()
|
||||
await db.refresh(invite)
|
||||
return invite
|
||||
"""Marks an invite as inactive (used) and reloads with relationships."""
|
||||
try:
|
||||
async with db.begin_nested() if db.in_transaction() else db.begin() as transaction:
|
||||
invite.is_active = False
|
||||
db.add(invite) # Add to session to track change
|
||||
await db.flush() # Persist is_active change
|
||||
|
||||
# Re-fetch with relationships
|
||||
stmt = (
|
||||
select(InviteModel)
|
||||
.where(InviteModel.id == invite.id)
|
||||
.options(
|
||||
selectinload(InviteModel.group),
|
||||
selectinload(InviteModel.creator)
|
||||
)
|
||||
)
|
||||
result = await db.execute(stmt)
|
||||
updated_invite = result.scalar_one_or_none()
|
||||
|
||||
if updated_invite is None: # Should not happen as invite is passed in
|
||||
await transaction.rollback()
|
||||
raise InviteOperationError("Failed to load invite after deactivation.")
|
||||
|
||||
await transaction.commit()
|
||||
return updated_invite
|
||||
except OperationalError as e:
|
||||
raise DatabaseConnectionError(f"DB connection error deactivating invite: {str(e)}")
|
||||
except SQLAlchemyError as e:
|
||||
raise DatabaseTransactionError(f"DB transaction error deactivating invite: {str(e)}")
|
||||
|
||||
# Ensure InviteOperationError is defined in app.core.exceptions
|
||||
# Example: class InviteOperationError(AppException): pass
|
||||
|
||||
# Optional: Function to periodically delete old, inactive invites
|
||||
# async def cleanup_old_invites(db: AsyncSession, older_than_days: int = 30): ...
|
@ -1,12 +1,13 @@
|
||||
# app/crud/item.py
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
from sqlalchemy.orm import selectinload # Ensure selectinload is imported
|
||||
from sqlalchemy import delete as sql_delete, update as sql_update # Use aliases
|
||||
from sqlalchemy.exc import SQLAlchemyError, IntegrityError, OperationalError
|
||||
from typing import Optional, List as PyList
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from app.models import Item as ItemModel
|
||||
from app.models import Item as ItemModel, User as UserModel # Import UserModel for type hints if needed for selectinload
|
||||
from app.schemas.item import ItemCreate, ItemUpdate
|
||||
from app.core.exceptions import (
|
||||
ItemNotFoundError,
|
||||
@ -14,46 +15,65 @@ from app.core.exceptions import (
|
||||
DatabaseIntegrityError,
|
||||
DatabaseQueryError,
|
||||
DatabaseTransactionError,
|
||||
ConflictError
|
||||
ConflictError,
|
||||
ItemOperationError # Add if specific item operation errors are needed
|
||||
)
|
||||
|
||||
async def create_item(db: AsyncSession, item_in: ItemCreate, list_id: int, user_id: int) -> ItemModel:
|
||||
"""Creates a new item record for a specific list."""
|
||||
try:
|
||||
db_item = ItemModel(
|
||||
name=item_in.name,
|
||||
quantity=item_in.quantity,
|
||||
list_id=list_id,
|
||||
added_by_id=user_id,
|
||||
is_complete=False # Default on creation
|
||||
# version is implicitly set to 1 by model default
|
||||
)
|
||||
db.add(db_item)
|
||||
await db.flush()
|
||||
await db.refresh(db_item)
|
||||
await db.commit() # Explicitly commit here
|
||||
return db_item
|
||||
async with db.begin_nested() if db.in_transaction() else db.begin() as transaction:
|
||||
db_item = ItemModel(
|
||||
name=item_in.name,
|
||||
quantity=item_in.quantity,
|
||||
list_id=list_id,
|
||||
added_by_id=user_id,
|
||||
is_complete=False
|
||||
)
|
||||
db.add(db_item)
|
||||
await db.flush() # Assigns ID
|
||||
|
||||
# Re-fetch with relationships
|
||||
stmt = (
|
||||
select(ItemModel)
|
||||
.where(ItemModel.id == db_item.id)
|
||||
.options(
|
||||
selectinload(ItemModel.added_by_user),
|
||||
selectinload(ItemModel.completed_by_user) # Will be None but loads relationship
|
||||
)
|
||||
)
|
||||
result = await db.execute(stmt)
|
||||
loaded_item = result.scalar_one_or_none()
|
||||
|
||||
if loaded_item is None:
|
||||
await transaction.rollback() # Should be handled by context manager on raise, but explicit for clarity
|
||||
raise ItemOperationError("Failed to load item after creation.") # Define ItemOperationError
|
||||
|
||||
await transaction.commit()
|
||||
return loaded_item
|
||||
except IntegrityError as e:
|
||||
await db.rollback() # Rollback on integrity error
|
||||
# Context manager handles rollback on error
|
||||
raise DatabaseIntegrityError(f"Failed to create item: {str(e)}")
|
||||
except OperationalError as e:
|
||||
await db.rollback() # Rollback on operational error
|
||||
raise DatabaseConnectionError(f"Database connection error: {str(e)}")
|
||||
except SQLAlchemyError as e:
|
||||
await db.rollback() # Rollback on other SQLAlchemy errors
|
||||
raise DatabaseTransactionError(f"Failed to create item: {str(e)}")
|
||||
except Exception as e: # Catch any other exception and attempt rollback
|
||||
await db.rollback()
|
||||
raise # Re-raise the original exception
|
||||
# Removed generic Exception block as SQLAlchemyError should cover DB issues,
|
||||
# and context manager handles rollback.
|
||||
|
||||
async def get_items_by_list_id(db: AsyncSession, list_id: int) -> PyList[ItemModel]:
|
||||
"""Gets all items belonging to a specific list, ordered by creation time."""
|
||||
try:
|
||||
result = await db.execute(
|
||||
stmt = (
|
||||
select(ItemModel)
|
||||
.where(ItemModel.list_id == list_id)
|
||||
.order_by(ItemModel.created_at.asc()) # Or desc() if preferred
|
||||
.options(
|
||||
selectinload(ItemModel.added_by_user),
|
||||
selectinload(ItemModel.completed_by_user)
|
||||
)
|
||||
.order_by(ItemModel.created_at.asc())
|
||||
)
|
||||
result = await db.execute(stmt)
|
||||
return result.scalars().all()
|
||||
except OperationalError as e:
|
||||
raise DatabaseConnectionError(f"Failed to connect to database: {str(e)}")
|
||||
@ -63,7 +83,16 @@ async def get_items_by_list_id(db: AsyncSession, list_id: int) -> PyList[ItemMod
|
||||
async def get_item_by_id(db: AsyncSession, item_id: int) -> Optional[ItemModel]:
|
||||
"""Gets a single item by its ID."""
|
||||
try:
|
||||
result = await db.execute(select(ItemModel).where(ItemModel.id == item_id))
|
||||
stmt = (
|
||||
select(ItemModel)
|
||||
.where(ItemModel.id == item_id)
|
||||
.options(
|
||||
selectinload(ItemModel.added_by_user),
|
||||
selectinload(ItemModel.completed_by_user),
|
||||
selectinload(ItemModel.list) # Often useful to get the parent list
|
||||
)
|
||||
)
|
||||
result = await db.execute(stmt)
|
||||
return result.scalars().first()
|
||||
except OperationalError as e:
|
||||
raise DatabaseConnectionError(f"Failed to connect to database: {str(e)}")
|
||||
@ -73,59 +102,72 @@ async def get_item_by_id(db: AsyncSession, item_id: int) -> Optional[ItemModel]:
|
||||
async def update_item(db: AsyncSession, item_db: ItemModel, item_in: ItemUpdate, user_id: int) -> ItemModel:
|
||||
"""Updates an existing item record, checking for version conflicts."""
|
||||
try:
|
||||
# Check version conflict
|
||||
if item_db.version != item_in.version:
|
||||
raise ConflictError(
|
||||
f"Item '{item_db.name}' (ID: {item_db.id}) has been modified. "
|
||||
f"Your version is {item_in.version}, current version is {item_db.version}. Please refresh."
|
||||
async with db.begin_nested() if db.in_transaction() else db.begin() as transaction:
|
||||
if item_db.version != item_in.version:
|
||||
# No need to rollback here, as the transaction hasn't committed.
|
||||
# The context manager will handle rollback if an exception is raised.
|
||||
raise ConflictError(
|
||||
f"Item '{item_db.name}' (ID: {item_db.id}) has been modified. "
|
||||
f"Your version is {item_in.version}, current version is {item_db.version}. Please refresh."
|
||||
)
|
||||
|
||||
update_data = item_in.model_dump(exclude_unset=True, exclude={'version'})
|
||||
|
||||
if 'is_complete' in update_data:
|
||||
if update_data['is_complete'] is True:
|
||||
if item_db.completed_by_id is None:
|
||||
update_data['completed_by_id'] = user_id
|
||||
else:
|
||||
update_data['completed_by_id'] = None
|
||||
|
||||
for key, value in update_data.items():
|
||||
setattr(item_db, key, value)
|
||||
|
||||
item_db.version += 1
|
||||
db.add(item_db) # Mark as dirty
|
||||
await db.flush()
|
||||
|
||||
# Re-fetch with relationships
|
||||
stmt = (
|
||||
select(ItemModel)
|
||||
.where(ItemModel.id == item_db.id)
|
||||
.options(
|
||||
selectinload(ItemModel.added_by_user),
|
||||
selectinload(ItemModel.completed_by_user),
|
||||
selectinload(ItemModel.list)
|
||||
)
|
||||
)
|
||||
result = await db.execute(stmt)
|
||||
updated_item = result.scalar_one_or_none()
|
||||
|
||||
update_data = item_in.model_dump(exclude_unset=True, exclude={'version'}) # Exclude version
|
||||
if updated_item is None: # Should not happen
|
||||
# Rollback will be handled by context manager on raise
|
||||
raise ItemOperationError("Failed to load item after update.")
|
||||
|
||||
# Special handling for is_complete
|
||||
if 'is_complete' in update_data:
|
||||
if update_data['is_complete'] is True:
|
||||
if item_db.completed_by_id is None: # Only set if not already completed by someone
|
||||
update_data['completed_by_id'] = user_id
|
||||
else:
|
||||
update_data['completed_by_id'] = None # Clear if marked incomplete
|
||||
|
||||
# Apply updates
|
||||
for key, value in update_data.items():
|
||||
setattr(item_db, key, value)
|
||||
|
||||
item_db.version += 1 # Increment version
|
||||
|
||||
db.add(item_db)
|
||||
await db.flush()
|
||||
await db.refresh(item_db)
|
||||
|
||||
# Commit the transaction if not part of a larger transaction
|
||||
await db.commit()
|
||||
|
||||
return item_db
|
||||
await transaction.commit()
|
||||
return updated_item
|
||||
except IntegrityError as e:
|
||||
await db.rollback()
|
||||
raise DatabaseIntegrityError(f"Failed to update item due to integrity constraint: {str(e)}")
|
||||
except OperationalError as e:
|
||||
await db.rollback()
|
||||
raise DatabaseConnectionError(f"Database connection error while updating item: {str(e)}")
|
||||
except ConflictError: # Re-raise ConflictError
|
||||
await db.rollback()
|
||||
except ConflictError: # Re-raise ConflictError, rollback handled by context manager
|
||||
raise
|
||||
except SQLAlchemyError as e:
|
||||
await db.rollback()
|
||||
raise DatabaseTransactionError(f"Failed to update item: {str(e)}")
|
||||
|
||||
async def delete_item(db: AsyncSession, item_db: ItemModel) -> None:
|
||||
"""Deletes an item record. Version check should be done by the caller (API endpoint)."""
|
||||
try:
|
||||
await db.delete(item_db)
|
||||
await db.commit()
|
||||
return None
|
||||
async with db.begin_nested() if db.in_transaction() else db.begin() as transaction:
|
||||
await db.delete(item_db)
|
||||
await transaction.commit()
|
||||
# No return needed for None
|
||||
except OperationalError as e:
|
||||
await db.rollback()
|
||||
# Rollback handled by context manager
|
||||
raise DatabaseConnectionError(f"Database connection error while deleting item: {str(e)}")
|
||||
except SQLAlchemyError as e:
|
||||
await db.rollback()
|
||||
# Rollback handled by context manager
|
||||
raise DatabaseTransactionError(f"Failed to delete item: {str(e)}")
|
||||
|
||||
# Ensure ItemOperationError is defined in app.core.exceptions if used
|
||||
# Example: class ItemOperationError(AppException): pass
|
@ -17,15 +17,14 @@ from app.core.exceptions import (
|
||||
DatabaseIntegrityError,
|
||||
DatabaseQueryError,
|
||||
DatabaseTransactionError,
|
||||
ConflictError
|
||||
ConflictError,
|
||||
ListOperationError
|
||||
)
|
||||
|
||||
async def create_list(db: AsyncSession, list_in: ListCreate, creator_id: int) -> ListModel:
|
||||
"""Creates a new list record."""
|
||||
try:
|
||||
# Check if we're already in a transaction
|
||||
if db.in_transaction():
|
||||
# If we're already in a transaction, just create the list
|
||||
async with db.begin_nested() if db.in_transaction() else db.begin() as transaction:
|
||||
db_list = ListModel(
|
||||
name=list_in.name,
|
||||
description=list_in.description,
|
||||
@ -34,23 +33,27 @@ async def create_list(db: AsyncSession, list_in: ListCreate, creator_id: int) ->
|
||||
is_complete=False
|
||||
)
|
||||
db.add(db_list)
|
||||
await db.flush()
|
||||
await db.refresh(db_list)
|
||||
return db_list
|
||||
else:
|
||||
# If no transaction is active, start one
|
||||
async with db.begin():
|
||||
db_list = ListModel(
|
||||
name=list_in.name,
|
||||
description=list_in.description,
|
||||
group_id=list_in.group_id,
|
||||
created_by_id=creator_id,
|
||||
is_complete=False
|
||||
await db.flush() # Assigns ID
|
||||
|
||||
# Re-fetch with relationships for the response
|
||||
stmt = (
|
||||
select(ListModel)
|
||||
.where(ListModel.id == db_list.id)
|
||||
.options(
|
||||
selectinload(ListModel.creator),
|
||||
selectinload(ListModel.group)
|
||||
# selectinload(ListModel.items) # Optionally add if items are always needed in response
|
||||
)
|
||||
db.add(db_list)
|
||||
await db.flush()
|
||||
await db.refresh(db_list)
|
||||
return db_list
|
||||
)
|
||||
result = await db.execute(stmt)
|
||||
loaded_list = result.scalar_one_or_none()
|
||||
|
||||
if loaded_list is None:
|
||||
await transaction.rollback()
|
||||
raise ListOperationError("Failed to load list after creation.")
|
||||
|
||||
await transaction.commit()
|
||||
return loaded_list
|
||||
except IntegrityError as e:
|
||||
raise DatabaseIntegrityError(f"Failed to create list: {str(e)}")
|
||||
except OperationalError as e:
|
||||
@ -66,14 +69,22 @@ async def get_lists_for_user(db: AsyncSession, user_id: int) -> PyList[ListModel
|
||||
)
|
||||
user_group_ids = group_ids_result.scalars().all()
|
||||
|
||||
# Build conditions for the OR clause dynamically
|
||||
conditions = [
|
||||
and_(ListModel.created_by_id == user_id, ListModel.group_id.is_(None))
|
||||
]
|
||||
if user_group_ids: # Only add the IN clause if there are group IDs
|
||||
if user_group_ids:
|
||||
conditions.append(ListModel.group_id.in_(user_group_ids))
|
||||
|
||||
query = select(ListModel).where(or_(*conditions)).order_by(ListModel.updated_at.desc())
|
||||
query = (
|
||||
select(ListModel)
|
||||
.where(or_(*conditions))
|
||||
.options(
|
||||
selectinload(ListModel.creator),
|
||||
selectinload(ListModel.group)
|
||||
# selectinload(ListModel.items) # Consider if items are needed for list previews
|
||||
)
|
||||
.order_by(ListModel.updated_at.desc())
|
||||
)
|
||||
|
||||
result = await db.execute(query)
|
||||
return result.scalars().all()
|
||||
@ -85,11 +96,17 @@ async def get_lists_for_user(db: AsyncSession, user_id: int) -> PyList[ListModel
|
||||
async def get_list_by_id(db: AsyncSession, list_id: int, load_items: bool = False) -> Optional[ListModel]:
|
||||
"""Gets a single list by ID, optionally loading its items."""
|
||||
try:
|
||||
query = select(ListModel).where(ListModel.id == list_id)
|
||||
query = (
|
||||
select(ListModel)
|
||||
.where(ListModel.id == list_id)
|
||||
.options(
|
||||
selectinload(ListModel.creator),
|
||||
selectinload(ListModel.group)
|
||||
)
|
||||
)
|
||||
if load_items:
|
||||
query = query.options(
|
||||
selectinload(ListModel.items)
|
||||
.options(
|
||||
selectinload(ListModel.items).options(
|
||||
joinedload(ItemModel.added_by_user),
|
||||
joinedload(ItemModel.completed_by_user)
|
||||
)
|
||||
@ -104,8 +121,9 @@ async def get_list_by_id(db: AsyncSession, list_id: int, load_items: bool = Fals
|
||||
async def update_list(db: AsyncSession, list_db: ListModel, list_in: ListUpdate) -> ListModel:
|
||||
"""Updates an existing list record, checking for version conflicts."""
|
||||
try:
|
||||
async with db.begin():
|
||||
if list_db.version != list_in.version:
|
||||
async with db.begin_nested() if db.in_transaction() else db.begin() as transaction:
|
||||
if list_db.version != list_in.version: # list_db here is the one passed in, pre-loaded by API layer
|
||||
await transaction.rollback() # Rollback before raising
|
||||
raise ConflictError(
|
||||
f"List '{list_db.name}' (ID: {list_db.id}) has been modified. "
|
||||
f"Your version is {list_in.version}, current version is {list_db.version}. Please refresh."
|
||||
@ -118,34 +136,54 @@ async def update_list(db: AsyncSession, list_db: ListModel, list_in: ListUpdate)
|
||||
|
||||
list_db.version += 1
|
||||
|
||||
db.add(list_db)
|
||||
db.add(list_db) # Add the already attached list_db to mark it dirty for the session
|
||||
await db.flush()
|
||||
await db.refresh(list_db)
|
||||
return list_db
|
||||
|
||||
# Re-fetch with relationships for the response
|
||||
stmt = (
|
||||
select(ListModel)
|
||||
.where(ListModel.id == list_db.id)
|
||||
.options(
|
||||
selectinload(ListModel.creator),
|
||||
selectinload(ListModel.group)
|
||||
# selectinload(ListModel.items) # Optionally add if items are always needed in response
|
||||
)
|
||||
)
|
||||
result = await db.execute(stmt)
|
||||
updated_list = result.scalar_one_or_none()
|
||||
|
||||
if updated_list is None: # Should not happen
|
||||
await transaction.rollback()
|
||||
raise ListOperationError("Failed to load list after update.")
|
||||
|
||||
await transaction.commit()
|
||||
return updated_list
|
||||
except IntegrityError as e:
|
||||
await db.rollback()
|
||||
# Ensure rollback if not handled by context manager (though it should be)
|
||||
if db.in_transaction(): await db.rollback()
|
||||
raise DatabaseIntegrityError(f"Failed to update list due to integrity constraint: {str(e)}")
|
||||
except OperationalError as e:
|
||||
await db.rollback()
|
||||
if db.in_transaction(): await db.rollback()
|
||||
raise DatabaseConnectionError(f"Database connection error while updating list: {str(e)}")
|
||||
except ConflictError:
|
||||
await db.rollback()
|
||||
# Already rolled back or will be by context manager if transaction was started here
|
||||
raise
|
||||
except SQLAlchemyError as e:
|
||||
await db.rollback()
|
||||
if db.in_transaction(): await db.rollback()
|
||||
raise DatabaseTransactionError(f"Failed to update list: {str(e)}")
|
||||
|
||||
async def delete_list(db: AsyncSession, list_db: ListModel) -> None:
|
||||
"""Deletes a list record. Version check should be done by the caller (API endpoint)."""
|
||||
try:
|
||||
async with db.begin():
|
||||
async with db.begin_nested() if db.in_transaction() else db.begin() as transaction: # Standardize transaction
|
||||
await db.delete(list_db)
|
||||
return None
|
||||
await transaction.commit() # Explicit commit
|
||||
# return None # Already implicitly returns None
|
||||
except OperationalError as e:
|
||||
await db.rollback()
|
||||
# Rollback should be handled by the async with block on exception
|
||||
raise DatabaseConnectionError(f"Database connection error while deleting list: {str(e)}")
|
||||
except SQLAlchemyError as e:
|
||||
await db.rollback()
|
||||
# Rollback should be handled by the async with block on exception
|
||||
raise DatabaseTransactionError(f"Failed to delete list: {str(e)}")
|
||||
|
||||
async def check_list_permission(db: AsyncSession, list_id: int, user_id: int, require_creator: bool = False) -> ListModel:
|
||||
@ -212,39 +250,48 @@ async def get_list_by_name_and_group(
|
||||
db: AsyncSession,
|
||||
name: str,
|
||||
group_id: Optional[int],
|
||||
user_id: int
|
||||
user_id: int # user_id is for permission check, not direct list attribute
|
||||
) -> Optional[ListModel]:
|
||||
"""
|
||||
Gets a list by name and group, ensuring the user has permission to access it.
|
||||
Used for conflict resolution when creating lists.
|
||||
"""
|
||||
try:
|
||||
# Build the base query
|
||||
query = select(ListModel).where(ListModel.name == name)
|
||||
# Base query for the list itself
|
||||
base_query = select(ListModel).where(ListModel.name == name)
|
||||
|
||||
# Add group condition
|
||||
if group_id is not None:
|
||||
query = query.where(ListModel.group_id == group_id)
|
||||
base_query = base_query.where(ListModel.group_id == group_id)
|
||||
else:
|
||||
query = query.where(ListModel.group_id.is_(None))
|
||||
base_query = base_query.where(ListModel.group_id.is_(None))
|
||||
|
||||
# Add permission conditions
|
||||
conditions = [
|
||||
ListModel.created_by_id == user_id # User is creator
|
||||
]
|
||||
if group_id is not None:
|
||||
# User is member of the group
|
||||
conditions.append(
|
||||
and_(
|
||||
ListModel.group_id == group_id,
|
||||
ListModel.created_by_id != user_id # Not the creator
|
||||
)
|
||||
)
|
||||
# Add eager loading for common relationships
|
||||
base_query = base_query.options(
|
||||
selectinload(ListModel.creator),
|
||||
selectinload(ListModel.group)
|
||||
)
|
||||
|
||||
query = query.where(or_(*conditions))
|
||||
list_result = await db.execute(base_query)
|
||||
target_list = list_result.scalar_one_or_none()
|
||||
|
||||
if not target_list:
|
||||
return None
|
||||
|
||||
# Permission check
|
||||
is_creator = target_list.created_by_id == user_id
|
||||
|
||||
if is_creator:
|
||||
return target_list
|
||||
|
||||
if target_list.group_id:
|
||||
from app.crud.group import is_user_member # Assuming this is a quick check not needing its own transaction
|
||||
is_member_of_group = await is_user_member(db, group_id=target_list.group_id, user_id=user_id)
|
||||
if is_member_of_group:
|
||||
return target_list
|
||||
|
||||
# If not creator and (not a group list or not a member of the group list)
|
||||
return None
|
||||
|
||||
result = await db.execute(query)
|
||||
return result.scalars().first()
|
||||
except OperationalError as e:
|
||||
raise DatabaseConnectionError(f"Failed to connect to database: {str(e)}")
|
||||
except SQLAlchemyError as e:
|
||||
|
@ -3,84 +3,135 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
from sqlalchemy.orm import selectinload, joinedload
|
||||
from sqlalchemy import or_
|
||||
from decimal import Decimal
|
||||
from sqlalchemy.exc import SQLAlchemyError, OperationalError, IntegrityError
|
||||
from decimal import Decimal, ROUND_HALF_UP
|
||||
from typing import List as PyList, Optional, Sequence
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from app.models import (
|
||||
Settlement as SettlementModel,
|
||||
User as UserModel,
|
||||
Group as GroupModel
|
||||
Group as GroupModel,
|
||||
UserGroup as UserGroupModel
|
||||
)
|
||||
from app.schemas.expense import SettlementCreate, SettlementUpdate
|
||||
from app.core.exceptions import (
|
||||
UserNotFoundError,
|
||||
GroupNotFoundError,
|
||||
InvalidOperationError,
|
||||
DatabaseConnectionError,
|
||||
DatabaseIntegrityError,
|
||||
DatabaseQueryError,
|
||||
DatabaseTransactionError,
|
||||
SettlementOperationError,
|
||||
ConflictError
|
||||
)
|
||||
from app.schemas.expense import SettlementCreate, SettlementUpdate # SettlementUpdate not used yet
|
||||
from app.core.exceptions import UserNotFoundError, GroupNotFoundError, InvalidOperationError
|
||||
|
||||
async def create_settlement(db: AsyncSession, settlement_in: SettlementCreate, current_user_id: int) -> SettlementModel:
|
||||
"""Creates a new settlement record."""
|
||||
|
||||
# Validate Payer, Payee, and Group exist
|
||||
payer = await db.get(UserModel, settlement_in.paid_by_user_id)
|
||||
if not payer:
|
||||
raise UserNotFoundError(user_id=settlement_in.paid_by_user_id, identifier="Payer")
|
||||
|
||||
payee = await db.get(UserModel, settlement_in.paid_to_user_id)
|
||||
if not payee:
|
||||
raise UserNotFoundError(user_id=settlement_in.paid_to_user_id, identifier="Payee")
|
||||
|
||||
if settlement_in.paid_by_user_id == settlement_in.paid_to_user_id:
|
||||
raise InvalidOperationError("Payer and Payee cannot be the same user.")
|
||||
|
||||
group = await db.get(GroupModel, settlement_in.group_id)
|
||||
if not group:
|
||||
raise GroupNotFoundError(settlement_in.group_id)
|
||||
|
||||
# Optional: Check if current_user_id is part of the group or is one of the parties involved
|
||||
# This is more of an API-level permission check but could be added here if strict.
|
||||
# For example: if current_user_id not in [settlement_in.paid_by_user_id, settlement_in.paid_to_user_id]:
|
||||
# is_in_group = await db.execute(select(UserGroupModel).where(UserGroupModel.group_id == group.id, UserGroupModel.user_id == current_user_id))
|
||||
# if not is_in_group.first():
|
||||
# raise InvalidOperationError("You can only record settlements you are part of or for groups you belong to.")
|
||||
|
||||
db_settlement = SettlementModel(
|
||||
group_id=settlement_in.group_id,
|
||||
paid_by_user_id=settlement_in.paid_by_user_id,
|
||||
paid_to_user_id=settlement_in.paid_to_user_id,
|
||||
amount=settlement_in.amount.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP),
|
||||
settlement_date=settlement_in.settlement_date if settlement_in.settlement_date else datetime.now(timezone.utc),
|
||||
description=settlement_in.description
|
||||
# created_by_user_id = current_user_id # Optional: Who recorded this settlement
|
||||
)
|
||||
db.add(db_settlement)
|
||||
try:
|
||||
await db.commit()
|
||||
await db.refresh(db_settlement, attribute_names=["payer", "payee", "group"])
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
raise InvalidOperationError(f"Failed to save settlement: {str(e)}")
|
||||
async with db.begin_nested() if db.in_transaction() else db.begin() as transaction:
|
||||
payer = await db.get(UserModel, settlement_in.paid_by_user_id)
|
||||
if not payer:
|
||||
raise UserNotFoundError(user_id=settlement_in.paid_by_user_id, identifier="Payer")
|
||||
|
||||
payee = await db.get(UserModel, settlement_in.paid_to_user_id)
|
||||
if not payee:
|
||||
raise UserNotFoundError(user_id=settlement_in.paid_to_user_id, identifier="Payee")
|
||||
|
||||
if settlement_in.paid_by_user_id == settlement_in.paid_to_user_id:
|
||||
raise InvalidOperationError("Payer and Payee cannot be the same user.")
|
||||
|
||||
group = await db.get(GroupModel, settlement_in.group_id)
|
||||
if not group:
|
||||
raise GroupNotFoundError(settlement_in.group_id)
|
||||
|
||||
# Permission check example (can be in API layer too)
|
||||
# if current_user_id not in [payer.id, payee.id]:
|
||||
# is_member_stmt = select(UserGroupModel.id).where(UserGroupModel.group_id == group.id, UserGroupModel.user_id == current_user_id).limit(1)
|
||||
# is_member_result = await db.execute(is_member_stmt)
|
||||
# if not is_member_result.scalar_one_or_none():
|
||||
# raise InvalidOperationError("Settlement recorder must be part of the group or one of the parties.")
|
||||
|
||||
db_settlement = SettlementModel(
|
||||
group_id=settlement_in.group_id,
|
||||
paid_by_user_id=settlement_in.paid_by_user_id,
|
||||
paid_to_user_id=settlement_in.paid_to_user_id,
|
||||
amount=settlement_in.amount.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP),
|
||||
settlement_date=settlement_in.settlement_date if settlement_in.settlement_date else datetime.now(timezone.utc),
|
||||
description=settlement_in.description
|
||||
# created_by_user_id = current_user_id # Optional: Who recorded this settlement
|
||||
)
|
||||
db.add(db_settlement)
|
||||
await db.flush()
|
||||
|
||||
# Re-fetch with relationships
|
||||
stmt = (
|
||||
select(SettlementModel)
|
||||
.where(SettlementModel.id == db_settlement.id)
|
||||
.options(
|
||||
selectinload(SettlementModel.payer),
|
||||
selectinload(SettlementModel.payee),
|
||||
selectinload(SettlementModel.group)
|
||||
)
|
||||
)
|
||||
result = await db.execute(stmt)
|
||||
loaded_settlement = result.scalar_one_or_none()
|
||||
|
||||
if loaded_settlement is None:
|
||||
await transaction.rollback() # Should be handled by context manager
|
||||
raise SettlementOperationError("Failed to load settlement after creation.")
|
||||
|
||||
await transaction.commit()
|
||||
return loaded_settlement
|
||||
except (UserNotFoundError, GroupNotFoundError, InvalidOperationError) as e:
|
||||
# These are validation errors, re-raise them.
|
||||
# If a transaction was started, context manager handles rollback.
|
||||
raise
|
||||
except IntegrityError as e:
|
||||
raise DatabaseIntegrityError(f"Failed to save settlement due to DB integrity: {str(e)}")
|
||||
except OperationalError as e:
|
||||
raise DatabaseConnectionError(f"DB connection error during settlement creation: {str(e)}")
|
||||
except SQLAlchemyError as e:
|
||||
raise DatabaseTransactionError(f"DB transaction error during settlement creation: {str(e)}")
|
||||
|
||||
return db_settlement
|
||||
|
||||
async def get_settlement_by_id(db: AsyncSession, settlement_id: int) -> Optional[SettlementModel]:
|
||||
result = await db.execute(
|
||||
select(SettlementModel)
|
||||
.options(
|
||||
selectinload(SettlementModel.payer),
|
||||
selectinload(SettlementModel.payee),
|
||||
selectinload(SettlementModel.group)
|
||||
try:
|
||||
result = await db.execute(
|
||||
select(SettlementModel)
|
||||
.options(
|
||||
selectinload(SettlementModel.payer),
|
||||
selectinload(SettlementModel.payee),
|
||||
selectinload(SettlementModel.group)
|
||||
)
|
||||
.where(SettlementModel.id == settlement_id)
|
||||
)
|
||||
.where(SettlementModel.id == settlement_id)
|
||||
)
|
||||
return result.scalars().first()
|
||||
return result.scalars().first()
|
||||
except OperationalError as e:
|
||||
raise DatabaseConnectionError(f"DB connection error fetching settlement: {str(e)}")
|
||||
except SQLAlchemyError as e:
|
||||
raise DatabaseQueryError(f"DB query error fetching settlement: {str(e)}")
|
||||
|
||||
async def get_settlements_for_group(db: AsyncSession, group_id: int, skip: int = 0, limit: int = 100) -> Sequence[SettlementModel]:
|
||||
result = await db.execute(
|
||||
select(SettlementModel)
|
||||
.where(SettlementModel.group_id == group_id)
|
||||
.order_by(SettlementModel.settlement_date.desc(), SettlementModel.created_at.desc())
|
||||
.offset(skip).limit(limit)
|
||||
.options(selectinload(SettlementModel.payer), selectinload(SettlementModel.payee))
|
||||
)
|
||||
return result.scalars().all()
|
||||
try:
|
||||
result = await db.execute(
|
||||
select(SettlementModel)
|
||||
.where(SettlementModel.group_id == group_id)
|
||||
.order_by(SettlementModel.settlement_date.desc(), SettlementModel.created_at.desc())
|
||||
.offset(skip).limit(limit)
|
||||
.options(
|
||||
selectinload(SettlementModel.payer),
|
||||
selectinload(SettlementModel.payee),
|
||||
selectinload(SettlementModel.group)
|
||||
)
|
||||
)
|
||||
return result.scalars().all()
|
||||
except OperationalError as e:
|
||||
raise DatabaseConnectionError(f"DB connection error fetching group settlements: {str(e)}")
|
||||
except SQLAlchemyError as e:
|
||||
raise DatabaseQueryError(f"DB query error fetching group settlements: {str(e)}")
|
||||
|
||||
|
||||
async def get_settlements_involving_user(
|
||||
db: AsyncSession,
|
||||
@ -89,18 +140,28 @@ async def get_settlements_involving_user(
|
||||
skip: int = 0,
|
||||
limit: int = 100
|
||||
) -> Sequence[SettlementModel]:
|
||||
query = (
|
||||
select(SettlementModel)
|
||||
.where(or_(SettlementModel.paid_by_user_id == user_id, SettlementModel.paid_to_user_id == user_id))
|
||||
.order_by(SettlementModel.settlement_date.desc(), SettlementModel.created_at.desc())
|
||||
.offset(skip).limit(limit)
|
||||
.options(selectinload(SettlementModel.payer), selectinload(SettlementModel.payee), selectinload(SettlementModel.group))
|
||||
)
|
||||
if group_id:
|
||||
query = query.where(SettlementModel.group_id == group_id)
|
||||
try:
|
||||
query = (
|
||||
select(SettlementModel)
|
||||
.where(or_(SettlementModel.paid_by_user_id == user_id, SettlementModel.paid_to_user_id == user_id))
|
||||
.order_by(SettlementModel.settlement_date.desc(), SettlementModel.created_at.desc())
|
||||
.offset(skip).limit(limit)
|
||||
.options(
|
||||
selectinload(SettlementModel.payer),
|
||||
selectinload(SettlementModel.payee),
|
||||
selectinload(SettlementModel.group)
|
||||
)
|
||||
)
|
||||
if group_id:
|
||||
query = query.where(SettlementModel.group_id == group_id)
|
||||
|
||||
result = await db.execute(query)
|
||||
return result.scalars().all()
|
||||
except OperationalError as e:
|
||||
raise DatabaseConnectionError(f"DB connection error fetching user settlements: {str(e)}")
|
||||
except SQLAlchemyError as e:
|
||||
raise DatabaseQueryError(f"DB query error fetching user settlements: {str(e)}")
|
||||
|
||||
result = await db.execute(query)
|
||||
return result.scalars().all()
|
||||
|
||||
async def update_settlement(db: AsyncSession, settlement_db: SettlementModel, settlement_in: SettlementUpdate) -> SettlementModel:
|
||||
"""
|
||||
@ -108,58 +169,100 @@ async def update_settlement(db: AsyncSession, settlement_db: SettlementModel, se
|
||||
Only allows updates to description and settlement_date.
|
||||
Requires version matching for optimistic locking.
|
||||
Assumes SettlementUpdate schema includes a version field.
|
||||
Assumes SettlementModel has version and updated_at fields.
|
||||
"""
|
||||
# Check if SettlementUpdate schema has 'version'. If not, this check needs to be adapted or version passed differently.
|
||||
if not hasattr(settlement_in, 'version') or settlement_db.version != settlement_in.version:
|
||||
raise InvalidOperationError(
|
||||
f"Settlement (ID: {settlement_db.id}) has been modified. "
|
||||
f"Your version does not match current version {settlement_db.version}. Please refresh.",
|
||||
# status_code=status.HTTP_409_CONFLICT
|
||||
)
|
||||
|
||||
update_data = settlement_in.model_dump(exclude_unset=True, exclude={"version"})
|
||||
allowed_to_update = {"description", "settlement_date"}
|
||||
updated_something = False
|
||||
|
||||
for field, value in update_data.items():
|
||||
if field in allowed_to_update:
|
||||
setattr(settlement_db, field, value)
|
||||
updated_something = True
|
||||
else:
|
||||
raise InvalidOperationError(f"Field '{field}' cannot be updated. Only {', '.join(allowed_to_update)} are allowed for settlements.")
|
||||
|
||||
if not updated_something and not settlement_in.model_fields_set.intersection(allowed_to_update):
|
||||
pass # No actual updatable fields provided, but version matched.
|
||||
|
||||
settlement_db.version += 1 # Assuming SettlementModel has a version field, add if missing
|
||||
settlement_db.updated_at = datetime.now(timezone.utc)
|
||||
|
||||
try:
|
||||
await db.commit()
|
||||
await db.refresh(settlement_db)
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
raise InvalidOperationError(f"Failed to update settlement: {str(e)}")
|
||||
async with db.begin_nested() if db.in_transaction() else db.begin() as transaction:
|
||||
# Ensure the settlement_db passed is managed by the current session if not already.
|
||||
# This is usually true if fetched by an endpoint dependency using the same session.
|
||||
# If not, `db.add(settlement_db)` might be needed before modification if it's detached.
|
||||
|
||||
if not hasattr(settlement_db, 'version') or not hasattr(settlement_in, 'version'):
|
||||
raise InvalidOperationError("Version field is missing in model or input for optimistic locking.")
|
||||
|
||||
if settlement_db.version != settlement_in.version:
|
||||
raise ConflictError( # Make sure ConflictError is defined in exceptions
|
||||
f"Settlement (ID: {settlement_db.id}) has been modified. "
|
||||
f"Your version {settlement_in.version} does not match current version {settlement_db.version}. Please refresh."
|
||||
)
|
||||
|
||||
update_data = settlement_in.model_dump(exclude_unset=True, exclude={"version"})
|
||||
allowed_to_update = {"description", "settlement_date"}
|
||||
updated_something = False
|
||||
|
||||
for field, value in update_data.items():
|
||||
if field in allowed_to_update:
|
||||
setattr(settlement_db, field, value)
|
||||
updated_something = True
|
||||
# Silently ignore fields not allowed to update or raise error:
|
||||
# else:
|
||||
# raise InvalidOperationError(f"Field '{field}' cannot be updated.")
|
||||
|
||||
if not updated_something and not settlement_in.model_fields_set.intersection(allowed_to_update):
|
||||
# No updatable fields were actually provided, or they didn't change
|
||||
# Still, we might want to return the re-loaded settlement if version matched.
|
||||
pass
|
||||
|
||||
settlement_db.version += 1
|
||||
settlement_db.updated_at = datetime.now(timezone.utc) # Ensure model has this field
|
||||
|
||||
db.add(settlement_db) # Mark as dirty
|
||||
await db.flush()
|
||||
|
||||
# Re-fetch with relationships
|
||||
stmt = (
|
||||
select(SettlementModel)
|
||||
.where(SettlementModel.id == settlement_db.id)
|
||||
.options(
|
||||
selectinload(SettlementModel.payer),
|
||||
selectinload(SettlementModel.payee),
|
||||
selectinload(SettlementModel.group)
|
||||
)
|
||||
)
|
||||
result = await db.execute(stmt)
|
||||
updated_settlement = result.scalar_one_or_none()
|
||||
|
||||
if updated_settlement is None: # Should not happen
|
||||
await transaction.rollback()
|
||||
raise SettlementOperationError("Failed to load settlement after update.")
|
||||
|
||||
await transaction.commit()
|
||||
return updated_settlement
|
||||
except ConflictError as e: # ConflictError should be defined in exceptions
|
||||
raise
|
||||
except InvalidOperationError as e:
|
||||
raise
|
||||
except IntegrityError as e:
|
||||
raise DatabaseIntegrityError(f"Failed to update settlement due to DB integrity: {str(e)}")
|
||||
except OperationalError as e:
|
||||
raise DatabaseConnectionError(f"DB connection error during settlement update: {str(e)}")
|
||||
except SQLAlchemyError as e:
|
||||
raise DatabaseTransactionError(f"DB transaction error during settlement update: {str(e)}")
|
||||
|
||||
return settlement_db
|
||||
|
||||
async def delete_settlement(db: AsyncSession, settlement_db: SettlementModel, expected_version: Optional[int] = None) -> None:
|
||||
"""
|
||||
Deletes a settlement. Requires version matching if expected_version is provided.
|
||||
Assumes SettlementModel has a version field.
|
||||
"""
|
||||
if expected_version is not None:
|
||||
if not hasattr(settlement_db, 'version') or settlement_db.version != expected_version:
|
||||
raise InvalidOperationError(
|
||||
f"Settlement (ID: {settlement_db.id}) cannot be deleted. "
|
||||
f"Expected version {expected_version} does not match current version. Please refresh.",
|
||||
# status_code=status.HTTP_409_CONFLICT
|
||||
)
|
||||
|
||||
await db.delete(settlement_db)
|
||||
try:
|
||||
await db.commit()
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
raise InvalidOperationError(f"Failed to delete settlement: {str(e)}")
|
||||
return None
|
||||
async with db.begin_nested() if db.in_transaction() else db.begin() as transaction:
|
||||
if expected_version is not None:
|
||||
if not hasattr(settlement_db, 'version') or settlement_db.version != expected_version:
|
||||
raise ConflictError( # Make sure ConflictError is defined
|
||||
f"Settlement (ID: {settlement_db.id}) cannot be deleted. "
|
||||
f"Expected version {expected_version} does not match current version {settlement_db.version}. Please refresh."
|
||||
)
|
||||
|
||||
await db.delete(settlement_db)
|
||||
await transaction.commit()
|
||||
except ConflictError as e: # ConflictError should be defined
|
||||
raise
|
||||
except OperationalError as e:
|
||||
raise DatabaseConnectionError(f"DB connection error during settlement deletion: {str(e)}")
|
||||
except SQLAlchemyError as e:
|
||||
raise DatabaseTransactionError(f"DB transaction error during settlement deletion: {str(e)}")
|
||||
|
||||
# Ensure SettlementOperationError and ConflictError are defined in app.core.exceptions
|
||||
# Example: class SettlementOperationError(AppException): pass
|
||||
# Example: class ConflictError(AppException): status_code = 409
|
@ -1,10 +1,11 @@
|
||||
# app/crud/user.py
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
from sqlalchemy.orm import selectinload # Ensure selectinload is imported
|
||||
from sqlalchemy.exc import SQLAlchemyError, IntegrityError, OperationalError
|
||||
from typing import Optional
|
||||
|
||||
from app.models import User as UserModel # Alias to avoid name clash
|
||||
from app.models import User as UserModel, UserGroup as UserGroupModel, Group as GroupModel # Import related models for selectinload
|
||||
from app.schemas.user import UserCreate
|
||||
from app.core.security import hash_password
|
||||
from app.core.exceptions import (
|
||||
@ -13,14 +14,26 @@ from app.core.exceptions import (
|
||||
DatabaseConnectionError,
|
||||
DatabaseIntegrityError,
|
||||
DatabaseQueryError,
|
||||
DatabaseTransactionError
|
||||
DatabaseTransactionError,
|
||||
UserOperationError # Add if specific user operation errors are needed
|
||||
)
|
||||
|
||||
async def get_user_by_email(db: AsyncSession, email: str) -> Optional[UserModel]:
|
||||
"""Fetches a user from the database by email."""
|
||||
"""Fetches a user from the database by email, with common relationships."""
|
||||
try:
|
||||
async with db.begin():
|
||||
result = await db.execute(select(UserModel).filter(UserModel.email == email))
|
||||
# db.begin() is not strictly necessary for a single read, but ensures atomicity if multiple reads were added.
|
||||
# For a single select, it can be omitted if preferred, session handles connection.
|
||||
async with db.begin(): # Or remove if only a single select operation
|
||||
stmt = (
|
||||
select(UserModel)
|
||||
.filter(UserModel.email == email)
|
||||
.options(
|
||||
selectinload(UserModel.group_associations).selectinload(UserGroupModel.group), # Groups user is member of
|
||||
selectinload(UserModel.created_groups) # Groups user created
|
||||
# Add other relationships as needed by UserPublic schema
|
||||
)
|
||||
)
|
||||
result = await db.execute(stmt)
|
||||
return result.scalars().first()
|
||||
except OperationalError as e:
|
||||
raise DatabaseConnectionError(f"Failed to connect to database: {str(e)}")
|
||||
@ -28,24 +41,46 @@ async def get_user_by_email(db: AsyncSession, email: str) -> Optional[UserModel]
|
||||
raise DatabaseQueryError(f"Failed to query user: {str(e)}")
|
||||
|
||||
async def create_user(db: AsyncSession, user_in: UserCreate) -> UserModel:
|
||||
"""Creates a new user record in the database."""
|
||||
"""Creates a new user record in the database with common relationships loaded."""
|
||||
try:
|
||||
async with db.begin():
|
||||
async with db.begin_nested() if db.in_transaction() else db.begin() as transaction:
|
||||
_hashed_password = hash_password(user_in.password)
|
||||
db_user = UserModel(
|
||||
email=user_in.email,
|
||||
password_hash=_hashed_password,
|
||||
hashed_password=_hashed_password, # Field name in model is hashed_password
|
||||
name=user_in.name
|
||||
)
|
||||
db.add(db_user)
|
||||
await db.flush() # Flush to get DB-generated values
|
||||
await db.refresh(db_user)
|
||||
return db_user
|
||||
await db.flush() # Flush to get DB-generated values like ID
|
||||
|
||||
# Re-fetch with relationships
|
||||
stmt = (
|
||||
select(UserModel)
|
||||
.where(UserModel.id == db_user.id)
|
||||
.options(
|
||||
selectinload(UserModel.group_associations).selectinload(UserGroupModel.group),
|
||||
selectinload(UserModel.created_groups)
|
||||
# Add other relationships as needed by UserPublic schema
|
||||
)
|
||||
)
|
||||
result = await db.execute(stmt)
|
||||
loaded_user = result.scalar_one_or_none()
|
||||
|
||||
if loaded_user is None:
|
||||
await transaction.rollback() # Should be handled by context manager, but explicit
|
||||
raise UserOperationError("Failed to load user after creation.") # Define UserOperationError
|
||||
|
||||
await transaction.commit()
|
||||
return loaded_user
|
||||
except IntegrityError as e:
|
||||
if "unique constraint" in str(e).lower():
|
||||
raise EmailAlreadyRegisteredError()
|
||||
raise DatabaseIntegrityError(f"Failed to create user: {str(e)}")
|
||||
# Context manager handles rollback on error
|
||||
if "unique constraint" in str(e).lower() and ("users_email_key" in str(e).lower() or "ix_users_email" in str(e).lower()):
|
||||
raise EmailAlreadyRegisteredError(email=user_in.email)
|
||||
raise DatabaseIntegrityError(f"Failed to create user due to integrity issue: {str(e)}")
|
||||
except OperationalError as e:
|
||||
raise DatabaseConnectionError(f"Database connection error: {str(e)}")
|
||||
raise DatabaseConnectionError(f"Database connection error during user creation: {str(e)}")
|
||||
except SQLAlchemyError as e:
|
||||
raise DatabaseTransactionError(f"Failed to create user: {str(e)}")
|
||||
raise DatabaseTransactionError(f"Failed to create user due to other DB error: {str(e)}")
|
||||
|
||||
# Ensure UserOperationError is defined in app.core.exceptions if used
|
||||
# Example: class UserOperationError(AppException): pass
|
@ -36,15 +36,9 @@ async def get_async_session() -> AsyncSession: # type: ignore
|
||||
Ensures the session is closed after the request.
|
||||
"""
|
||||
async with AsyncSessionLocal() as session:
|
||||
try:
|
||||
yield session
|
||||
# Commit the transaction if no errors occurred
|
||||
await session.commit()
|
||||
except Exception:
|
||||
await session.rollback()
|
||||
raise
|
||||
finally:
|
||||
await session.close() # Not strictly necessary with async context manager, but explicit
|
||||
yield session
|
||||
# The 'async with' block handles session.close() automatically.
|
||||
# Commit/rollback should be handled by the functions using the session.
|
||||
|
||||
# Alias for backward compatibility
|
||||
get_db = get_async_session
|
Loading…
Reference in New Issue
Block a user