154 lines
6.8 KiB
Python
154 lines
6.8 KiB
Python
# app/core/gemini.py
|
|
import logging
import re
from typing import List, Optional

import google.generativeai as genai
from google.api_core import exceptions as google_exceptions
from google.generativeai.types import HarmBlockThreshold, HarmCategory  # For safety settings

from app.config import settings
|
|
|
|
logger = logging.getLogger(__name__)

# Model requested from the Gemini API. Kept in a single constant so the
# startup log message always names the model that is actually used.
GEMINI_MODEL_NAME = "gemini-2.0-flash"

# --- Global state for the initialized model client ---
# Holds the genai.GenerativeModel on success; stays None otherwise.
gemini_flash_client = None
# Human-readable reason initialization failed; None means no error was recorded.
gemini_initialization_error = None

# --- Configure and initialize (runs once, when this module is imported) ---
try:
    if settings.GEMINI_API_KEY:
        genai.configure(api_key=settings.GEMINI_API_KEY)

        gemini_flash_client = genai.GenerativeModel(
            model_name=GEMINI_MODEL_NAME,
            # Default safety settings applied to requests made with this client.
            # Adjust these based on your expected content and risk tolerance.
            safety_settings={
                HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
                HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
                HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
                HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
            },
            # A default generation_config=genai.types.GenerationConfig(...)
            # (temperature, max_output_tokens, stop_sequences, ...) can be
            # supplied here; it can be overridden per request.
        )
        logger.info(
            "Gemini AI client initialized successfully for model '%s'.",
            GEMINI_MODEL_NAME,
        )
    else:
        # Record why the client is missing so callers can report it later.
        gemini_initialization_error = "GEMINI_API_KEY not configured. Gemini client not initialized."
        logger.error(gemini_initialization_error)
except Exception as e:
    # Import-time boundary: record and log any unexpected failure instead of
    # breaking the import, and leave the client unset.
    gemini_initialization_error = f"Failed to initialize Gemini AI client: {e}"
    logger.exception(gemini_initialization_error)  # Log full traceback
    gemini_flash_client = None  # Ensure client is None on error
|
|
|
|
|
|
# --- Accessor for the shared client (surfaces any initialization error) ---
def get_gemini_client():
    """
    Fetch the shared Gemini model client created when the module was imported.

    Returns:
        The ``genai.GenerativeModel`` stored in ``gemini_flash_client``.

    Raises:
        RuntimeError: If initialization recorded an error, or if the client is
            still unset for any other reason.
    """
    init_error = gemini_initialization_error
    if init_error:
        raise RuntimeError(f"Gemini client could not be initialized: {init_error}")

    client = gemini_flash_client
    if client is not None:
        return client

    # Defensive fallback: an unset client with no recorded error is unexpected.
    raise RuntimeError("Gemini client is not available (unknown initialization issue).")
|
|
|
|
# --- Prompt constants ---

# Exact reply the prompt requests when the image is not a shopping list or
# receipt. The parser drops it, so callers get an empty list instead of a
# free-form sentence being returned as a shopping item.
OCR_NO_ITEMS_SENTINEL = "NO_ITEMS_FOUND"

OCR_ITEM_EXTRACTION_PROMPT = f"""
Extract the shopping list items from this image.
List each distinct item on a new line.
Ignore prices, quantities, store names, discounts, taxes, totals, and other non-item text.
Focus only on the names of the products or items to be purchased.
Output only the item names: no bullets, numbering, headings, or commentary.
If the image does not appear to be a shopping list or receipt, respond with exactly {OCR_NO_ITEMS_SENTINEL} and nothing else.
Example output for a grocery list:
Milk
Eggs
Bread
Apples
Organic Bananas
"""

# Leading list markers a model may still emit: "-", "*" or "•" bullets, or
# "1." / "1)" numbering, each followed by whitespace.
_LIST_MARKER_RE = re.compile(r"^(?:[-*\u2022]|\d+[.)])\s+")

# Mime type used when the image signature is not recognised.
_DEFAULT_IMAGE_MIME_TYPE = "image/jpeg"


