"""Helpers for asking Gemini questions about an image (URL or local file)."""

import logging
import os
import re
from io import BytesIO
from typing import List, Optional, Union

import google.generativeai as genai
import requests
from dotenv import load_dotenv
from PIL import Image

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

load_dotenv()

# Configure Gemini API for image processing
genai.configure(api_key=os.getenv("GEMINI_API_KEY_IMAGE"))

# Instruction text sent ahead of the numbered question list.
_PROMPT_HEADER = """
Answer the following questions about the image.
Give the answers in the same order as the questions.
Answers should be descriptive and detailed.
Give one answer per line with numbering as "1. 2. 3. ..".

Example answer format:
1. Answer 1, Explanation
2. Answer 2, Explanation
3. Answer 3, Explanation

Questions:
"""

# Placeholder used when an answer cannot be recovered from the model output.
_UNEXTRACTED_ANSWER = "Unable to extract answer from image"

# A numbered line: "1. text", "1) text", "1 - text", "1: text" or "1 text".
# A delimiter or whitespace must follow the number, so a bare run of digits
# is never split into "number + answer" by regex backtracking. The answer
# text may be empty ("1." on its own line, with the answer on the next lines).
_NUMBERED_LINE_RE = re.compile(r"^\s*(\d+)\s*(?:[.):-]\s*|\s+)(.*)$")


def load_image(image_source: str) -> Image.Image:
    """Load an image from an http(s) URL or a local file path, converted to RGB.

    Args:
        image_source: URL starting with "http://" or "https://", or a path to
            an existing file.

    Returns:
        The decoded image in RGB mode.

    Raises:
        RuntimeError: if the source is neither a URL nor an existing file, or
            if downloading or decoding fails. The original error is chained.
    """
    try:
        if image_source.startswith(("http://", "https://")):
            logger.info(f"Loading image from URL: {image_source}")
            response = requests.get(image_source, timeout=30)
            response.raise_for_status()
            with Image.open(BytesIO(response.content)) as img:
                return img.convert("RGB")
        if os.path.isfile(image_source):
            logger.info(f"Loading image from file: {image_source}")
            # The context manager closes the file handle once convert() has
            # produced an independent in-memory copy.
            with Image.open(image_source) as img:
                return img.convert("RGB")
        raise ValueError("Invalid image source: must be a valid URL or file path")
    except Exception as e:
        logger.error(f"Failed to load image from {image_source}: {e}")
        raise RuntimeError(f"Failed to load image: {e}") from e


def _build_prompt(questions: List[str]) -> str:
    """Return the instruction header followed by the questions numbered from 1."""
    return _PROMPT_HEADER + "\n".join(
        f"{i + 1}. {q}" for i, q in enumerate(questions)
    )


def get_answer_for_image(
    image_source: str, questions: List[str], retries: int = 3
) -> List[str]:
    """Ask questions about an image using the Gemini vision model.

    Each attempt calls the model once. An attempt is accepted when its reply
    contains at least one sequentially numbered answer per question. If an
    attempt raises or comes back incomplete, the next attempt is made.

    Args:
        image_source: URL or local path accepted by ``load_image``.
        questions: Questions to ask, answered in the same order.
        retries: Maximum number of model calls.

    Returns:
        Exactly ``len(questions)`` answers. If every attempt either raised or
        came back incomplete, the answers are a best-effort extraction from
        the last reply received; missing entries hold a placeholder string.

    Raises:
        RuntimeError: if the image cannot be loaded, if every attempt raised,
            or if ``retries`` is not positive.
    """
    try:
        logger.info(f"Processing image with {len(questions)} questions")
        image = load_image(image_source)
        prompt = _build_prompt(questions)
        model = genai.GenerativeModel("gemini-1.5-flash")

        last_raw_text: Optional[str] = None
        last_error: Optional[Exception] = None
        for attempt in range(retries):
            logger.info(f"Attempt {attempt + 1} of {retries} to get response from Gemini")
            try:
                response = model.generate_content(
                    [prompt, image],
                    generation_config=genai.types.GenerationConfig(
                        temperature=0.4, max_output_tokens=2048
                    ),
                )
                raw_text = response.text.strip()
            except Exception as e:
                logger.error(f"Attempt {attempt + 1} failed: {e}")
                last_error = e
                continue

            logger.info(f"Received response from Gemini: {len(raw_text)} characters")
            # Count the answers really present. extract_ordered_answers always
            # pads to the expected length, so it cannot detect a short reply.
            found = len(_parse_numbered_answers(raw_text))
            if found >= len(questions):
                answers = extract_ordered_answers(raw_text, len(questions))
                logger.info(f"Successfully extracted {len(answers)} answers")
                return answers
            logger.warning(f"Expected {len(questions)} answers, got {found}")
            last_raw_text = raw_text

        if last_raw_text is not None:
            # Retries exhausted but we did get text back: return what can be
            # salvaged (padded with placeholders) rather than failing outright.
            logger.warning("Retries exhausted; returning best-effort answers")
            return extract_ordered_answers(last_raw_text, len(questions))
        if last_error is not None:
            raise RuntimeError(f"Failed after {retries} attempts: {last_error}") from last_error
        raise RuntimeError("Failed to get valid response from Gemini.")
    except Exception as e:
        logger.error(f"Error in get_answer_for_image: {e}")
        raise


