# app/crud/invite.py
|
|
import secrets
|
|
from datetime import datetime, timedelta, timezone
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
from sqlalchemy.future import select
|
|
from sqlalchemy.orm import selectinload # Ensure selectinload is imported
|
|
from sqlalchemy import delete # Import delete statement
|
|
from sqlalchemy.exc import SQLAlchemyError, OperationalError, IntegrityError
|
|
from typing import Optional
|
|
|
|
from app.models import Invite as InviteModel, Group as GroupModel, User as UserModel # Import related models for selectinload
|
|
from app.core.exceptions import (
|
|
DatabaseConnectionError,
|
|
DatabaseIntegrityError,
|
|
DatabaseQueryError,
|
|
DatabaseTransactionError,
|
|
InviteOperationError # Add new specific exception
|
|
)
|
|
|
|
# Invite codes should be reasonably unique, but handle potential collision
|
|
# Upper bound on how many random codes create_invite() tries before it gives
# up and raises InviteOperationError.
MAX_CODE_GENERATION_ATTEMPTS = 5
|
|
|
|
async def create_invite(db: AsyncSession, group_id: int, creator_id: int, expires_in_days: int = 7) -> Optional[InviteModel]:
    """Create a new active invite code for a group.

    A random URL-safe code is generated and checked against existing active
    invites (up to MAX_CODE_GENERATION_ATTEMPTS times). The invite is then
    inserted and re-fetched with its ``group`` and ``creator`` relationships
    eagerly loaded.

    Args:
        db: Async SQLAlchemy session used for all queries.
        group_id: ID of the group the invite belongs to.
        creator_id: ID of the user creating the invite.
        expires_in_days: Lifetime of the invite, counted from now (UTC).

    Returns:
        The newly created invite with relationships loaded. The annotation is
        Optional for backward compatibility, but every path visible here
        either returns an invite or raises.

    Raises:
        InviteOperationError: No unused code could be generated, a collision
            was detected inside the transaction, or the new row could not be
            re-loaded.
        DatabaseIntegrityError: The database rejected the insert.
        DatabaseConnectionError: An operational/connection error occurred.
        DatabaseTransactionError: Any other SQLAlchemy error occurred.
    """
    expires_at = datetime.now(timezone.utc) + timedelta(days=expires_in_days)

    try:
        # Cheap pre-check to reduce the chance of a collision before the
        # insert. It lives inside the try block so that database failures here
        # are translated into the application exceptions like everything else.
        for _ in range(MAX_CODE_GENERATION_ATTEMPTS):
            potential_code = secrets.token_urlsafe(16)
            # `== True` is intentional: it builds a SQL expression.
            existing_check_stmt = (
                select(InviteModel.id)
                .where(InviteModel.code == potential_code, InviteModel.is_active == True)  # noqa: E712
                .limit(1)
            )
            existing_result = await db.execute(existing_check_stmt)
            if existing_result.scalar_one_or_none() is None:
                break  # Found a potentially unique code
        else:
            # Every attempt collided with an existing active invite.
            raise InviteOperationError("Failed to generate a unique invite code after several attempts.")

        # Use a nested transaction when the session is already inside one,
        # otherwise open a new transaction.
        # NOTE(review): the pre-check query above may already have started a
        # transaction on the session (autobegin); if so the nested branch is
        # always taken and the outer commit is left to the caller - confirm.
        tx_context = db.begin_nested() if db.in_transaction() else db.begin()
        async with tx_context as transaction:
            # Final check within the transaction to narrow the race window
            # between the pre-check and the insert.
            final_check_stmt = (
                select(InviteModel.id)
                .where(InviteModel.code == potential_code, InviteModel.is_active == True)  # noqa: E712
                .limit(1)
            )
            final_check_result = await db.execute(final_check_stmt)
            if final_check_result.scalar_one_or_none() is not None:
                await transaction.rollback()
                raise InviteOperationError("Invite code collision detected during transaction.")

            db_invite = InviteModel(
                code=potential_code,
                group_id=group_id,
                created_by_id=creator_id,
                expires_at=expires_at,
                is_active=True,
            )
            db.add(db_invite)
            await db.flush()  # Assigns ID

            # Re-fetch with relationships so callers can read them without
            # triggering lazy loads.
            stmt = (
                select(InviteModel)
                .where(InviteModel.id == db_invite.id)
                .options(
                    selectinload(InviteModel.group),
                    selectinload(InviteModel.creator),
                )
            )
            result = await db.execute(stmt)
            loaded_invite = result.scalar_one_or_none()

            if loaded_invite is None:
                await transaction.rollback()
                raise InviteOperationError("Failed to load invite after creation.")

            await transaction.commit()
            return loaded_invite
    except IntegrityError as e:  # e.g. a DB unique constraint on the code was violated
        raise DatabaseIntegrityError(f"Failed to create invite due to DB integrity: {e}") from e
    except OperationalError as e:
        raise DatabaseConnectionError(f"DB connection error during invite creation: {e}") from e
    except SQLAlchemyError as e:
        raise DatabaseTransactionError(f"DB transaction error during invite creation: {e}") from e
