chansung commited on
Commit
eba227e
1 Parent(s): cdcfc92

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -28
app.py CHANGED
@@ -1,18 +1,12 @@
1
  import gradio as gr
2
 
 
 
3
  from matplotlib import gridspec
4
  import matplotlib.pyplot as plt
5
- import numpy as np
6
- from PIL import Image
7
- import tensorflow as tf
8
- from transformers import SegformerFeatureExtractor, TFSegformerForSemanticSegmentation
9
 
10
- feature_extractor = SegformerFeatureExtractor.from_pretrained(
11
- "nvidia/segformer-b5-finetuned-ade-640-640"
12
- )
13
- model = TFSegformerForSemanticSegmentation.from_pretrained(
14
- "nvidia/segformer-b5-finetuned-ade-640-640"
15
- )
16
 
17
  def ade_palette():
18
  """ADE20K palette that maps each class to RGB values."""
@@ -169,14 +163,17 @@ def ade_palette():
169
  [92, 0, 255],
170
  ]
171
 
 
172
  labels_list = []
 
 
 
 
173
 
174
  with open(r'labels.txt', 'r') as fp:
175
  for line in fp:
176
  labels_list.append(line[:-1])
177
 
178
- colormap = np.asarray(ade_palette())
179
-
180
  def label_to_color_image(label):
181
  if label.ndim != 2:
182
  raise ValueError("Expect 2-D input label")
@@ -199,7 +196,7 @@ def draw_plot(pred_img, seg):
199
  FULL_LABEL_MAP = np.arange(len(LABEL_NAMES)).reshape(len(LABEL_NAMES), 1)
200
  FULL_COLOR_MAP = label_to_color_image(FULL_LABEL_MAP)
201
 
202
- unique_labels = np.unique(seg.numpy().astype("uint8"))
203
  ax = plt.subplot(grid_spec[1])
204
  plt.imshow(FULL_COLOR_MAP[unique_labels].astype(np.uint8), interpolation="nearest")
205
  ax.yaxis.tick_right()
@@ -209,18 +206,17 @@ def draw_plot(pred_img, seg):
209
  return fig
210
 
211
  def sepia(input_img):
212
- input_img = Image.fromarray(input_img)
213
-
214
- inputs = feature_extractor(images=input_img, return_tensors="tf")
215
- outputs = model(**inputs)
216
- logits = outputs.logits
217
-
218
- logits = tf.transpose(logits, [0, 2, 3, 1])
219
- logits = tf.image.resize(
220
- logits, input_img.size[::-1]
221
- ) # We reverse the shape of `image` because `image.size` returns width and height.
222
- seg = tf.math.argmax(logits, axis=-1)[0]
223
-
224
  color_seg = np.zeros(
225
  (seg.shape[0], seg.shape[1], 3), dtype=np.uint8
226
  ) # height, width, 3
@@ -232,16 +228,15 @@ def sepia(input_img):
232
  color_seg = color_seg[..., ::-1]
233
 
234
  # Show image + mask
235
- pred_img = np.array(input_img) * 0.5 + color_seg * 0.5
236
  pred_img = pred_img.astype(np.uint8)
237
 
238
  fig = draw_plot(pred_img, seg)
239
  return fig
240
 
241
  demo = gr.Interface(sepia,
242
- gr.Image(shape=(200, 200)),
243
  outputs=['plot'],
244
- # examples=["ADE_val_00000001.jpeg"],
245
  allow_flagging='never')
246
 
247
  demo.launch()
 
1
  import gradio as gr
2
 
3
+ import numpy as np
4
+ import cv2
5
  from matplotlib import gridspec
6
  import matplotlib.pyplot as plt
7
+ import onnxruntime as ort
 
 
 
8
 
9
+ import wget
 
 
 
 
 
10
 
11
  def ade_palette():
12
  """ADE20K palette that maps each class to RGB values."""
 
163
  [92, 0, 255],
164
  ]
165
 
166
+ url='https://github.com/deep-diver/segformer-tf-transformers/releases/download/1.0/segformer-b5-finetuned-ade-640-640.onnx'
167
  labels_list = []
168
+ colormap = np.asarray(ade_palette())
169
+
170
+ model_path = wget.download(url)
171
+ sess = ort.InferenceSession(model_path)
172
 
173
  with open(r'labels.txt', 'r') as fp:
174
  for line in fp:
175
  labels_list.append(line[:-1])
176
 
 
 
177
  def label_to_color_image(label):
178
  if label.ndim != 2:
179
  raise ValueError("Expect 2-D input label")
 
196
  FULL_LABEL_MAP = np.arange(len(LABEL_NAMES)).reshape(len(LABEL_NAMES), 1)
197
  FULL_COLOR_MAP = label_to_color_image(FULL_LABEL_MAP)
198
 
199
+ unique_labels = np.unique(seg)
200
  ax = plt.subplot(grid_spec[1])
201
  plt.imshow(FULL_COLOR_MAP[unique_labels].astype(np.uint8), interpolation="nearest")
202
  ax.yaxis.tick_right()
 
206
  return fig
207
 
208
  def sepia(input_img):
209
+ img = cv2.imread(input_img).astype(np.float32)
210
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
211
+ img_batch = np.expand_dims(img, axis=0)
212
+ img_batch = np.transpose(img_batch, (0, 3, 1, 2))
213
+
214
+ logits = sess.run(None, {"pixel_values": img_batch})[0]
215
+
216
+ logits = np.transpose(logits, (0, 2, 3, 1))
217
+ seg = np.argmax(logits, axis=-1)[0].astype('float32')
218
+ seg = cv2.resize(seg, (640, 640)).astype('uint8')
219
+
 
220
  color_seg = np.zeros(
221
  (seg.shape[0], seg.shape[1], 3), dtype=np.uint8
222
  ) # height, width, 3
 
228
  color_seg = color_seg[..., ::-1]
229
 
230
  # Show image + mask
231
+ pred_img = img * 0.5 + color_seg * 0.5
232
  pred_img = pred_img.astype(np.uint8)
233
 
234
  fig = draw_plot(pred_img, seg)
235
  return fig
236
 
237
  demo = gr.Interface(sepia,
238
+ gr.inputs.Image(type="filepath", shape=(640, 640)),
239
  outputs=['plot'],
 
240
  allow_flagging='never')
241
 
242
  demo.launch()