hmdliu committed
Commit 54b7544 • 1 Parent(s): b6879dc

Add mask guidance

Files changed (2)
  1. README.md +1 -1
  2. app.py +82 -25
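
The core of the change: the new segment_image_with_guidance path biases SAM's probability map with a user-painted mask before thresholding. A minimal NumPy sketch of just that step (values here are illustrative; the real inputs are SAM's sigmoid map and the editor's brush layer):

import numpy as np

# Illustrative stand-ins for the real inputs: SAM's sigmoid probability
# map and the binary mask painted by the user (both H x W).
prob_map = np.array([[0.2, 0.6],
                     [0.4, 0.1]])
guidance_mask = np.array([[1.0, 0.0],
                          [1.0, 0.0]])
offset, threshold = 0.4, 0.5

# Mask guidance as in segment_image_with_guidance below: raise the score
# of painted pixels by a fixed offset, then threshold the enhanced map.
enhance_map = prob_map + offset * guidance_mask
pred_mask = (enhance_map > threshold).astype(float)
print(pred_mask)  # [[1. 1.]
                  #  [1. 0.]]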
README.md CHANGED
@@ -1,6 +1,6 @@
 ---
 title: Sidewalks Seg
-emoji: 📊
+emoji: 🚶‍♀️
 colorFrom: indigo
 colorTo: indigo
 sdk: gradio
app.py CHANGED
@@ -1,4 +1,5 @@
 import torch
+import numpy as np
 import gradio as gr
 import matplotlib.pyplot as plt
 from PIL import Image
@@ -6,47 +7,103 @@ from transformers import SamModel, SamProcessor
 
 device = 'cuda' if torch.cuda.is_available() else 'cpu'
 processor = SamProcessor.from_pretrained('facebook/sam-vit-base')
-model = SamModel.from_pretrained('hmdliu/sidewalks-seg-base')
+model = SamModel.from_pretrained('hmdliu/sidewalks-seg')
 model.to(device)
 
-def segment_image(image, threshold):
-    # init data
-    width, height = image.size
-    prompt = [0, 0, width, height]
+def segment_image(image, threshold, x_min, y_min, x_max, y_max):
+    # tolerate TIFF image input
+    image.save('image.png')
+    # init input data
+    prompt = [x_min, y_min, x_max, y_max]
     inputs = processor(image, input_boxes=[[prompt]], return_tensors='pt')
     # make prediction
     outputs = model(pixel_values=inputs['pixel_values'].to(device),
                     input_boxes=inputs['input_boxes'].to(device),
                     multimask_output=False)
     prob_map = torch.sigmoid(outputs.pred_masks.squeeze()).cpu().detach()
-    prediction = (prob_map > threshold).float()
-    prob_map, prediction = prob_map.numpy(), prediction.numpy()
+    pred_mask = (prob_map > threshold).float().numpy()
     # visualize results
-    save_image(image, 'image.png')
-    save_image(prob_map, 'prob.png', cmap='jet')
-    save_image(prediction, 'mask.png', cmap='gray')
-    return Image.open('image.png'), Image.open('mask.png'), Image.open('prob.png')
+    plt.figure(figsize=(8, 8))
+    plt.imshow(prob_map.numpy(), cmap='jet', interpolation='nearest')
+    plt.axis('off')
+    plt.tight_layout()
+    plt.savefig('prob.png', bbox_inches='tight', pad_inches=0)
+    plt.close()
+    # post-processing
+    ret_image = Image.open('image.png')
+    ret_pred = (Image.open('image.png'), [(pred_mask, 'Sidewalks')])
+    ret_prob = Image.open('prob.png')
+    return ret_image, ret_pred, ret_prob
 
-def save_image(image, path, **kwargs):
-    plt.figure(figsize=(8, 8))
-    plt.imshow(image, interpolation='nearest', **kwargs)
-    plt.axis('off')
-    plt.tight_layout()
-    plt.savefig(path, bbox_inches='tight', pad_inches=0)
-    plt.close()
+def segment_image_with_guidance(image, threshold, offset, x_min, y_min, x_max, y_max):
+    # tolerate TIFF image input
+    image['background'].save('image.png')
+    # init input data
+    prompt = [x_min, y_min, x_max, y_max]
+    img = Image.open('image.png').convert('RGB')
+    inputs = processor(img, input_boxes=[[prompt]], return_tensors='pt')
+    # make prediction
+    outputs = model(pixel_values=inputs['pixel_values'].to(device),
+                    input_boxes=inputs['input_boxes'].to(device),
+                    multimask_output=False)
+    prob_map = torch.sigmoid(outputs.pred_masks.squeeze()).cpu().detach()
+    # perform mask guidance
+    guidance_mask = (np.max(np.array(image['layers'][0]), axis=2) != 0).astype(float)
+    enhance_map = prob_map.numpy() + offset * guidance_mask
+    pred_mask = (enhance_map > threshold).astype(float)
+    # visualize results
+    plt.figure(figsize=(8, 8))
+    plt.imshow(enhance_map, cmap='jet', interpolation='nearest')
+    plt.axis('off')
+    plt.tight_layout()
+    plt.savefig('prob.png', bbox_inches='tight', pad_inches=0)
+    plt.close()
+    # post-processing
+    regions = [(guidance_mask, 'Guidance'), (pred_mask, 'Sidewalks')]
+    return (image['background'], regions), Image.open('prob.png')
 
 with gr.Blocks() as demo:
-    with gr.Row():
-        with gr.Column():
-            image_input = gr.Image(type='pil', label='TIFF Image')
-            threshold_slider = gr.Slider(minimum=0, maximum=1, step=0.01, value=0.5, label='Prediction Threshold')
-            segment_button = gr.Button('Segment')
-        with gr.Column():
-            prediction = gr.Image(type='pil', label='Segmentation Result')
-            prob_map = gr.Image(type='pil', label='Probability Map')
-    segment_button.click(
+    with gr.Tab('Baseline'):
+        with gr.Row():
+            with gr.Column():
+                t1_input = gr.Image(type='pil', label='Input Image')
+                with gr.Row():
+                    t1_x_min = gr.Textbox(value=0, label='x_min')
+                    t1_y_min = gr.Textbox(value=0, label='y_min')
+                    t1_x_max = gr.Textbox(value=256, label='x_max')
+                    t1_y_max = gr.Textbox(value=256, label='y_max')
+                t1_slider = gr.Slider(minimum=0, maximum=1, step=0.01, value=0.5, label='Prediction Threshold')
+                t1_segment = gr.Button('Segment')
+            with gr.Column():
+                t1_pred = gr.AnnotatedImage(color_map={'Sidewalks': '#0000FF'}, label='Prediction')
+            with gr.Column():
+                t1_prob_map = gr.Image(type='pil', label='Probability Map')
+    with gr.Tab('Mask Guidance'):
+        with gr.Row():
+            with gr.Column():
+                t2_input = gr.ImageEditor(type='pil', crop_size='2:1', label='Input Image',
+                                          brush=gr.Brush(default_size='5', color_mode='fixed'),
+                                          sources=['upload'], transforms=[])
+                with gr.Row():
+                    t2_x_min = gr.Textbox(value=0, label='x_min')
+                    t2_y_min = gr.Textbox(value=0, label='y_min')
+                    t2_x_max = gr.Textbox(value=256, label='x_max')
+                    t2_y_max = gr.Textbox(value=256, label='y_max')
+                t2_thresh = gr.Slider(minimum=0, maximum=1, step=0.01, value=0.5, label='Prediction Threshold')
+                t2_offset = gr.Slider(minimum=0, maximum=1, step=0.01, value=0.4, label='Guidance Offset')
+                t2_segment = gr.Button('Segment')
+            with gr.Column():
+                t2_pred = gr.AnnotatedImage(color_map={'Guidance': '#FF0000', 'Sidewalks': '#0000FF'}, label='Prediction')
+            with gr.Column():
+                t2_prob_map = gr.Image(type='pil', label='Probability Map')
+    t1_segment.click(
         segment_image,
-        inputs=[image_input, threshold_slider],
-        outputs=[image_input, prediction, prob_map]
+        inputs=[t1_input, t1_slider, t1_x_min, t1_y_min, t1_x_max, t1_y_max],
+        outputs=[t1_input, t1_pred, t1_prob_map]
+    )
+    t2_segment.click(
+        segment_image_with_guidance,
+        inputs=[t2_input, t2_thresh, t2_offset, t2_x_min, t2_y_min, t2_x_max, t2_y_max],
+        outputs=[t2_pred, t2_prob_map]
    )
 demo.launch(debug=True, show_error=True)
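
For context on the Gradio side (a sketch assuming Gradio 4.x ImageEditor semantics, which the code above relies on): gr.ImageEditor(type='pil') hands the callback a dict with a 'background' image plus painted 'layers', and gr.AnnotatedImage expects (base_image, [(mask, label), ...]) pairs, which is the shape of ret_pred and regions above. Deriving the guidance mask from a brush layer:

import numpy as np
from PIL import Image

# Hypothetical editor value, mirroring what gr.ImageEditor(type='pil')
# passes to segment_image_with_guidance (Gradio 4.x semantics assumed).
editor_value = {
    'background': Image.new('RGB', (256, 256)),
    'layers': [Image.new('RGBA', (256, 256))],  # brush strokes live here
}

# A painted pixel has at least one non-zero RGBA channel, so collapsing
# the channel axis with max() yields the binary guidance mask.
layer = np.array(editor_value['layers'][0])
guidance_mask = (np.max(layer, axis=2) != 0).astype(float)
print(guidance_mask.shape, guidance_mask.sum())  # (256, 256) 0.0 (empty layer)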