app.py
CHANGED
@@ -45,20 +45,21 @@ def greet(input_img):
|
|
45 |
outputs = model(**inputs)
|
46 |
logits = outputs.logits
|
47 |
|
48 |
-
|
49 |
-
logits = logits.detach().numpy()
|
50 |
-
logits = np.transpose(logits, [0, 2, 3, 1])
|
51 |
|
52 |
-
|
|
|
53 |
|
54 |
-
|
|
|
|
|
|
|
55 |
|
56 |
color_seg = np.zeros(
|
57 |
(seg.shape[0], seg.shape[1], 3), dtype=np.uint8
|
58 |
)
|
59 |
-
|
60 |
for label, color in enumerate(colormap):
|
61 |
-
color_seg[seg == label, :] = color
|
62 |
|
63 |
pred_img = np.array(input_img) * 0.5 + color_seg * 0.5
|
64 |
pred_img = pred_img.astype(np.uint8)
|
|
|
45 |
outputs = model(**inputs)
|
46 |
logits = outputs.logits
|
47 |
|
48 |
+
logits_np = logits.detach().numpy()
|
|
|
|
|
49 |
|
50 |
+
logits_tf = tf.convert_to_tensor(logits_np)
|
51 |
+
logits_tf = tf.transpose(logits_tf, [0, 2, 3, 1])
|
52 |
|
53 |
+
logits_tf = tf.image.resize(
|
54 |
+
logits_tf, input_img.size[::-1]
|
55 |
+
)
|
56 |
+
seg = tf.math.argmax(logits_tf, axis=-1)[0]
|
57 |
|
58 |
color_seg = np.zeros(
|
59 |
(seg.shape[0], seg.shape[1], 3), dtype=np.uint8
|
60 |
)
|
|
|
61 |
for label, color in enumerate(colormap):
|
62 |
+
color_seg[seg.numpy() == label, :] = color
|
63 |
|
64 |
pred_img = np.array(input_img) * 0.5 + color_seg * 0.5
|
65 |
pred_img = pred_img.astype(np.uint8)
|