mitlist/be/app/crud/group.py

269 lines
12 KiB
Python

# app/crud/group.py
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy.orm import selectinload # For eager loading members
from sqlalchemy.exc import SQLAlchemyError, IntegrityError, OperationalError
from typing import Optional, List
from sqlalchemy import delete, func
from app.models import User as UserModel, Group as GroupModel, UserGroup as UserGroupModel
from app.schemas.group import GroupCreate
from app.models import UserRoleEnum # Import enum
from app.core.exceptions import (
GroupOperationError,
GroupNotFoundError,
DatabaseConnectionError,
DatabaseIntegrityError,
DatabaseQueryError,
DatabaseTransactionError,
GroupMembershipError,
GroupPermissionError # Import GroupPermissionError
)
# --- Group CRUD ---
async def create_group(db: AsyncSession, group_in: GroupCreate, creator_id: int) -> GroupModel:
    """Create a group and add its creator as the owner, in one transaction.

    The group row, the creator's ``owner`` membership row and the re-load of
    the group all happen inside a single ``db.begin()`` block. That block
    commits when it exits normally and rolls back if anything inside raises.

    Args:
        db: Async session to run on. Any transaction already open on it is
            rolled back before the new one starts (see comment below).
        group_in: Creation payload; only ``name`` is read here.
        creator_id: ID of the creating user. It is stored as
            ``created_by_id`` and added as a member with ``UserRoleEnum.owner``.

    Returns:
        The new group, re-selected with ``member_associations`` and each
        association's ``user`` eagerly loaded.

    Raises:
        DatabaseTransactionError: If a pre-existing transaction cannot be
            rolled back, or on any ``SQLAlchemyError`` not covered below.
        DatabaseIntegrityError: On ``IntegrityError`` (constraint violation).
        DatabaseConnectionError: On ``OperationalError``.
        GroupOperationError: If the group cannot be re-loaded after insert.
    """
    try:
        # ``db.begin()`` raises if the session already has a transaction.
        # SQLAlchemy autobegins one on the first query, so this is the normal
        # state whenever the caller already used the session (e.g. to load the
        # current user). Roll it back so a fresh transaction can start.
        # NOTE(review): this discards any uncommitted work the caller had on
        # the session and expires its loaded objects. Confirm no caller relies
        # on that state surviving.
        if db.in_transaction():
            try:
                await db.rollback()
            except SQLAlchemyError as e_rb:
                # Session state is now uncertain; surface it as a transaction error.
                raise DatabaseTransactionError(
                    f"Session had an active transaction that could not be rolled back: {e_rb}"
                ) from e_rb

        async with db.begin():
            db_group = GroupModel(name=group_in.name, created_by_id=creator_id)
            db.add(db_group)
            # Flush (not commit) so the INSERT runs and db_group.id is assigned.
            await db.flush()

            db_user_group = UserGroupModel(
                user_id=creator_id,
                group_id=db_group.id,
                role=UserRoleEnum.owner,
            )
            db.add(db_user_group)
            # Flush the membership INSERT. Nothing is committed until the
            # ``async with`` block exits.
            await db.flush()

            # Re-select the group with its associations and their users so the
            # caller can read them without lazy loading.
            stmt = (
                select(GroupModel)
                .where(GroupModel.id == db_group.id)
                .options(
                    selectinload(GroupModel.member_associations).selectinload(UserGroupModel.user)
                )
            )
            result = await db.execute(stmt)
            loaded_group = result.scalar_one_or_none()
            if loaded_group is None:
                # Should not happen right after the insert. Raising here also
                # rolls the transaction back.
                raise GroupOperationError("Failed to load group after creation.")
            # NOTE(review): the commit on leaving this block expires loaded_group
            # if the session uses the default expire_on_commit=True. Later
            # attribute access would then need a lazy load, so confirm the
            # session factory sets expire_on_commit=False.
            return loaded_group
    except IntegrityError as e:
        raise DatabaseIntegrityError(f"Failed to create group due to integrity issue: {str(e)}") from e
    except OperationalError as e:
        raise DatabaseConnectionError(f"Database connection error during group creation: {str(e)}") from e
    except SQLAlchemyError as e:
        raise DatabaseTransactionError(f"Database transaction error during group creation: {str(e)}") from e
