mitlist/be/app/crud/invite.py

143 lines
6.5 KiB
Python

# app/crud/invite.py
import secrets
from datetime import datetime, timedelta, timezone
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy.orm import selectinload # Ensure selectinload is imported
from sqlalchemy import delete # Import delete statement
from sqlalchemy.exc import SQLAlchemyError, OperationalError, IntegrityError
from typing import Optional
from app.models import Invite as InviteModel, Group as GroupModel, User as UserModel # Import related models for selectinload
from app.core.exceptions import (
DatabaseConnectionError,
DatabaseIntegrityError,
DatabaseQueryError,
DatabaseTransactionError,
InviteOperationError # Add new specific exception
)
# Upper bound on how many random codes create_invite() draws before giving up.
# Random codes make collisions rare, but the cap guarantees the search loop ends.
MAX_CODE_GENERATION_ATTEMPTS: int = 5
async def _active_code_exists(db: AsyncSession, code: str) -> bool:
    """Return True if an *active* invite already uses ``code``."""
    stmt = (
        select(InviteModel.id)
        .where(InviteModel.code == code, InviteModel.is_active == True)  # noqa: E712 - SQL expression, not a Python bool test
        .limit(1)
    )
    result = await db.execute(stmt)
    return result.scalar_one_or_none() is not None


async def create_invite(db: AsyncSession, group_id: int, creator_id: int, expires_in_days: int = 7) -> Optional[InviteModel]:
    """Create a new active invite code for a group.

    Args:
        db: Session used for the uniqueness probe and the insert.
        group_id: Group the invite grants access to.
        creator_id: User recorded as the invite's creator.
        expires_in_days: Lifetime of the invite, counted from now (UTC).

    Returns:
        The persisted invite, re-selected with ``group`` and ``creator``
        eager-loaded via selectinload. Every failure path raises, so the
        ``Optional`` in the annotation is never realised.

    Raises:
        InviteOperationError: No unused code was found within
            MAX_CODE_GENERATION_ATTEMPTS draws, a collision was detected
            inside the transaction, or the new row could not be re-selected.
        DatabaseIntegrityError: The database rejected the insert (IntegrityError).
        DatabaseConnectionError: An OperationalError occurred.
        DatabaseTransactionError: Any other SQLAlchemyError occurred.
    """
    expires_at = datetime.now(timezone.utc) + timedelta(days=expires_in_days)
    try:
        # The pre-check only reduces the collision chance. A DB unique
        # constraint on active codes would be the real guarantee (any
        # violation surfaces as IntegrityError below). The probe sits inside
        # the try so its DB errors are translated like every other query.
        for _ in range(MAX_CODE_GENERATION_ATTEMPTS):
            code = secrets.token_urlsafe(16)
            if not await _active_code_exists(db, code):
                break
        else:
            # Reached when every draw collided. This also covers a
            # non-positive attempt limit, which would otherwise fall through
            # and insert code=None.
            raise InviteOperationError("Failed to generate a unique invite code after several attempts.")

        # Join an open transaction through a SAVEPOINT, otherwise start one.
        # The context manager commits on clean exit and rolls back on any
        # exception, so no explicit commit/rollback is needed in the block.
        # NOTE(review): the probe above already executed on `db`, which under
        # SQLAlchemy autobegin opens a transaction, so begin_nested() is
        # effectively always taken and the outer transaction is left to the
        # caller to commit. Confirm callers commit.
        async with db.begin_nested() if db.in_transaction() else db.begin():
            # Re-check inside the transaction to narrow the race window.
            if await _active_code_exists(db, code):
                raise InviteOperationError("Invite code collision detected during transaction.")
            db_invite = InviteModel(
                code=code,
                group_id=group_id,
                created_by_id=creator_id,
                expires_at=expires_at,
                is_active=True,
            )
            db.add(db_invite)
            await db.flush()  # Assigns db_invite.id
            # Re-select so the relationships are eager-loaded for the caller.
            stmt = (
                select(InviteModel)
                .where(InviteModel.id == db_invite.id)
                .options(
                    selectinload(InviteModel.group),
                    selectinload(InviteModel.creator),
                )
            )
            result = await db.execute(stmt)
            loaded_invite = result.scalar_one_or_none()
            if loaded_invite is None:
                raise InviteOperationError("Failed to load invite after creation.")
        return loaded_invite
    except IntegrityError as e:  # e.g. a DB unique constraint on code was violated
        raise DatabaseIntegrityError(f"Failed to create invite due to DB integrity: {str(e)}") from e
    except OperationalError as e:
        raise DatabaseConnectionError(f"DB connection error during invite creation: {str(e)}") from e
    except SQLAlchemyError as e:
        raise DatabaseTransactionError(f"DB transaction error during invite creation: {str(e)}") from e
