189 lines
8.5 KiB
Python
189 lines
8.5 KiB
Python
# app/crud/invite.py
|
|
import secrets
|
|
from datetime import datetime, timedelta, timezone
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
from sqlalchemy.future import select
|
|
from sqlalchemy.orm import selectinload # Ensure selectinload is imported
|
|
from sqlalchemy import delete # Import delete statement
|
|
from sqlalchemy.exc import SQLAlchemyError, OperationalError, IntegrityError
|
|
from typing import Optional
|
|
|
|
from app.models import Invite as InviteModel, Group as GroupModel, User as UserModel # Import related models for selectinload
|
|
from app.core.exceptions import (
|
|
DatabaseConnectionError,
|
|
DatabaseIntegrityError,
|
|
DatabaseQueryError,
|
|
DatabaseTransactionError,
|
|
InviteOperationError # Add new specific exception
|
|
)
|
|
|
|
# Invite codes are random, so collisions are unlikely; if one occurs, code generation is retried up to this many times.
|
|
MAX_CODE_GENERATION_ATTEMPTS = 5
|
|
|
|
async def deactivate_all_active_invites_for_group(db: AsyncSession, group_id: int) -> None:
    """Deactivate every currently active invite that belongs to a group.

    Selects the group's invites whose ``is_active`` flag is set and clears
    the flag on each of them. Nothing is flushed or committed here; the
    caller is expected to flush/commit as part of its own transaction.

    Args:
        db: Async session used to run the query and track the changes.
        group_id: ID of the group whose active invites are deactivated.

    Raises:
        DatabaseConnectionError: If SQLAlchemy raises ``OperationalError``.
        DatabaseTransactionError: If any other ``SQLAlchemyError`` is raised.
    """
    try:
        stmt = select(InviteModel).where(
            InviteModel.group_id == group_id,
            InviteModel.is_active == True,  # noqa: E712 - builds a SQL expression
        )
        result = await db.execute(stmt)

        # An empty result simply means there is nothing to deactivate.
        for invite in result.scalars().all():
            invite.is_active = False
            db.add(invite)
    except OperationalError as e:
        # Rollback/commit is left to the caller, which knows whether the
        # overall operation succeeded; chain the cause for debuggability.
        raise DatabaseConnectionError(f"DB connection error deactivating invites for group {group_id}: {str(e)}") from e
    except SQLAlchemyError as e:
        raise DatabaseTransactionError(f"DB transaction error deactivating invites for group {group_id}: {str(e)}") from e
|
|
|
|
async def _active_invite_code_exists(db: AsyncSession, code: str) -> bool:
    """Return True if an active invite already uses ``code``."""
    stmt = (
        select(InviteModel.id)
        .where(InviteModel.code == code, InviteModel.is_active == True)  # noqa: E712
        .limit(1)
    )
    result = await db.execute(stmt)
    return result.scalar_one_or_none() is not None