|
|
|
|
async def get_active_invite_by_code(db: AsyncSession, code: str) -> Optional[InviteModel]:
    """Fetch an active, non-expired invite by its code.

    Args:
        db: Async SQLAlchemy session used for the query.
        code: The invite code to look up.

    Returns:
        The first invite whose code matches, whose ``is_active`` flag is set
        and whose ``expires_at`` is later than the current UTC time, with its
        ``group`` and ``creator`` relationships eagerly loaded; ``None`` when
        no such invite exists.

    Raises:
        DatabaseConnectionError: An operational/connection error occurred.
        DatabaseQueryError: Any other SQLAlchemy error occurred.
    """
    now = datetime.now(timezone.utc)
    try:
        stmt = (
            select(InviteModel)
            .where(
                InviteModel.code == code,
                # `== True` is intentional: it builds a SQL expression.
                InviteModel.is_active == True,  # noqa: E712
                InviteModel.expires_at > now,
            )
            .options(
                selectinload(InviteModel.group),
                selectinload(InviteModel.creator),
            )
        )
        result = await db.execute(stmt)
        return result.scalars().first()
    except OperationalError as e:
        # Chain the original error so the driver-level cause stays visible.
        raise DatabaseConnectionError(f"DB connection error fetching invite: {e}") from e
    except SQLAlchemyError as e:
        raise DatabaseQueryError(f"DB query error fetching invite: {e}") from e
|
|
|
|
async def deactivate_invite(db: AsyncSession, invite: InviteModel) -> InviteModel:
    """Mark an invite as inactive (used) and reload it with relationships.

    Args:
        db: Async SQLAlchemy session used for the update and re-fetch.
        invite: The invite instance to deactivate; it is mutated in place.

    Returns:
        The invite re-fetched by ID with its ``group`` and ``creator``
        relationships eagerly loaded.

    Raises:
        InviteOperationError: The invite could not be re-loaded after the
            update.
        DatabaseConnectionError: An operational/connection error occurred.
        DatabaseTransactionError: Any other SQLAlchemy error occurred.
    """
    try:
        # Use a nested transaction when the session is already inside one,
        # otherwise open a new transaction.
        tx_context = db.begin_nested() if db.in_transaction() else db.begin()
        async with tx_context as transaction:
            invite.is_active = False
            db.add(invite)  # Add to session to track change
            await db.flush()  # Persist is_active change

            # Re-fetch with relationships so callers can read them without
            # triggering lazy loads.
            stmt = (
                select(InviteModel)
                .where(InviteModel.id == invite.id)
                .options(
                    selectinload(InviteModel.group),
                    selectinload(InviteModel.creator),
                )
            )
            result = await db.execute(stmt)
            updated_invite = result.scalar_one_or_none()

            if updated_invite is None:  # Should not happen as invite is passed in
                await transaction.rollback()
                raise InviteOperationError("Failed to load invite after deactivation.")

            await transaction.commit()
            return updated_invite
    except OperationalError as e:
        # Chain the original error so the driver-level cause stays visible.
        raise DatabaseConnectionError(f"DB connection error deactivating invite: {e}") from e
    except SQLAlchemyError as e:
        raise DatabaseTransactionError(f"DB transaction error deactivating invite: {e}") from e
|
|
|
|
# Ensure InviteOperationError is defined in app.core.exceptions
# Example: class InviteOperationError(AppException): pass

# Optional: Function to periodically delete old, inactive invites
# async def cleanup_old_invites(db: AsyncSession, older_than_days: int = 30): ...