mitlist/be/app/crud/list.py

352 lines
14 KiB
Python

# app/crud/list.py
import logging
from typing import Optional, List as PyList

from sqlalchemy import or_, and_, delete as sql_delete, func as sql_func, desc
from sqlalchemy.engine import Row
from sqlalchemy.exc import SQLAlchemyError, IntegrityError, OperationalError
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy.orm import selectinload, joinedload

from app.schemas.list import ListStatus
from app.models import List as ListModel, UserGroup as UserGroupModel, Item as ItemModel
from app.schemas.list import ListCreate, ListUpdate
from app.core.exceptions import (
    ListNotFoundError,
    ListPermissionError,
    ListCreatorRequiredError,
    DatabaseConnectionError,
    DatabaseIntegrityError,
    DatabaseQueryError,
    DatabaseTransactionError,
    ConflictError,
    ListOperationError
)
logger = logging.getLogger(__name__)  # Module logger; the write paths below log database errors through it.
async def create_list(db: AsyncSession, list_in: ListCreate, creator_id: int) -> ListModel:
    """Insert a new list owned by ``creator_id`` and return it with relations loaded.

    The work runs inside a SAVEPOINT (``begin_nested``) when the session is
    already in a transaction; otherwise a new ``begin()`` scope is opened.

    Args:
        db: Async session used for the insert and the read-back.
        list_in: Creation payload supplying ``name``, ``description`` and ``group_id``.
        creator_id: ID of the user recorded as ``created_by_id``.

    Returns:
        The persisted list, re-selected with ``creator`` and ``group`` eagerly loaded.

    Raises:
        ListOperationError: If the freshly inserted row cannot be read back.
        DatabaseIntegrityError: On an ``IntegrityError``.
        DatabaseConnectionError: On an ``OperationalError``.
        DatabaseTransactionError: On any other ``SQLAlchemyError``.
    """
    try:
        tx_scope = db.begin_nested() if db.in_transaction() else db.begin()
        async with tx_scope:
            new_list = ListModel(
                name=list_in.name,
                description=list_in.description,
                group_id=list_in.group_id,
                created_by_id=creator_id,
                is_complete=False,
            )
            db.add(new_list)
            await db.flush()  # makes new_list.id available for the read-back below

            # Read the row back with relationships eagerly loaded for the response.
            reload_stmt = (
                select(ListModel)
                .options(
                    selectinload(ListModel.creator),
                    selectinload(ListModel.group),
                )
                .where(ListModel.id == new_list.id)
            )
            reloaded = (await db.execute(reload_stmt)).scalar_one_or_none()
            if reloaded is None:
                raise ListOperationError("Failed to load list after creation.")
            return reloaded
    except IntegrityError as exc:
        logger.error(f"Database integrity error during list creation: {str(exc)}", exc_info=True)
        raise DatabaseIntegrityError(f"Failed to create list: {str(exc)}")
    except OperationalError as exc:
        logger.error(f"Database connection error during list creation: {str(exc)}", exc_info=True)
        raise DatabaseConnectionError(f"Database connection error: {str(exc)}")
    except SQLAlchemyError as exc:
        logger.error(f"Unexpected SQLAlchemy error during list creation: {str(exc)}", exc_info=True)
        raise DatabaseTransactionError(f"Failed to create list: {str(exc)}")
async def get_lists_for_user(db: AsyncSession, user_id: int) -> PyList[ListModel]:
    """Return every list visible to ``user_id``, most recently updated first.

    Visible lists are the user's own personal lists (created by them with no
    group) plus every list belonging to a group the user is a member of.
    ``creator``, ``group`` and ``items`` (with each item's ``added_by_user`` and
    ``completed_by_user``) are eagerly loaded.

    Raises:
        DatabaseConnectionError: On an ``OperationalError``.
        DatabaseQueryError: On any other ``SQLAlchemyError``.
    """
    try:
        membership_rows = await db.execute(
            select(UserGroupModel.group_id).where(UserGroupModel.user_id == user_id)
        )
        member_group_ids = membership_rows.scalars().all()

        # Personal lists: created by this user and not attached to any group.
        visibility = [and_(ListModel.created_by_id == user_id, ListModel.group_id.is_(None))]
        # Only add the group clause when the user actually belongs to a group.
        if member_group_ids:
            visibility.append(ListModel.group_id.in_(member_group_ids))

        item_loader = selectinload(ListModel.items).options(
            joinedload(ItemModel.added_by_user),
            joinedload(ItemModel.completed_by_user),
        )
        stmt = (
            select(ListModel)
            .where(or_(*visibility))
            .options(
                selectinload(ListModel.creator),
                selectinload(ListModel.group),
                item_loader,
            )
            .order_by(ListModel.updated_at.desc())
        )
        rows = await db.execute(stmt)
        return rows.scalars().all()
    except OperationalError as exc:
        raise DatabaseConnectionError(f"Failed to connect to database: {str(exc)}")
    except SQLAlchemyError as exc:
        raise DatabaseQueryError(f"Failed to query user lists: {str(exc)}")
