298 lines
12 KiB
Python
298 lines
12 KiB
Python
# app/crud/list.py
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
from sqlalchemy.future import select
|
|
from sqlalchemy.orm import selectinload, joinedload
|
|
from sqlalchemy import or_, and_, delete as sql_delete, func as sql_func, desc
|
|
from sqlalchemy.exc import SQLAlchemyError, IntegrityError, OperationalError
|
|
from typing import Optional, List as PyList
|
|
|
|
from app.schemas.list import ListStatus
|
|
from app.models import List as ListModel, UserGroup as UserGroupModel, Item as ItemModel
|
|
from app.schemas.list import ListCreate, ListUpdate
|
|
from app.core.exceptions import (
|
|
ListNotFoundError,
|
|
ListPermissionError,
|
|
ListCreatorRequiredError,
|
|
DatabaseConnectionError,
|
|
DatabaseIntegrityError,
|
|
DatabaseQueryError,
|
|
DatabaseTransactionError,
|
|
ConflictError,
|
|
ListOperationError
|
|
)
|
|
|
|
async def create_list(db: AsyncSession, list_in: ListCreate, creator_id: int) -> ListModel:
    """Create a new list owned by ``creator_id`` and return it with relationships loaded.

    The work runs inside ``db.begin_nested()`` (a SAVEPOINT) when the session
    already has a transaction open, otherwise inside ``db.begin()``. In the
    nested case the ``commit()`` below only releases the savepoint.
    NOTE(review): the caller is then presumably responsible for committing the
    outer transaction — confirm against the session dependency.

    Args:
        db: Async session to use.
        list_in: Creation payload; ``name``, ``description`` and ``group_id`` are copied.
        creator_id: ID stored as ``created_by_id`` on the new list.

    Returns:
        The inserted list, re-selected with ``creator`` and ``group`` eagerly loaded.

    Raises:
        ListOperationError: If the freshly inserted row cannot be re-selected.
        DatabaseIntegrityError: On an ``IntegrityError`` during the insert.
        DatabaseConnectionError: On an ``OperationalError``.
        DatabaseTransactionError: On any other SQLAlchemy error.
    """
    try:
        async with db.begin_nested() if db.in_transaction() else db.begin() as transaction:
            db_list = ListModel(
                name=list_in.name,
                description=list_in.description,
                group_id=list_in.group_id,
                created_by_id=creator_id,
                is_complete=False,
            )
            db.add(db_list)
            await db.flush()  # emits the INSERT so db_list.id is assigned

            # Re-select with creator/group eagerly loaded so the response can be
            # built without issuing further queries.
            stmt = (
                select(ListModel)
                .where(ListModel.id == db_list.id)
                .options(
                    selectinload(ListModel.creator),
                    selectinload(ListModel.group),
                    # selectinload(ListModel.items) could be added if the response needs items.
                )
            )
            result = await db.execute(stmt)
            loaded_list = result.scalar_one_or_none()

            if loaded_list is None:
                await transaction.rollback()
                raise ListOperationError("Failed to load list after creation.")

            await transaction.commit()
            return loaded_list
    except IntegrityError as e:
        # `from e` keeps the original driver error as __cause__ for debugging.
        raise DatabaseIntegrityError(f"Failed to create list: {str(e)}") from e
    except OperationalError as e:
        raise DatabaseConnectionError(f"Database connection error: {str(e)}") from e
    except SQLAlchemyError as e:
        raise DatabaseTransactionError(f"Failed to create list: {str(e)}") from e
