mitlist/be/app/crud/list.py

298 lines
12 KiB
Python

# app/crud/list.py
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy.orm import selectinload, joinedload
from sqlalchemy import or_, and_, delete as sql_delete, func as sql_func, desc
from sqlalchemy.exc import SQLAlchemyError, IntegrityError, OperationalError
from typing import Optional, List as PyList
from app.schemas.list import ListStatus
from app.models import List as ListModel, UserGroup as UserGroupModel, Item as ItemModel
from app.schemas.list import ListCreate, ListUpdate
from app.core.exceptions import (
ListNotFoundError,
ListPermissionError,
ListCreatorRequiredError,
DatabaseConnectionError,
DatabaseIntegrityError,
DatabaseQueryError,
DatabaseTransactionError,
ConflictError,
ListOperationError
)
async def create_list(db: AsyncSession, list_in: ListCreate, creator_id: int) -> ListModel:
    """Create a new list record and return it with its relationships loaded.

    The insert runs inside a SAVEPOINT when the session already has a
    transaction open, otherwise inside a fresh transaction. The ``async with``
    block commits on a normal exit and rolls back when an exception escapes
    it, so no explicit commit/rollback calls are made here.

    Args:
        db: Async session used for the insert and the follow-up select.
        list_in: Payload providing ``name``, ``description`` and ``group_id``.
        creator_id: Stored as the new list's ``created_by_id``.

    Returns:
        The newly created list, re-selected with ``creator`` and ``group``
        eagerly loaded (items are not loaded).

    Raises:
        ListOperationError: The list could not be re-selected after the flush.
        DatabaseIntegrityError: The insert violated an integrity constraint.
        DatabaseConnectionError: The database reported an operational error.
        DatabaseTransactionError: Any other SQLAlchemy error occurred.
    """
    try:
        transaction_ctx = db.begin_nested() if db.in_transaction() else db.begin()
        async with transaction_ctx:
            db_list = ListModel(
                name=list_in.name,
                description=list_in.description,
                group_id=list_in.group_id,
                created_by_id=creator_id,
                is_complete=False,
            )
            db.add(db_list)
            await db.flush()  # Assigns the primary key so the row can be re-selected.

            # Re-fetch with relationships so the response never needs a lazy load.
            stmt = (
                select(ListModel)
                .where(ListModel.id == db_list.id)
                .options(
                    selectinload(ListModel.creator),
                    selectinload(ListModel.group),
                )
            )
            result = await db.execute(stmt)
            loaded_list = result.scalar_one_or_none()
            if loaded_list is None:
                # Raising inside the block makes the context manager roll back.
                raise ListOperationError("Failed to load list after creation.")
            return loaded_list
    except IntegrityError as e:
        raise DatabaseIntegrityError(f"Failed to create list: {str(e)}") from e
    except OperationalError as e:
        raise DatabaseConnectionError(f"Database connection error: {str(e)}") from e
    except SQLAlchemyError as e:
        raise DatabaseTransactionError(f"Failed to create list: {str(e)}") from e
async def get_lists_for_user(db: AsyncSession, user_id: int) -> PyList[ListModel]:
    """Return every list the user can access, most recently updated first.

    A list qualifies when it is a group-less list created by the user, or when
    its group is one the user has a ``UserGroup`` row for. ``creator`` and
    ``group`` are eagerly loaded; items are not.

    Raises:
        DatabaseConnectionError: The database reported an operational error.
        DatabaseQueryError: Any other SQLAlchemy error occurred.
    """
    try:
        membership_stmt = select(UserGroupModel.group_id).where(
            UserGroupModel.user_id == user_id
        )
        membership_result = await db.execute(membership_stmt)
        member_group_ids = membership_result.scalars().all()

        # Personal (group-less) lists are matched by creator.
        personal_clause = and_(
            ListModel.created_by_id == user_id,
            ListModel.group_id.is_(None),
        )
        # Only add the IN (...) clause when the user belongs to at least one group.
        group_clauses = (
            [ListModel.group_id.in_(member_group_ids)] if member_group_ids else []
        )

        eager_loads = (
            selectinload(ListModel.creator),
            selectinload(ListModel.group),
        )
        stmt = (
            select(ListModel)
            .where(or_(personal_clause, *group_clauses))
            .options(*eager_loads)
            .order_by(desc(ListModel.updated_at))
        )
        rows = await db.execute(stmt)
        return rows.scalars().all()
    except OperationalError as e:
        raise DatabaseConnectionError(f"Failed to connect to database: {str(e)}")
    except SQLAlchemyError as e:
        raise DatabaseQueryError(f"Failed to query user lists: {str(e)}")