async def get_active_invite_by_code(db: AsyncSession, code: str) -> Optional[InviteModel]:
    """Return the active, unexpired invite with the given code, or None.

    An invite matches when its ``code`` equals ``code``, ``is_active`` is
    true, and ``expires_at`` is later than the current UTC time. The
    ``group`` and ``creator`` relationships are eager-loaded via selectinload.

    Args:
        db: Session to query with.
        code: Invite code to look up.

    Returns:
        The first matching invite, or None if none matches.

    Raises:
        DatabaseConnectionError: An OperationalError occurred.
        DatabaseQueryError: Any other SQLAlchemyError occurred.
    """
    # NOTE(review): compares against a timezone-aware "now". Assumes the
    # expires_at column stores aware/UTC datetimes; confirm against the model.
    now = datetime.now(timezone.utc)
    try:
        stmt = (
            select(InviteModel)
            .where(
                InviteModel.code == code,
                InviteModel.is_active == True,  # noqa: E712 - SQL expression
                InviteModel.expires_at > now,
            )
            .options(
                selectinload(InviteModel.group),
                selectinload(InviteModel.creator),
            )
        )
        result = await db.execute(stmt)
        return result.scalars().first()
    except OperationalError as e:
        raise DatabaseConnectionError(f"DB connection error fetching invite: {str(e)}") from e
    except SQLAlchemyError as e:
        raise DatabaseQueryError(f"DB query error fetching invite: {str(e)}") from e
async def deactivate_invite(db: AsyncSession, invite: InviteModel) -> InviteModel:
    """Mark an invite inactive (used) and return it reloaded with relationships.

    Args:
        db: Session the change is flushed through.
        invite: The invite to deactivate. ``is_active`` is set to False on this
            object.

    Returns:
        The invite re-selected by id, with ``group`` and ``creator``
        eager-loaded via selectinload.

    Raises:
        InviteOperationError: The invite could not be re-selected after the flush.
        DatabaseConnectionError: An OperationalError occurred.
        DatabaseTransactionError: Any other SQLAlchemyError occurred.
    """
    try:
        # Join an open transaction through a SAVEPOINT, otherwise start one.
        # The context manager commits on clean exit and rolls back on any
        # exception, so no explicit commit/rollback is needed in the block.
        async with db.begin_nested() if db.in_transaction() else db.begin():
            invite.is_active = False
            db.add(invite)  # Ensure the instance is attached so the change is tracked
            await db.flush()  # Persist the is_active change
            # Re-select so the relationships are eager-loaded for the caller.
            stmt = (
                select(InviteModel)
                .where(InviteModel.id == invite.id)
                .options(
                    selectinload(InviteModel.group),
                    selectinload(InviteModel.creator),
                )
            )
            result = await db.execute(stmt)
            updated_invite = result.scalar_one_or_none()
            if updated_invite is None:  # Should not happen, since the invite was just flushed
                raise InviteOperationError("Failed to load invite after deactivation.")
        return updated_invite
    except OperationalError as e:
        raise DatabaseConnectionError(f"DB connection error deactivating invite: {str(e)}") from e
    except SQLAlchemyError as e:
        raise DatabaseTransactionError(f"DB transaction error deactivating invite: {str(e)}") from e
# Ensure InviteOperationError is defined in app.core.exceptions
# Example: class InviteOperationError(AppException): pass
# Optional: Function to periodically delete old, inactive invites
# async def cleanup_old_invites(db: AsyncSession, older_than_days: int = 30): ...