File size: 5,551 Bytes
e8051be
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
import os
import requests
import google.generativeai as genai
from PIL import Image
from io import BytesIO
from typing import List, Union
import logging

from dotenv import load_dotenv

# Set up logging.
# NOTE(review): basicConfig() configures the *root* logger at import time, so
# importing this module affects logging for the whole process (and per the
# stdlib docs it does nothing if the root logger already has handlers).
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Load variables from a .env file into the environment; this must run before
# the os.getenv() call below so GEMINI_API_KEY_IMAGE can come from .env.
load_dotenv()

# Configure Gemini API for image processing.
# NOTE(review): if GEMINI_API_KEY_IMAGE is unset, os.getenv returns None here;
# presumably the missing key only surfaces on the first API call — confirm.
genai.configure(api_key=os.getenv("GEMINI_API_KEY_IMAGE"))

def load_image(image_source: str) -> Image.Image:
    """Load an image from an HTTP(S) URL or a local file path.

    The image is fully decoded and converted to RGB before it is returned, and
    the underlying file/stream opened by PIL is closed.

    Args:
        image_source: An ``http://`` or ``https://`` URL (scheme matched
            case-insensitively) or the path of an existing local file.

    Returns:
        The decoded image in RGB mode.

    Raises:
        RuntimeError: If the source is neither a URL nor an existing file, the
            download fails, or PIL cannot decode the data. The original
            exception is chained as ``__cause__``.
    """
    try:
        if image_source.lower().startswith(("http://", "https://")):
            logger.info(f"Loading image from URL: {image_source}")
            response = requests.get(image_source, timeout=30)
            response.raise_for_status()
            source = BytesIO(response.content)
        elif os.path.isfile(image_source):
            logger.info(f"Loading image from file: {image_source}")
            source = image_source
        else:
            raise ValueError("Invalid image source: must be a valid URL or file path")

        # Open inside a context manager so PIL's file handle is released even
        # for multi-frame formats, where it may otherwise stay open after load.
        # convert() returns a new, fully loaded image that outlives the handle.
        with Image.open(source) as img:
            return img.convert("RGB")
    except Exception as e:
        logger.error(f"Failed to load image from {image_source}: {e}")
        raise RuntimeError(f"Failed to load image: {e}") from e

def get_answer_for_image(
    image_source: str,
    questions: List[str],
    retries: int = 3,
    model_name: str = "gemini-1.5-flash",
) -> List[str]:
    """Ask questions about an image using a Gemini vision model.

    All questions are sent in a single request together with the image. The
    numbered reply is parsed back into one answer per question by
    ``extract_ordered_answers``.

    Args:
        image_source: URL or local file path of the image (see ``load_image``).
        questions: Questions to ask, in order.
        retries: Maximum number of generation attempts. Values below 1 mean no
            request is made and ``RuntimeError`` is raised.
        model_name: Gemini model identifier passed to
            ``genai.GenerativeModel``. The default preserves the previous
            hard-coded model.

    Returns:
        One answer string per question, in question order. If ``questions`` is
        empty, ``[]`` is returned immediately, without loading the image or
        calling the API.

    Raises:
        RuntimeError: If the image cannot be loaded, if the last attempt raises
            (chained as ``__cause__``), or if no attempt yields the expected
            number of answers.
    """
    if not questions:
        # Nothing to ask: avoid loading the image and making a billed API call.
        logger.info("No questions supplied; skipping Gemini request")
        return []

    try:
        logger.info(f"Processing image with {len(questions)} questions")
        image = load_image(image_source)

        prompt = """

Answer the following questions about the image. Give the answers in the same order as the questions. 

Answers should be descriptive and detailed. Give one answer per line with numbering as "1. 2. 3. ..".



Example answer format:

1. Answer 1, Explanation

2. Answer 2, Explanation

3. Answer 3, Explanation



Questions: 

"""
        prompt += "\n".join(f"{i}. {q}" for i, q in enumerate(questions, start=1))

        model = genai.GenerativeModel(model_name)

        for attempt in range(1, retries + 1):
            try:
                logger.info(f"Attempt {attempt} of {retries} to get response from Gemini")
                response = model.generate_content(
                    [prompt, image],
                    generation_config=genai.types.GenerationConfig(
                        temperature=0.4,
                        max_output_tokens=2048,
                    ),
                )
                raw_text = response.text.strip()
                logger.info(f"Received response from Gemini: {len(raw_text)} characters")

                answers = extract_ordered_answers(raw_text, len(questions))
                # NOTE: extract_ordered_answers pads/truncates to the requested
                # count, so this check currently always passes. It is kept as a
                # safeguard in case the parser ever returns a short list.
                if len(answers) == len(questions):
                    logger.info(f"Successfully extracted {len(answers)} answers")
                    return answers
                logger.warning(f"Expected {len(questions)} answers, got {len(answers)}")

            except Exception as e:
                # Any failure in this attempt (API error, unusable response)
                # is retried until the attempt budget is exhausted.
                logger.error(f"Attempt {attempt} failed: {e}")
                if attempt == retries:
                    raise RuntimeError(f"Failed after {retries} attempts: {e}") from e

        raise RuntimeError("Failed to get valid response from Gemini.")

    except Exception as e:
        logger.error(f"Error in get_answer_for_image: {e}")
        raise

