Spaces:
Sleeping
Sleeping
File size: 4,881 Bytes
249c00e 49f5e8c 249c00e 49f5e8c 249c00e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 |
import gradio as gr
import json
import time
import os
from google import genai
from google.genai import types
from google.genai import errors
from bioclip import TreeOfLifeClassifier, Rank
# Total number of Gemini API attempts made before a server error is re-raised
# (see generate_content_str). PROMPT_RETRYIES is the original, misspelled name and
# is kept as an alias so existing references keep working.
PROMPT_RETRIES = 2
PROMPT_RETRYIES = PROMPT_RETRIES

# Default prompt shown in the UI; the response must be JSON for make_crops to parse it.
DEFAULT_PROMPT = """
Return bounding boxes and a description for each species in this image.
Ensure you only return valid JSON.
""".strip()

# Initialize the BioCLIP classifier once at import time so every request reuses the
# same instance (presumably avoiding a model reload per call).
classifier = TreeOfLifeClassifier()
def crop_image(image, gemini_bounding_box):
    """
    Crop the image to a Gemini bounding box.

    Gemini boxes are ``(y_min, x_min, y_max, x_max)`` normalized to the 0-1000 range.
    Each coordinate is clamped to 0-1000. Each min/max pair is reordered if it is
    swapped. As a result, slightly out-of-range or inverted boxes still produce a
    crop inside the image instead of one that Pillow would pad or reject.

    :param image: PIL Image object (anything exposing ``.size`` and ``.crop``)
    :param gemini_bounding_box: sequence of (y_min, x_min, y_max, x_max) in range 0-1000
    :return: Cropped PIL Image
    :raises ValueError: if the box does not contain exactly four values
    """
    width, height = image.size
    y_a, x_a, y_b, x_b = (
        min(max(float(value), 0.0), 1000.0) for value in gemini_bounding_box
    )
    # Models occasionally emit min/max swapped; sort each axis so left <= right, top <= bottom.
    y_min, y_max = sorted((y_a, y_b))
    x_min, x_max = sorted((x_a, x_b))
    # Convert normalized coordinates to pixel values
    left = int(x_min / 1000 * width)
    upper = int(y_min / 1000 * height)
    right = int(x_max / 1000 * width)
    lower = int(y_max / 1000 * height)
    # Crop and return the image
    return image.crop((left, upper, right, lower))
def predict_species(img):
    """
    Classify one image at species rank with the shared BioCLIP classifier.

    :param img: image to classify (a cropped PIL image in this app)
    :return: the first entry of the classifier's top-1 prediction list
    """
    single_image_batch = [img]
    top_one = classifier.predict(single_image_batch, Rank.SPECIES, k=1)
    return top_one[0]
def make_crops(image, predictions_json_txt):
    """
    Crop *image* once for every prediction that carries a ``box_2d`` bounding box.

    :param image: PIL Image object
    :param predictions_json_txt: JSON text holding a list of prediction objects
        (a single top-level object is treated as a one-item list)
    :return: List of cropped images; empty if the text is missing, is not valid
        JSON, or contains no usable boxes
    """
    try:
        predictions = json.loads(predictions_json_txt)
    except (json.JSONDecodeError, TypeError) as e:
        # TypeError covers a missing response body (None) from the model.
        print(str(e))
        return []  # Return empty list if JSON parsing fails

    if isinstance(predictions, dict):
        predictions = [predictions]
    elif not isinstance(predictions, list):
        print(f"Unexpected predictions JSON type: {type(predictions).__name__}")
        return []

    cropped_images = []
    for prediction in predictions:
        # Only dict items can hold a box; `"box_2d" in <str>` would be a substring
        # test and lead to indexing a string with "box_2d".
        if not isinstance(prediction, dict) or "box_2d" not in prediction:
            continue
        try:
            cropped_images.append(crop_image(image, prediction["box_2d"]))
        except Exception as e:
            # Best-effort: one malformed box should not discard the remaining crops.
            print(f"Error cropping image: {e}")
    return cropped_images
