hyo37009 commited on
Commit
2692a45
1 Parent(s): e18d0fd
Files changed (1) hide show
  1. app.py +3 -5
app.py CHANGED
@@ -45,10 +45,7 @@ def greet(input_img):
45
  outputs = model(**inputs)
46
  logits = outputs.logits
47
 
48
- logits_np = logits.detach().numpy()
49
-
50
- logits_tf = tf.convert_to_tensor(logits_np)
51
- logits_tf = tf.transpose(logits_tf, [0, 2, 3, 1])
52
 
53
  logits_tf = tf.image.resize(
54
  logits_tf, input_img.size[::-1]
@@ -57,7 +54,7 @@ def greet(input_img):
57
 
58
  color_seg = np.zeros(
59
  (seg.shape[0], seg.shape[1], 3), dtype=np.uint8
60
- )
61
  for label, color in enumerate(colormap):
62
  color_seg[seg.numpy() == label, :] = color
63
 
@@ -67,6 +64,7 @@ def greet(input_img):
67
  fig = draw_plot(pred_img, seg)
68
  return fig
69
 
 
70
  def draw_plot(pred_img, seg):
71
  fig = plt.figure(figsize=(20, 15))
72
  grid_spec = gridspec.GridSpec(1, 2, width_ratios=[6, 1])
 
45
  outputs = model(**inputs)
46
  logits = outputs.logits
47
 
48
+ logits_tf = tf.transpose(logits, [0, 2, 3, 1])
 
 
 
49
 
50
  logits_tf = tf.image.resize(
51
  logits_tf, input_img.size[::-1]
 
54
 
55
  color_seg = np.zeros(
56
  (seg.shape[0], seg.shape[1], 3), dtype=np.uint8
57
+ ) # height, width, 3
58
  for label, color in enumerate(colormap):
59
  color_seg[seg.numpy() == label, :] = color
60
 
 
64
  fig = draw_plot(pred_img, seg)
65
  return fig
66
 
67
+
68
  def draw_plot(pred_img, seg):
69
  fig = plt.figure(figsize=(20, 15))
70
  grid_spec = gridspec.GridSpec(1, 2, width_ratios=[6, 1])