async def get_list_by_id(db: AsyncSession, list_id: int, load_items: bool = False) -> Optional[ListModel]:
    """Fetch a single list by primary key, or ``None`` when no such list exists.

    ``creator`` and ``group`` are always eagerly loaded. When ``load_items`` is
    true, ``items`` are loaded too, each with ``added_by_user`` and
    ``completed_by_user`` joined in.

    Raises:
        DatabaseConnectionError: On an ``OperationalError``.
        DatabaseQueryError: On any other ``SQLAlchemyError``.
    """
    try:
        loaders = [
            selectinload(ListModel.creator),
            selectinload(ListModel.group),
        ]
        if load_items:
            loaders.append(
                selectinload(ListModel.items).options(
                    joinedload(ItemModel.added_by_user),
                    joinedload(ItemModel.completed_by_user),
                )
            )
        stmt = select(ListModel).where(ListModel.id == list_id).options(*loaders)
        found = await db.execute(stmt)
        return found.scalars().first()
    except OperationalError as exc:
        raise DatabaseConnectionError(f"Failed to connect to database: {str(exc)}")
    except SQLAlchemyError as exc:
        raise DatabaseQueryError(f"Failed to query list: {str(exc)}")
async def update_list(db: AsyncSession, list_db: ListModel, list_in: ListUpdate) -> ListModel:
    """Apply a partial update to ``list_db`` after checking its version.

    Only fields explicitly set on ``list_in`` (excluding ``version``) are copied
    onto ``list_db``. ``version`` is then incremented, the session is flushed,
    and the row is re-selected with ``creator`` and ``group`` eagerly loaded.
    Runs in a SAVEPOINT if a transaction is already open, else in a new
    ``begin()`` scope.

    NOTE(review): the version comparison is done in Python against the
    instance the caller loaded; no ``WHERE version = ...`` guard is visible
    here, so concurrent updates may both pass. Confirm whether the model sets
    ``version_id_col``.

    Raises:
        ConflictError: If ``list_in.version`` differs from ``list_db.version``.
        ListOperationError: If the row cannot be re-selected after the update.
        DatabaseIntegrityError: On an ``IntegrityError``.
        DatabaseConnectionError: On an ``OperationalError``.
        DatabaseTransactionError: On any other ``SQLAlchemyError``.
    """
    try:
        tx_scope = db.begin_nested() if db.in_transaction() else db.begin()
        async with tx_scope:
            # list_db is supplied pre-loaded by the API layer.
            if list_db.version != list_in.version:
                raise ConflictError(
                    f"List '{list_db.name}' (ID: {list_db.id}) has been modified. "
                    f"Your version is {list_in.version}, current version is {list_db.version}. Please refresh."
                )

            changes = list_in.model_dump(exclude_unset=True, exclude={'version'})
            for field_name, new_value in changes.items():
                setattr(list_db, field_name, new_value)
            list_db.version += 1

            db.add(list_db)  # already attached; keeps the instance marked for this session
            await db.flush()

            refresh_stmt = (
                select(ListModel)
                .options(
                    selectinload(ListModel.creator),
                    selectinload(ListModel.group),
                )
                .where(ListModel.id == list_db.id)
            )
            refreshed = (await db.execute(refresh_stmt)).scalar_one_or_none()
            if refreshed is None:
                raise ListOperationError("Failed to load list after update.")
            return refreshed
    except IntegrityError as exc:
        logger.error(f"Database integrity error during list update: {str(exc)}", exc_info=True)
        raise DatabaseIntegrityError(f"Failed to update list due to integrity constraint: {str(exc)}")
    except OperationalError as exc:
        logger.error(f"Database connection error while updating list: {str(exc)}", exc_info=True)
        raise DatabaseConnectionError(f"Database connection error while updating list: {str(exc)}")
    except ConflictError:
        raise
    except SQLAlchemyError as exc:
        logger.error(f"Unexpected SQLAlchemy error during list update: {str(exc)}", exc_info=True)
        raise DatabaseTransactionError(f"Failed to update list: {str(exc)}")