def _call_gemini_with_retries(client, model, prompt, pil_image, config, tries, retry_delay):
    """
    Send ``[prompt, pil_image]`` to Gemini, retrying on ``errors.ServerError``.

    :param tries: total number of attempts; values below 1 are treated as 1
    :param retry_delay: seconds to sleep between attempts
    :return: the Gemini response object
    :raises errors.ServerError: when the final attempt also fails
    """
    # Clamp so a non-positive `tries` means a single attempt rather than retrying forever.
    attempts_left = max(tries, 1)
    while True:
        try:
            return client.models.generate_content(
                model=model,
                contents=[prompt, pil_image],
                config=config,
            )
        except errors.ServerError as e:
            attempts_left -= 1
            if attempts_left <= 0:
                raise
            print(f"Retrying... {e}")
            time.sleep(retry_delay)


def _format_species_label(prediction):
    """Build the Gallery caption "common name - species - score" for one BioCLIP prediction."""
    return (
        f"{prediction['common_name']} - {prediction['species']} - "
        f"{round(prediction['score'], 3)}"
    )


def generate_content_str(
    api_key,
    prompt,
    pil_image,
    tries=PROMPT_RETRYIES,
    model="gemini-2.5-pro-exp-03-25",
    retry_delay=5,
):
    """
    Ask Gemini for bounding boxes, crop each box, and label every crop with BioCLIP.

    :param api_key: Gemini API key entered in the UI
    :param prompt: prompt text; it must ask for bounding boxes in JSON
    :param pil_image: uploaded PIL image
    :param tries: total Gemini attempts before a server error is re-raised
    :param model: Gemini model name
    :param retry_delay: seconds to wait between attempts
    :return: (raw JSON text from Gemini, list of (cropped image, label) for the Gallery)
    :raises gr.Error: on a missing API key, a missing image, or an empty model response
    :raises errors.ServerError: if every attempt fails with a server error
    """
    if not api_key:
        raise gr.Error("Please enter a Gemini API key.")
    if pil_image is None:
        raise gr.Error("Please upload an image.")

    # Initialize the client with the provided API key
    client = genai.Client(api_key=api_key)
    generate_content_config = types.GenerateContentConfig(
        response_mime_type="application/json",
    )
    response = _call_gemini_with_retries(
        client, model, prompt, pil_image, generate_content_config, tries, retry_delay
    )
    print("Result", response.text)

    response_text = response.text
    if not response_text:
        raise gr.Error("Gemini returned an empty response; try again or adjust the prompt.")

    crop_images = make_crops(image=pil_image, predictions_json_txt=response_text)
    crop_images_with_labels = [
        (img, _format_species_label(predict_species(img))) for img in crop_images
    ]
    return response_text, crop_images_with_labels
# Define the Gradio interface.
# Components are created (and therefore laid out) in the order written inside
# each `with` context, so statement order here determines the on-screen layout.
with gr.Blocks(title="Gemini 2.5 Pro Explore") as demo:
    gr.Markdown("# Image Analysis with Gemini 2.5 Pro + BioCLIP")
    with gr.Row():
        # Left column: user inputs.
        with gr.Column():
            gr.Markdown("## Upload an image and enter a prompt to get predictions")
            api_key_input = gr.Textbox(
                label="Gemini API Key",
                placeholder="Enter your Gemini API key here...",
                type="password",
            )
            # type="pil" so the handler receives a PIL Image, which crop_image expects.
            image_input = gr.Image(label="Upload an image", type="pil")
            # NOTE: make_crops only reads a "box_2d" key from each returned prediction.
            gr.Markdown("The prompt below must request bounding boxes.")
            prompt_input = gr.TextArea(
                label="Enter your prompt",
                placeholder="Describe what you want to analyze...",
                value=DEFAULT_PROMPT,
            )
            submit_btn = gr.Button("Analyze")
        # Right column: Gemini JSON output and the labelled crops.
        with gr.Column():
            gr.Markdown("## Gemini Results")
            output = gr.JSON(label="Predictions")
            gr.Markdown("## Cropped Images with BioCLIP Predictions")
            image_gallery = gr.Gallery(label="Images", show_label=True)
    # Inputs map positionally onto generate_content_str(api_key, prompt, pil_image);
    # its (raw JSON text, [(crop, label), ...]) return value fills the two outputs.
    submit_btn.click(
        fn=generate_content_str,
        inputs=[api_key_input, prompt_input, image_input],
        outputs=[output, image_gallery],
    )
# Launch the app
if __name__ == "__main__":
    demo.launch()
|