Spaces:
Sleeping
Sleeping
import os | |
import requests | |
import google.generativeai as genai | |
from PIL import Image | |
from io import BytesIO | |
from typing import List, Union | |
import logging | |
from dotenv import load_dotenv | |
# Module setup: INFO-level root logging plus a logger scoped to this module.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Pull variables from a .env file (python-dotenv) before the key lookup below.
load_dotenv()

# Gemini credentials dedicated to image processing. os.environ.get behaves like
# os.getenv and yields None when GEMINI_API_KEY_IMAGE is unset.
genai.configure(api_key=os.environ.get("GEMINI_API_KEY_IMAGE"))
def load_image(image_source: str) -> Image.Image:
    """Load an image from an HTTP(S) URL or a local file path.

    Args:
        image_source: An ``http://`` / ``https://`` URL, or the path of an
            existing file on disk.

    Returns:
        The decoded image, converted to RGB mode.

    Raises:
        RuntimeError: If the source is neither a URL nor an existing file,
            the download fails (including non-2xx HTTP status), or the data
            cannot be decoded. The underlying exception is chained as
            ``__cause__``.
    """
    try:
        if image_source.startswith(("http://", "https://")):
            logger.info("Loading image from URL: %s", image_source)
            response = requests.get(image_source, timeout=30)
            response.raise_for_status()
            source = BytesIO(response.content)
        elif os.path.isfile(image_source):
            logger.info("Loading image from file: %s", image_source)
            source = image_source
        else:
            raise ValueError("Invalid image source: must be a valid URL or file path")
        # Use a context manager so the file/buffer PIL opened is released
        # deterministically; convert() returns a new, fully-loaded image.
        with Image.open(source) as img:
            return img.convert("RGB")
    except Exception as e:
        logger.error("Failed to load image from %s: %s", image_source, e)
        raise RuntimeError(f"Failed to load image: {e}") from e
# Instruction header sent ahead of the numbered question list.
_QA_PROMPT_HEADER = (
    "\n"
    "Answer the following questions about the image. Give the answers in the same order as the questions.\n"
    "Answers should be descriptive and detailed. Give one answer per line with numbering as \"1. 2. 3. ..\".\n"
    "Example answer format:\n"
    "1. Answer 1, Explanation\n"
    "2. Answer 2, Explanation\n"
    "3. Answer 3, Explanation\n"
    "Questions:\n"
)


def get_answer_for_image(
    image_source: str,
    questions: List[str],
    retries: int = 3,
    *,
    retry_delay: float = 1.0,
    model_name: str = "gemini-1.5-flash",
) -> List[str]:
    """Ask questions about an image using a Gemini vision model.

    The image is loaded once. All questions go out in a single numbered
    prompt, and ``extract_ordered_answers`` splits the reply into one answer
    per question.

    Args:
        image_source: URL or local file path accepted by ``load_image``.
        questions: Questions to ask, in order. An empty list returns ``[]``
            without loading the image or calling the model.
        retries: Maximum number of model calls before giving up.
        retry_delay: Base wait in seconds before a retry. It doubles after
            each failed attempt (1x, 2x, 4x, ...). ``0`` disables waiting.
        model_name: Model identifier passed to ``genai.GenerativeModel``.
            NOTE(review): confirm the default model is still served for this
            API key.

    Returns:
        One answer string per question, in question order.

    Raises:
        RuntimeError: Raised in three cases:
            - the image cannot be loaded;
            - the final attempt raises (the original error is chained as
              ``__cause__``);
            - no attempt yields the expected number of answers. This
              includes ``retries < 1``.
    """
    import time

    try:
        question_count = len(questions)
        logger.info("Processing image with %d questions", question_count)
        if not questions:
            # Nothing to ask: skip the download and the (billable) model call.
            return []

        image = load_image(image_source)
        prompt = _QA_PROMPT_HEADER + "\n".join(
            f"{number}. {question}" for number, question in enumerate(questions, start=1)
        )
        model = genai.GenerativeModel(model_name)
        config = genai.types.GenerationConfig(temperature=0.4, max_output_tokens=2048)

        for attempt in range(1, retries + 1):
            logger.info("Attempt %d of %d to get response from Gemini", attempt, retries)
            try:
                response = model.generate_content([prompt, image], generation_config=config)
                raw_text = response.text.strip()
                logger.info("Received response from Gemini: %d characters", len(raw_text))
                answers = extract_ordered_answers(raw_text, question_count)
            except Exception as e:
                logger.error("Attempt %d failed: %s", attempt, e)
                if attempt == retries:
                    raise RuntimeError(f"Failed after {retries} attempts: {e}") from e
            else:
                # extract_ordered_answers pads its result to the requested
                # length, so this is a guard against a short list rather than
                # a parse-quality check.
                if len(answers) == question_count:
                    logger.info("Successfully extracted %d answers", len(answers))
                    return answers
                logger.warning("Expected %d answers, got %d", question_count, len(answers))

            if attempt < retries and retry_delay > 0:
                # Exponential backoff gives transient failures (e.g. rate
                # limits) time to clear instead of retrying immediately.
                time.sleep(retry_delay * 2 ** (attempt - 1))

        raise RuntimeError("Failed to get valid response from Gemini.")
    except Exception as e:
        logger.error("Error in get_answer_for_image: %s", e)
        raise
