han-byeol commited on
Commit
b60307f
1 Parent(s): 3c5e90d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -17
app.py CHANGED
@@ -87,32 +87,30 @@ def sepia(input_img):
87
  ) # We reverse the shape of `image` because `image.size` returns width and height.
88
  seg = tf.math.argmax(logits, axis=-1)[0]
89
 
90
- color_seg = np.zeros(
91
- (seg.shape[0], seg.shape[1], 3), dtype=np.uint8
92
- ) # height, width, 3
93
- for label, color in enumerate(colormap):
94
- color_seg[seg.numpy() == label, :] = color
95
-
96
- # Show image + mask
97
- pred_img = np.array(input_img) * 0.5 + color_seg * 0.5
98
- pred_img = pred_img.astype(np.uint8)
99
-
100
- fig = draw_plot(pred_img, seg)
101
- return fig
102
-
103
  # Obtain probabilities
104
  probabilities = tf.nn.softmax(logits, axis=-1)
105
 
106
  # Visualize probabilities as bar plot
107
- plt.figure(figsize=(12,6))
108
  class_names = labels_list
109
  y_pos = np.arange(len(class_names))
110
- plt.bart(y_pos, probabilities.numpy().mean(axis=(0, 1)), align='center')
111
- plt.ytickes(y_pos, class_names)
 
112
  plt.xlabel('Probability')
113
  plt.title('Class Probabilities')
114
 
115
- return plt.gcf()
 
 
 
 
 
 
 
 
 
 
 
116
 
117
  demo = gr.Interface(fn=sepia,
118
  inputs=gr.Image(shape=(400, 600)),
 
87
  ) # We reverse the shape of `image` because `image.size` returns width and height.
88
  seg = tf.math.argmax(logits, axis=-1)[0]
89
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90
  # Obtain probabilities
91
  probabilities = tf.nn.softmax(logits, axis=-1)
92
 
93
  # Visualize probabilities as bar plot
 
94
  class_names = labels_list
95
  y_pos = np.arange(len(class_names))
96
+ plt.figure(figsize=(12, 6))
97
+ plt.barh(y_pos, probabilities.numpy().mean(axis=(0, 1)), align='center')
98
+ plt.yticks(y_pos, class_names)
99
  plt.xlabel('Probability')
100
  plt.title('Class Probabilities')
101
 
102
+ # Obtain colored segmentation mask
103
+ color_seg = label_to_color_image(seg.numpy())
104
+
105
+ # Show image + mask + probability bar plot
106
+ pred_img = np.array(input_img) * 0.5 + color_seg * 0.5
107
+ pred_img = pred_img.astype(np.uint8)
108
+
109
+ fig, ax = plt.subplots(1, 2, figsize=(18, 9))
110
+ ax[0].imshow(pred_img)
111
+ ax[0].axis('off')
112
+ ax[1].imshow(plt.gca().get_children()[1].get_children()[0].get_children()[2].get_children()[0].get_children()[0].get_children()[0].get_children()[1].get_children()[1].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[1].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[1].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].get_children()[0].
113
+
114
 
115
  demo = gr.Interface(fn=sepia,
116
  inputs=gr.Image(shape=(400, 600)),