195 lines
8.0 KiB
Python
195 lines
8.0 KiB
Python
# app/core/gemini.py
|
|
import logging
|
|
from typing import List
|
|
import google.generativeai as genai
|
|
from google.generativeai.types import HarmCategory, HarmBlockThreshold # For safety settings
|
|
from google.api_core import exceptions as google_exceptions
|
|
from app.config import settings
|
|
from app.core.exceptions import (
|
|
OCRServiceUnavailableError,
|
|
OCRServiceConfigError,
|
|
OCRUnexpectedError,
|
|
OCRQuotaExceededError
|
|
)
|
|
|
|
logger = logging.getLogger(__name__)


def _create_gemini_client():
    """Configure the Gemini SDK and build the shared model client from settings.

    Never raises for ``Exception`` subclasses: failures are logged and reported
    through the second element of the returned pair.

    Returns:
        ``(client, None)`` on success, or ``(None, reason)`` when the API key is
        not set or any configuration/construction step raises.
    """
    try:
        api_key = settings.GEMINI_API_KEY
        if not api_key:
            missing_key_msg = "GEMINI_API_KEY not configured. Gemini client not initialized."
            logger.error(missing_key_msg)
            return None, missing_key_msg

        genai.configure(api_key=api_key)
        model = genai.GenerativeModel(
            model_name=settings.GEMINI_MODEL_NAME,
            # Settings hold enum member *names* (strings); resolve them on the SDK enums.
            safety_settings={
                getattr(HarmCategory, category_name): getattr(HarmBlockThreshold, threshold_name)
                for category_name, threshold_name in settings.GEMINI_SAFETY_SETTINGS.items()
            },
            generation_config=genai.types.GenerationConfig(**settings.GEMINI_GENERATION_CONFIG),
        )
        logger.info(f"Gemini AI client initialized successfully for model '{settings.GEMINI_MODEL_NAME}'.")
        return model, None
    except Exception as exc:
        failure_msg = f"Failed to initialize Gemini AI client: {exc}"
        logger.exception(failure_msg)  # includes the full traceback
        return None, failure_msg


# Module-level state, populated once at import time. Exactly one of these is
# None: the ready model client, or the reason it could not be created.
gemini_flash_client, gemini_initialization_error = _create_gemini_client()
|
|
|
|
|
|
# --- Function to get the client (optional, allows checking error) ---
|
|
def get_gemini_client():
    """Return the Gemini model client created when this module was imported.

    Raises:
        RuntimeError: If initialization recorded an error, or if no client
            exists even though no error was recorded.
    """
    init_error = gemini_initialization_error
    client = gemini_flash_client

    if not init_error and client is not None:
        return client
    if init_error:
        raise RuntimeError(f"Gemini client could not be initialized: {init_error}")
    # Defensive fallback: no recorded error, yet no client either.
    raise RuntimeError("Gemini client is not available (unknown initialization issue).")
|
|
|
|
# Default instruction text for Gemini item extraction: one instruction or
# example item per line, with a leading and a trailing newline.
# NOTE(review): the extraction code in this module reads
# settings.OCR_ITEM_EXTRACTION_PROMPT, not this constant — confirm whether
# this module-level copy is still used anywhere.
OCR_ITEM_EXTRACTION_PROMPT = (
    "\n"
    "Extract the shopping list items from this image.\n"
    "List each distinct item on a new line.\n"
    "Ignore prices, quantities, store names, discounts, taxes, totals, and other non-item text.\n"
    "Focus only on the names of the products or items to be purchased.\n"
    "If the image does not appear to be a shopping list or receipt, state that clearly.\n"
    "Example output for a grocery list:\n"
    "Milk\n"
    "Eggs\n"
    "Bread\n"
    "Apples\n"
    "Organic Bananas\n"
)
|
|
|
|
def _unusable_gemini_response_reason(response):
    """Explain why a Gemini response has no usable content.

    Args:
        response: The object returned by ``generate_content_async``.

    Returns:
        ``None`` when the first candidate has content parts; otherwise the
        message to raise as a ``ValueError``.
    """
    candidates = response.candidates
    if candidates and candidates[0].content.parts:
        return None
    if not candidates:
        return "Gemini response was empty or incomplete. Finish Reason: UNKNOWN"

    first = candidates[0]
    # ``finish_reason`` may be an SDK enum member (e.g. FinishReason.SAFETY) rather
    # than a string; an enum member never equals 'SAFETY', so compare on the member
    # name, falling back to str() for plain-string values.
    reason_name = getattr(first.finish_reason, "name", str(first.finish_reason))
    if reason_name == "SAFETY":
        return f"Gemini response blocked due to safety settings. Ratings: {first.safety_ratings}"
    return f"Gemini response was empty or incomplete. Finish Reason: {first.finish_reason}"


