SkalskiP committed
Commit
23cb925
1 Parent(s): 3bd34d6
Files changed (1)
  1. app.py +78 -8
app.py CHANGED
@@ -1,39 +1,109 @@
 
 
import gradio as gr
import numpy as np
import supervision as sv
import torch
from PIL import Image
- from transformers import pipeline

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
SAM_GENERATOR = pipeline(
    task="mask-generation",
    model="facebook/sam-vit-large",
    device=DEVICE)


- def run_segmentation(image_rgb_pil: Image.Image) -> sv.Detections:
    outputs = SAM_GENERATOR(image_rgb_pil, points_per_batch=32)
    mask = np.array(outputs['masks'])
    return sv.Detections(xyxy=sv.mask_to_xyxy(masks=mask), mask=mask)


- def inference(image_rgb_pil: Image.Image) -> Image.Image:
-     detections = run_segmentation(image_rgb_pil)
-     mask_annotator = sv.MaskAnnotator(color_lookup=sv.ColorLookup.INDEX)

    img_bgr_numpy = np.array(image_rgb_pil)[:, :, ::-1]
-     annotated_bgr_image = mask_annotator.annotate(
        scene=img_bgr_numpy, detections=detections)
    return Image.fromarray(annotated_bgr_image[:, :, ::-1])


with gr.Blocks() as demo:
    with gr.Row():
-         input_image = gr.Image(image_mode='RGB', type='pil')
        result_image = gr.Image(image_mode='RGB', type='pil')
    submit_button = gr.Button("Submit")

-     submit_button.click(inference, inputs=[input_image], outputs=result_image)

demo.launch(debug=False)

+ from typing import List
+
import gradio as gr
import numpy as np
import supervision as sv
import torch
from PIL import Image
+ from transformers import pipeline, CLIPProcessor, CLIPModel
+
+ MARKDOWN = """
+ # Segment Anything Model + MetaCLIP
+ This is a demo of Open Vocabulary Image Segmentation using the
+ [Segment Anything Model](https://github.com/facebookresearch/segment-anything) and
+ [MetaCLIP](https://github.com/facebookresearch/MetaCLIP) combo.
+ """

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
SAM_GENERATOR = pipeline(
    task="mask-generation",
    model="facebook/sam-vit-large",
    device=DEVICE)
+ CLIP_MODEL = CLIPModel.from_pretrained("facebook/metaclip-b32-400m").to(DEVICE)
+ CLIP_PROCESSOR = CLIPProcessor.from_pretrained("facebook/metaclip-b32-400m")
+ MASK_ANNOTATOR = sv.MaskAnnotator(
+     color=sv.Color.red(),
+     color_lookup=sv.ColorLookup.INDEX)


+ def run_sam(image_rgb_pil: Image.Image) -> sv.Detections:
    outputs = SAM_GENERATOR(image_rgb_pil, points_per_batch=32)
    mask = np.array(outputs['masks'])
    return sv.Detections(xyxy=sv.mask_to_xyxy(masks=mask), mask=mask)


+ def run_clip(image_rgb_pil: Image.Image, text: List[str]) -> np.ndarray:
+     inputs = CLIP_PROCESSOR(
+         text=text,
+         images=image_rgb_pil,
+         return_tensors="pt",
+         padding=True
+     ).to(DEVICE)
+     outputs = CLIP_MODEL(**inputs)
+     probs = outputs.logits_per_image.softmax(dim=1)
+     return probs.detach().cpu().numpy()
+
+
+ def reverse_mask_image(image: np.ndarray, mask: np.ndarray, gray_value=128):
+     gray_color = np.array([gray_value, gray_value, gray_value], dtype=np.uint8)
+     return np.where(mask[..., None], image, gray_color)  # keep pixels inside the mask, gray out the rest

+
+ def annotate(image_rgb_pil: Image.Image, detections: sv.Detections) -> Image.Image:
    img_bgr_numpy = np.array(image_rgb_pil)[:, :, ::-1]
+     annotated_bgr_image = MASK_ANNOTATOR.annotate(
        scene=img_bgr_numpy, detections=detections)
    return Image.fromarray(annotated_bgr_image[:, :, ::-1])


+ def filter_detections(
+     image_rgb_pil: Image.Image,
+     detections: sv.Detections,
+     prompt: str
+ ) -> sv.Detections:
+     img_rgb_numpy = np.array(image_rgb_pil)
+     text = [f"a picture of {prompt}", "a picture of background"]
+     filtering_mask = []
+
+     for xyxy, mask in zip(detections.xyxy, detections.mask):
+         crop = sv.crop_image(image=img_rgb_numpy, xyxy=xyxy)
+         mask_crop = sv.crop_image(image=mask, xyxy=xyxy)
+         masked_crop = reverse_mask_image(image=crop, mask=mask_crop)
+         masked_crop_pil = Image.fromarray(masked_crop)
+         probs = run_clip(image_rgb_pil=masked_crop_pil, text=text)
+         class_index = np.argmax(probs)
+         filtering_mask.append(class_index == 0)  # index 0 = prompt caption, index 1 = "background"
+
+     filtering_mask = np.array(filtering_mask)
+     return detections[filtering_mask]
+
+
+ def inference(image_rgb_pil: Image.Image, prompt: str) -> Image.Image:
+     width, height = image_rgb_pil.size
+     area = width * height
+
+     detections = run_sam(image_rgb_pil)
+     detections = detections[detections.area / area > 0.005]  # drop masks below 0.5% of the image area
+     detections = filter_detections(
+         image_rgb_pil=image_rgb_pil,
+         detections=detections,
+         prompt=prompt)
+
+     return annotate(image_rgb_pil=image_rgb_pil, detections=detections)
+
+
with gr.Blocks() as demo:
+     gr.Markdown(MARKDOWN)
    with gr.Row():
+         with gr.Column():
+             input_image = gr.Image(image_mode='RGB', type='pil')
+             prompt_text = gr.Textbox(label="Prompt", value="dog")
        result_image = gr.Image(image_mode='RGB', type='pil')
    submit_button = gr.Button("Submit")

+     submit_button.click(
+         inference,
+         inputs=[input_image, prompt_text],
+         outputs=result_image)

demo.launch(debug=False)
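
For a quick sanity check of the new pipeline outside the Gradio UI, something along these lines should work (a sketch, not part of the commit; the image path and prompt are placeholder values):

    from PIL import Image

    # any local RGB test image; "example.jpg" is a placeholder path
    image = Image.open("example.jpg").convert("RGB")

    # SAM proposes masks, tiny ones are dropped, and MetaCLIP keeps only the
    # segments whose masked crop matches the text prompt; the result is the
    # input image with the surviving masks drawn in red
    result = inference(image_rgb_pil=image, prompt="dog")
    result.save("example_annotated.jpg")

The prompt-versus-"a picture of background" caption pair inside filter_detections is what turns MetaCLIP's zero-shot scores into a keep/drop decision for each SAM segment.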