# app/crud/invite.py
"""CRUD helpers for group invite codes."""
import logging
import secrets
from datetime import datetime, timedelta, timezone
from typing import Optional

from sqlalchemy import delete  # Import delete statement
from sqlalchemy.exc import IntegrityError, OperationalError, SQLAlchemyError
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy.orm import selectinload

from app.models import Invite as InviteModel, Group as GroupModel, User as UserModel  # Related models for selectinload
from app.core.exceptions import (
    DatabaseConnectionError,
    DatabaseIntegrityError,
    DatabaseQueryError,
    DatabaseTransactionError,
    InviteOperationError,
)

logger = logging.getLogger(__name__)

# Invite codes should be reasonably unique, but handle potential collision.
MAX_CODE_GENERATION_ATTEMPTS = 5


def _transaction(db: AsyncSession):
    """Return a transaction context for *db*.

    Uses a SAVEPOINT (``begin_nested``) when the session already has an open
    transaction, otherwise starts a new transaction with ``begin``, which is
    committed when the ``async with`` block exits without an exception.
    """
    return db.begin_nested() if db.in_transaction() else db.begin()


def _with_relations(stmt):
    """Add eager loading of the invite's ``group`` and ``creator`` relationships to *stmt*."""
    return stmt.options(
        selectinload(InviteModel.group),
        selectinload(InviteModel.creator),
    )


async def deactivate_all_active_invites_for_group(db: AsyncSession, group_id: int) -> None:
    """Deactivate every currently active invite code for a group.

    Args:
        db: Session to run in; see ``_transaction`` for SAVEPOINT vs. new transaction.
        group_id: Group whose active invites should be deactivated.

    Raises:
        DatabaseConnectionError: On ``OperationalError``.
        DatabaseTransactionError: On any other ``SQLAlchemyError``.
    """
    try:
        async with _transaction(db):
            result = await db.execute(
                select(InviteModel).where(
                    InviteModel.group_id == group_id,
                    InviteModel.is_active == True,  # noqa: E712 - SQL expression, not a Python comparison
                )
            )
            active_invites = result.scalars().all()
            if not active_invites:
                return  # No active invites to deactivate.
            # Rows were loaded through this session, so they are already tracked;
            # mutating them is enough for the flush to persist the change.
            for invite in active_invites:
                invite.is_active = False
            await db.flush()
    except OperationalError as e:
        logger.error("Database connection error deactivating invites for group %s: %s", group_id, e, exc_info=True)
        raise DatabaseConnectionError(f"DB connection error deactivating invites for group {group_id}: {str(e)}") from e
    except SQLAlchemyError as e:
        logger.error("Unexpected SQLAlchemy error deactivating invites for group %s: %s", group_id, e, exc_info=True)
        raise DatabaseTransactionError(f"DB transaction error deactivating invites for group {group_id}: {str(e)}") from e


async def _generate_unique_code(db: AsyncSession) -> str:
    """Return a random URL-safe code that no active invite currently uses.

    Raises:
        InviteOperationError: If all ``MAX_CODE_GENERATION_ATTEMPTS`` candidates collided
            (or the limit is not positive).
    """
    for _ in range(MAX_CODE_GENERATION_ATTEMPTS):
        candidate = secrets.token_urlsafe(16)
        result = await db.execute(
            select(InviteModel.id)
            .where(InviteModel.code == candidate, InviteModel.is_active == True)  # noqa: E712
            .limit(1)
        )
        if result.scalar_one_or_none() is None:
            return candidate
    raise InviteOperationError("Failed to generate a unique invite code after several attempts.")


async def create_invite(
    db: AsyncSession, group_id: int, creator_id: int, expires_in_days: int = 365 * 100
) -> Optional[InviteModel]:
    """Create a new invite code for a group, deactivating its existing active invites first.

    Args:
        db: Session to run in; see ``_transaction`` for SAVEPOINT vs. new transaction.
        group_id: Group the invite belongs to.
        creator_id: Stored as the invite's ``created_by_id``.
        expires_in_days: Lifetime of the invite; defaults to roughly 100 years.

    Returns:
        The new invite, re-queried with ``group`` and ``creator`` eagerly loaded.
        Never ``None``: a failed reload raises instead.

    Raises:
        InviteOperationError: If no unique code could be generated or the reload failed.
        DatabaseIntegrityError: On ``IntegrityError``.
        DatabaseConnectionError: On ``OperationalError``.
        DatabaseTransactionError: On any other ``SQLAlchemyError``.
    """
    try:
        async with _transaction(db):
            await deactivate_all_active_invites_for_group(db, group_id)

            expires_at = datetime.now(timezone.utc) + timedelta(days=expires_in_days)
            # Uniqueness is only checked against *active* invites. A race (or a clash
            # with an inactive code, if the column is unique -- TODO confirm schema)
            # would surface as IntegrityError below.
            code = await _generate_unique_code(db)

            db_invite = InviteModel(
                code=code,
                group_id=group_id,
                created_by_id=creator_id,
                expires_at=expires_at,
                is_active=True,
            )
            db.add(db_invite)
            await db.flush()  # Assigns db_invite.id for the reload below.

            result = await db.execute(
                _with_relations(select(InviteModel).where(InviteModel.id == db_invite.id))
            )
            loaded_invite = result.scalar_one_or_none()
            if loaded_invite is None:
                raise InviteOperationError("Failed to load invite after creation and flush.")
            return loaded_invite
    except InviteOperationError:
        # Already a domain-specific error; propagate unchanged.
        raise
    except IntegrityError as e:
        logger.error("Database integrity error during invite creation for group %s: %s", group_id, e, exc_info=True)
        raise DatabaseIntegrityError(f"Failed to create invite due to DB integrity issue: {str(e)}") from e
    except OperationalError as e:
        logger.error("Database connection error during invite creation for group %s: %s", group_id, e, exc_info=True)
        raise DatabaseConnectionError(f"DB connection error during invite creation: {str(e)}") from e
    except SQLAlchemyError as e:
        logger.error("Unexpected SQLAlchemy error during invite creation for group %s: %s", group_id, e, exc_info=True)
        raise DatabaseTransactionError(f"DB transaction error during invite creation: {str(e)}") from e