async def get_user_groups(db: AsyncSession, user_id: int) -> List[GroupModel]:
    """Return every group in which ``user_id`` has a membership row.

    Each returned group has ``member_associations`` and each association's
    ``user`` eagerly loaded via ``selectinload``.

    Raises:
        DatabaseConnectionError: On ``OperationalError``.
        DatabaseQueryError: On any other ``SQLAlchemyError``.
    """
    try:
        members_with_users = selectinload(GroupModel.member_associations).selectinload(UserGroupModel.user)
        groups_query = (
            select(GroupModel)
            .join(UserGroupModel)
            .where(UserGroupModel.user_id == user_id)
            .options(members_with_users)
        )
        rows = await db.execute(groups_query)
        groups = rows.scalars().all()
    except OperationalError as exc:
        raise DatabaseConnectionError(f"Failed to connect to database: {str(exc)}")
    except SQLAlchemyError as exc:
        raise DatabaseQueryError(f"Failed to query user groups: {str(exc)}")
    return groups
async def get_group_by_id(db: AsyncSession, group_id: int) -> Optional[GroupModel]:
    """Fetch a single group by primary key.

    The group's ``member_associations`` and each association's ``user`` are
    always eagerly loaded (``selectinload``).

    Returns:
        The group, or ``None`` when no group has ``group_id``.

    Raises:
        DatabaseConnectionError: On ``OperationalError``.
        DatabaseQueryError: On any other ``SQLAlchemyError``.
    """
    try:
        eager_members = selectinload(GroupModel.member_associations).selectinload(UserGroupModel.user)
        group_query = select(GroupModel).where(GroupModel.id == group_id).options(eager_members)
        found_group = (await db.execute(group_query)).scalars().first()
    except OperationalError as exc:
        raise DatabaseConnectionError(f"Failed to connect to database: {str(exc)}")
    except SQLAlchemyError as exc:
        raise DatabaseQueryError(f"Failed to query group: {str(exc)}")
    return found_group
async def is_user_member(db: AsyncSession, group_id: int, user_id: int) -> bool:
    """Return whether ``user_id`` has a membership row in ``group_id``.

    Unlike ``check_group_membership``, this does not verify that the group
    exists and does not raise for non-membership.

    Raises:
        DatabaseConnectionError: On ``OperationalError``.
        DatabaseQueryError: On any other ``SQLAlchemyError``.
    """
    try:
        probe = (
            select(UserGroupModel.id)
            .where(UserGroupModel.group_id == group_id, UserGroupModel.user_id == user_id)
            .limit(1)
        )
        membership_id = (await db.execute(probe)).scalar_one_or_none()
    except OperationalError as exc:
        raise DatabaseConnectionError(f"Failed to connect to database: {str(exc)}")
    except SQLAlchemyError as exc:
        raise DatabaseQueryError(f"Failed to check group membership: {str(exc)}")
    return membership_id is not None
async def get_user_role_in_group(db: AsyncSession, group_id: int, user_id: int) -> Optional[UserRoleEnum]:
    """Look up the role a user holds in a group.

    Returns:
        The membership's role, or ``None`` when the user has no membership
        row in the group.

    Raises:
        DatabaseConnectionError: On ``OperationalError``.
        DatabaseQueryError: On any other ``SQLAlchemyError``. This includes
            more than one matching membership row, because
            ``scalar_one_or_none()`` rejects multiple rows.
    """
    try:
        role_query = select(UserGroupModel.role).where(
            UserGroupModel.group_id == group_id,
            UserGroupModel.user_id == user_id,
        )
        role_rows = await db.execute(role_query)
        role = role_rows.scalar_one_or_none()
    except OperationalError as exc:
        raise DatabaseConnectionError(f"Failed to connect to database: {str(exc)}")
    except SQLAlchemyError as exc:
        raise DatabaseQueryError(f"Failed to query user role: {str(exc)}")
    return role