async def get_list_by_id(db: AsyncSession, list_id: int, load_items: bool = False) -> Optional[ListModel]:
    """Fetch one list by primary key, or ``None`` when no row matches.

    ``creator`` and ``group`` are always eagerly loaded. With
    ``load_items=True`` the list's items are loaded as well, each joined to
    its ``added_by_user`` and ``completed_by_user``.

    Raises:
        DatabaseConnectionError: The database reported an operational error.
        DatabaseQueryError: Any other SQLAlchemy error occurred.
    """
    try:
        loader_options = [
            selectinload(ListModel.creator),
            selectinload(ListModel.group),
        ]
        if load_items:
            items_loader = selectinload(ListModel.items).options(
                joinedload(ItemModel.added_by_user),
                joinedload(ItemModel.completed_by_user),
            )
            loader_options.append(items_loader)

        stmt = select(ListModel).where(ListModel.id == list_id).options(*loader_options)
        rows = await db.execute(stmt)
        return rows.scalars().first()
    except OperationalError as e:
        raise DatabaseConnectionError(f"Failed to connect to database: {str(e)}")
    except SQLAlchemyError as e:
        raise DatabaseQueryError(f"Failed to query list: {str(e)}")
async def update_list(db: AsyncSession, list_db: ListModel, list_in: ListUpdate) -> ListModel:
    """Apply an update to an existing list, rejecting stale versions.

    The caller passes the already-loaded ``list_db``. Its ``version`` must
    equal ``list_in.version``; otherwise nothing is written and a
    ``ConflictError`` is raised. On success every explicitly-set field of
    ``list_in`` (except ``version``) is copied onto ``list_db`` and the
    version is incremented by one.

    The work runs inside a SAVEPOINT when the session already has a
    transaction open, otherwise inside a fresh transaction. The ``async with``
    block commits on normal exit and rolls back on any exception. The error
    handlers deliberately do NOT call ``db.rollback()``: doing so would also
    discard the caller's outer transaction when only a SAVEPOINT was opened
    here.

    Args:
        db: Async session that ``list_db`` is attached to.
        list_db: The list to modify, pre-loaded by the caller.
        list_in: Update payload; must carry the ``version`` the client saw.

    Returns:
        The updated list, re-selected with ``creator`` and ``group`` eagerly
        loaded (items are not loaded).

    Raises:
        ConflictError: ``list_in.version`` differs from ``list_db.version``.
        ListOperationError: The list could not be re-selected after the flush.
        DatabaseIntegrityError: The update violated an integrity constraint.
        DatabaseConnectionError: The database reported an operational error.
        DatabaseTransactionError: Any other SQLAlchemy error occurred.
    """
    try:
        transaction_ctx = db.begin_nested() if db.in_transaction() else db.begin()
        async with transaction_ctx:
            if list_db.version != list_in.version:
                # Raising here lets the context manager roll the block back.
                raise ConflictError(
                    f"List '{list_db.name}' (ID: {list_db.id}) has been modified. "
                    f"Your version is {list_in.version}, current version is {list_db.version}. Please refresh."
                )

            # Only fields the client actually sent; the version is managed here.
            update_data = list_in.model_dump(exclude_unset=True, exclude={'version'})
            for key, value in update_data.items():
                setattr(list_db, key, value)
            list_db.version += 1

            db.add(list_db)  # No-op for an attached instance; re-attaches a detached one.
            await db.flush()

            # Re-fetch with relationships so the response never needs a lazy load.
            stmt = (
                select(ListModel)
                .where(ListModel.id == list_db.id)
                .options(
                    selectinload(ListModel.creator),
                    selectinload(ListModel.group),
                )
            )
            result = await db.execute(stmt)
            updated_list = result.scalar_one_or_none()
            if updated_list is None:  # Should not happen for a row we just flushed.
                raise ListOperationError("Failed to load list after update.")
            return updated_list
    except IntegrityError as e:
        raise DatabaseIntegrityError(
            f"Failed to update list due to integrity constraint: {str(e)}"
        ) from e
    except OperationalError as e:
        raise DatabaseConnectionError(
            f"Database connection error while updating list: {str(e)}"
        ) from e
    except ConflictError:
        # Listed before the generic handler so a version conflict is never rewrapped.
        raise
    except SQLAlchemyError as e:
        raise DatabaseTransactionError(f"Failed to update list: {str(e)}") from e
