# Scraped page header (not part of the program): "Spaces: Sleeping".
import gradio as gr | |
import PIL.Image | |
import torch | |
from transformers import AutoModelForZeroShotObjectDetection, AutoProcessor | |
# Inference device, chosen once at import time: CUDA when torch reports a
# usable GPU, otherwise the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
class Detector:
    """Zero-shot object detector built from a pretrained checkpoint and its processor."""

    def __init__(self, model_id: str):
        """Load the processor and model for ``model_id`` and move the model to ``DEVICE``.

        Args:
            model_id: Checkpoint identifier passed to ``from_pretrained``.
        """
        self.device = DEVICE
        self.processor = AutoProcessor.from_pretrained(model_id)
        loaded_model = AutoModelForZeroShotObjectDetection.from_pretrained(model_id)
        self.model = loaded_model.to(self.device)

    def detect(
        self,
        image: PIL.Image.Image,
        text_labels: list[str],
        threshold: float = 0.4,
    ):
        """Detect the objects described by ``text_labels`` in ``image``.

        Args:
            image: Image to run detection on.
            text_labels: Text descriptions of the objects to look for.
            threshold: Score threshold forwarded to the processor's
                post-processing step.

        Returns:
            A list of dicts, one per detection, with keys ``label`` (the
            matched text label), ``confidence`` (score rounded to 3 decimals)
            and ``box`` (coordinates as returned by the processor, rounded to
            2 decimals).
        """
        # The processor receives the labels as a batch of one label list.
        encoded = self.processor(
            images=image, text=[text_labels], return_tensors="pt"
        )
        encoded = encoded.to(self.device)

        # Inference only: no gradients are needed.
        with torch.no_grad():
            raw_outputs = self.model(**encoded)

        # Post-processing is given the image size as (height, width).
        image_size = (image.height, image.width)
        processed = self.processor.post_process_grounded_object_detection(
            raw_outputs, threshold=threshold, target_sizes=[image_size]
        )

        # A single image was submitted, so only the first result is relevant.
        first = processed[0]
        return [
            {
                "label": name,
                "confidence": round(score.item(), 3),
                "box": [round(coord, 2) for coord in box.tolist()],
            }
            for box, score, name in zip(
                first["boxes"], first["scores"], first["text_labels"]
            )
        ]
# One detector per checkpoint size, all loaded eagerly at import time.
# Insertion order matters: callers iterate this dict and rely on the
# tiny -> base -> large ordering.
models = {
    "tiny": Detector("iSEE-Laboratory/llmdet_tiny"),
    "base": Detector("iSEE-Laboratory/llmdet_base"),
    "large": Detector("iSEE-Laboratory/llmdet_large"),
}
def _postprocess(detections): | |
annotations = [] | |
for detection in detections: | |
box = detection["box"] | |
mask = (int(box[0]), int(box[1]), int(box[2]), int(box[3])) | |
label = f"{detection['label']} ({detection['confidence']:.2f})" | |
annotations.append((mask, label)) | |
return annotations | |
def detect_objects(image, labels, confidence_threshold):
    """Run every registered detector on ``image`` and collect annotated results.

    Args:
        image: Input image, or ``None`` when the user has not supplied one.
        labels: Comma-separated object descriptions, e.g.
            ``"a cat, a remote control"``.
        confidence_threshold: Score threshold forwarded to each detector.

    Returns:
        A tuple with one ``(image, annotations)`` pair per entry in
        ``models``, in the dict's insertion order.

    Raises:
        gr.Error: If no image was provided, or if ``labels`` contains no
            non-blank entry.
    """
    # Fail with a readable UI message instead of an attribute error deep
    # inside the processor when nothing was uploaded.
    if image is None:
        raise gr.Error("Please provide an input image.")

    # Drop blank entries so an empty textbox, a trailing comma or ",," does
    # not send empty prompts to the models.
    stripped = (part.strip() for part in (labels or "").split(","))
    parsed_labels = [label for label in stripped if label]
    if not parsed_labels:
        raise gr.Error("Please enter at least one object label.")

    return tuple(
        (
            image,
            _postprocess(
                detector.detect(
                    image,
                    parsed_labels,
                    threshold=confidence_threshold,
                )
            ),
        )
        for detector in models.values()
    )
# NOTE(review): the nesting below was reconstructed from an unindented copy of
# this script — confirm the row/column layout against the running app.
with gr.Blocks(delete_cache=(5, 10)) as demo:
    # Page title with links to the paper and the repository.
    gr.Markdown(
        "# LLMDet Arena ✨\n"
        " ### [Paper](https://arxiv.org/abs/2501.18954)"
        " - [Repository](https://github.com/iSEE-Laboratory/LLMDet)"
    )

    # Input row: image upload in one column, detection settings in the other.
    with gr.Row():
        with gr.Column():
            gr.Markdown("## Input Image")
            image_input = gr.Image(type="pil", image_mode="RGB", format="jpeg")
        with gr.Column():
            gr.Markdown("## Settings")
            confidence_slider = gr.Slider(
                minimum=0,
                maximum=1,
                value=0.3,
                step=0.01,
                interactive=True,
                label="Confidence threshold:",
            )
            # Sample labels, shown only as the textbox placeholder.
            labels = ["a cat", "a remote control"]
            text_input = gr.Textbox(
                label="Object labels (comma separated):",
                placeholder=",".join(labels),
                lines=1,
            )

    with gr.Row():
        detect_button = gr.Button(value="Detect Objects")

    with gr.Row():
        gr.Markdown("## Output Annotated Images")

    # One annotated image per model size, created in tiny/base/large order.
    with gr.Row():
        (
            output_annotated_image_tiny,
            output_annotated_image_base,
            output_annotated_image_large,
        ) = [
            gr.AnnotatedImage(label=size_label, format="jpeg")
            for size_label in ("TINY", "BASE", "LARGE")
        ]

    # The button and the examples share the same inputs and outputs; the
    # output order must match the order of the tuple detect_objects returns.
    detection_inputs = [image_input, text_input, confidence_slider]
    detection_outputs = [
        output_annotated_image_tiny,
        output_annotated_image_base,
        output_annotated_image_large,
    ]

    # Run detection when the button is pressed.
    detect_button.click(
        fn=detect_objects,
        inputs=detection_inputs,
        outputs=detection_outputs,
    )

    with gr.Row():
        gr.Markdown("## Examples")

    # Each example row is [image URL, comma-separated labels, threshold].
    with gr.Row():
        gr.Examples(
            examples=[
                [
                    "http://images.cocodataset.org/val2017/000000039769.jpg",
                    "a cat, a remote control",
                    0.3,
                ],
                [
                    "http://images.cocodataset.org/val2017/000000370486.jpg",
                    "a person",
                    0.3,
                ],
                [
                    "http://images.cocodataset.org/train2017/000000345263.jpg",
                    "a red apple, a green apple",
                    0.3,
                ],
            ],
            inputs=detection_inputs,
            outputs=detection_outputs,
            fn=detect_objects,
            cache_examples=True,
        )

# Start the app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()