async def create_invite(
    db: AsyncSession,
    group_id: int,
    creator_id: int,
    expires_in_days: int = 365 * 100,  # Default to 100 years
) -> Optional[InviteModel]:
    """Create a new invite code for a group.

    Any invites that are currently active for the group are deactivated
    first. A random URL-safe code is then generated (retrying up to
    ``MAX_CODE_GENERATION_ATTEMPTS`` times if an active invite already uses
    it), the invite is added and flushed, and the row is re-fetched with its
    ``group`` and ``creator`` relationships eagerly loaded.

    No commit is issued here; the caller (or a request-level transaction) is
    responsible for committing or rolling back.

    Args:
        db: Async session used for all queries and the insert.
        group_id: ID of the group the invite belongs to.
        creator_id: ID of the user creating the invite.
        expires_in_days: Lifetime of the invite in days, counted from now (UTC).

    Returns:
        The newly created invite with ``group`` and ``creator`` loaded.

    Raises:
        InviteOperationError: If no unused code could be generated, a
            collision is detected right before the insert, or the invite
            cannot be re-loaded after the flush.
        DatabaseIntegrityError: If the insert violates a database constraint.
        DatabaseConnectionError: If SQLAlchemy raises ``OperationalError``.
        DatabaseTransactionError: If any other ``SQLAlchemyError`` is raised.
    """
    # Deactivate existing active invites for this group (this helper already
    # translates its own database errors).
    await deactivate_all_active_invites_for_group(db, group_id)

    expires_at = datetime.now(timezone.utc) + timedelta(days=expires_in_days)

    try:
        potential_code = None
        for _ in range(MAX_CODE_GENERATION_ATTEMPTS):
            potential_code = secrets.token_urlsafe(16)
            if not await _active_invite_code_exists(db, potential_code):
                break
        else:
            # Every attempt collided with an existing active code.
            raise InviteOperationError("Failed to generate a unique invite code after several attempts.")

        # Final check right before the insert: narrows (but cannot close) the
        # window in which another request could have taken the same code.
        if await _active_invite_code_exists(db, potential_code):
            raise InviteOperationError("Invite code collision detected just before creation attempt.")

        db_invite = InviteModel(
            code=potential_code,
            group_id=group_id,
            created_by_id=creator_id,
            expires_at=expires_at,
            is_active=True,
        )
        db.add(db_invite)
        # Flush so the primary key is assigned before the re-fetch below.
        await db.flush()

        # Re-fetch with relationships eagerly loaded.
        stmt = (
            select(InviteModel)
            .where(InviteModel.id == db_invite.id)
            .options(
                selectinload(InviteModel.group),
                selectinload(InviteModel.creator),
            )
        )
        result = await db.execute(stmt)
        loaded_invite = result.scalar_one_or_none()
    except IntegrityError as e:
        # NOTE(review): assumes DatabaseIntegrityError accepts a message like
        # the sibling Database*Error classes used in this module - confirm.
        raise DatabaseIntegrityError(f"DB integrity error creating invite for group {group_id}: {str(e)}") from e
    except OperationalError as e:
        raise DatabaseConnectionError(f"DB connection error creating invite for group {group_id}: {str(e)}") from e
    except SQLAlchemyError as e:
        raise DatabaseTransactionError(f"DB transaction error creating invite for group {group_id}: {str(e)}") from e

    if loaded_invite is None:
        # Implies the flush did not persist the row or the ID was wrong.
        raise InviteOperationError("Failed to load invite after creation and flush.")

    return loaded_invite
|
|
|
|
async def get_active_invite_for_group(db: AsyncSession, group_id: int) -> Optional[InviteModel]:
    """Return the active, non-expired invite for a group, if any.

    Only invites with ``is_active`` set and ``expires_at`` later than the
    current UTC time are considered. If several match, the one with the most
    recent ``created_at`` is returned. The ``group`` and ``creator``
    relationships are eagerly loaded.

    Args:
        db: Async session used to run the query.
        group_id: ID of the group to look up.

    Returns:
        The matching invite, or ``None`` when the group has none.

    Raises:
        DatabaseConnectionError: If SQLAlchemy raises ``OperationalError``.
        DatabaseQueryError: If any other ``SQLAlchemyError`` is raised.
    """
    now = datetime.now(timezone.utc)
    try:
        stmt = (
            select(InviteModel)
            .where(
                InviteModel.group_id == group_id,
                InviteModel.is_active == True,  # noqa: E712 - builds a SQL expression
                InviteModel.expires_at > now,  # Still respect expiry, even if very long
            )
            # Several active invites should not happen; prefer the newest if they do.
            .order_by(InviteModel.created_at.desc())
            .limit(1)
            .options(
                selectinload(InviteModel.group),
                selectinload(InviteModel.creator),
            )
        )
        result = await db.execute(stmt)
        return result.scalars().first()
    except OperationalError as e:
        raise DatabaseConnectionError(f"DB connection error fetching active invite for group {group_id}: {str(e)}") from e
    except SQLAlchemyError as e:
        raise DatabaseQueryError(f"DB query error fetching active invite for group {group_id}: {str(e)}") from e
