229 lines
9.7 KiB
Python
229 lines
9.7 KiB
Python
# app/core/gemini.py
|
|
import logging
|
|
from typing import List
|
|
import google.generativeai as genai
|
|
from google.generativeai.types import HarmCategory, HarmBlockThreshold # For safety settings
|
|
from google.api_core import exceptions as google_exceptions
|
|
from app.config import settings
|
|
from app.core.exceptions import (
|
|
OCRServiceUnavailableError,
|
|
OCRServiceConfigError,
|
|
OCRUnexpectedError,
|
|
OCRQuotaExceededError,
|
|
OCRProcessingError
|
|
)
|
|
|
|
logger = logging.getLogger(__name__)


# --- Module-level Gemini client state (populated once, at import time) ---
# After the initialization below has run, either `gemini_flash_client` holds the
# GenerativeModel or `gemini_initialization_error` holds a message explaining why
# it could not be created. get_gemini_client() inspects both.
gemini_flash_client = None
gemini_initialization_error = None


try:
    if not settings.GEMINI_API_KEY:
        # Nothing to configure without a key; remember why so callers can fail fast.
        gemini_initialization_error = "GEMINI_API_KEY not configured. Gemini client not initialized."
        logger.error(gemini_initialization_error)
    else:
        genai.configure(api_key=settings.GEMINI_API_KEY)
        gemini_flash_client = genai.GenerativeModel(
            model_name=settings.GEMINI_MODEL_NAME,
            generation_config=genai.types.GenerationConfig(**settings.GEMINI_GENERATION_CONFIG),
        )
        logger.info(f"Gemini AI client initialized successfully for model '{settings.GEMINI_MODEL_NAME}'.")
except Exception as e:
    # Any exception while configuring the SDK or building the model is recorded
    # instead of propagated, so importing this module never fails.
    gemini_initialization_error = f"Failed to initialize Gemini AI client: {e}"
    logger.exception(gemini_initialization_error)  # includes the traceback
    gemini_flash_client = None  # never leave a partially initialized client behind
|
|
|
|
|
|
# --- Accessor for the module-level client (lets callers surface init errors) ---


def get_gemini_client():
    """
    Return the Gemini model that was created when this module was imported.

    Raises:
        OCRServiceConfigError: if initialization recorded an error, or if no
            client object is available for any other reason.
    """
    client_unavailable = gemini_initialization_error or gemini_flash_client is None
    if client_unavailable:
        raise OCRServiceConfigError()
    return gemini_flash_client