def _guess_image_mime_type(image_bytes) -> str:
    """Infer the image mime type from its leading magic bytes (PNG/WebP), else JPEG."""
    header = bytes(image_bytes[:12]) if image_bytes else b""
    if header.startswith(b"\x89PNG\r\n\x1a\n"):
        return "image/png"
    if header[:4] == b"RIFF" and header[8:12] == b"WEBP":
        return "image/webp"
    return _DEFAULT_IMAGE_MIME_TYPE


def _raise_for_unusable_response(response) -> None:
    """Raise ValueError unless the first candidate of ``response`` has content parts."""
    candidates = response.candidates
    if candidates and candidates[0].content.parts:
        return

    logger.warning("Gemini response blocked or empty.", extra={"response": response})

    if not candidates:
        # NOTE(review): with no candidates the block reason is expected on
        # response.prompt_feedback; read defensively. Confirm against SDK docs.
        feedback = getattr(response, "prompt_feedback", "N/A")
        raise ValueError(
            f"Gemini response was empty or incomplete. Finish Reason: UNKNOWN. "
            f"Prompt feedback: {feedback}"
        )

    candidate = candidates[0]
    # finish_reason may be an SDK enum rather than a str; a plain == 'SAFETY'
    # would never match the enum, so compare by its .name when present.
    finish_reason = getattr(candidate.finish_reason, "name", candidate.finish_reason)
    if finish_reason == "SAFETY":
        raise ValueError(
            f"Gemini response blocked due to safety settings. Ratings: {candidate.safety_ratings}"
        )
    raise ValueError(f"Gemini response was empty or incomplete. Finish Reason: {finish_reason}")


def _parse_item_lines(raw_text: str) -> List[str]:
    """Split model output into cleaned item names, dropping noise and the no-items sentinel."""
    items = []
    for line in raw_text.splitlines():
        cleaned = _LIST_MARKER_RE.sub("", line.strip(), count=1).strip()
        # Skip empty and single-character lines (same length filter as before).
        if len(cleaned) <= 1:
            continue
        if cleaned.rstrip(".").upper() == OCR_NO_ITEMS_SENTINEL:
            continue
        items.append(cleaned)
    return items


async def extract_items_from_image_gemini(
    image_bytes: bytes, mime_type: Optional[str] = None
) -> List[str]:
    """
    Uses Gemini Flash to extract shopping list items from image bytes.

    Args:
        image_bytes: The image content as bytes.
        mime_type: Mime type of ``image_bytes`` (e.g. "image/png"). When omitted,
            it is inferred from a PNG or WebP signature, falling back to
            "image/jpeg".

    Returns:
        A list of extracted item strings. The list is empty when the model
        reports that the image is not a shopping list or receipt.

    Raises:
        RuntimeError: If the Gemini client is not initialized.
        google_exceptions.GoogleAPIError: For API call errors (quota, invalid key etc.).
        ValueError: If the response is blocked or contains no usable text, or
            processing fails unexpectedly.
    """
    client = get_gemini_client()  # Raises RuntimeError if not initialized

    # Multimodal input: text prompt first, then the image part.
    image_part = {
        "mime_type": mime_type or _guess_image_mime_type(image_bytes),
        "data": image_bytes,
    }
    prompt_parts = [OCR_ITEM_EXTRACTION_PROMPT, image_part]

    logger.info("Sending image to Gemini for item extraction...")
    try:
        # Async variant so the call does not block the event loop.
        response = await client.generate_content_async(prompt_parts)
        _raise_for_unusable_response(response)
        # response.text is a shortcut for the first candidate's text content.
        raw_text = response.text
        logger.info("Received raw text from Gemini.")
        items = _parse_item_lines(raw_text)
    except google_exceptions.GoogleAPIError as e:
        logger.exception("Gemini API Error: %s", e)
        # Re-raise unchanged so the endpoint can handle specific API errors (e.g. quota).
        raise
    except ValueError:
        # Already a descriptive ValueError (blocked/empty response); don't re-wrap it.
        raise
    except Exception as e:
        logger.exception("Unexpected error during Gemini item extraction: %s", e)
        raise ValueError(f"Failed to process image with Gemini: {e}") from e

    logger.info("Extracted %d potential items.", len(items))
    return items