rag-bajaj / LLM /image_answerer.py
quantumbit's picture
Upload 39 files
e8051be verified
import os
import requests
import google.generativeai as genai
from PIL import Image
from io import BytesIO
from typing import List, Union
import logging
from dotenv import load_dotenv
# Set up logging.
# NOTE: basicConfig runs at import time, so importing this module requests
# INFO-level logging on the root logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Must run before the os.getenv() call below so that a key defined only in a
# .env file is visible in the environment (presumably python-dotenv's default
# .env lookup -- confirm where the file is expected to live).
load_dotenv()
# Configure Gemini API for image processing.
# os.getenv returns None when GEMINI_API_KEY_IMAGE is unset; no check is made
# here, so a missing key is passed to genai.configure as None.
genai.configure(api_key=os.getenv("GEMINI_API_KEY_IMAGE"))
def load_image(image_source: str) -> Image.Image:
    """Load an image from an HTTP(S) URL or a local file path.

    Args:
        image_source: Either a URL starting with ``http://`` or ``https://``
            or a path to an existing file on disk.

    Returns:
        A new in-memory PIL image converted to ``RGB`` mode.

    Raises:
        RuntimeError: If the download fails, the file cannot be decoded, or
            ``image_source`` is neither a URL nor an existing file. The
            original exception is attached as ``__cause__``.
    """
    try:
        if image_source.startswith(("http://", "https://")):
            logger.info(f"Loading image from URL: {image_source}")
            response = requests.get(image_source, timeout=30)
            response.raise_for_status()
            with Image.open(BytesIO(response.content)) as downloaded:
                return downloaded.convert("RGB")

        if os.path.isfile(image_source):
            logger.info(f"Loading image from file: {image_source}")
            # Image.open is lazy and keeps the file handle; convert() returns
            # an independent copy, so the source can be closed right away.
            with Image.open(image_source) as on_disk:
                return on_disk.convert("RGB")

        raise ValueError("Invalid image source: must be a valid URL or file path")
    except Exception as e:
        logger.error(f"Failed to load image from {image_source}: {e}")
        # Chain explicitly so the root cause stays attached to the wrapper.
        raise RuntimeError(f"Failed to load image: {e}") from e
def get_answer_for_image(image_source: str, questions: List[str], retries: int = 3) -> List[str]:
    """Ask questions about an image using the Gemini vision model.

    The image is loaded once, all questions are sent in a single numbered
    prompt to ``gemini-1.5-flash``, and the reply is parsed back into one
    answer per question.

    Args:
        image_source: URL or local path of the image (see ``load_image``).
        questions: Questions to ask, in the order answers are wanted.
        retries: Maximum number of calls to the model.

    Returns:
        A list of answers with the same length as ``questions``.

    Raises:
        RuntimeError: If the image cannot be loaded, if the last attempt
            raises (the underlying error is attached as ``__cause__``), or
            if no attempt produced a usable reply.
    """
    try:
        logger.info(f"Processing image with {len(questions)} questions")
        image = load_image(image_source)

        prompt = (
            "\n"
            "Answer the following questions about the image. Give the answers in the same order as the questions.\n"
            "Answers should be descriptive and detailed. Give one answer per line with numbering as \"1. 2. 3. ..\".\n"
            "Example answer format:\n"
            "1. Answer 1, Explanation\n"
            "2. Answer 2, Explanation\n"
            "3. Answer 3, Explanation\n"
            "Questions:\n"
        )
        prompt += "\n".join(f"{i}. {q}" for i, q in enumerate(questions, start=1))

        model = genai.GenerativeModel("gemini-1.5-flash")
        # Identical for every attempt, so build it once outside the loop.
        generation_config = genai.types.GenerationConfig(
            temperature=0.4,
            max_output_tokens=2048,
        )

        for attempt in range(1, retries + 1):
            try:
                logger.info(f"Attempt {attempt} of {retries} to get response from Gemini")
                response = model.generate_content(
                    [prompt, image],
                    generation_config=generation_config,
                )
                raw_text = response.text.strip()
                logger.info(f"Received response from Gemini: {len(raw_text)} characters")

                answers = extract_ordered_answers(raw_text, len(questions))
                # NOTE(review): extract_ordered_answers pads its result to the
                # requested length, so this check is expected to pass whenever
                # parsing returns; the warning is only a safety net.
                if len(answers) == len(questions):
                    logger.info(f"Successfully extracted {len(answers)} answers")
                    return answers
                logger.warning(f"Expected {len(questions)} answers, got {len(answers)}")
            except Exception as e:
                logger.error(f"Attempt {attempt} failed: {e}")
                if attempt == retries:
                    # Keep the original failure attached as the cause.
                    raise RuntimeError(f"Failed after {retries} attempts: {e}") from e

        raise RuntimeError("Failed to get valid response from Gemini.")
    except Exception as e:
        logger.error(f"Error in get_answer_for_image: {e}")
        raise