def extract_ordered_answers(raw_text: str, expected_count: int) -> List[str]:
    """Parse the raw Gemini output into exactly ``expected_count`` answers.

    Numbered lines such as ``"1. text"``, ``"1) text"``, ``"1: text"``,
    ``"1 - text"`` or ``"1 text"`` are collected in the order they appear. The
    numbers themselves are not used to reorder. A leading number glued to
    punctuation and another digit is not treated as a list marker, so
    continuation lines like ``"1.5 km away"`` or ``"3-4 people"`` are not
    mistaken for answers.

    If fewer numbered lines than expected are found, every non-blank line is
    used instead. The list is then truncated to ``expected_count``. If it is
    still too short, it is padded with ``"Unable to extract answer from
    image"``.

    Args:
        raw_text: Raw text returned by the model. ``None`` is treated as empty.
        expected_count: Number of answers wanted. Negative values are treated
            as 0.

    Returns:
        A list of exactly ``max(expected_count, 0)`` answer strings.
    """
    import re

    # Clamp so a negative count can't turn the slice below into "all but last".
    count = max(expected_count, 0)
    logger.debug(f"Extracting {count} answers from raw text")
    lines = (raw_text or "").splitlines()

    # A list marker is a number followed either by a delimiter (. ) : -) that
    # is not immediately followed by another digit, or by whitespace.
    numbered = re.compile(r"^\s*(\d+)\s*(?:[.):\-](?!\d)|\s)\s*(.+)")

    answers = []
    for line in lines:
        match = numbered.match(line)
        if match:
            answer_text = match.group(2).strip()
            if answer_text:  # Skip markers with nothing after them
                answers.append(answer_text)

    # Fallback: if numbering failed, use plain lines
    if len(answers) < count:
        logger.warning("Numbered extraction failed, using fallback method")
        answers = [line.strip() for line in lines if line.strip()]

    result = answers[:count]
    # Pad so callers can always pair answers with their questions one-to-one.
    result.extend(["Unable to extract answer from image"] * (count - len(result)))

    logger.info(f"Extracted {len(result)} answers")
    return result

def process_image_query(image_path: str, query: str) -> str:
    """Answer a single question about an image without raising.

    Thin wrapper that sends ``query`` as a one-item batch to
    ``get_answer_for_image``.

    Args:
        image_path: URL or local file path of the image.
        query: The question to ask.

    Returns:
        The first answer returned. Returns ``"Unable to process image query"``
        if the answer list is empty, or ``"Error processing image: <error>"``
        if any exception was raised.
    """
    try:
        replies = get_answer_for_image(image_path, [query])
        if not replies:
            return "Unable to process image query"
        return replies[0]
    except Exception as exc:
        logger.error(f"Error processing image query: {exc}")
        return f"Error processing image: {exc}"

def process_multiple_image_queries(image_path: str, queries: List[str]) -> List[str]:
    """Answer several questions about one image without raising.

    Args:
        image_path: URL or local file path of the image.
        queries: Questions to ask, in order.

    Returns:
        The answers from ``get_answer_for_image``. If any exception is raised,
        returns the same ``"Error processing image: <error>"`` message repeated
        once per query instead.
    """
    try:
        answers = get_answer_for_image(image_path, queries)
    except Exception as exc:
        logger.error(f"Error processing multiple image queries: {exc}")
        failure_message = f"Error processing image: {exc}"
        return [failure_message] * len(queries)
    return answers