import gradio as gr
import PIL.Image
import torch
from transformers import AutoModelForZeroShotObjectDetection, AutoProcessor

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


class Detector:
    """Wraps a zero-shot object-detection checkpoint together with its processor."""

    def __init__(self, model_id: str):
        self.device = DEVICE
        self.processor = AutoProcessor.from_pretrained(model_id)
        self.model = AutoModelForZeroShotObjectDetection.from_pretrained(model_id).to(
            self.device
        )

    def detect(
        self,
        image: PIL.Image.Image,
        text_labels: list[str],
        threshold: float = 0.4,
    ):
        inputs = self.processor(
            images=image, text=[text_labels], return_tensors="pt"
        ).to(self.device)
        with torch.no_grad():
            outputs = self.model(**inputs)
        # Rescale predicted boxes to the original image size and keep only
        # detections above the confidence threshold.
        results = self.processor.post_process_grounded_object_detection(
            outputs, threshold=threshold, target_sizes=[(image.height, image.width)]
        )
        detections = []
        result = results[0]
        for box, score, text_label in zip(
            result["boxes"], result["scores"], result["text_labels"]
        ):
            box = [round(x, 2) for x in box.tolist()]
            detections.append(
                dict(
                    label=text_label,
                    confidence=round(score.item(), 3),
                    box=box,
                )
            )
        return detections


# Load all three LLMDet variants once at startup so the arena can compare them.
models = dict(
    tiny=Detector("iSEE-Laboratory/llmdet_tiny"),
    base=Detector("iSEE-Laboratory/llmdet_base"),
    large=Detector("iSEE-Laboratory/llmdet_large"),
)


def _postprocess(detections):
    # Convert raw detections into the (bounding box, label) pairs expected by
    # gr.AnnotatedImage.
    annotations = []
    for detection in detections:
        box = detection["box"]
        bounding_box = (int(box[0]), int(box[1]), int(box[2]), int(box[3]))
        label = f"{detection['label']} ({detection['confidence']:.2f})"
        annotations.append((bounding_box, label))
    return annotations


def detect_objects(image, labels, confidence_threshold):
    # Run every model variant on the same image and labels, returning one
    # (image, annotations) pair per variant.
    labels = [label.strip() for label in labels.split(",")]
    detections = []
    for model_name in models.keys():
        detection = models[model_name].detect(
            image,
            labels,
            threshold=confidence_threshold,
        )
        detections.append(_postprocess(detection))
    return tuple((image, det) for det in detections)


with gr.Blocks(delete_cache=(5, 10)) as demo:
    gr.Markdown(
        "# LLMDet Arena ✨\n ### [Paper](https://arxiv.org/abs/2501.18954) - [Repository](https://github.com/iSEE-Laboratory/LLMDet)"
    )
    with gr.Row():
        with gr.Column():
            gr.Markdown("## Input Image")
            image_input = gr.Image(type="pil", image_mode="RGB", format="jpeg")
        with gr.Column():
            gr.Markdown("## Settings")
            confidence_slider = gr.Slider(
                0,
                1,
                value=0.3,
                step=0.01,
                interactive=True,
                label="Confidence threshold:",
            )
            labels = ["a cat", "a remote control"]
            text_input = gr.Textbox(
                label="Object labels (comma separated):",
                placeholder=",".join(labels),
                lines=1,
            )
    with gr.Row():
        detect_button = gr.Button("Detect Objects")
    with gr.Row():
        gr.Markdown("## Output Annotated Images")
    with gr.Row():
        output_annotated_image_tiny = gr.AnnotatedImage(label="TINY", format="jpeg")
        output_annotated_image_base = gr.AnnotatedImage(label="BASE", format="jpeg")
        output_annotated_image_large = gr.AnnotatedImage(label="LARGE", format="jpeg")

    # Connect the button to the detection function
    detect_button.click(
        fn=detect_objects,
        inputs=[image_input, text_input, confidence_slider],
        outputs=[
            output_annotated_image_tiny,
            output_annotated_image_base,
            output_annotated_image_large,
        ],
    )

    with gr.Row():
        gr.Markdown("## Examples")
    with gr.Row():
        gr.Examples(
            examples=[
                [
                    "http://images.cocodataset.org/val2017/000000039769.jpg",
                    "a cat, a remote control",
                    0.3,
                ],
                [
                    "http://images.cocodataset.org/val2017/000000370486.jpg",
                    "a person",
                    0.3,
                ],
                [
                    "http://images.cocodataset.org/train2017/000000345263.jpg",
                    "a red apple, a green apple",
                    0.3,
                ],
            ],
            inputs=[image_input, text_input, confidence_slider],
            outputs=[
                output_annotated_image_tiny,
                output_annotated_image_base,
                output_annotated_image_large,
            ],
            fn=detect_objects,
            cache_examples=True,
        )

if __name__ == "__main__":
    demo.launch()
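
# A minimal sketch of calling one detector directly, bypassing the Gradio UI.
# This usage is an assumption for illustration only; "cats.jpg" is a
# hypothetical local file, not part of the demo:
#
#   image = PIL.Image.open("cats.jpg").convert("RGB")
#   for det in models["tiny"].detect(image, ["a cat", "a remote control"], threshold=0.3):
#       print(det["label"], det["confidence"], det["box"])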