"""Gradio demo comparing OWLv2 and Grounding DINO zero-shot object detection."""

# NOTE(review): `spaces` is kept as the very first import, as in the original
# file -- presumably so the Hugging Face ZeroGPU runtime is set up before
# torch touches CUDA. Confirm before reordering.
import spaces
from transformers import (
    Owlv2Processor,
    Owlv2ForObjectDetection,
    AutoProcessor,
    AutoModelForZeroShotObjectDetection,
)
import torch
import gradio as gr

OWL_CHECKPOINT = "google/owlv2-base-patch16-ensemble"
DINO_CHECKPOINT = "IDEA-Research/grounding-dino-base"

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Both checkpoints are loaded once at import time and reused for every request.
owl_model = Owlv2ForObjectDetection.from_pretrained(OWL_CHECKPOINT).to(device)
owl_processor = Owlv2Processor.from_pretrained(OWL_CHECKPOINT)

dino_processor = AutoProcessor.from_pretrained(DINO_CHECKPOINT)
dino_model = AutoModelForZeroShotObjectDetection.from_pretrained(DINO_CHECKPOINT).to(device)


@spaces.GPU
def infer(img, text_queries, score_threshold, model):
    """Run one detector on ``img`` and return boxes for Gradio's AnnotatedImage.

    Args:
        img: Image whose ``shape[:2]`` is read as ``(height, width)``
            (assumes the NumPy array that ``gr.Image`` delivers -- confirm
            if the input component's ``type`` is ever changed).
        text_queries: List of candidate label strings.
        score_threshold: Minimum confidence a detection needs to be kept.
        model: ``"owl"`` or ``"dino"``, selecting which detector runs.

    Returns:
        A list of ``(box, label)`` pairs where ``box`` is a list of four
        ints taken from the model's predicted box and ``label`` is a string.

    Raises:
        ValueError: If ``model`` is neither ``"owl"`` nor ``"dino"``.
    """
    if model not in ("owl", "dino"):
        # Previously an unknown name fell through to a NameError on `results`.
        raise ValueError(f"Unknown model {model!r}; expected 'owl' or 'dino'.")
    if not text_queries:
        # Nothing to look for; skip the forward pass entirely.
        return []

    if model == "dino":
        # Grounding DINO takes a single prompt in which every query ends
        # with a dot; its documentation also asks for lower-cased text.
        prompt = " ".join(
            f"{query.strip().lower()}." for query in text_queries if query.strip()
        )
        # shape[:2] is (rows, cols), i.e. (height, width).
        height, width = img.shape[:2]
        target_sizes = [(height, width)]
        inputs = dino_processor(text=prompt, images=img, return_tensors="pt").to(device)

        with torch.no_grad():
            outputs = dino_model(**inputs)

        outputs.logits = outputs.logits.cpu()
        outputs.pred_boxes = outputs.pred_boxes.cpu()
        results = dino_processor.post_process_grounded_object_detection(
            outputs=outputs,
            input_ids=inputs.input_ids,
            box_threshold=score_threshold,
            target_sizes=target_sizes,
        )
    else:
        # Boxes are rescaled to a square whose side is the longer image edge.
        # NOTE(review): presumably because the OWLv2 processor pads the image
        # to a square before resizing -- confirm against the processor docs.
        size = max(img.shape[:2])
        target_sizes = torch.Tensor([[size, size]])
        inputs = owl_processor(text=text_queries, images=img, return_tensors="pt").to(device)

        with torch.no_grad():
            outputs = owl_model(**inputs)

        outputs.logits = outputs.logits.cpu()
        outputs.pred_boxes = outputs.pred_boxes.cpu()
        # Pass the threshold explicitly: otherwise the post-processor's own
        # default is applied first and a lower slider value has no effect.
        results = owl_processor.post_process_object_detection(
            outputs=outputs,
            threshold=score_threshold,
            target_sizes=target_sizes,
        )

    detections = results[0]
    result_labels = []
    for box, score, label in zip(
        detections["boxes"], detections["scores"], detections["labels"]
    ):
        if score < score_threshold:
            continue
        if model == "owl":
            # OWL returns an index into the query list rather than the text.
            label = text_queries[label.cpu().item()]
        result_labels.append(([int(coord) for coord in box.tolist()], label))
    return result_labels


def query_image(img, text_queries, owl_threshold, dino_threshold):
    """Run both detectors on ``img`` and return one annotated result each.

    Args:
        img: Input image from the Gradio image component.
        text_queries: Comma-separated candidate labels typed by the user.
        owl_threshold: Score threshold applied to OWLv2 detections.
        dino_threshold: Score threshold applied to Grounding DINO detections.

    Returns:
        Two ``(img, annotations)`` tuples -- OWLv2 first, Grounding DINO
        second -- matching the two ``gr.AnnotatedImage`` outputs.

    Raises:
        gr.Error: If no image was supplied.
    """
    if img is None:
        raise gr.Error("Please provide an input image.")

    # Drop surrounding whitespace and empty entries ("cat, , dog" -> 2 labels).
    labels = [query.strip() for query in (text_queries or "").split(",")]
    labels = [query for query in labels if query]
    if not labels:
        return (img, []), (img, [])

    owl_annotations = infer(img, labels, owl_threshold, "owl")
    # Each model gets its own slider value (DINO previously reused OWL's).
    dino_annotations = infer(img, labels, dino_threshold, "dino")
    return (img, owl_annotations), (img, dino_annotations)


owl_threshold = gr.Slider(0, 1, value=0.16, label="OWL Threshold")
dino_threshold = gr.Slider(0, 1, value=0.12, label="Grounding DINO Threshold")
owl_output = gr.AnnotatedImage(label="OWL Output")
dino_output = gr.AnnotatedImage(label="Grounding DINO Output")

demo = gr.Interface(
    query_image,
    inputs=[
        gr.Image(label="Input Image"),
        gr.Textbox(label="Candidate Labels"),
        owl_threshold,
        dino_threshold,
    ],
    outputs=[owl_output, dino_output],
    title="OWLv2 ⚔ Grounding DINO",
    description="Compare two state-of-the-art zero-shot object detection models [OWLv2](https://huggingface.co/google/owlv2-base-patch16-ensemble) and [Grounding DINO](https://huggingface.co/IDEA-Research/grounding-dino-base) in this Space. Simply enter an image and the objects you want to find with comma, or try one of the examples. Play with the threshold to filter out low confidence predictions in each model.",
    examples=[
        ["./bee.jpg", "bee, flower", 0.16, 0.12],
        ["./cats.png", "cat, fishnet", 0.16, 0.12],
    ],
)
demo.launch(debug=True)