# app/crud/group.py
"""Async CRUD helpers for groups and group memberships.

Also hosts two user helpers (``get_user_by_email`` and ``create_user``) that
are kept here for backward compatibility with existing callers.
"""

from typing import List, Optional

from sqlalchemy import delete, func
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy.orm import selectinload  # For eager loading members

from app.models import (
    User as UserModel,
    Group as GroupModel,
    UserGroup as UserGroupModel,
    UserRoleEnum,
)
from app.schemas.group import GroupCreate
from app.core.security import hash_password
from app.schemas.user import UserCreate


# --- User CRUD ---
# (These are user helpers and should ideally live in user.py; kept here so
# existing imports keep working.)

async def get_user_by_email(db: AsyncSession, email: str) -> Optional[UserModel]:
    """Return the first user whose email equals *email*, or ``None``."""
    result = await db.execute(select(UserModel).filter(UserModel.email == email))
    return result.scalars().first()


async def create_user(db: AsyncSession, user_in: UserCreate) -> UserModel:
    """Create a user from *user_in*, storing a hash of the password.

    The new row is committed and refreshed before being returned.
    """
    db_user = UserModel(
        email=user_in.email,
        password_hash=hash_password(user_in.password),
        name=user_in.name,
    )
    db.add(db_user)
    await db.commit()
    await db.refresh(db_user)
    return db_user

# --- End User CRUD ---


# --- Group CRUD ---

async def create_group(db: AsyncSession, group_in: GroupCreate, creator_id: int) -> GroupModel:
    """Create a group and add the creator as its owner.

    Both the group row and the owner membership row are committed together.
    """
    db_group = GroupModel(name=group_in.name, created_by_id=creator_id)
    db.add(db_group)
    # Flush so db_group.id is populated for the membership row below.
    await db.flush()

    db_user_group = UserGroupModel(
        user_id=creator_id,
        group_id=db_group.id,
        role=UserRoleEnum.owner,
    )
    db.add(db_user_group)

    await db.commit()
    await db.refresh(db_group)
    return db_group


async def get_user_groups(db: AsyncSession, user_id: int) -> List[GroupModel]:
    """Return all groups that *user_id* has a membership row in."""
    stmt = (
        select(GroupModel)
        .join(UserGroupModel)
        .where(UserGroupModel.user_id == user_id)
        # Preload membership rows so callers can read them without lazy loads.
        .options(selectinload(GroupModel.member_associations))
    )
    result = await db.execute(stmt)
    return result.scalars().all()


async def get_group_by_id(db: AsyncSession, group_id: int) -> Optional[GroupModel]:
    """Return the group with *group_id*, or ``None`` if there is none.

    Membership rows and each membership's user are eagerly loaded.
    """
    stmt = (
        select(GroupModel)
        .where(GroupModel.id == group_id)
        .options(
            selectinload(GroupModel.member_associations).selectinload(UserGroupModel.user)
        )
    )
    result = await db.execute(stmt)
    return result.scalars().first()


async def is_user_member(db: AsyncSession, group_id: int, user_id: int) -> bool:
    """Return ``True`` if *user_id* has a membership row in *group_id*."""
    stmt = (
        select(UserGroupModel.id)  # A single column is enough for an existence check
        .where(UserGroupModel.group_id == group_id, UserGroupModel.user_id == user_id)
        .limit(1)
    )
    result = await db.execute(stmt)
    return result.scalar_one_or_none() is not None


async def get_user_role_in_group(
    db: AsyncSession, group_id: int, user_id: int
) -> Optional[UserRoleEnum]:
    """Return the user's role in the group, or ``None`` if not a member."""
    stmt = select(UserGroupModel.role).where(
        UserGroupModel.group_id == group_id,
        UserGroupModel.user_id == user_id,
    )
    result = await db.execute(stmt)
    return result.scalar_one_or_none()


async def add_user_to_group(
    db: AsyncSession,
    group_id: int,
    user_id: int,
    role: UserRoleEnum = UserRoleEnum.member,
) -> Optional[UserGroupModel]:
    """Add a user to a group unless a membership already exists.

    Returns the new membership row, or ``None`` when the user is already in
    the group.
    """
    existing = await db.execute(
        select(UserGroupModel).where(
            UserGroupModel.group_id == group_id,
            UserGroupModel.user_id == user_id,
        )
    )
    if existing.scalar_one_or_none() is not None:
        return None  # Indicate user already in group

    db_user_group = UserGroupModel(user_id=user_id, group_id=group_id, role=role)
    db.add(db_user_group)
    await db.commit()
    await db.refresh(db_user_group)
    return db_user_group


async def remove_user_from_group(db: AsyncSession, group_id: int, user_id: int) -> bool:
    """Remove a user's membership from a group.

    Returns ``True`` if at least one membership row was deleted, ``False`` if
    there was nothing to delete.
    """
    result = await db.execute(
        delete(UserGroupModel)
        .where(UserGroupModel.group_id == group_id, UserGroupModel.user_id == user_id)
        .returning(UserGroupModel.id)
    )
    # Consume the RETURNING rows before committing, and tolerate more than one
    # deleted row instead of raising MultipleResultsFound.
    deleted_ids = result.scalars().all()
    await db.commit()
    return bool(deleted_ids)


async def get_group_member_count(db: AsyncSession, group_id: int) -> int:
    """Return the number of membership rows for *group_id*."""
    result = await db.execute(
        select(func.count(UserGroupModel.id)).where(UserGroupModel.group_id == group_id)
    )
    return result.scalar_one()