File size: 4,881 Bytes
249c00e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49f5e8c
249c00e
 
 
 
 
 
 
 
 
 
49f5e8c
249c00e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
import gradio as gr
import json
import time
import os
from google import genai
from google.genai import types
from google.genai import errors
from bioclip import TreeOfLifeClassifier, Rank


# Total number of attempts generate_content_str makes when Gemini returns a
# server error.
PROMPT_RETRIES = 2
# Backward-compatible alias for the original (misspelled) constant name.
PROMPT_RETRYIES = PROMPT_RETRIES

# Default prompt pre-filled in the UI. Crops are only produced when the model's
# JSON output contains "box_2d" bounding boxes.
DEFAULT_PROMPT = """
Return bounding boxes and a description for each species in this image.
Ensure you only return valid JSON.
""".strip()

# Load the BioCLIP classifier once at import time so every request reuses the
# same model instead of reloading it on each call.
classifier = TreeOfLifeClassifier()


def crop_image(image, gemini_bounding_box):
    """
    Crop the image to a Gemini-style bounding box.

    Gemini boxes are ``(y_min, x_min, y_max, x_max)`` on a 0-1000 grid,
    independent of the real image size. They are scaled to pixels here,
    clamped to the image bounds, and min/max are reordered if the model
    returned them swapped.

    :param image: PIL Image object (anything exposing ``.size`` and ``.crop``)
    :param gemini_bounding_box: sequence of (y_min, x_min, y_max, x_max) in range 0-1000
    :return: Cropped PIL Image
    :raises ValueError: if the box does not have four values, or covers no pixels
        once scaled and clamped
    """
    width, height = image.size
    y_min, x_min, y_max, x_max = gemini_bounding_box

    def to_pixels(value, extent):
        # Scale from the 0-1000 grid, then clamp. Out-of-range model output
        # would otherwise make PIL pad the crop with empty (black) pixels.
        return min(max(int(value / 1000 * extent), 0), extent)

    left, right = sorted((to_pixels(x_min, width), to_pixels(x_max, width)))
    upper, lower = sorted((to_pixels(y_min, height), to_pixels(y_max, height)))

    # A zero-area crop cannot be classified downstream; reject it here so the
    # caller's per-box error handling can skip it.
    if left == right or upper == lower:
        raise ValueError(f"Bounding box {gemini_bounding_box!r} covers no pixels")

    return image.crop((left, upper, right, lower))


def predict_species(img):
    """
    Run BioCLIP on one image at species rank and return its first result.

    The image is sent as a one-element batch with ``k=1``, and element 0 of
    the returned list is handed back. Presumably this is a dict carrying at
    least 'common_name', 'species' and 'score' (the keys the labelling code
    reads) — confirm against the bioclip API.

    :param img: image to classify (e.g. a PIL crop)
    :return: first entry of ``classifier.predict``'s result
    """
    return classifier.predict([img], Rank.SPECIES, k=1)[0]


def make_crops(image, predictions_json_txt):
    """
    Parse Gemini's JSON output and crop ``image`` to each bounding box it lists.

    :param image: PIL Image object
    :param predictions_json_txt: str of JSON — a list of prediction dicts (a single
        top-level dict is also accepted); entries with a ``"box_2d"`` key are cropped
    :return: List of cropped images in the order their boxes appear. The list is
        empty when the text is missing, is not valid JSON, or has no usable boxes.
    """
    try:
        predictions = json.loads(predictions_json_txt)
    except (json.JSONDecodeError, TypeError) as e:
        # TypeError covers a missing (None) response text.
        print(f"Could not parse predictions JSON: {e}")
        return []

    # Iterating a bare dict would yield its keys, so wrap it as a one-item list.
    if isinstance(predictions, dict):
        predictions = [predictions]
    elif not isinstance(predictions, list):
        print(f"Unexpected predictions JSON type: {type(predictions).__name__}")
        return []

    cropped_images = []
    for prediction in predictions:
        # Skip non-object entries instead of crashing on the membership test
        # or on subscripting a string.
        if not isinstance(prediction, dict) or "box_2d" not in prediction:
            continue
        try:
            cropped_images.append(crop_image(image, prediction["box_2d"]))
        except Exception as e:
            # Best-effort: one malformed box must not discard the other crops.
            print(f"Error cropping image: {e}")

    return cropped_images