|
|
|
|
async def get_active_invite_by_code(db: AsyncSession, code: str) -> Optional[InviteModel]:
    """Return the active, non-expired invite that has the given code, if any.

    Only invites with ``is_active`` set and ``expires_at`` later than the
    current UTC time are considered. The ``group`` and ``creator``
    relationships are eagerly loaded.

    Args:
        db: Async session used to run the query.
        code: Invite code to look up.

    Returns:
        The first matching invite, or ``None`` when there is no match.

    Raises:
        DatabaseConnectionError: If SQLAlchemy raises ``OperationalError``.
        DatabaseQueryError: If any other ``SQLAlchemyError`` is raised.
    """
    now = datetime.now(timezone.utc)
    try:
        stmt = (
            select(InviteModel)
            .where(
                InviteModel.code == code,
                InviteModel.is_active == True,  # noqa: E712 - builds a SQL expression
                InviteModel.expires_at > now,
            )
            .options(
                selectinload(InviteModel.group),
                selectinload(InviteModel.creator),
            )
        )
        result = await db.execute(stmt)
        return result.scalars().first()
    except OperationalError as e:
        raise DatabaseConnectionError(f"DB connection error fetching invite: {str(e)}") from e
    except SQLAlchemyError as e:
        raise DatabaseQueryError(f"DB query error fetching invite: {str(e)}") from e
|
|
|
|
async def deactivate_invite(db: AsyncSession, invite: InviteModel) -> InviteModel:
    """Mark an invite as inactive (used) and reload it with relationships.

    The change runs inside a SAVEPOINT (``begin_nested``) when the session
    already has a transaction in progress, otherwise inside a new
    transaction (``begin``). The transaction context manager commits on
    normal exit and rolls back if an exception propagates, so no explicit
    commit/rollback calls are made inside the block.

    Args:
        db: Async session used to persist the change and re-fetch the row.
        invite: The invite to deactivate.

    Returns:
        The re-fetched invite with ``group`` and ``creator`` eagerly loaded.

    Raises:
        InviteOperationError: If the invite cannot be re-loaded afterwards.
        DatabaseConnectionError: If SQLAlchemy raises ``OperationalError``.
        DatabaseTransactionError: If any other ``SQLAlchemyError`` is raised.
    """
    try:
        transaction_ctx = db.begin_nested() if db.in_transaction() else db.begin()
        async with transaction_ctx:
            invite.is_active = False
            db.add(invite)  # Add to session to track change
            await db.flush()  # Persist is_active change

            # Re-fetch with relationships eagerly loaded.
            stmt = (
                select(InviteModel)
                .where(InviteModel.id == invite.id)
                .options(
                    selectinload(InviteModel.group),
                    selectinload(InviteModel.creator),
                )
            )
            result = await db.execute(stmt)
            updated_invite = result.scalar_one_or_none()

            if updated_invite is None:  # Should not happen as invite is passed in
                # Raising here lets the context manager roll the change back.
                raise InviteOperationError("Failed to load invite after deactivation.")

        return updated_invite
    except OperationalError as e:
        raise DatabaseConnectionError(f"DB connection error deactivating invite: {str(e)}") from e
    except SQLAlchemyError as e:
        raise DatabaseTransactionError(f"DB transaction error deactivating invite: {str(e)}") from e
|
|
|
|
# Ensure InviteOperationError is defined in app.core.exceptions
|
|
# Example: class InviteOperationError(AppException): pass
|
|
|
|
# Optional: Function to periodically delete old, inactive invites
|
|
# async def cleanup_old_invites(db: AsyncSession, older_than_days: int = 30): ... |