import google.generativeai as genai
from google.generativeai.types import HarmBlockThreshold, HarmCategory
import gradio as gr
from PIL import Image, ImageDraw, ImageFont
import json
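
# Gradio app that asks Gemini 1.5 Flash for the bounding boxes of objects the
# user names in an image, then draws and labels those boxes with PIL.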


async def get_bounding_boxes(prompt: str, image: Image.Image, api_key: str):
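    """Ask Gemini for bounding boxes that match the user's prompt.

    Returns the model's bounding boxes (labels plus [ymin, xmin, ymax, xmax]
    coordinates, which Gemini reports on a 0-1000 normalized scale) together
    with its one-sentence explanation.
    """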
    system_prompt = """
You are a helpful assistant who always responds with JSON containing bounding boxes, labels, and an explanation based on the user input, and nothing else.
Your response can also include multiple bounding boxes and their labels in the list.
The values in the list should be integers.
Here are some example responses:
{
    "explanation": "User asked for the bounding box of the dragon, so I will provide the bounding box of the dragon.",
    "bounding_boxes": [
        {"label": "dragon", "box": [ymin, xmin, ymax, xmax]}
    ]
}
{
    "explanation": "User asked for the bounding box of the fruits which are red in color, so I will provide the bounding boxes of the apple and the tomato.",
    "bounding_boxes": [
        {"label": "apple", "box": [ymin, xmin, ymax, xmax]},
        {"label": "tomato", "box": [ymin, xmin, ymax, xmax]}
    ]
}
""".strip()

    prompt = f"Return the bounding boxes and labels of: {prompt}"

    # A user turn may mix text and image parts in one message.
    messages = [
        {"role": "user", "parts": [prompt, image]},
    ]

    genai.configure(api_key=api_key)

    # Request a JSON response so it can be parsed directly with json.loads.
    generation_config = {
        "temperature": 1,
        "max_output_tokens": 8192,
        "response_mime_type": "application/json",
    }

    model = genai.GenerativeModel(
        model_name="gemini-1.5-flash",
        generation_config=generation_config,
        # Disable blocking across all harm categories so detection prompts are not filtered.
        safety_settings={
            HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
            HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
            HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
            HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
        },
        system_instruction=system_prompt,
    )

    try:
        response = await model.generate_content_async(messages)
    except Exception as e:
        # Surface common failure modes as friendly Gradio errors.
        if "API key not valid" in str(e):
            raise gr.Error("Invalid API key. Please provide a valid Gemini API key.")
        elif "rate limit" in str(e).lower():
            raise gr.Error("Rate limit exceeded for the API key.")
        else:
            raise gr.Error(f"Failed to generate content: {str(e)}")

    response_json = json.loads(response.text)

    explanation = response_json["explanation"]
    bounding_boxes = response_json["bounding_boxes"]

    return bounding_boxes, explanation


async def adjust_bounding_box(bounding_boxes, image):
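    """Scale the model's 0-1000 normalized [ymin, xmin, ymax, xmax] boxes to
    pixel [xmin, ymin, xmax, ymax] boxes, the order PIL's rectangle expects."""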
    width, height = image.size
    adjusted_boxes = []
    for item in bounding_boxes:
        label = item["label"]
        ymin, xmin, ymax, xmax = [coord / 1000 for coord in item["box"]]
        xmin *= width
        xmax *= width
        ymin *= height
        ymax *= height
        adjusted_boxes.append({"label": label, "box": [xmin, ymin, xmax, ymax]})
    return adjusted_boxes


async def process_image(image, text, api_key):
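    """Full pipeline: load the image, query Gemini, rescale the boxes,
    and draw labeled rectangles onto the image for display."""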
    if not api_key:
        raise gr.Error("Please provide a Gemini API key.")

    image = Image.open(image)

    bounding_boxes, explanation = await get_bounding_boxes(text, image, api_key)
    adjusted_boxes = await adjust_bounding_box(bounding_boxes, image)

    draw = ImageDraw.Draw(image)
    # ImageFont.load_default(size=...) requires Pillow >= 10.1; use load_default() on older versions.
    font = ImageFont.load_default(size=20)

    for item in adjusted_boxes:
        box = item["box"]
        label = item["label"]
        draw.rectangle(box, outline="red", width=3)
        # Place the label just above the top-left corner of the box.
        draw.text((box[0], box[1] - 25), label, fill="red", font=font)

    adjusted_boxes_str = "\n".join(
        f"{item['label']}: {item['box']}" for item in adjusted_boxes
    )

    return explanation, image, adjusted_boxes_str


async def gradio_app(image, text, api_key):
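    """Gradio entry point: forward the UI inputs to process_image."""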
    return await process_image(image, text, api_key)


# Simple Gradio UI: an image, the object prompt, and an API key in;
# the explanation, the annotated image, and the box coordinates out.
iface = gr.Interface(
    fn=gradio_app,
    inputs=[
        gr.Image(type="filepath"),
        gr.Textbox(label="Object(s) to detect", value="person"),
        gr.Textbox(label="Your Gemini API Key", type="password"),
    ],
    outputs=[
        gr.Textbox(label="Explanation"),
        gr.Image(type="pil", label="Output Image"),
        gr.Textbox(label="Coordinates of the detected objects"),
    ],
    title="OBJECT DETECTOR ✨",
    description="Detect objects in images using the Gemini 1.5 Flash model.",
    allow_flagging="never",
)

iface.launch()