def generate_content_str(
    api_key,
    prompt,
    pil_image,
    tries=PROMPT_RETRYIES,
    model="gemini-2.5-pro-exp-03-25",
    retry_delay=5,
):
    """
    Ask Gemini for bounding boxes, crop each box, and label it with BioCLIP.

    This is the Gradio click handler. Gradio passes the API key, prompt and
    uploaded image positionally.

    :param api_key: Gemini API key entered by the user
    :param prompt: text prompt sent with the image; it must request bounding boxes
    :param pil_image: PIL Image uploaded by the user
    :param tries: total number of attempts when Gemini returns a server error
        (values below 1 are treated as 1)
    :param model: Gemini model name; experimental model IDs may be retired,
        so override this if the default stops working
    :param retry_delay: seconds to wait between attempts
    :return: tuple of (raw Gemini response text, list of (crop, label) pairs for
        the gallery)
    :raises gr.Error: if the API key or image is missing, or Gemini returns no text
    :raises errors.ServerError: if every attempt fails with a server error
    """
    # Fail fast with a user-visible message instead of an opaque SDK error.
    if not api_key or not api_key.strip():
        raise gr.Error("Please enter a Gemini API key.")
    if pil_image is None:
        raise gr.Error("Please upload an image.")

    # Initialize the client with the provided API key
    client = genai.Client(api_key=api_key.strip())
    generate_content_config = types.GenerateContentConfig(
        response_mime_type="application/json",
    )

    # Only the remote call is retried; the `<= 0` check also guarantees the
    # loop terminates for non-positive `tries`.
    attempts_left = max(tries, 1)
    while True:
        try:
            response = client.models.generate_content(
                model=model,
                contents=[prompt, pil_image],
                config=generate_content_config,
            )
            break
        except errors.ServerError as e:
            attempts_left -= 1
            if attempts_left <= 0:
                raise
            print(f"Retrying... {e}")
            time.sleep(retry_delay)

    print("Result", response.text)
    if response.text is None:
        raise gr.Error("Gemini returned an empty response.")

    crop_images_with_labels = []
    for img in make_crops(image=pil_image, predictions_json_txt=response.text):
        prediction = predict_species(img)
        label = f"{prediction['common_name']} - {prediction['species']} - {round(prediction['score'],3)}"
        crop_images_with_labels.append((img, label))
    return response.text, crop_images_with_labels


# Define the Gradio interface.
# Layout: a left column of inputs (API key, image, prompt, button) and a right
# column of outputs (raw Gemini JSON and a gallery of labelled crops).
# Components appear on screen in the order they are created inside each
# `with` block, so statement order here is the layout.
with gr.Blocks(title="Gemini 2.5 Pro Explore") as demo:
    gr.Markdown("# Image Analysis with Gemini 2.5 Pro + BioCLIP")

    with gr.Row():
        # Input column.
        with gr.Column():
            gr.Markdown("## Upload an image and enter a prompt to get predictions")
            # type="password" masks the key in the browser.
            api_key_input = gr.Textbox(
                label="Gemini API Key",
                placeholder="Enter your Gemini API key here...",
                type="password",
            )
            # type="pil" delivers the upload to the handler as a PIL image,
            # which is what the cropping code expects.
            image_input = gr.Image(label="Upload an image", type="pil")
            # Crops are only produced when Gemini's JSON contains "box_2d"
            # entries, hence this note.
            gr.Markdown("The prompt below must request bounding boxes.")
            prompt_input = gr.TextArea(
                label="Enter your prompt",
                placeholder="Describe what you want to analyze...",
                value=DEFAULT_PROMPT,
            )
            submit_btn = gr.Button("Analyze")

        # Output column.
        with gr.Column():
            gr.Markdown("## Gemini Results")
            # Shows the raw text returned by Gemini.
            output = gr.JSON(label="Predictions")
            gr.Markdown("## Cropped Images with BioCLIP Predictions")
            # Receives a list of (image, label) pairs.
            image_gallery = gr.Gallery(label="Images", show_label=True)

    # Clicking "Analyze" passes (key, prompt, image) positionally to
    # generate_content_str; its (text, gallery items) return tuple fills the
    # JSON and gallery outputs in that order.
    submit_btn.click(
        fn=generate_content_str,
        inputs=[api_key_input, prompt_input, image_input],
        outputs=[output, image_gallery],
    )

# Launch the app only when this file is executed directly, not when imported.
if __name__ == "__main__":
    demo.launch()