File size: 2,226 Bytes
9e9e5c4
 
 
 
 
 
 
 
 
72a36cb
9e9e5c4
 
72a36cb
 
9e9e5c4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72a36cb
 
 
9e9e5c4
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
from huggingface_hub import hf_hub_download

from transformers import pipeline
import torch
import numpy as np
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor
import gradio as gr

# Download the SAM2 "small" checkpoint into the working directory.
hf_hub_download(repo_id="merve/sam2-hiera-small", filename="sam2_hiera_small.pt", local_dir="./")

# Run on the GPU when one is available, otherwise fall back to the CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Path of the file fetched above (it was downloaded into local_dir="./").
CHECKPOINT = "./sam2_hiera_small.pt"
# Model config passed to build_sam2; presumably the config matching the
# "small" checkpoint above — confirm if the checkpoint is changed.
CONFIG = "sam2_hiera_s.yaml"

sam2_model = build_sam2(CONFIG, CHECKPOINT, device=DEVICE, apply_postprocessing=False)
predictor = SAM2ImagePredictor(sam2_model)

# OWLv2 zero-shot detector: turns free-text labels into boxes that SAM2 segments.
checkpoint = "google/owlv2-base-patch16-ensemble"
# Use the same DEVICE as SAM2; a hard-coded "cuda" fails on CPU-only hosts
# even though DEVICE above already falls back to the CPU.
detector = pipeline(model=checkpoint, task="zero-shot-object-detection", device=DEVICE)



def _parse_candidate_labels(texts):
    """Split a comma-separated label string into stripped, non-empty labels.

    e.g. ``"bird, cat,,"`` -> ``["bird", "cat"]``.
    """
    return [label.strip() for label in texts.split(",") if label.strip()]


def _box_to_xyxy(box):
    """Convert a detector box dict into a ``[x0, y0, x1, y1]`` corner list."""
    # The pipeline already reports absolute corner coordinates, which is the
    # XYXY layout SAM2's box prompt expects; no width/height arithmetic needed.
    return [box["xmin"], box["ymin"], box["xmax"], box["ymax"]]


def query(image, texts, threshold):
    """Detect objects named in ``texts`` with OWLv2 and segment them with SAM2.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input.
        texts: comma-separated candidate labels, e.g. ``"bird, branch"``.
        threshold: minimum detection score forwarded to the OWLv2 pipeline.

    Returns:
        ``(image, annotations)`` for ``gr.AnnotatedImage``, where
        ``annotations`` is a list of ``(mask, label)`` pairs, one per detection.
    """
    labels = _parse_candidate_labels(texts)
    if not labels:
        # Nothing to look for: show the image without annotations.
        return image, []

    predictions = detector(
        image,
        candidate_labels=labels,
        threshold=threshold,
    )
    if not predictions:
        return image, []

    # The SAM2 image embedding depends only on the image, so compute it once
    # instead of once per detected box.
    predictor.set_image(image)

    result_labels = []
    for pred in predictions:
        masks, _scores, _logits = predictor.predict(
            box=_box_to_xyxy(pred["box"]),
            multimask_output=False,
        )
        # With multimask_output=False SAM2 returns a single mask (1, H, W);
        # take it as the (H, W) array that AnnotatedImage documents.
        result_labels.append((masks[0], pred["label"]))
    return image, result_labels


# Blurb rendered under the demo title.
description = (
    "This Space combines OWLv2, the state-of-the-art zero-shot object detection "
    "model with SAM2, the state-of-the-art mask generation model. "
    "SAM2 normally doesn't accept text input. "
    "Combining SAM with OWLv2 makes SAM2 text promptable. "
    "Try the example or input an image and comma separated candidate labels "
    "to segment."
)

# Inputs are listed in the positional order query(image, texts, threshold) takes.
demo = gr.Interface(
    fn=query,
    inputs=[
        gr.Image(type="pil", label="Image Input"),
        gr.Textbox(label="Candidate Labels"),
        gr.Slider(0, 1, value=0.05, label="Confidence Threshold"),
    ],
    outputs="annotatedimage",
    title="OWLv2 🤝 SAMv2",
    description=description,
    examples=[
        ["./bird.jpg", "bird", 0.15],
        ["./buddha.JPG", "buddha", 0.15],
    ],
    # Example outputs are precomputed by running query on each example.
    cache_examples=True,
)
demo.launch(debug=True)