|
|
|
|
async def get_lists_for_user(db: AsyncSession, user_id: int) -> PyList[ListModel]:
    """Return every list ``user_id`` can access, most recently updated first.

    A list is included when it is either:
      * a personal list (``group_id IS NULL``) created by the user, or
      * attached to a group the user belongs to (per ``UserGroupModel``).

    ``creator`` and ``group`` are eagerly loaded on each returned list.

    Raises:
        DatabaseConnectionError: On an ``OperationalError``.
        DatabaseQueryError: On any other SQLAlchemy error.
    """
    try:
        # Membership is resolved as an IN-subquery, so the whole lookup is a
        # single round-trip evaluated against one consistent snapshot.
        member_group_ids = select(UserGroupModel.group_id).where(
            UserGroupModel.user_id == user_id
        )
        query = (
            select(ListModel)
            .where(
                or_(
                    and_(ListModel.created_by_id == user_id, ListModel.group_id.is_(None)),
                    ListModel.group_id.in_(member_group_ids),
                )
            )
            .options(
                selectinload(ListModel.creator),
                selectinload(ListModel.group),
                # selectinload(ListModel.items) could be added if list previews need items.
            )
            .order_by(ListModel.updated_at.desc())
        )
        result = await db.execute(query)
        # Materialize as a real list to match the declared return type.
        return list(result.scalars().all())
    except OperationalError as e:
        raise DatabaseConnectionError(f"Failed to connect to database: {str(e)}") from e
    except SQLAlchemyError as e:
        raise DatabaseQueryError(f"Failed to query user lists: {str(e)}") from e
|
|
|
|
async def get_list_by_id(db: AsyncSession, list_id: int, load_items: bool = False) -> Optional[ListModel]:
    """Fetch a single list by primary key.

    ``creator`` and ``group`` are always eagerly loaded. When ``load_items`` is
    true, ``items`` is also loaded, together with each item's
    ``added_by_user`` and ``completed_by_user``.

    Returns:
        The list, or ``None`` if no list has ``list_id``.

    Raises:
        DatabaseConnectionError: On an ``OperationalError``.
        DatabaseQueryError: On any other SQLAlchemy error.
    """
    try:
        loader_options = [
            selectinload(ListModel.creator),
            selectinload(ListModel.group),
        ]
        if load_items:
            loader_options.append(
                selectinload(ListModel.items).options(
                    joinedload(ItemModel.added_by_user),
                    joinedload(ItemModel.completed_by_user),
                )
            )

        query = select(ListModel).where(ListModel.id == list_id).options(*loader_options)
        result = await db.execute(query)
        return result.scalars().first()
    except OperationalError as e:
        raise DatabaseConnectionError(f"Failed to connect to database: {str(e)}") from e
    except SQLAlchemyError as e:
        raise DatabaseQueryError(f"Failed to query list: {str(e)}") from e
|
|
|
|
async def update_list(db: AsyncSession, list_db: ListModel, list_in: ListUpdate) -> ListModel:
    """Apply ``list_in`` to ``list_db`` after an optimistic-concurrency version check.

    ``list_in.version`` must equal ``list_db.version`` (``list_db`` is expected to
    be pre-loaded by the API layer). On success, every field explicitly set in
    ``list_in`` except ``version`` is copied onto ``list_db``, and ``version`` is
    incremented by one.

    NOTE(review): the version comparison is done in Python against the
    already-loaded object. The UPDATE is not conditioned on the old version, so
    two concurrent requests that read the same version can both succeed.
    SQLAlchemy's ``version_id_col`` mapper option or an
    ``UPDATE ... WHERE version = :v`` would close that gap.

    Returns:
        The updated list, re-selected with ``creator`` and ``group`` eagerly loaded.

    Raises:
        ConflictError: If ``list_in.version`` differs from ``list_db.version``.
        ListOperationError: If the list cannot be re-selected after the flush.
        DatabaseIntegrityError: On an ``IntegrityError``.
        DatabaseConnectionError: On an ``OperationalError``.
        DatabaseTransactionError: On any other SQLAlchemy error.
    """
    try:
        async with db.begin_nested() if db.in_transaction() else db.begin() as transaction:
            if list_db.version != list_in.version:
                await transaction.rollback()  # nothing was changed; release before raising
                raise ConflictError(
                    f"List '{list_db.name}' (ID: {list_db.id}) has been modified. "
                    f"Your version is {list_in.version}, current version is {list_db.version}. Please refresh."
                )

            # Only fields the client actually sent; the version is managed here.
            update_data = list_in.model_dump(exclude_unset=True, exclude={'version'})
            for key, value in update_data.items():
                setattr(list_db, key, value)
            list_db.version += 1

            db.add(list_db)  # already attached; add() is harmless and makes intent explicit
            await db.flush()

            # Re-select with creator/group eagerly loaded for the response.
            stmt = (
                select(ListModel)
                .where(ListModel.id == list_db.id)
                .options(
                    selectinload(ListModel.creator),
                    selectinload(ListModel.group),
                )
            )
            result = await db.execute(stmt)
            updated_list = result.scalar_one_or_none()

            if updated_list is None:  # should not happen: the row was just flushed
                await transaction.rollback()
                raise ListOperationError("Failed to load list after update.")

            await transaction.commit()
            return updated_list
    except ConflictError:
        # Already rolled back above; propagate unchanged.
        raise
    except IntegrityError as e:
        # NOTE(review): AsyncSession.rollback() rolls back the session's whole
        # transaction. When this call only opened a SAVEPOINT, the caller's outer
        # transaction is discarded too. Kept as-is for backward compatibility.
        if db.in_transaction():
            await db.rollback()
        raise DatabaseIntegrityError(f"Failed to update list due to integrity constraint: {str(e)}") from e
    except OperationalError as e:
        if db.in_transaction():
            await db.rollback()
        raise DatabaseConnectionError(f"Database connection error while updating list: {str(e)}") from e
    except SQLAlchemyError as e:
        if db.in_transaction():
            await db.rollback()
        raise DatabaseTransactionError(f"Failed to update list: {str(e)}") from e
