hmdliu commited on
Commit
b6879dc
1 Parent(s): 722738e

Update layout

Browse files
Files changed (1) hide show
  1. app.py +9 -10
app.py CHANGED
@@ -22,18 +22,18 @@ def segment_image(image, threshold):
22
  prediction = (prob_map > threshold).float()
23
  prob_map, prediction = prob_map.numpy(), prediction.numpy()
24
  # visualize results
 
 
 
 
 
 
25
  plt.figure(figsize=(8, 8))
26
- plt.imshow(prediction, cmap='gray', interpolation='nearest')
27
- plt.axis('off')
28
- plt.tight_layout()
29
- plt.savefig('mask.png', bbox_inches='tight', pad_inches=0)
30
- plt.figure(figsize=(8, 8))
31
- plt.imshow(prob_map, cmap='jet', interpolation='nearest')
32
  plt.axis('off')
33
  plt.tight_layout()
34
- plt.savefig('heatmap.png', bbox_inches='tight', pad_inches=0)
35
  plt.close()
36
- return Image.open('mask.png'), Image.open('heatmap.png')
37
 
38
  with gr.Blocks() as demo:
39
  with gr.Row():
@@ -43,11 +43,10 @@ with gr.Blocks() as demo:
43
  segment_button = gr.Button('Segment')
44
  with gr.Column():
45
  prediction = gr.Image(type='pil', label='Segmentation Result')
46
- with gr.Column():
47
  prob_map = gr.Image(type='pil', label='Probability Map')
48
  segment_button.click(
49
  segment_image,
50
  inputs=[image_input, threshold_slider],
51
- outputs=[prediction, prob_map]
52
  )
53
  demo.launch(debug=True, show_error=True)
 
22
  prediction = (prob_map > threshold).float()
23
  prob_map, prediction = prob_map.numpy(), prediction.numpy()
24
  # visualize results
25
+ save_image(image, 'image.png')
26
+ save_image(prob_map, 'prob.png', cmap='jet')
27
+ save_image(prediction, 'mask.png', cmap='gray')
28
+ return Image.open('image.png'), Image.open('mask.png'), Image.open('prob.png')
29
+
30
+ def save_image(image, path, **kwargs):
31
  plt.figure(figsize=(8, 8))
32
+ plt.imshow(image, interpolation='nearest', **kwargs)
 
 
 
 
 
33
  plt.axis('off')
34
  plt.tight_layout()
35
+ plt.savefig(path, bbox_inches='tight', pad_inches=0)
36
  plt.close()
 
37
 
38
  with gr.Blocks() as demo:
39
  with gr.Row():
 
43
  segment_button = gr.Button('Segment')
44
  with gr.Column():
45
  prediction = gr.Image(type='pil', label='Segmentation Result')
 
46
  prob_map = gr.Image(type='pil', label='Probability Map')
47
  segment_button.click(
48
  segment_image,
49
  inputs=[image_input, threshold_slider],
50
+ outputs=[image_input, prediction, prob_map]
51
  )
52
  demo.launch(debug=True, show_error=True)