async def delete_list(db: AsyncSession, list_db: ListModel) -> None:
    """Delete ``list_db`` inside a SAVEPOINT or a new transaction scope.

    Any version check must already have been done by the caller (API endpoint).

    Raises:
        DatabaseConnectionError: On an ``OperationalError``.
        DatabaseTransactionError: On any other ``SQLAlchemyError``.
    """
    try:
        tx_scope = db.begin_nested() if db.in_transaction() else db.begin()
        async with tx_scope:
            await db.delete(list_db)
    except OperationalError as exc:
        logger.error(f"Database connection error while deleting list: {str(exc)}", exc_info=True)
        raise DatabaseConnectionError(f"Database connection error while deleting list: {str(exc)}")
    except SQLAlchemyError as exc:
        logger.error(f"Unexpected SQLAlchemy error while deleting list: {str(exc)}", exc_info=True)
        raise DatabaseTransactionError(f"Failed to delete list: {str(exc)}")
async def check_list_permission(db: AsyncSession, list_id: int, user_id: int, require_creator: bool = False) -> ListModel:
    """Load a list (with items) and verify that ``user_id`` may access it.

    The creator always has access. When ``require_creator`` is true, nobody
    else does. Otherwise a non-creator needs membership in the list's group;
    personal (group-less) lists are closed to non-creators.

    Returns:
        The list as loaded by ``get_list_by_id(..., load_items=True)``.

    Raises:
        ListNotFoundError: If no list has ``list_id``.
        ListCreatorRequiredError: If ``require_creator`` is set and the user is not the creator.
        ListPermissionError: If a non-creator is not a member of the list's group,
            or the list has no group.
        DatabaseConnectionError: On an ``OperationalError``.
        DatabaseQueryError: On any other ``SQLAlchemyError``.
    """
    try:
        target = await get_list_by_id(db, list_id=list_id, load_items=True)
        if not target:
            raise ListNotFoundError(list_id)

        owns_list = target.created_by_id == user_id
        if owns_list:
            return target
        if require_creator:
            raise ListCreatorRequiredError(list_id, "access")

        # Past this point the user is not the creator; only group membership can grant access.
        if not target.group_id:
            raise ListPermissionError(list_id)

        # Local import, presumably to avoid a circular import with app.crud.group.
        from app.crud.group import is_user_member
        if not await is_user_member(db, group_id=target.group_id, user_id=user_id):
            raise ListPermissionError(list_id)
        return target
    except OperationalError as exc:
        raise DatabaseConnectionError(f"Failed to connect to database: {str(exc)}")
    except SQLAlchemyError as exc:
        raise DatabaseQueryError(f"Failed to check list permissions: {str(exc)}")
async def get_list_status(db: AsyncSession, list_id: int) -> ListStatus:
    """Summarize how fresh a list is.

    Returns a ``ListStatus`` with the list's own ``updated_at``, the number of
    its items, and the newest ``updated_at`` among those items. The outer join
    keeps lists that have no items: their count is 0 and the latest item
    timestamp is NULL.

    Raises:
        ListNotFoundError: If no list has ``list_id``.
        DatabaseConnectionError: On an ``OperationalError``.
        DatabaseQueryError: On any other ``SQLAlchemyError``.
    """
    try:
        item_count_col = sql_func.count(ItemModel.id).label("item_count")
        newest_item_col = sql_func.max(ItemModel.updated_at).label("latest_item_updated_at")
        stmt = (
            select(ListModel.updated_at, item_count_col, newest_item_col)
            .select_from(ListModel)
            .outerjoin(ItemModel, ItemModel.list_id == ListModel.id)
            .where(ListModel.id == list_id)
            .group_by(ListModel.id)
        )
        row = (await db.execute(stmt)).first()
        if row is None:
            raise ListNotFoundError(list_id)
        return ListStatus(
            updated_at=row.updated_at,
            item_count=row.item_count,
            latest_item_updated_at=row.latest_item_updated_at,
        )
    except OperationalError as exc:
        raise DatabaseConnectionError(f"Failed to connect to database: {str(exc)}")
    except SQLAlchemyError as exc:
        raise DatabaseQueryError(f"Failed to get list status: {str(exc)}")
