Karin0616 commited on
Commit
2bee93c
1 Parent(s): ae23d32
Files changed (1) hide show
  1. app.py +4 -14
app.py CHANGED
@@ -77,14 +77,7 @@ def draw_plot(pred_img, seg):
77
  ax.tick_params(width=0.0, labelsize=25)
78
  return fig
79
 
80
- def sepia(input_img, *label_buttons):
81
- selected_color = None
82
- for label, button_state in zip(labels_list, labels_list):
83
- if button_state:
84
- label_index = label_buttons.index(label)
85
- selected_color = colormap[label_index]
86
- break
87
-
88
  input_img = Image.fromarray(input_img)
89
 
90
  inputs = feature_extractor(images=input_img, return_tensors="tf")
@@ -100,12 +93,9 @@ def sepia(input_img, *label_buttons):
100
  color_seg = np.zeros(
101
  (seg.shape[0], seg.shape[1], 3), dtype=np.uint8
102
  ) # height, width, 3
103
- if selected_color:
104
- label = colormap.index(selected_color)
105
- color_seg[seg.numpy() == label, :] = selected_color
106
- '''for label, color in enumerate(colormap):
107
  color_seg[seg.numpy() == label, :] = color
108
- '''
109
  # Show image + mask
110
  pred_img = np.array(input_img) * 0.5 + color_seg * 0.5
111
  pred_img = pred_img.astype(np.uint8)
@@ -114,7 +104,7 @@ def sepia(input_img, *label_buttons):
114
  return fig
115
 
116
  demo = gr.Interface(fn=sepia,
117
- inputs=[gr.Image(shape=(564, 846)),gr.ButtonGroup([gr.Button(label) for label in labels_list], label="Select a label")],
118
  outputs=['plot'],
119
  live=True,
120
  examples=["city1.jpg","city2.jpg","city3.jpg"],
 
77
  ax.tick_params(width=0.0, labelsize=25)
78
  return fig
79
 
80
+ def sepia(input_img):
 
 
 
 
 
 
 
81
  input_img = Image.fromarray(input_img)
82
 
83
  inputs = feature_extractor(images=input_img, return_tensors="tf")
 
93
  color_seg = np.zeros(
94
  (seg.shape[0], seg.shape[1], 3), dtype=np.uint8
95
  ) # height, width, 3
96
+ for label, color in enumerate(colormap):
 
 
 
97
  color_seg[seg.numpy() == label, :] = color
98
+
99
  # Show image + mask
100
  pred_img = np.array(input_img) * 0.5 + color_seg * 0.5
101
  pred_img = pred_img.astype(np.uint8)
 
104
  return fig
105
 
106
  demo = gr.Interface(fn=sepia,
107
+ inputs=gr.Image(shape=(564,846)),
108
  outputs=['plot'],
109
  live=True,
110
  examples=["city1.jpg","city2.jpg","city3.jpg"],