|
|
|
|
# Instruction sent to Gemini ahead of the image. The model is asked to answer with
# one item per line, which is the format the extraction code splits on.
# NOTE(review): the extraction code in this module reads
# `settings.OCR_ITEM_EXTRACTION_PROMPT`; keep that setting and this constant in sync
# (or consolidate them).
# NOTE(review): the "state that clearly" instruction makes the model reply with a
# sentence for non-list images; the line parser only drops lines of length <= 1
# (and, in GeminiOCRService, lines starting with "Example"), so such a sentence
# comes back as an "item".
OCR_ITEM_EXTRACTION_PROMPT = (
    "\n"
    "Extract the shopping list items from this image.\n"
    "List each distinct item on a new line.\n"
    "Ignore prices, quantities, store names, discounts, taxes, totals, and other non-item text.\n"
    "Focus only on the names of the products or items to be purchased.\n"
    "If the image does not appear to be a shopping list or receipt, state that clearly.\n"
    "Example output for a grocery list:\n"
    "Milk\n"
    "Eggs\n"
    "Bread\n"
    "Apples\n"
    "Organic Bananas\n"
)
|
|
|
|
async def extract_items_from_image_gemini(image_bytes: bytes, mime_type: str = "image/jpeg") -> List[str]:
    """
    Use Gemini Flash to extract shopping list items from image bytes.

    The prompt is taken from ``settings.OCR_ITEM_EXTRACTION_PROMPT`` when that
    attribute exists, otherwise from the module-level ``OCR_ITEM_EXTRACTION_PROMPT``.
    The model's text reply is split into lines, each line is stripped, and every
    line longer than one character is kept as an item.

    Args:
        image_bytes: The image content as bytes.
        mime_type: The MIME type of the image (e.g., "image/jpeg", "image/png", "image/webp").

    Returns:
        The extracted item strings, in the order the model listed them.

    Raises:
        OCRServiceConfigError: If the Gemini client is not initialized.
        OCRQuotaExceededError: If the API reports a quota / resource-exhausted error.
        OCRServiceUnavailableError: For other Google API call errors.
        OCRProcessingError: If the prompt or the response was blocked by safety filtering.
        OCRUnexpectedError: If the response has no content for any other reason,
            or for any other unexpected error.
    """
    try:
        client = get_gemini_client()  # raises OCRServiceConfigError if not initialized

        # Prefer the configured prompt, but fall back to this module's constant so a
        # settings object without the attribute does not fail every request.
        prompt = getattr(settings, "OCR_ITEM_EXTRACTION_PROMPT", OCR_ITEM_EXTRACTION_PROMPT)
        image_part = {"mime_type": mime_type, "data": image_bytes}

        logger.info("Sending image to Gemini for item extraction...")
        # Multimodal request: text prompt first, then the image.
        response = await client.generate_content_async([prompt, image_part])

        # --- Detect blocked / empty responses before reading response.text ---
        candidates = response.candidates
        if not candidates or not candidates[0].content.parts:
            logger.warning("Gemini response blocked or empty.", extra={"response": response})
            if candidates:
                first = candidates[0]
                # finish_reason is an SDK enum rather than a plain str, so a direct
                # == 'SAFETY' comparison does not match it; compare by enum name
                # (falling back to the raw value if it has no .name).
                finish_reason = first.finish_reason
                if getattr(finish_reason, "name", finish_reason) == "SAFETY":
                    raise OCRProcessingError(
                        f"Gemini response blocked due to safety settings. Ratings: {first.safety_ratings}"
                    )
            else:
                # No candidates at all: a blocked prompt is reported via prompt_feedback.
                feedback = getattr(response, "prompt_feedback", None)
                block_reason = getattr(feedback, "block_reason", None)
                if block_reason:
                    raise OCRProcessingError(
                        f"Gemini prompt blocked. Reason: {getattr(block_reason, 'name', block_reason)}"
                    )
            raise OCRUnexpectedError()

        # Shortcut for the text of the first candidate's content.
        raw_text = response.text
        logger.info("Received raw text from Gemini.")

        # One item per line; blank lines and single characters are treated as noise.
        items = [
            cleaned
            for cleaned in (line.strip() for line in raw_text.splitlines())
            if len(cleaned) > 1
        ]

        logger.info(f"Extracted {len(items)} potential items.")
        return items

    except google_exceptions.GoogleAPIError as e:
        logger.error(f"Gemini API Error: {e}", exc_info=True)
        # ResourceExhausted is the API's quota / rate-limit error; the text match is
        # kept as a fallback for other error types that mention quota.
        if isinstance(e, google_exceptions.ResourceExhausted) or "quota" in str(e).lower():
            raise OCRQuotaExceededError() from e
        raise OCRServiceUnavailableError() from e
    except (OCRServiceConfigError, OCRQuotaExceededError, OCRServiceUnavailableError, OCRProcessingError, OCRUnexpectedError):
        # Already mapped to a domain exception; propagate unchanged.
        raise
    except Exception as e:
        # Anything else during generation or parsing is wrapped for callers.
        logger.error(f"Unexpected error during Gemini item extraction: {e}", exc_info=True)
        raise OCRUnexpectedError() from e