async def extract_items_from_image_gemini(image_bytes: bytes, mime_type: str = "image/jpeg") -> List[str]:
    """Use Gemini Flash to extract shopping list items from image bytes.

    Args:
        image_bytes: The image content as bytes.
        mime_type: The MIME type of the image (e.g., "image/jpeg", "image/png", "image/webp").

    Returns:
        One string per line of the model's text reply, stripped of surrounding
        whitespace; blank and single-character lines are dropped.

    Raises:
        RuntimeError: If the Gemini client is not initialized.
        google_exceptions.GoogleAPIError: For API call errors (quota, invalid key etc.),
            re-raised unchanged.
        ValueError: If the response is blocked or contains no usable content, or
            if any other unexpected error occurs while generating or reading it.
    """
    client = get_gemini_client()  # Raises RuntimeError if not initialized

    # Multimodal request: the text instructions first, then the inline image part.
    prompt_parts = [
        settings.OCR_ITEM_EXTRACTION_PROMPT,
        {"mime_type": mime_type, "data": image_bytes},
    ]

    logger.info("Sending image to Gemini for item extraction...")
    try:
        response = await client.generate_content_async(prompt_parts)
        unusable_reason = _unusable_gemini_response_reason(response)
        # Only read ``response.text`` once the first candidate is known to have parts.
        raw_text = response.text if unusable_reason is None else None
    except google_exceptions.GoogleAPIError as e:
        logger.error(f"Gemini API Error: {e}", exc_info=True)
        # Re-raise unchanged so the endpoint can handle specific API errors (e.g. quota).
        raise
    except Exception as e:
        logger.error(f"Unexpected error during Gemini item extraction: {e}", exc_info=True)
        raise ValueError(f"Failed to process image with Gemini: {e}") from e

    if unusable_reason is not None:
        # Raised outside the try so this already-descriptive error is not logged a
        # second time and re-wrapped as "Failed to process image with Gemini: ...".
        logger.warning("Gemini response blocked or empty.", extra={"response": response})
        raise ValueError(unusable_reason)

    logger.info("Received raw text from Gemini.")

    # One item per line: strip whitespace, drop blank and single-character lines.
    items = [line for line in (raw.strip() for raw in raw_text.splitlines()) if len(line) > 1]

    logger.info(f"Extracted {len(items)} potential items.")
    return items
|
|
|
|
class GeminiOCRService:
    """Extract shopping-list items from images with Gemini, raising ``OCR*`` errors."""

    def __init__(self):
        """Configure the SDK and build the model from application settings.

        Raises:
            OCRServiceConfigError: If configuration or model construction fails.
        """
        try:
            genai.configure(api_key=settings.GEMINI_API_KEY)
            # Safety and generation settings must go to the constructor: assigning
            # ``model.safety_settings`` / ``model.generation_config`` afterwards is
            # not used by GenerativeModel when it builds requests. The conversion
            # matches the module-level client (settings hold enum member names).
            self.model = genai.GenerativeModel(
                settings.GEMINI_MODEL_NAME,
                safety_settings={
                    getattr(HarmCategory, category): getattr(HarmBlockThreshold, threshold)
                    for category, threshold in settings.GEMINI_SAFETY_SETTINGS.items()
                },
                generation_config=genai.types.GenerationConfig(
                    **settings.GEMINI_GENERATION_CONFIG
                ),
            )
        except Exception as e:
            logger.error(f"Failed to initialize Gemini client: {e}", exc_info=True)
            raise OCRServiceConfigError() from e

    async def extract_items(self, image_data: bytes, mime_type: str = "image/jpeg") -> List[str]:
        """Extract shopping list items from an image using Gemini Vision.

        Args:
            image_data: Raw image bytes.
            mime_type: MIME type of ``image_data``; defaults to "image/jpeg".

        Returns:
            The stripped, non-empty lines of the model's reply, excluding lines
            that start with "Example".

        Raises:
            OCRQuotaExceededError: If the request fails with a quota/rate-limit error.
            OCRServiceUnavailableError: If the request or reading the response text
                fails for any other reason.
            OCRUnexpectedError: If the model returns empty text.
        """
        image_parts = [{"mime_type": mime_type, "data": image_data}]
        try:
            response = await self.model.generate_content_async(
                contents=[settings.OCR_ITEM_EXTRACTION_PROMPT, *image_parts]
            )
            text = response.text
        except Exception as e:
            logger.error(f"Error during OCR extraction: {e}", exc_info=True)
            # ResourceExhausted is the API-core quota/rate-limit error; the message
            # check still covers errors that only mention quota in their text.
            if isinstance(e, google_exceptions.ResourceExhausted) or "quota" in str(e).lower():
                raise OCRQuotaExceededError() from e
            raise OCRServiceUnavailableError() from e

        # Checked outside the try so the broad handler above cannot turn this
        # into OCRServiceUnavailableError.
        if not text:
            raise OCRUnexpectedError()

        stripped_lines = (line.strip() for line in text.split("\n"))
        return [line for line in stripped_lines if line and not line.startswith("Example")]