async def delete_list(db: AsyncSession, list_db: ListModel) -> None:
    """Delete a list record.

    No version check is performed here; that is the caller's (API endpoint's)
    responsibility. The delete runs inside a SAVEPOINT when the session
    already has a transaction open, otherwise inside a fresh transaction. The
    ``async with`` block commits on normal exit and rolls back on exception,
    so no explicit commit is issued.

    Args:
        db: Async session that ``list_db`` is attached to.
        list_db: The list instance to delete.

    Raises:
        DatabaseConnectionError: The database reported an operational error.
        DatabaseTransactionError: Any other SQLAlchemy error occurred.
    """
    try:
        async with (db.begin_nested() if db.in_transaction() else db.begin()):
            await db.delete(list_db)
    except OperationalError as e:
        raise DatabaseConnectionError(
            f"Database connection error while deleting list: {str(e)}"
        ) from e
    except SQLAlchemyError as e:
        raise DatabaseTransactionError(f"Failed to delete list: {str(e)}") from e
async def check_list_permission(db: AsyncSession, list_id: int, user_id: int, require_creator: bool = False) -> ListModel:
    """Load a list (with its items) and verify the user may access it.

    The creator always passes. When ``require_creator`` is set, nobody else
    does. Otherwise a non-creator passes only if the list belongs to a group
    and the user is a member of that group.

    Returns:
        The loaded list when the check succeeds.

    Raises:
        ListNotFoundError: No list with ``list_id`` exists.
        ListCreatorRequiredError: ``require_creator`` is set and the user is
            not the creator.
        ListPermissionError: The user is neither the creator nor a member of
            the list's group (or the list has no group).
        DatabaseConnectionError: The database reported an operational error.
        DatabaseQueryError: Any other SQLAlchemy error occurred.
    """
    try:
        found = await get_list_by_id(db, list_id=list_id, load_items=True)
        if not found:
            raise ListNotFoundError(list_id)

        # The creator is allowed regardless of require_creator.
        if found.created_by_id == user_id:
            return found

        if require_creator:
            raise ListCreatorRequiredError(list_id, "access")

        # A non-creator can only get in through group membership.
        if not found.group_id:
            raise ListPermissionError(list_id)

        # Function-level import; presumably avoids a circular import with
        # app.crud.group (confirm before moving it to module level).
        from app.crud.group import is_user_member

        if await is_user_member(db, group_id=found.group_id, user_id=user_id):
            return found
        raise ListPermissionError(list_id)
    except OperationalError as e:
        raise DatabaseConnectionError(f"Failed to connect to database: {str(e)}")
    except SQLAlchemyError as e:
        raise DatabaseQueryError(f"Failed to check list permissions: {str(e)}")
