shengqiangShi committed on
Commit e4dee6a · 1 Parent(s): 15450dc

Application file

app.py CHANGED
@@ -2,6 +2,29 @@ import torch
 import gradio as gr
 from transformers import Owlv2Processor, Owlv2ForObjectDetection
 import spaces
+import numpy as np
+from PIL import Image
+import io
+import random
+from transformers import SamModel, SamProcessor
+
+
+def apply_colored_masks_on_image(image, masks):
+    if not isinstance(image, Image.Image):
+        image = Image.fromarray(image.astype('uint8'), 'RGB')
+
+    image_rgba = image.convert("RGBA")
+
+    for i in range(masks.shape[0]):
+        mask = masks[i].squeeze().cpu().numpy()
+        mask_image = Image.fromarray((mask * 255).astype(np.uint8), 'L')
+        color = tuple([random.randint(0, 255) for _ in range(3)] + [128])
+        colored_mask = Image.new("RGBA", image.size, color)
+        colored_mask.putalpha(mask_image)
+        image_rgba = Image.alpha_composite(image_rgba, colored_mask)
+
+    return image_rgba
+
+

 # Use GPU if available
 if torch.cuda.is_available():
@@ -11,55 +34,82 @@ else:

 model = Owlv2ForObjectDetection.from_pretrained("google/owlv2-base-patch16-ensemble").to(device)
 processor = Owlv2Processor.from_pretrained("google/owlv2-base-patch16-ensemble")
+model_sam = SamModel.from_pretrained("facebook/sam-vit-huge").to(device)
+processor_sam = SamProcessor.from_pretrained("facebook/sam-vit-huge")
+

 @spaces.GPU
-def query_image(img, text_queries, score_threshold):
-    text_queries = text_queries
+def query_image(img, text_queries, score_threshold=0.5):
     text_queries = text_queries.split(",")
-
     size = max(img.shape[:2])
     target_sizes = torch.Tensor([[size, size]])
     inputs = processor(text=text_queries, images=img, return_tensors="pt").to(device)

     with torch.no_grad():
-        outputs = model(**inputs)
-
-    outputs.logits = outputs.logits.cpu()
-    outputs.pred_boxes = outputs.pred_boxes.cpu()
-    results = processor.post_process_object_detection(outputs=outputs, target_sizes=target_sizes)
+        model_outputs = model(**inputs)
+    model_outputs.logits = model_outputs.logits.cpu()
+    model_outputs.pred_boxes = model_outputs.pred_boxes.cpu()
+    results = processor.post_process_object_detection(outputs=model_outputs, target_sizes=target_sizes)
+
     boxes, scores, labels = results[0]["boxes"], results[0]["scores"], results[0]["labels"]

+
+    img_pil = Image.fromarray(img.astype('uint8'), 'RGB')
+
     result_labels = []
+    result_boxes = []
     for box, score, label in zip(boxes, scores, labels):
-        box = [int(i) for i in box.tolist()]
-        if score < score_threshold:
-            continue
-        result_labels.append((box, text_queries[label.item()]))
-    return img, result_labels
+        if score >= score_threshold:
+            box = [int(i) for i in box.tolist()]
+            label_text = text_queries[label.item()]
+            result_labels.append((box, label_text))
+            result_boxes.append(box)
+
+    input_boxes_for_sam = [result_boxes]
+    sam_image = generate_image_with_sam(np.array(img_pil), input_boxes_for_sam)
+
+    return sam_image, result_labels
+
+
+def generate_image_with_sam(img, boxes):
+    img_pil = Image.fromarray(img.astype('uint8'), 'RGB')
+    inputs = processor_sam(img_pil, return_tensors="pt").to(device)
+
+    image_embeddings = model_sam.get_image_embeddings(inputs["pixel_values"])
+
+    inputs = processor_sam(img_pil, input_boxes=boxes, return_tensors="pt").to(device)
+    inputs.pop("pixel_values", None)
+    inputs.update({"image_embeddings": image_embeddings})
+
+    with torch.no_grad():
+        outputs = model_sam(**inputs, multimask_output=False)
+
+    masks = processor_sam.image_processor.post_process_masks(outputs.pred_masks.cpu(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu())
+    scores = outputs.iou_scores
+    print(type(scores))
+    print(scores.shape if hasattr(scores, 'shape') else scores)
+
+
+    SAM_image = apply_colored_masks_on_image(img_pil, masks[0])
+    return SAM_image


 description = """
-Try this demo for <a href="https://huggingface.co/docs/transformers/main/en/model_doc/owlv2">OWLv2</a>,
-introduced in <a href="https://arxiv.org/abs/2306.09683">Scaling Open-Vocabulary Object Detection</a>.
-\n\n Compared to OWLVIT, OWLv2 performs better both in yield and performance (average precision).
-You can use OWLv2 to query images with text descriptions of any object.
-To use it, simply upload an image and enter comma separated text descriptions of objects you want to query the image for. You
-can also use the score threshold slider to set a threshold to filter out low probability predictions.
-\n\nOWL-ViT is trained on text templates,
-hence you can get better predictions by querying the image with text templates used in training the original model: e.g. *"photo of a star-spangled banner"*,
-*"image of a shoe"*. Refer to the <a href="https://arxiv.org/abs/2103.00020">CLIP</a> paper to see the full list of text templates used to augment the training data.
-\n\n<a href="https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/zeroshot_object_detection_with_owlvit.ipynb">Colab demo</a>
+Split anything
 """
 demo = gr.Interface(
-    query_image,
-    inputs=[gr.Image(), "text", gr.Slider(0, 1, value=0.1)],
-    outputs="annotatedimage",
-    title="Zero-Shot Object Detection with OWLv2",
-    description=description,
+    fn=query_image,
+    inputs=[gr.Image(), gr.Textbox(label="Query Text"), gr.Slider(0, 1, value=0.5, label="Score Threshold")],
+    outputs=gr.AnnotatedImage(),
+    title="Zero-Shot Object Detection SV3",
+    description="This interface demonstrates object detection using zero-shot object detection and SAM for image segmentation.",
     examples=[
-        ["assets/astronaut.png", "human face, rocket, star-spangled banner, nasa badge", 0.11],
-        ["assets/coffee.png", "coffee mug, spoon, plate", 0.1],
-        ["assets/butterflies.jpeg", "orange butterfly", 0.3],
+        ["images/purple cell.png", "purple cells", 0.11],
+        ["images/dark_cell.png", "gray cells", 0.1],
+        ["images/animals.png", "Rabbit,Squirrel,Parrot,Hedgehog,Turtle,Ladybug,Chick,Frog,Butterfly,Snail,Mouse", 0.1],
+
     ],
 )
-demo.launch()
+
+demo.launch()
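The new code chains OWLv2 zero-shot detection into SAM: boxes that clear the score threshold become SAM box prompts (note the extra batch nesting in input_boxes_for_sam = [result_boxes]), and the returned masks are composited onto the image by apply_colored_masks_on_image. Below is a minimal local sketch of exercising query_image outside the Gradio UI, run in a session where app.py's definitions have already been executed, and assuming the spaces.GPU decorator falls back to a plain function call when not running on a ZeroGPU Space; the query string and threshold are illustrative, and images/animals.png is one of the files added in this commit.

import numpy as np
from PIL import Image

# One of the example images added in this commit.
img = np.array(Image.open("images/animals.png").convert("RGB"))

# The comma-separated query string is split inside query_image; boxes above the
# threshold are forwarded to SAM as box prompts and the masks are overlaid.
overlay, labelled_boxes = query_image(img, "Rabbit,Squirrel,Parrot", score_threshold=0.3)

overlay.save("overlay.png")   # RGBA image with translucent per-mask colours
print(labelled_boxes)         # e.g. [([x0, y0, x1, y1], "Rabbit"), ...]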
images/animals.png ADDED
images/dark_cell.png ADDED
images/purple cell.png ADDED
requirements.txt CHANGED
@@ -3,4 +3,6 @@ torch>=1.7.0
 torchvision>=0.8.1
 git+https://github.com/huggingface/transformers.git
 scipy
-spaces
+spaces
+matplotlib
+pillow
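One optional tightening of requirements.txt, not part of this commit: the git+ line tracks the transformers main branch, so Space rebuilds are unpinned. Assuming a released transformers version that already ships both Owlv2 and SAM (4.35.0 or later, to the best of my knowledge), a pinned sketch would be:

torch>=1.7.0
torchvision>=0.8.1
transformers>=4.35.0
scipy
spaces
matplotlib
pillow

matplotlib is kept here because the commit adds it, although app.py as committed does not import it; pillow is used via PIL for the mask overlays.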