# LLMDet_Arena / app.py
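"""Gradio demo that runs the tiny, base, and large LLMDet zero-shot object
detectors side by side on the same image and label prompt."""
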
import gradio as gr
import PIL.Image
import torch
from transformers import AutoModelForZeroShotObjectDetection, AutoProcessor
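
# Use the GPU when one is available; all three detectors share this device.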
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


class Detector:
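    """One LLMDet checkpoint plus its processor, ready for zero-shot detection."""
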
    def __init__(self, model_id: str):
        self.device = DEVICE
        self.processor = AutoProcessor.from_pretrained(model_id)
        self.model = AutoModelForZeroShotObjectDetection.from_pretrained(model_id).to(
            self.device
        )

    def detect(
        self,
        image: PIL.Image.Image,
        text_labels: list[str],
        threshold: float = 0.4,
    ):
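        """Detect text_labels in the image; return one dict per detection with
        its label, confidence, and [x1, y1, x2, y2] box in pixel coordinates."""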
        inputs = self.processor(
            images=image, text=[text_labels], return_tensors="pt"
        ).to(self.device)
        with torch.no_grad():
            outputs = self.model(**inputs)
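        # Post-processing rescales boxes to the original image size and drops
        # detections scoring below the threshold.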
        results = self.processor.post_process_grounded_object_detection(
            outputs, threshold=threshold, target_sizes=[(image.height, image.width)]
        )
        detections = []
        result = results[0]
        for box, score, label in zip(
            result["boxes"], result["scores"], result["text_labels"]
        ):
            box = [round(x, 2) for x in box.tolist()]
            detections.append(
                dict(
                    label=label,
                    confidence=round(score.item(), 3),
                    box=box,
                )
            )
        return detections
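

# One Detector instance per LLMDet model size, loaded eagerly at startup.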
models = dict(
    tiny=Detector("iSEE-Laboratory/llmdet_tiny"),
    base=Detector("iSEE-Laboratory/llmdet_base"),
    large=Detector("iSEE-Laboratory/llmdet_large"),
)


def _postprocess(detections):
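    """Convert detect() output into (box, label) annotations for gr.AnnotatedImage."""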
    annotations = []
    for detection in detections:
        box = detection["box"]
        bbox = (int(box[0]), int(box[1]), int(box[2]), int(box[3]))
        label = f"{detection['label']} ({detection['confidence']:.2f})"
        annotations.append((bbox, label))
    return annotations


def detect_objects(image, labels, confidence_threshold):
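    """Run all three models on the image and return one annotated image per model."""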
    # Ignore empty entries such as trailing commas in the label string.
    labels = [label.strip() for label in labels.split(",") if label.strip()]
    detections = []
    for model_name in models:
        detection = models[model_name].detect(
            image,
            labels,
            threshold=confidence_threshold,
        )
        detections.append(_postprocess(detection))
    return tuple((image, det) for det in detections)


with gr.Blocks(delete_cache=(5, 10)) as demo:
    gr.Markdown(
        "# LLMDet Arena ✨\n"
        "### [Paper](https://arxiv.org/abs/2501.18954) - "
        "[Repository](https://github.com/iSEE-Laboratory/LLMDet)"
    )
    with gr.Row():
        with gr.Column():
            gr.Markdown("## Input Image")
            image_input = gr.Image(type="pil", image_mode="RGB", format="jpeg")
        with gr.Column():
            gr.Markdown("## Settings")
            confidence_slider = gr.Slider(
                0,
                1,
                value=0.3,
                step=0.01,
                interactive=True,
                label="Confidence threshold:",
            )
            labels = ["a cat", "a remote control"]
            text_input = gr.Textbox(
                label="Object labels (comma separated):",
                placeholder=",".join(labels),
                lines=1,
            )
    with gr.Row():
        detect_button = gr.Button("Detect Objects")
    with gr.Row():
        gr.Markdown("## Output Annotated Images")
    with gr.Row():
        output_annotated_image_tiny = gr.AnnotatedImage(label="TINY", format="jpeg")
        output_annotated_image_base = gr.AnnotatedImage(label="BASE", format="jpeg")
        output_annotated_image_large = gr.AnnotatedImage(label="LARGE", format="jpeg")

    # Connect the button to the detection function
    detect_button.click(
        fn=detect_objects,
        inputs=[image_input, text_input, confidence_slider],
        outputs=[
            output_annotated_image_tiny,
            output_annotated_image_base,
            output_annotated_image_large,
        ],
    )
    with gr.Row():
        gr.Markdown("## Examples")
    with gr.Row():
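        # With cache_examples=True, Gradio precomputes detect_objects for each
        # example and replays the stored outputs when an example is clicked.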
        gr.Examples(
            examples=[
                [
                    "http://images.cocodataset.org/val2017/000000039769.jpg",
                    "a cat, a remote control",
                    0.3,
                ],
                [
                    "http://images.cocodataset.org/val2017/000000370486.jpg",
                    "a person",
                    0.3,
                ],
                [
                    "http://images.cocodataset.org/train2017/000000345263.jpg",
                    "a red apple, a green apple",
                    0.3,
                ],
            ],
            inputs=[image_input, text_input, confidence_slider],
            outputs=[
                output_annotated_image_tiny,
                output_annotated_image_base,
                output_annotated_image_large,
            ],
            fn=detect_objects,
            cache_examples=True,
        )

if __name__ == "__main__":
    demo.launch()