Update app to support confidence adjustment and mask options
Updated app.py and requirements.txt to introduce a slider for adjusting the confidence threshold and to add a solid mask annotator alongside the existing semi-transparent one. The confidence value is now passed through to the inference and filter_detections functions, and the gradio version has been pinned to 3.50.2 in requirements.txt for consistency. These changes give users a better interactive experience and finer control over the segmentation results.
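In effect, each mask proposed by SAM is cropped, composited over a gray background, scored by MetaCLIP against the prompt versus a generic "background" caption, and kept only if its prompt probability exceeds the slider value. A minimal sketch of that check (the keep_mask helper is illustrative, not part of app.py; it assumes probs is the CLIP softmax over [prompt, background] for one masked crop):

    import numpy as np

    def keep_mask(probs: np.ndarray, confidence: float) -> bool:
        # probs has shape (1, 2): [[p(prompt), p(background)]]
        return bool(probs[0][0] > confidence)

    # A crop scored 0.72 for the prompt passes the default slider value of 0.6.
    print(keep_mask(np.array([[0.72, 0.28]]), confidence=0.6))  # True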
- app.py +44 -19
- requirements.txt +1 -1
app.py
CHANGED
@@ -14,11 +14,12 @@ This is the demo for a Open Vocabulary Image Segmentation using
 [MetaCLIP](https://github.com/facebookresearch/MetaCLIP) combo.
 """
 EXAMPLES = [
-    ["https://media.roboflow.com/notebooks/examples/dog.jpeg", "dog"],
-    ["https://media.roboflow.com/notebooks/examples/dog.jpeg", "building"],
-    ["https://media.roboflow.com/notebooks/examples/dog-3.jpeg", "jacket"],
+    ["https://media.roboflow.com/notebooks/examples/dog.jpeg", "dog", 0.5],
+    ["https://media.roboflow.com/notebooks/examples/dog.jpeg", "building", 0.5],
+    ["https://media.roboflow.com/notebooks/examples/dog-3.jpeg", "jacket", 0.5],
+    ["https://media.roboflow.com/notebooks/examples/dog-3.jpeg", "coffee", 0.6],
 ]
-
+MIN_AREA_THRESHOLD = 0.01
 DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 SAM_GENERATOR = pipeline(
     task="mask-generation",
@@ -26,9 +27,13 @@ SAM_GENERATOR = pipeline(
     device=DEVICE)
 CLIP_MODEL = CLIPModel.from_pretrained("facebook/metaclip-b32-400m").to(DEVICE)
 CLIP_PROCESSOR = CLIPProcessor.from_pretrained("facebook/metaclip-b32-400m")
-MASK_ANNOTATOR = sv.MaskAnnotator(
+SEMITRANSPARENT_MASK_ANNOTATOR = sv.MaskAnnotator(
     color=sv.Color.red(),
     color_lookup=sv.ColorLookup.INDEX)
+SOLID_MASK_ANNOTATOR = sv.MaskAnnotator(
+    color=sv.Color.red(),
+    color_lookup=sv.ColorLookup.INDEX,
+    opacity=1)
 
 
 def run_sam(image_rgb_pil: Image.Image) -> sv.Detections:
@@ -54,9 +59,13 @@ def reverse_mask_image(image: np.ndarray, mask: np.ndarray, gray_value=128):
     return np.where(mask[..., None], image, gray_color)
 
 
-def annotate(image_rgb_pil: Image.Image, detections: sv.Detections) -> Image.Image:
+def annotate(
+    image_rgb_pil: Image.Image,
+    detections: sv.Detections,
+    annotator: sv.MaskAnnotator
+) -> Image.Image:
     img_bgr_numpy = np.array(image_rgb_pil)[:, :, ::-1]
-    annotated_bgr_image = MASK_ANNOTATOR.annotate(
+    annotated_bgr_image = annotator.annotate(
         scene=img_bgr_numpy, detections=detections)
     return Image.fromarray(annotated_bgr_image[:, :, ::-1])
 
@@ -64,7 +73,8 @@ def annotate(image_rgb_pil: Image.Image, detections: sv.Detections) -> Image.Image:
 def filter_detections(
     image_rgb_pil: Image.Image,
     detections: sv.Detections,
-    prompt: str
+    prompt: str,
+    confidence: float
 ) -> sv.Detections:
     img_rgb_numpy = np.array(image_rgb_pil)
     text = [f"a picture of {prompt}", "a picture of background"]
@@ -76,27 +86,38 @@ def filter_detections(
         masked_crop = reverse_mask_image(image=crop, mask=mask_crop)
         masked_crop_pil = Image.fromarray(masked_crop)
         probs = run_clip(image_rgb_pil=masked_crop_pil, text=text)
-        class_index = np.argmax(probs)
-        filtering_mask.append(class_index == 0)
+        filtering_mask.append(probs[0][0] > confidence)
 
     filtering_mask = np.array(filtering_mask)
     return detections[filtering_mask]
 
 
-def inference(image_rgb_pil: Image.Image, prompt: str) -> List[Image.Image]:
+def inference(
+    image_rgb_pil: Image.Image,
+    prompt: str,
+    confidence: float
+) -> List[Image.Image]:
     width, height = image_rgb_pil.size
     area = width * height
 
     detections = run_sam(image_rgb_pil)
-    detections = detections[detections.area / area >
+    detections = detections[detections.area / area > MIN_AREA_THRESHOLD]
     detections = filter_detections(
         image_rgb_pil=image_rgb_pil,
         detections=detections,
-        prompt=prompt)
+        prompt=prompt,
+        confidence=confidence)
 
+    blank_image = Image.new("RGB", (width, height), "black")
     return [
-        annotate(
-            image_rgb_pil=image_rgb_pil, detections=detections)
+        annotate(
+            image_rgb_pil=image_rgb_pil,
+            detections=detections,
+            annotator=SEMITRANSPARENT_MASK_ANNOTATOR),
+        annotate(
+            image_rgb_pil=blank_image,
+            detections=detections,
+            annotator=SOLID_MASK_ANNOTATOR)
     ]
 
 
@@ -104,15 +125,19 @@ with gr.Blocks() as demo:
     gr.Markdown(MARKDOWN)
     with gr.Row():
         with gr.Column():
-            input_image = gr.Image(image_mode='RGB', type='pil', height=500)
-            prompt_text = gr.Textbox(label="Prompt", value="dog")
+            input_image = gr.Image(
+                image_mode='RGB', type='pil', height=500)
+            prompt_text = gr.Textbox(
+                label="Prompt", value="dog")
+            confidence_slider = gr.Slider(
+                label="Confidence", minimum=0.5, maximum=1.0, step=0.05, value=0.6)
             submit_button = gr.Button("Submit")
     gallery = gr.Gallery(label="Result", object_fit="scale-down", preview=True)
     with gr.Row():
         gr.Examples(
             examples=EXAMPLES,
             fn=inference,
-            inputs=[input_image, prompt_text],
+            inputs=[input_image, prompt_text, confidence_slider],
             outputs=[gallery],
             cache_examples=True,
             run_on_click=True
@@ -120,7 +145,7 @@ with gr.Blocks() as demo:
 
     submit_button.click(
         inference,
-        inputs=[input_image, prompt_text],
+        inputs=[input_image, prompt_text, confidence_slider],
         outputs=gallery)
 
     demo.launch(debug=False, show_error=True)
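For context, a condensed sketch of how the two gallery images are now produced from the same detections; it mirrors the annotate() calls added above (semi-transparent masks over the original photo, solid red masks over a black canvas so they read as a plain mask image) and is illustrative rather than a copy of app.py:

    import numpy as np
    import supervision as sv
    from PIL import Image

    SEMITRANSPARENT = sv.MaskAnnotator(
        color=sv.Color.red(), color_lookup=sv.ColorLookup.INDEX)
    SOLID = sv.MaskAnnotator(
        color=sv.Color.red(), color_lookup=sv.ColorLookup.INDEX, opacity=1)

    def render(image_rgb_pil: Image.Image, detections: sv.Detections) -> list:
        blank = Image.new("RGB", image_rgb_pil.size, "black")
        images = []
        for canvas, annotator in ((image_rgb_pil, SEMITRANSPARENT), (blank, SOLID)):
            bgr = np.array(canvas)[:, :, ::-1]               # RGB -> BGR, as in app.py
            bgr = annotator.annotate(scene=bgr, detections=detections)
            images.append(Image.fromarray(bgr[:, :, ::-1]))  # back to RGB
        return images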
requirements.txt
CHANGED
@@ -4,6 +4,6 @@ torchvision
 
 numpy
 pillow
-gradio
+gradio==3.50.2
 transformers
 supervision