|
|
|
|
async def delete_list(db: AsyncSession, list_db: ListModel) -> None:
    """Delete ``list_db``.

    Any version/permission check must be done by the caller (API endpoint).
    Uses a SAVEPOINT when the session already has a transaction open,
    otherwise a new top-level transaction. Committing flushes the pending
    DELETE.

    Raises:
        DatabaseConnectionError: On an ``OperationalError``.
        DatabaseTransactionError: On any other SQLAlchemy error.
    """
    try:
        async with db.begin_nested() if db.in_transaction() else db.begin() as transaction:
            await db.delete(list_db)
            await transaction.commit()
    except OperationalError as e:
        # The `async with` block has already rolled back its own transaction/savepoint.
        raise DatabaseConnectionError(f"Database connection error while deleting list: {str(e)}") from e
    except SQLAlchemyError as e:
        raise DatabaseTransactionError(f"Failed to delete list: {str(e)}") from e
|
|
|
|
async def check_list_permission(db: AsyncSession, list_id: int, user_id: int, require_creator: bool = False) -> ListModel:
    """Fetch list ``list_id`` (items included) and verify ``user_id`` may access it.

    Access rules:
      * the list's creator always has access;
      * with ``require_creator=True``, only the creator has access;
      * otherwise, members of the list's group have access;
      * a list without a group is creator-only.

    Returns:
        The list, loaded via ``get_list_by_id(..., load_items=True)``.

    Raises:
        ListNotFoundError: If the list does not exist.
        ListCreatorRequiredError: If ``require_creator`` is set and the user is not the creator.
        ListPermissionError: If the user is neither the creator nor a member of the list's group.
        DatabaseConnectionError: On an ``OperationalError`` reaching this function.
        DatabaseQueryError: On any other SQLAlchemy error reaching this function.
    """
    try:
        list_db = await get_list_by_id(db, list_id=list_id, load_items=True)
        if not list_db:
            raise ListNotFoundError(list_id)

        # The creator passes every check, with or without require_creator.
        if list_db.created_by_id == user_id:
            return list_db
        if require_creator:
            raise ListCreatorRequiredError(list_id, "access")
        if not list_db.group_id:
            # Personal list owned by someone else.
            raise ListPermissionError(list_id)

        # Imported inside the function, presumably to avoid a circular import.
        from app.crud.group import is_user_member

        if not await is_user_member(db, group_id=list_db.group_id, user_id=user_id):
            raise ListPermissionError(list_id)
        return list_db
    except OperationalError as e:
        raise DatabaseConnectionError(f"Failed to connect to database: {str(e)}") from e
    except SQLAlchemyError as e:
        raise DatabaseQueryError(f"Failed to check list permissions: {str(e)}") from e