async def get_list_by_name_and_group(
    db: AsyncSession,
    name: str,
    group_id: Optional[int],
    user_id: int,  # used for the permission check, not as a list attribute filter
) -> Optional[ListModel]:
    """Find an accessible list by exact name within a group or among personal lists.

    Used for conflict resolution when creating lists.

    Args:
        db: Async session to query with.
        name: Exact list name to match.
        group_id: Group to search in, or ``None`` to search the user's own
            personal (group-less) lists.
        user_id: The requesting user. A match is returned only if they created
            it or are a member of its group.

    Returns:
        The matching list with ``creator`` and ``group`` eagerly loaded, or
        ``None`` when no accessible list has that name. If several lists in
        the same scope share the name, the one with the lowest ID is returned.

    Raises:
        DatabaseConnectionError: On an ``OperationalError``.
        DatabaseQueryError: On any other ``SQLAlchemyError``.
    """
    try:
        stmt = select(ListModel).where(ListModel.name == name)
        if group_id is not None:
            stmt = stmt.where(ListModel.group_id == group_id)
        else:
            # Personal lists are only ever accessible to their creator. Without this
            # filter, two users who each own a same-named personal list made
            # scalar_one_or_none() raise MultipleResultsFound, which surfaced as a
            # DatabaseQueryError. A unique constraint would not prevent that case,
            # because NULL group_ids compare as distinct.
            stmt = stmt.where(
                ListModel.group_id.is_(None),
                ListModel.created_by_id == user_id,
            )
        stmt = (
            stmt.options(
                selectinload(ListModel.creator),
                selectinload(ListModel.group),
            )
            # Names are not known to be unique within a group: pick
            # deterministically instead of raising on duplicates.
            .order_by(ListModel.id)
            .limit(1)
        )
        target_list = (await db.execute(stmt)).scalars().first()
        if target_list is None:
            return None

        if target_list.created_by_id == user_id:
            return target_list

        if target_list.group_id is not None:
            # Local import, presumably to avoid a circular import with app.crud.group.
            from app.crud.group import is_user_member
            if await is_user_member(db, group_id=target_list.group_id, user_id=user_id):
                return target_list

        # Not the creator, and not a member of the list's group.
        return None
    except OperationalError as e:
        raise DatabaseConnectionError(f"Failed to connect to database: {str(e)}")
    except SQLAlchemyError as e:
        raise DatabaseQueryError(f"Failed to query list by name and group: {str(e)}")
async def get_lists_statuses_by_ids(db: AsyncSession, list_ids: PyList[int], user_id: int) -> PyList[Row]:
    """Return status rows for those of ``list_ids`` that ``user_id`` may access.

    Each returned row exposes ``id``, ``updated_at``, ``item_count`` and
    ``latest_item_updated_at``. The rows are result ``Row`` objects, not
    ``ListModel`` instances. Lists that do not exist or that the user cannot
    access are silently omitted, so the result may be shorter than
    ``list_ids``. An empty ``list_ids`` returns ``[]`` without querying.

    Access rule: the user's own personal (group-less) lists, or lists
    belonging to a group the user is a member of.

    Raises:
        DatabaseConnectionError: On an ``OperationalError``.
        DatabaseQueryError: On any other ``SQLAlchemyError``.
    """
    if not list_ids:
        return []
    try:
        group_ids_result = await db.execute(
            select(UserGroupModel.group_id).where(UserGroupModel.user_id == user_id)
        )
        user_group_ids = group_ids_result.scalars().all()

        access_conditions = [
            and_(ListModel.created_by_id == user_id, ListModel.group_id.is_(None))
        ]
        # Only add the group clause when the user actually belongs to a group.
        if user_group_ids:
            access_conditions.append(ListModel.group_id.in_(user_group_ids))

        # Outer join so lists without items still appear, with count 0 and NULL max.
        query = (
            select(
                ListModel.id,
                ListModel.updated_at,
                sql_func.count(ItemModel.id).label("item_count"),
                sql_func.max(ItemModel.updated_at).label("latest_item_updated_at"),
            )
            .select_from(ListModel)
            .outerjoin(ItemModel, ListModel.id == ItemModel.list_id)
            .where(ListModel.id.in_(list_ids), or_(*access_conditions))
            .group_by(ListModel.id)
        )
        result = await db.execute(query)
        return result.all()
    except OperationalError as e:
        raise DatabaseConnectionError(f"Failed to connect to database: {str(e)}")
    except SQLAlchemyError as e:
        raise DatabaseQueryError(f"Failed to get lists statuses: {str(e)}")