GroundingSAM / app.py
merve's picture
merve HF staff
Update app.py
3298147 verified
raw
history blame
No virus
3.4 kB
from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection
import torch
from transformers import SamModel, SamProcessor
import spaces
import numpy as np
# Run on the GPU when torch can see one, otherwise fall back to the CPU.
# NOTE(review): on ZeroGPU Spaces, confirm torch.cuda.is_available() reports
# True at import time so the models are placed on the GPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# SAM (Segment Anything): turns a bounding-box prompt into a segmentation mask.
# Placed on `device` like the detector below; a hard-coded "cuda" here makes
# the module fail to import on a CPU-only machine.
sam_model = SamModel.from_pretrained("facebook/sam-vit-base").to(device)
sam_processor = SamProcessor.from_pretrained("facebook/sam-vit-base")

# Grounding DINO: zero-shot, text-prompted object detector that supplies the
# boxes SAM is prompted with.
model_id = "IDEA-Research/grounding-dino-base"
dino_processor = AutoProcessor.from_pretrained(model_id)
dino_model = AutoModelForZeroShotObjectDetection.from_pretrained(model_id).to(device)
def infer_dino(img, text_queries, score_threshold):
    """Detect the given text queries in an image with Grounding DINO.

    Args:
        img: input image. Its size is read via ``img.shape[:2]`` as
            (height, width), so a numpy array in (H, W, C) layout is assumed
            (what gr.Image supplies by default -- TODO confirm).
        text_queries: iterable of candidate label strings. Surrounding
            whitespace is stripped and blank entries are ignored.
        score_threshold: minimum box score, forwarded as ``box_threshold`` to
            ``post_process_grounded_object_detection``.

    Returns:
        A one-element list (one image) of dicts with ``"scores"``,
        ``"labels"`` and ``"boxes"``, with boxes rescaled to the image's
        (height, width). If no non-blank query is given, the model is not run
        and the dict holds empty results.
    """
    # Grounding DINO expects one prompt in which every phrase is lowercase and
    # terminated by a period, e.g. "cat. fishnet."
    phrases = [query.strip().lower() for query in text_queries]
    queries = " ".join(f"{phrase}." for phrase in phrases if phrase)
    if not queries:
        return [{"scores": torch.empty(0), "labels": [], "boxes": torch.empty((0, 4))}]

    # numpy images are (height, width, channels); post-processing wants
    # target sizes as (height, width).
    height, width = img.shape[:2]
    target_sizes = [(height, width)]

    inputs = dino_processor(text=queries, images=img, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = dino_model(**inputs)
    # Move the predictions to the CPU before post-processing.
    outputs.logits = outputs.logits.cpu()
    outputs.pred_boxes = outputs.pred_boxes.cpu()
    results = dino_processor.post_process_grounded_object_detection(
        outputs=outputs,
        input_ids=inputs.input_ids,
        box_threshold=score_threshold,
        target_sizes=target_sizes,
    )
    return results
@spaces.GPU
def query_image(img, text_queries, dino_threshold):
    """Segment every object in an image that matches the candidate labels.

    Grounding DINO turns the text labels into bounding boxes. Each box is then
    used as a prompt for SAM, which produces one mask per box.

    Args:
        img: input image from the gr.Image component.
        text_queries: comma-separated candidate labels, e.g. "cat, fishnet".
        dino_threshold: box score threshold passed through to ``infer_dino``.

    Returns:
        ``(img, [(mask, label), ...])``, the (image, annotations) pair shown
        by the "annotatedimage" output.
    """
    candidate_labels = [query.strip() for query in text_queries.split(",")]
    dino_output = infer_dino(img, candidate_labels, dino_threshold)

    result_labels = []
    for pred in dino_output:
        boxes = pred["boxes"].cpu()
        labels = pred["labels"]
        for box, label in zip(boxes, labels):
            # Skip detections that came back without a label string.
            if label == "":
                continue
            # SAM takes boxes nested as [image][box][x0, y0, x1, y1].
            inputs = sam_processor(
                img,
                input_boxes=[[box.tolist()]],
                return_tensors="pt",
            ).to(device)
            with torch.no_grad():
                outputs = sam_model(**inputs)
            # Resize SAM's predicted masks back to the original image size,
            # then take the first image, first box, first mask.
            mask = sam_processor.image_processor.post_process_masks(
                outputs.pred_masks.cpu(),
                inputs["original_sizes"].cpu(),
                inputs["reshaped_input_sizes"].cpu(),
            )[0][0][0].numpy()
            # Add a leading axis to the 2-D mask.
            mask = mask[np.newaxis, ...]
            result_labels.append((mask, label))
    return img, result_labels
import gradio as gr

# User-facing description shown at the top of the Space.
description = "This Space combines [GroundingDINO](https://huggingface.co/IDEA-Research/grounding-dino-base), a bleeding-edge zero-shot object detection model with [SAM](https://huggingface.co/facebook/sam-vit-base), the state-of-the-art mask generation model. SAM normally doesn't accept text input. Combining SAM with GroundingDINO makes SAM text promptable. Try the example or input an image and comma separated candidate labels to segment."

# Inputs map positionally onto query_image(img, text_queries, dino_threshold);
# its (image, [(mask, label), ...]) return value feeds the annotated-image output.
demo = gr.Interface(
    query_image,
    inputs=[
        gr.Image(label="Image Input"),
        gr.Textbox(label="Candidate Labels"),
        gr.Slider(0, 1, value=0.05, label="Confidence Threshold for GroundingDINO"),
    ],
    outputs="annotatedimage",
    title="GroundingDINO 🀝 SAM for Zero-shot Segmentation",
    description=description,
    examples=[
        ["./cats.png", "cat, fishnet", 0.16],
        ["./bee.jpg", "bee, flower", 0.16],
    ],
)
demo.launch(debug=True)