async def get_active_invite_for_group(db: AsyncSession, group_id: int) -> Optional[InviteModel]:
    """Return the most recent active, non-expired invite for a group, or ``None``.

    ``group`` and ``creator`` are eagerly loaded on the returned invite.

    Raises:
        DatabaseConnectionError: On ``OperationalError``.
        DatabaseQueryError: On any other ``SQLAlchemyError``.
    """
    now = datetime.now(timezone.utc)
    try:
        stmt = _with_relations(
            select(InviteModel)
            .where(
                InviteModel.group_id == group_id,
                InviteModel.is_active == True,  # noqa: E712
                InviteModel.expires_at > now,  # Still respect expiry, even if very long.
            )
            # Most recent first, in case several are active (should not happen).
            .order_by(InviteModel.created_at.desc())
            .limit(1)
        )
        result = await db.execute(stmt)
        return result.scalars().first()
    except OperationalError as e:
        logger.error("Database connection error fetching active invite for group %s: %s", group_id, e, exc_info=True)
        raise DatabaseConnectionError(f"DB connection error fetching active invite for group {group_id}: {str(e)}") from e
    except SQLAlchemyError as e:
        logger.error("DB query error fetching active invite for group %s: %s", group_id, e, exc_info=True)
        raise DatabaseQueryError(f"DB query error fetching active invite for group {group_id}: {str(e)}") from e


async def get_active_invite_by_code(db: AsyncSession, code: str) -> Optional[InviteModel]:
    """Return the active, non-expired invite with the given code, or ``None``.

    ``group`` and ``creator`` are eagerly loaded on the returned invite.

    Raises:
        DatabaseConnectionError: On ``OperationalError``.
        DatabaseQueryError: On any other ``SQLAlchemyError``.
    """
    now = datetime.now(timezone.utc)
    try:
        stmt = _with_relations(
            select(InviteModel).where(
                InviteModel.code == code,
                InviteModel.is_active == True,  # noqa: E712
                InviteModel.expires_at > now,
            )
        )
        result = await db.execute(stmt)
        return result.scalars().first()
    except OperationalError as e:
        # The invite code itself is not logged: it grants access to a group.
        logger.error("Database connection error fetching invite by code: %s", e, exc_info=True)
        raise DatabaseConnectionError(f"DB connection error fetching invite: {str(e)}") from e
    except SQLAlchemyError as e:
        logger.error("DB query error fetching invite by code: %s", e, exc_info=True)
        raise DatabaseQueryError(f"DB query error fetching invite: {str(e)}") from e


async def deactivate_invite(db: AsyncSession, invite: InviteModel) -> InviteModel:
    """Mark an invite as inactive (used) and return it reloaded with relationships.

    Args:
        db: Session to run in; see ``_transaction`` for SAVEPOINT vs. new transaction.
        invite: The invite to deactivate.

    Returns:
        The invite re-queried with ``group`` and ``creator`` eagerly loaded.

    Raises:
        InviteOperationError: If the invite cannot be reloaded after the update.
        DatabaseConnectionError: On ``OperationalError``.
        DatabaseTransactionError: On any other ``SQLAlchemyError``.
    """
    try:
        async with _transaction(db):
            invite.is_active = False
            db.add(invite)  # Ensure the session tracks the change even if *invite* came from elsewhere.
            await db.flush()

            result = await db.execute(
                _with_relations(select(InviteModel).where(InviteModel.id == invite.id))
            )
            updated_invite = result.scalar_one_or_none()
            if updated_invite is None:  # Should not happen as invite is passed in.
                raise InviteOperationError("Failed to load invite after deactivation.")
            return updated_invite
    except OperationalError as e:
        logger.error("Database connection error deactivating invite: %s", e, exc_info=True)
        raise DatabaseConnectionError(f"DB connection error deactivating invite: {str(e)}") from e
    except SQLAlchemyError as e:
        logger.error("Unexpected SQLAlchemy error deactivating invite: %s", e, exc_info=True)
        raise DatabaseTransactionError(f"DB transaction error deactivating invite: {str(e)}") from e


# InviteOperationError must be defined in app.core.exceptions, e.g.:
#     class InviteOperationError(AppException): pass
#
# Optional: a function to periodically delete old, inactive invites, e.g.
#     async def cleanup_old_invites(db: AsyncSession, older_than_days: int = 30): ...