def extract_ordered_answers(raw_text: str, expected_count: int) -> List[str]:
    """Parse the raw Gemini output into a list of answers aligned to questions.

    A line of the form ``N. text``, ``N) text``, ``N - text`` or ``N: text``
    starts answer ``N`` when ``1 <= N <= expected_count`` and ``N`` has not
    been seen yet. Every other non-empty line that follows is treated as a
    continuation of the current answer and joined to it with a space; this
    covers wrapped answers, nested lists, and lines that merely begin with a
    number such as ``2024 ...`` or ``1.5 ...``. Text before the first
    numbered answer is ignored.

    If no line is numbered at all, the non-empty lines are used as answers in
    the order they appear.

    Args:
        raw_text: Model reply to parse.
        expected_count: Number of answers wanted (one per question).

    Returns:
        Exactly ``expected_count`` strings (an empty list when
        ``expected_count`` is zero or negative). Slots with no answer hold
        ``"Unable to extract answer from image"``.
    """
    import re

    placeholder = "Unable to extract answer from image"
    logger.debug(f"Extracting {expected_count} answers from raw text")
    if expected_count <= 0:
        return []

    # Number, explicit separator, then text. The lookahead rejects decimals,
    # times and dates ("1.5", "10:30", "12-05-2024") as list markers.
    numbered_line = re.compile(r"^\s*(\d+)\s*[.):\-]+(?!\d)\s*(.*)$")
    lines = [
        stripped
        for stripped in (line.strip() for line in raw_text.splitlines())
        if stripped
    ]

    fragments_by_number = {}
    current = None  # fragment list of the answer currently being collected
    for line in lines:
        match = numbered_line.match(line)
        number = int(match.group(1)) if match else 0
        if 1 <= number <= expected_count and number not in fragments_by_number:
            current = fragments_by_number[number] = []
            first_fragment = match.group(2).strip()
            if first_fragment:
                current.append(first_fragment)
        elif current is not None:
            # Unnumbered, duplicate or out-of-range line: belongs to the
            # answer that is currently open.
            current.append(line)

    if fragments_by_number:
        if len(fragments_by_number) < expected_count:
            logger.warning(
                f"Expected {expected_count} numbered answers, "
                f"found {len(fragments_by_number)}"
            )
        # Place each answer in the slot of its own number so that a missing
        # answer does not shift the ones after it.
        result = [
            " ".join(fragments_by_number.get(number, [])) or placeholder
            for number in range(1, expected_count + 1)
        ]
    else:
        # Fallback: the model did not number anything, use plain lines.
        logger.warning("Numbered extraction failed, using fallback method")
        result = lines[:expected_count]
        result.extend([placeholder] * (expected_count - len(result)))

    logger.info(f"Extracted {len(result)} answers")
    return result
def process_image_query(image_path: str, query: str) -> str:
    """Answer one question about an image, never raising to the caller.

    Wraps ``get_answer_for_image`` with a single-element question list.
    Returns the first answer, a fixed fallback message when no answer comes
    back, or an ``"Error processing image: ..."`` string if anything raises.
    """
    try:
        replies = get_answer_for_image(image_path, [query])
        if replies:
            return replies[0]
        return "Unable to process image query"
    except Exception as exc:
        # Best-effort API: report the failure as text instead of propagating.
        logger.error(f"Error processing image query: {exc}")
        return f"Error processing image: {exc!s}"
def process_multiple_image_queries(image_path: str, queries: List[str]) -> List[str]:
    """Answer several questions about an image, never raising to the caller.

    Delegates to ``get_answer_for_image``. If that call raises, the error is
    logged and the same ``"Error processing image: ..."`` string is returned
    once per query, so the result length still matches ``queries``.
    """
    try:
        replies = get_answer_for_image(image_path, queries)
    except Exception as exc:
        logger.error(f"Error processing multiple image queries: {exc}")
        failure_message = f"Error processing image: {exc!s}"
        return [failure_message] * len(queries)
    return replies