|
|
|
|
class GeminiOCRService:
    """
    Instance-based Gemini OCR client for shopping-list images.

    Construction configures the SDK and builds this instance's own
    ``GenerativeModel`` (separate from the module-level ``gemini_flash_client``);
    construction failures are reported as ``OCRServiceConfigError``.
    """

    def __init__(self):
        """
        Configure the SDK with ``settings.GEMINI_API_KEY`` and create the model.

        Raises:
            OCRServiceConfigError: If configuring the SDK or creating the model fails.
        """
        try:
            genai.configure(api_key=settings.GEMINI_API_KEY)
            self.model = genai.GenerativeModel(
                model_name=settings.GEMINI_MODEL_NAME,
                generation_config=genai.types.GenerationConfig(
                    **settings.GEMINI_GENERATION_CONFIG
                ),
            )
        except Exception as e:
            logger.error(f"Failed to initialize Gemini client: {e}", exc_info=True)
            raise OCRServiceConfigError() from e

    @staticmethod
    def _raise_if_blocked(response) -> None:
        """
        Raise OCRProcessingError if Gemini blocked the prompt or the response.

        Inspects the first candidate's ``finish_reason``; when there are no
        candidates, inspects ``prompt_feedback.block_reason`` instead. Returns
        ``None`` when nothing indicates a block.
        """
        candidates = getattr(response, "candidates", None)
        if candidates:
            first = candidates[0]
            finish_reason = getattr(first, "finish_reason", None)
            # The SDK exposes finish_reason as an enum, not a str; compare by name
            # (falling back to the raw value when it has no .name).
            if getattr(finish_reason, "name", finish_reason) == "SAFETY":
                safety_ratings = getattr(first, "safety_ratings", "N/A")
                raise OCRProcessingError(
                    f"Gemini response blocked due to safety settings. Ratings: {safety_ratings}"
                )
            return

        block_reason = getattr(getattr(response, "prompt_feedback", None), "block_reason", None)
        if block_reason:
            raise OCRProcessingError(
                f"Gemini prompt blocked. Reason: {getattr(block_reason, 'name', block_reason)}"
            )

    async def extract_items(self, image_data: bytes, mime_type: str = "image/jpeg") -> List[str]:
        """
        Extract shopping list items from an image using Gemini Vision.

        The prompt is taken from ``settings.OCR_ITEM_EXTRACTION_PROMPT`` when that
        attribute exists, otherwise from the module-level ``OCR_ITEM_EXTRACTION_PROMPT``.
        Each stripped line of the reply longer than one character that does not
        start with "Example" is returned as an item.

        Args:
            image_data: The image content as bytes.
            mime_type: The MIME type of the image (e.g., "image/jpeg", "image/png", "image/webp").

        Returns:
            The extracted item strings, in the order the model listed them.

        Raises:
            OCRQuotaExceededError: If the API reports a quota / resource-exhausted error.
            OCRServiceUnavailableError: For other Google API call errors.
            OCRProcessingError: If the prompt or the response was blocked by safety filtering.
            OCRUnexpectedError: If the response is empty, or for any other unexpected error.
        """
        try:
            # Prefer the configured prompt; fall back to this module's constant.
            prompt = getattr(settings, "OCR_ITEM_EXTRACTION_PROMPT", OCR_ITEM_EXTRACTION_PROMPT)
            image_parts = [{"mime_type": mime_type, "data": image_data}]

            response = await self.model.generate_content_async(
                contents=[prompt, *image_parts]
            )

            # Check for blocks *before* reading response.text: the SDK's text
            # accessor raises when the candidate has no parts (e.g. a safety block),
            # which previously turned every block into OCRUnexpectedError.
            self._raise_if_blocked(response)

            candidates = getattr(response, "candidates", None)
            has_parts = bool(candidates) and bool(candidates[0].content.parts)
            text = response.text if has_parts else ""
            if not text:
                logger.warning("Gemini response is empty")
                raise OCRUnexpectedError()

            # One item per line; drop blanks, single characters, and echoed
            # "Example ..." lines from the prompt.
            items = [
                cleaned
                for cleaned in (line.strip() for line in text.splitlines())
                if len(cleaned) > 1 and not cleaned.startswith("Example")
            ]

            logger.info(f"Extracted {len(items)} potential items.")
            return items

        except google_exceptions.GoogleAPIError as e:
            logger.error(f"Error during OCR extraction: {e}", exc_info=True)
            # ResourceExhausted is the API's quota / rate-limit error; the text
            # match is kept as a fallback for other errors that mention quota.
            if isinstance(e, google_exceptions.ResourceExhausted) or "quota" in str(e).lower():
                raise OCRQuotaExceededError() from e
            raise OCRServiceUnavailableError() from e
        except (OCRServiceConfigError, OCRQuotaExceededError, OCRServiceUnavailableError, OCRProcessingError, OCRUnexpectedError):
            # Already mapped to a domain exception; propagate unchanged.
            raise
        except Exception as e:
            logger.error(f"Unexpected error during Gemini item extraction: {e}", exc_info=True)
            raise OCRUnexpectedError() from e