def _parse_numbered_answers(raw_text: str) -> List[str]:
    """Split model output into its sequentially numbered answers.

    A line opens answer k only when it is numbered exactly k (1, 2, 3, ...).
    Every other non-blank line is appended to the answer being built, so
    multi-line answers stay whole. This includes lines that merely begin with
    a digit, such as "3 people are visible". Lines before the first "1." are
    ignored (for example, a preamble).

    Returns:
        One string per numbered answer found, with continuation lines joined
        by single spaces. An entry may be "" if its numbered line had no text
        and no continuation followed.
    """
    answers: List[List[str]] = []
    for line in raw_text.splitlines():
        text = line.strip()
        if not text:
            continue
        match = _NUMBERED_LINE_RE.match(text)
        if match and int(match.group(1)) == len(answers) + 1:
            answers.append([match.group(2).strip()])
        elif answers:
            answers[-1].append(text)
    return [" ".join(part for part in parts if part) for parts in answers]


def extract_ordered_answers(raw_text: str, expected_count: int) -> List[str]:
    """Parse raw Gemini output into exactly ``expected_count`` answers.

    Numbered answers are preferred (see ``_parse_numbered_answers``). If fewer
    than ``expected_count`` are found, each non-blank line is used as an
    answer instead. The result is truncated or padded with a placeholder to
    the expected length.
    """
    logger.debug(f"Extracting {expected_count} answers from raw text")
    answers = _parse_numbered_answers(raw_text)

    # Fallback: if numbering failed, use plain lines
    if len(answers) < expected_count:
        logger.warning("Numbered extraction failed, using fallback method")
        answers = [line.strip() for line in raw_text.splitlines() if line.strip()]

    # Exactly expected_count entries; empty or missing answers get a placeholder.
    result = [answer or _UNEXTRACTED_ANSWER for answer in answers[:expected_count]]
    result.extend([_UNEXTRACTED_ANSWER] * (expected_count - len(result)))

    logger.info(f"Extracted {len(result)} answers")
    return result


def process_image_query(image_path: str, query: str) -> str:
    """Answer a single question about an image.

    Returns:
        The answer text. On any failure, returns an error message string
        instead of raising.
    """
    try:
        answers = get_answer_for_image(image_path, [query])
        return answers[0] if answers else "Unable to process image query"
    except Exception as e:
        logger.error(f"Error processing image query: {e}")
        return f"Error processing image: {str(e)}"


def process_multiple_image_queries(image_path: str, queries: List[str]) -> List[str]:
    """Answer several questions about an image in one model call.

    Returns:
        One answer per query. On failure, every entry is the same error
        message string.
    """
    if not queries:
        # Nothing to ask; skip the image download and the API call.
        return []
    try:
        return get_answer_for_image(image_path, queries)
    except Exception as e:
        logger.error(f"Error processing multiple image queries: {e}")
        return [f"Error processing image: {str(e)}"] * len(queries)