|
|
|
|
async def get_list_status(db: AsyncSession, list_id: int) -> ListStatus:
    """Return the list's ``updated_at``, its newest item ``updated_at`` and its item count.

    Returns:
        ``ListStatus`` whose ``latest_item_updated_at`` is ``None`` and
        ``item_count`` is ``0`` when the list has no items.

    Raises:
        ListNotFoundError: If no list has ``list_id``.
        DatabaseConnectionError: On an ``OperationalError``.
        DatabaseQueryError: On any other SQLAlchemy error.
    """
    try:
        list_result = await db.execute(
            select(ListModel.updated_at).where(ListModel.id == list_id)
        )
        # Check for the row itself: scalar_one_or_none() would return None both
        # for a missing list and for an existing list whose updated_at is NULL.
        list_row = list_result.first()
        if list_row is None:
            raise ListNotFoundError(list_id)
        list_updated_at = list_row[0]

        item_status_query = (
            select(
                sql_func.max(ItemModel.updated_at).label("latest_item_updated_at"),
                sql_func.count(ItemModel.id).label("item_count"),
            )
            .where(ItemModel.list_id == list_id)
        )
        item_result = await db.execute(item_status_query)
        # An aggregate without GROUP BY yields one row even when there are no items;
        # the fallbacks are purely defensive.
        item_status = item_result.first()

        return ListStatus(
            list_updated_at=list_updated_at,
            latest_item_updated_at=item_status.latest_item_updated_at if item_status else None,
            item_count=item_status.item_count if item_status else 0,
        )
    except OperationalError as e:
        raise DatabaseConnectionError(f"Failed to connect to database: {str(e)}") from e
    except SQLAlchemyError as e:
        raise DatabaseQueryError(f"Failed to get list status: {str(e)}") from e
|
|
|
|
async def get_list_by_name_and_group(
    db: AsyncSession,
    name: str,
    group_id: Optional[int],
    user_id: int,  # used for the permission check, not a list attribute
) -> Optional[ListModel]:
    """Find a list named ``name`` in ``group_id`` that ``user_id`` may access.

    Used for conflict resolution when creating lists.

    * ``group_id is None``: only the user's own personal lists are searched.
      Personal lists are accessible to their creator alone, so other users'
      same-named personal lists can never be returned.
    * otherwise: lists in that group are searched. A match is returned if the
      user created it or is a member of the group.

    If several lists match, the one with the lowest ``id`` is used. This
    previously raised ``MultipleResultsFound``, which surfaced as
    ``DatabaseQueryError``.

    Returns:
        The accessible list with ``creator`` and ``group`` eagerly loaded, or ``None``.

    Raises:
        DatabaseConnectionError: On an ``OperationalError``.
        DatabaseQueryError: On any other SQLAlchemy error.
    """
    try:
        query = select(ListModel).where(ListModel.name == name)
        if group_id is not None:
            query = query.where(ListModel.group_id == group_id)
        else:
            query = query.where(
                ListModel.group_id.is_(None),
                ListModel.created_by_id == user_id,
            )
        query = (
            query.options(
                selectinload(ListModel.creator),
                selectinload(ListModel.group),
            )
            .order_by(ListModel.id)
            .limit(1)
        )

        list_result = await db.execute(query)
        target_list = list_result.scalars().first()
        if target_list is None:
            return None

        # Permission check: the creator always has access.
        if target_list.created_by_id == user_id:
            return target_list

        if target_list.group_id:
            # Imported inside the function, presumably to avoid a circular import.
            from app.crud.group import is_user_member

            if await is_user_member(db, group_id=target_list.group_id, user_id=user_id):
                return target_list

        # Not the creator, and either not a group list or not a member of its group.
        return None
    except OperationalError as e:
        raise DatabaseConnectionError(f"Failed to connect to database: {str(e)}") from e
    except SQLAlchemyError as e:
        raise DatabaseQueryError(f"Failed to query list by name and group: {str(e)}") from e