async def add_user_to_group(db: AsyncSession, group_id: int, user_id: int, role: UserRoleEnum = UserRoleEnum.member) -> Optional[UserGroupModel]:
    """Add a user to a group with the given role unless they are already a member.

    The duplicate check, the insert and the re-load all run inside a single
    ``db.begin()`` transaction, which is committed when the block exits.

    Args:
        db: Async session to run on.
        group_id: Target group.
        user_id: User to add.
        role: Role for the new membership; defaults to ``UserRoleEnum.member``.

    Returns:
        The new ``UserGroupModel`` with ``user`` and ``group`` eagerly loaded.
        Returns ``None`` if a membership row for this user and group already
        exists.

    Raises:
        DatabaseIntegrityError: On ``IntegrityError`` while inserting.
        DatabaseConnectionError: On ``OperationalError``.
        DatabaseTransactionError: On any other ``SQLAlchemyError``.
        GroupOperationError: If the new association cannot be re-loaded.
    """
    try:
        # The duplicate check runs *inside* the transaction on purpose.
        # Querying first would autobegin a transaction, and the following
        # db.begin() would then fail with "A transaction is already begun on
        # this Session".
        # The check alone does not stop a concurrent duplicate insert. That
        # relies on a DB constraint, if one exists.
        # NOTE(review): db.begin() still fails the same way if the caller left a
        # transaction open on this session before calling. create_group rolls
        # such a transaction back first; decide whether callers must pass a
        # clean session or whether this function should do the same.
        async with db.begin():
            existing_result = await db.execute(
                select(UserGroupModel.id)
                .where(UserGroupModel.group_id == group_id, UserGroupModel.user_id == user_id)
                # limit(1) so duplicate rows cannot make scalar_one_or_none() raise.
                .limit(1)
            )
            # Compare to None rather than using truthiness, so an id of 0 still counts.
            if existing_result.scalar_one_or_none() is not None:
                return None

            db_user_group = UserGroupModel(user_id=user_id, group_id=group_id, role=role)
            db.add(db_user_group)
            await db.flush()  # Runs the INSERT and assigns db_user_group.id.

            # Re-select with 'user' and 'group' eagerly loaded for the response.
            stmt = (
                select(UserGroupModel)
                .where(UserGroupModel.id == db_user_group.id)
                .options(
                    selectinload(UserGroupModel.user),
                    selectinload(UserGroupModel.group)
                )
            )
            result = await db.execute(stmt)
            loaded_user_group = result.scalar_one_or_none()
            if loaded_user_group is None:
                raise GroupOperationError(f"Failed to load user group association after adding user {user_id} to group {group_id}.")
            return loaded_user_group
    except IntegrityError as e:
        raise DatabaseIntegrityError(f"Failed to add user to group: {str(e)}") from e
    except OperationalError as e:
        raise DatabaseConnectionError(f"Database connection error: {str(e)}") from e
    except SQLAlchemyError as e:
        raise DatabaseTransactionError(f"Failed to add user to group: {str(e)}") from e
async def remove_user_from_group(db: AsyncSession, group_id: int, user_id: int) -> bool:
    """Delete a user's membership in a group, in its own transaction.

    Args:
        db: Async session to run on.
        group_id: Group to remove the user from.
        user_id: User whose membership row(s) are deleted.

    Returns:
        True if at least one membership row was deleted. False if the user had
        no membership in the group.

    Raises:
        DatabaseConnectionError: On ``OperationalError``.
        DatabaseTransactionError: On any other ``SQLAlchemyError``.
    """
    # NOTE(review): db.begin() raises if the caller already has a transaction
    # open on this session (e.g. one autobegun by earlier queries). Confirm
    # callers pass a session with no active transaction.
    try:
        async with db.begin():
            result = await db.execute(
                delete(UserGroupModel)
                .where(UserGroupModel.group_id == group_id, UserGroupModel.user_id == user_id)
                .returning(UserGroupModel.id)
            )
            # Consume every RETURNING row. scalar_one_or_none() would raise
            # MultipleResultsFound if duplicate membership rows existed, which
            # would roll the delete back.
            deleted_ids = result.scalars().all()
            return len(deleted_ids) > 0
    except OperationalError as e:
        raise DatabaseConnectionError(f"Database connection error: {str(e)}") from e
    except SQLAlchemyError as e:
        raise DatabaseTransactionError(f"Failed to remove user from group: {str(e)}") from e