async def get_list_status(db: AsyncSession, list_id: int) -> ListStatus:
    """Report a list's freshness: its own and its items' latest update, plus item count.

    Two queries are issued: one for the list's ``updated_at`` and one
    aggregate over its items (``MAX(updated_at)`` and ``COUNT(id)``).

    Raises:
        ListNotFoundError: The ``updated_at`` lookup for ``list_id`` yielded
            ``None``.
        DatabaseConnectionError: The database reported an operational error.
        DatabaseQueryError: Any other SQLAlchemy error occurred.
    """
    try:
        updated_at_result = await db.execute(
            select(ListModel.updated_at).where(ListModel.id == list_id)
        )
        list_updated_at = updated_at_result.scalar_one_or_none()
        if list_updated_at is None:
            raise ListNotFoundError(list_id)

        aggregate_stmt = select(
            sql_func.max(ItemModel.updated_at).label("latest_item_updated_at"),
            sql_func.count(ItemModel.id).label("item_count"),
        ).where(ItemModel.list_id == list_id)
        aggregate_row = (await db.execute(aggregate_stmt)).first()

        # Fall back to "no items" values if the aggregate row is missing.
        if aggregate_row:
            latest_item_updated_at = aggregate_row.latest_item_updated_at
            item_count = aggregate_row.item_count
        else:
            latest_item_updated_at = None
            item_count = 0

        return ListStatus(
            list_updated_at=list_updated_at,
            latest_item_updated_at=latest_item_updated_at,
            item_count=item_count,
        )
    except OperationalError as e:
        raise DatabaseConnectionError(f"Failed to connect to database: {str(e)}")
    except SQLAlchemyError as e:
        raise DatabaseQueryError(f"Failed to get list status: {str(e)}")
async def get_list_by_name_and_group(
    db: AsyncSession,
    name: str,
    group_id: Optional[int],
    user_id: int  # used to scope personal lists and for the permission check
) -> Optional[ListModel]:
    """Find a list by name within a group (or among the user's personal lists).

    Used for conflict resolution when creating lists.

    When ``group_id`` is given, the list is looked up by name inside that
    group and returned only if the user created it or is a member of the
    group. When ``group_id`` is ``None``, the lookup is restricted to
    group-less lists created by ``user_id``. Without that restriction,
    same-named personal lists owned by other users would either hide the
    caller's own list or make the single-row lookup fail.

    Args:
        db: Async session used for the lookup.
        name: Exact list name to match.
        group_id: Group to search in, or ``None`` for personal lists.
        user_id: The requesting user.

    Returns:
        The matching list with ``creator`` and ``group`` eagerly loaded, or
        ``None`` when there is no match or the user may not access it.

    Raises:
        DatabaseConnectionError: The database reported an operational error.
        DatabaseQueryError: Any other SQLAlchemy error occurred, including
            more than one row matching the lookup.
    """
    try:
        query = select(ListModel).where(ListModel.name == name)
        if group_id is not None:
            query = query.where(ListModel.group_id == group_id)
        else:
            # Personal lists are only visible to their creator, so scope the
            # lookup to this user instead of matching every user's lists.
            query = query.where(
                ListModel.group_id.is_(None),
                ListModel.created_by_id == user_id,
            )

        # Eager-load the relationships commonly needed by the response.
        query = query.options(
            selectinload(ListModel.creator),
            selectinload(ListModel.group),
        )

        result = await db.execute(query)
        target_list = result.scalar_one_or_none()
        if target_list is None:
            return None

        # Permission check: creator first, then group membership.
        if target_list.created_by_id == user_id:
            return target_list

        if target_list.group_id:
            # Function-level import; presumably avoids a circular import with
            # app.crud.group (confirm before moving it to module level).
            from app.crud.group import is_user_member

            if await is_user_member(db, group_id=target_list.group_id, user_id=user_id):
                return target_list

        # Neither the creator nor a member of the list's group.
        return None
    except OperationalError as e:
        raise DatabaseConnectionError(f"Failed to connect to database: {str(e)}") from e
    except SQLAlchemyError as e:
        raise DatabaseQueryError(f"Failed to query list by name and group: {str(e)}") from e