def extract_ordered_answers(raw_text: str, expected_count: int) -> List[str]:
    """Parse the raw Gemini output into exactly ``expected_count`` answers.

    A line that starts answer ``N`` looks like ``N. text``, ``N) text``,
    ``N - text`` or ``N: text``. It may be wrapped in Markdown ``*`` / ``_``
    emphasis, e.g. ``**1.** text``. The delimiter must not be followed by a
    digit, so ``1.5 m`` or ``10:30`` are not mistaken for item numbers.

    Placement rules:
    - Answers are placed by their number, so a skipped number leaves a gap
      instead of shifting later answers.
    - Non-empty lines after a numbered line are appended to that answer
      (space-joined), which keeps multi-line explanations.
    - Lines with a repeated or out-of-range number are treated the same way,
      e.g. a nested list inside an answer.
    - Text before the first numbered line, such as a preamble, is ignored.

    If the reply has no usable numbered lines at all, each non-empty line is
    taken as one answer, in order.

    Args:
        raw_text: Model output text.
        expected_count: Number of answers wanted. Negative values are treated
            as 0.

    Returns:
        A list of exactly ``max(expected_count, 0)`` strings. Missing or empty
        answers are filled with "Unable to extract answer from image".
    """
    import re

    logger.debug(f"Extracting {expected_count} answers from raw text")
    expected_count = max(expected_count, 0)
    placeholder = "Unable to extract answer from image"
    numbered_line = re.compile(r"^\s*[*_]*\s*(\d+)\s*[.):\-](?!\d)[*_]*\s*(.*)$")

    lines = raw_text.splitlines()
    answers_by_number = {}  # answer number -> list of text fragments
    current = None  # fragment list of the answer currently being built
    for line in lines:
        text = line.strip()
        if not text:
            continue
        match = numbered_line.match(line)
        if match:
            number = int(match.group(1))
            if 1 <= number <= expected_count and number not in answers_by_number:
                current = [match.group(2).strip()]
                answers_by_number[number] = current
                continue
        if current is not None:
            current.append(text)

    if not answers_by_number and expected_count:
        # No numbered structure found: fall back to one answer per line.
        logger.warning("Numbered extraction failed, using fallback method")
        result = [line.strip() for line in lines if line.strip()][:expected_count]
    else:
        result = [
            " ".join(part for part in answers_by_number.get(number, []) if part)
            for number in range(1, expected_count + 1)
        ]

    # Fill empty slots and pad so callers always get exactly expected_count items.
    result = [answer or placeholder for answer in result]
    result.extend([placeholder] * (expected_count - len(result)))
    logger.info(f"Extracted {len(result)} answers")
    return result
def process_image_query(image_path: str, query: str) -> str:
    """Answer a single question about an image.

    Delegates to ``get_answer_for_image`` with a one-element question list.
    Failures are not raised. Any exception is logged and turned into an
    "Error processing image: ..." string. An empty answer list yields
    "Unable to process image query".
    """
    try:
        replies = get_answer_for_image(image_path, [query])
        if not replies:
            return "Unable to process image query"
        return replies[0]
    except Exception as exc:
        logger.error(f"Error processing image query: {exc}")
        return f"Error processing image: {exc!s}"
def process_multiple_image_queries(image_path: str, queries: List[str]) -> List[str]:
    """Answer several questions about one image in a single request.

    On success, this returns whatever ``get_answer_for_image`` produced. On
    any exception, it logs the error and returns one identical
    "Error processing image: ..." string per query, so the result length
    still matches ``queries``.
    """
    try:
        answers = get_answer_for_image(image_path, queries)
    except Exception as exc:
        logger.error(f"Error processing multiple image queries: {exc}")
        failure_message = f"Error processing image: {exc!s}"
        return [failure_message] * len(queries)
    return answers