Spaces:
Sleeping
Sleeping
Add mask guidance
Browse files
README.md
CHANGED
@@ -1,6 +1,6 @@
|
|
1 |
---
|
2 |
title: Sidewalks Seg
|
3 |
-
emoji:
|
4 |
colorFrom: indigo
|
5 |
colorTo: indigo
|
6 |
sdk: gradio
|
|
|
1 |
---
|
2 |
title: Sidewalks Seg
|
3 |
+
emoji: 🚶‍♂️
|
4 |
colorFrom: indigo
|
5 |
colorTo: indigo
|
6 |
sdk: gradio
|
app.py
CHANGED
@@ -1,4 +1,5 @@
|
|
1 |
import torch
|
|
|
2 |
import gradio as gr
|
3 |
import matplotlib.pyplot as plt
|
4 |
from PIL import Image
|
@@ -6,47 +7,103 @@ from transformers import SamModel, SamProcessor
|
|
6 |
|
7 |
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
8 |
processor = SamProcessor.from_pretrained('facebook/sam-vit-base')
|
9 |
-
model = SamModel.from_pretrained('hmdliu/sidewalks-seg
|
10 |
model.to(device)
|
11 |
|
12 |
-
def segment_image(image, threshold):
|
13 |
-
#
|
14 |
-
|
15 |
-
|
|
|
16 |
inputs = processor(image, input_boxes=[[prompt]], return_tensors='pt')
|
17 |
# make prediction
|
18 |
outputs = model(pixel_values=inputs['pixel_values'].to(device),
|
19 |
input_boxes=inputs['input_boxes'].to(device),
|
20 |
multimask_output=False)
|
21 |
prob_map = torch.sigmoid(outputs.pred_masks.squeeze()).cpu().detach()
|
22 |
-
|
23 |
-
prob_map, prediction = prob_map.numpy(), prediction.numpy()
|
24 |
# visualize results
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
29 |
|
30 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
31 |
plt.figure(figsize=(8, 8))
|
32 |
-
plt.imshow(
|
33 |
plt.axis('off')
|
34 |
plt.tight_layout()
|
35 |
-
plt.savefig(
|
36 |
plt.close()
|
|
|
|
|
|
|
37 |
|
38 |
with gr.Blocks() as demo:
|
39 |
-
with gr.
|
40 |
-
with gr.
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
48 |
segment_image,
|
49 |
-
inputs=[
|
50 |
-
outputs=[
|
|
|
|
|
|
|
|
|
|
|
51 |
)
|
52 |
demo.launch(debug=True, show_error=True)
|
|
|
1 |
import torch
|
2 |
+
import numpy as np
|
3 |
import gradio as gr
|
4 |
import matplotlib.pyplot as plt
|
5 |
from PIL import Image
|
|
|
7 |
|
8 |
# Pick the compute device once at import time: CUDA when torch reports it
# as available, otherwise the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# The processor is loaded from 'facebook/sam-vit-base'; the model weights are
# loaded from 'hmdliu/sidewalks-seg' and moved to the selected device.
processor = SamProcessor.from_pretrained('facebook/sam-vit-base')
model = SamModel.from_pretrained('hmdliu/sidewalks-seg').to(device)
|
12 |
|
13 |
+
def segment_image(image, threshold, x_min, y_min, x_max, y_max):
    """Segment sidewalks in ``image`` using a single bounding-box prompt.

    Args:
        image: PIL image supplied by the input widget (``None`` when nothing
            has been uploaded).
        threshold: Probability cut-off; pixels whose predicted probability is
            strictly greater than it are labelled ``'Sidewalks'``.
        x_min, y_min, x_max, y_max: Corners of the box prompt. Numbers or
            numeric strings are accepted.

    Returns:
        A 3-tuple ``(png_image, (png_image, [(mask, 'Sidewalks')]), heatmap)``
        where ``png_image`` is the input re-encoded as PNG, ``mask`` is a
        float array of 0/1 values and ``heatmap`` is the probability map
        rendered with the ``jet`` colormap.

    Raises:
        gr.Error: If no image was provided or a box coordinate is not numeric.
    """
    import io  # local import: only needed for the in-memory PNG buffers

    if image is None:
        raise gr.Error('Please upload an image before segmenting.')
    try:
        prompt = [float(v) for v in (x_min, y_min, x_max, y_max)]
    except (TypeError, ValueError) as err:
        raise gr.Error('Bounding box coordinates must be numbers.') from err

    # tolerate TIFF image input: re-encode as PNG. This is done in memory
    # rather than through a fixed 'image.png' file so that concurrent
    # requests cannot overwrite each other's data.
    png_buf = io.BytesIO()
    image.save(png_buf, format='PNG')
    png_bytes = png_buf.getvalue()

    # init input data
    inputs = processor(image, input_boxes=[[prompt]], return_tensors='pt')

    # make prediction; gradients are never needed for inference
    with torch.no_grad():
        outputs = model(pixel_values=inputs['pixel_values'].to(device),
                        input_boxes=inputs['input_boxes'].to(device),
                        multimask_output=False)
    prob_map = torch.sigmoid(outputs.pred_masks.squeeze()).cpu()
    pred_mask = (prob_map > threshold).float().numpy()

    # visualize results on an explicit figure (no global pyplot state) and
    # always close it, even if rendering fails
    heat_buf = io.BytesIO()
    fig, ax = plt.subplots(figsize=(8, 8))
    try:
        ax.imshow(prob_map.numpy(), cmap='jet', interpolation='nearest')
        ax.axis('off')
        fig.tight_layout()
        fig.savefig(heat_buf, format='png', bbox_inches='tight', pad_inches=0)
    finally:
        plt.close(fig)
    heat_buf.seek(0)

    # post-processing: each output gets its own decoded copy of the PNG
    ret_image = Image.open(io.BytesIO(png_bytes))
    ret_pred = (Image.open(io.BytesIO(png_bytes)), [(pred_mask, 'Sidewalks')])
    ret_prob = Image.open(heat_buf)
    return ret_image, ret_pred, ret_prob
|
37 |
|
38 |
+
def segment_image_with_guidance(image, threshold, offset, x_min, y_min, x_max, y_max):
    """Segment sidewalks with a box prompt plus a user-drawn guidance mask.

    The predicted probability map is raised by ``offset`` wherever the first
    drawing layer is non-zero, and the result is thresholded.

    Args:
        image: Editor value, a mapping with a ``'background'`` PIL image and
            a ``'layers'`` list of drawn layers.
        threshold: Cut-off applied to the offset-enhanced map; values
            strictly greater than it are labelled ``'Sidewalks'``.
        offset: Amount added to the probability map inside the guidance mask.
        x_min, y_min, x_max, y_max: Corners of the box prompt. Numbers or
            numeric strings are accepted.

    Returns:
        ``((background, [(guidance_mask, 'Guidance'), (pred_mask,
        'Sidewalks')]), heatmap)`` where both masks are float arrays of 0/1
        values and ``heatmap`` is the enhanced map rendered with ``jet``.

    Raises:
        gr.Error: If no image was provided, a box coordinate is not numeric,
            or the drawn layer and the probability map differ in shape.
    """
    import io  # local import: only needed for the in-memory PNG buffers

    if not image or image['background'] is None:
        raise gr.Error('Please upload an image before segmenting.')
    background = image['background']
    try:
        prompt = [float(v) for v in (x_min, y_min, x_max, y_max)]
    except (TypeError, ValueError) as err:
        raise gr.Error('Bounding box coordinates must be numbers.') from err

    # tolerate TIFF image input: re-encode as PNG. This is done in memory
    # rather than through a fixed 'image.png' file so that concurrent
    # requests cannot overwrite each other's data.
    png_buf = io.BytesIO()
    background.save(png_buf, format='PNG')
    png_buf.seek(0)
    img = Image.open(png_buf).convert('RGB')

    # init input data
    inputs = processor(img, input_boxes=[[prompt]], return_tensors='pt')

    # make prediction; gradients are never needed for inference
    with torch.no_grad():
        outputs = model(pixel_values=inputs['pixel_values'].to(device),
                        input_boxes=inputs['input_boxes'].to(device),
                        multimask_output=False)
    prob_map = torch.sigmoid(outputs.pred_masks.squeeze()).cpu().numpy()

    # perform mask guidance; with no drawn layer the guidance is empty
    layers = image['layers']
    if layers:
        guidance_mask = (np.max(np.array(layers[0]), axis=2) != 0).astype(float)
    else:
        guidance_mask = np.zeros(prob_map.shape, dtype=float)
    if guidance_mask.shape != prob_map.shape:
        # Fail with a readable message instead of a NumPy broadcasting error.
        raise gr.Error(
            f'Guidance layer shape {guidance_mask.shape} does not match the '
            f'probability map shape {prob_map.shape}.'
        )
    enhance_map = prob_map + offset * guidance_mask
    pred_mask = (enhance_map > threshold).astype(float)

    # visualize results on an explicit figure (no global pyplot state) and
    # always close it, even if rendering fails
    heat_buf = io.BytesIO()
    fig, ax = plt.subplots(figsize=(8, 8))
    try:
        ax.imshow(enhance_map, cmap='jet', interpolation='nearest')
        ax.axis('off')
        fig.tight_layout()
        fig.savefig(heat_buf, format='png', bbox_inches='tight', pad_inches=0)
    finally:
        plt.close(fig)
    heat_buf.seek(0)

    # post-processing
    regions = [(guidance_mask, 'Guidance'), (pred_mask, 'Sidewalks')]
    return (background, regions), Image.open(heat_buf)
|
64 |
|
65 |
# Two-tab UI. 'Baseline' runs segment_image on a plain image; 'Mask Guidance'
# runs segment_image_with_guidance on an image editor value so the user can
# paint a guidance layer. Box coordinates use numeric inputs so only numbers
# can reach the handlers.
with gr.Blocks() as demo:
    with gr.Tab('Baseline'):
        with gr.Row():
            with gr.Column():
                t1_input = gr.Image(type='pil', label='Input Image')
                with gr.Row():
                    t1_x_min = gr.Number(value=0, precision=0, label='x_min')
                    t1_y_min = gr.Number(value=0, precision=0, label='y_min')
                    t1_x_max = gr.Number(value=256, precision=0, label='x_max')
                    t1_y_max = gr.Number(value=256, precision=0, label='y_max')
                t1_slider = gr.Slider(minimum=0, maximum=1, step=0.01, value=0.5,
                                      label='Prediction Threshold')
                t1_segment = gr.Button('Segment')
            with gr.Column():
                t1_pred = gr.AnnotatedImage(color_map={'Sidewalks': '#0000FF'},
                                            label='Prediction')
            with gr.Column():
                t1_prob_map = gr.Image(type='pil', label='Probability Map')
    with gr.Tab('Mask Guidance'):
        with gr.Row():
            with gr.Column():
                # default_size is given as an int (brush size), not a string.
                t2_input = gr.ImageEditor(type='pil', crop_size='2:1', label='Input Image',
                                          brush=gr.Brush(default_size=5, color_mode='fixed'),
                                          sources=['upload'], transforms=[])
                with gr.Row():
                    t2_x_min = gr.Number(value=0, precision=0, label='x_min')
                    t2_y_min = gr.Number(value=0, precision=0, label='y_min')
                    t2_x_max = gr.Number(value=256, precision=0, label='x_max')
                    t2_y_max = gr.Number(value=256, precision=0, label='y_max')
                t2_thresh = gr.Slider(minimum=0, maximum=1, step=0.01, value=0.5,
                                      label='Prediction Threshold')
                t2_offset = gr.Slider(minimum=0, maximum=1, step=0.01, value=0.4,
                                      label='Guidance Offset')
                t2_segment = gr.Button('Segment')
            with gr.Column():
                t2_pred = gr.AnnotatedImage(
                    color_map={'Guidance': '#FF0000', 'Sidewalks': '#0000FF'},
                    label='Prediction')
            with gr.Column():
                t2_prob_map = gr.Image(type='pil', label='Probability Map')
    # The baseline handler also writes back to t1_input (its first output).
    t1_segment.click(
        segment_image,
        inputs=[t1_input, t1_slider, t1_x_min, t1_y_min, t1_x_max, t1_y_max],
        outputs=[t1_input, t1_pred, t1_prob_map],
    )
    t2_segment.click(
        segment_image_with_guidance,
        inputs=[t2_input, t2_thresh, t2_offset, t2_x_min, t2_y_min, t2_x_max, t2_y_max],
        outputs=[t2_pred, t2_prob_map],
    )
demo.launch(debug=True, show_error=True)
|