hyo37009 commited on
Commit
e18d0fd
1 Parent(s): d2d6d64
Files changed (1) hide show
  1. app.py +8 -7
app.py CHANGED
@@ -45,20 +45,21 @@ def greet(input_img):
45
  outputs = model(**inputs)
46
  logits = outputs.logits
47
 
48
- # Use .detach().numpy() to convert PyTorch tensor to NumPy array
49
- logits = logits.detach().numpy()
50
- logits = np.transpose(logits, [0, 2, 3, 1])
51
 
52
- logits = np.resize(logits, input_img.size[::-1])
 
53
 
54
- seg = np.argmax(logits, axis=-1)[0]
 
 
 
55
 
56
  color_seg = np.zeros(
57
  (seg.shape[0], seg.shape[1], 3), dtype=np.uint8
58
  )
59
-
60
  for label, color in enumerate(colormap):
61
- color_seg[seg == label, :] = color
62
 
63
  pred_img = np.array(input_img) * 0.5 + color_seg * 0.5
64
  pred_img = pred_img.astype(np.uint8)
 
45
  outputs = model(**inputs)
46
  logits = outputs.logits
47
 
48
+ logits_np = logits.detach().numpy()
 
 
49
 
50
+ logits_tf = tf.convert_to_tensor(logits_np)
51
+ logits_tf = tf.transpose(logits_tf, [0, 2, 3, 1])
52
 
53
+ logits_tf = tf.image.resize(
54
+ logits_tf, input_img.size[::-1]
55
+ )
56
+ seg = tf.math.argmax(logits_tf, axis=-1)[0]
57
 
58
  color_seg = np.zeros(
59
  (seg.shape[0], seg.shape[1], 3), dtype=np.uint8
60
  )
 
61
  for label, color in enumerate(colormap):
62
+ color_seg[seg.numpy() == label, :] = color
63
 
64
  pred_img = np.array(input_img) * 0.5 + color_seg * 0.5
65
  pred_img = pred_img.astype(np.uint8)