async def get_group_member_count(db: AsyncSession, group_id: int) -> int:
    """Return how many membership rows reference ``group_id``.

    Uses ``COUNT(user_group.id)``, so a nonexistent group yields 0 rather
    than an error.

    Raises:
        DatabaseConnectionError: On ``OperationalError``.
        DatabaseQueryError: On any other ``SQLAlchemyError``.
    """
    try:
        count_query = select(func.count(UserGroupModel.id)).where(UserGroupModel.group_id == group_id)
        count_rows = await db.execute(count_query)
        member_total = count_rows.scalar_one()
    except OperationalError as exc:
        raise DatabaseConnectionError(f"Failed to connect to database: {str(exc)}")
    except SQLAlchemyError as exc:
        raise DatabaseQueryError(f"Failed to count group members: {str(exc)}")
    return member_total
async def check_group_membership(
    db: AsyncSession,
    group_id: int,
    user_id: int,
    action: str = "access this group"
) -> None:
    """
    Verify that a group exists and that a user belongs to it.

    Returns None on success; otherwise raises.

    Raises:
        GroupNotFoundError: If no group with group_id exists.
        GroupMembershipError: If the group exists but user_id has no
            membership row in it. ``action`` is passed through to the error.
        DatabaseConnectionError: On ``OperationalError`` while querying.
        DatabaseQueryError: On any other ``SQLAlchemyError`` while querying.
    """
    membership_id = None
    try:
        group = await db.get(GroupModel, group_id)
        # Only look for a membership row once the group itself is known to exist.
        if group:
            lookup = await db.execute(
                select(UserGroupModel.id)
                .where(UserGroupModel.group_id == group_id, UserGroupModel.user_id == user_id)
                .limit(1)
            )
            membership_id = lookup.scalar_one_or_none()
    except OperationalError as exc:
        raise DatabaseConnectionError(f"Failed to connect to database while checking membership: {str(exc)}")
    except SQLAlchemyError as exc:
        raise DatabaseQueryError(f"Failed to check group membership: {str(exc)}")

    # Domain errors are raised outside the try so the DB handlers never see them.
    if not group:
        raise GroupNotFoundError(group_id)
    if membership_id is None:
        raise GroupMembershipError(group_id, action=action)
    return None
async def check_user_role_in_group(
    db: AsyncSession,
    group_id: int,
    user_id: int,
    required_role: UserRoleEnum,
    action: str = "perform this action"
) -> None:
    """
    Check that a user is a member of a group with at least the required role.

    Roles are ranked owner (2) > member (1). A required_role with no rank in
    that table is treated as unsatisfiable (fail closed). Otherwise it would
    rank 0 and every member would pass the check.

    Raises:
        GroupNotFoundError: If the group_id does not exist.
        GroupMembershipError: If the user_id is not a member of the group.
        GroupPermissionError: If the user's role ranks below required_role,
            or either role has no rank.
        DatabaseConnectionError / DatabaseQueryError: From the underlying lookups.
    """
    # Membership check first; it also raises GroupNotFoundError for a missing group.
    await check_group_membership(db, group_id, user_id, action=f"be checked for permissions to {action}")
    actual_role = await get_user_role_in_group(db, group_id, user_id)

    role_hierarchy = {UserRoleEnum.owner: 2, UserRoleEnum.member: 1}
    required_level = role_hierarchy.get(required_role)
    # actual_level is None when the user has no role (e.g. the membership was
    # removed between the two queries) or holds an unranked role.
    actual_level = role_hierarchy.get(actual_role) if actual_role is not None else None
    if required_level is None or actual_level is None or actual_level < required_level:
        raise GroupPermissionError(
            group_id=group_id,
            action=f"{action} (requires at least '{required_role.value}' role)"
        )
    # Role is sufficient.
    return None