Cherie Ho
commited on
Commit
•
283b3f6
1
Parent(s):
b684d11
adjust threshold for viz
Browse files- mapper/utils/viz_2d.py +3 -2
mapper/utils/viz_2d.py
CHANGED
@@ -83,8 +83,9 @@ def one_hot_argmax_to_rgb(y, num_class):
|
|
83 |
class_colors = class_colors.values()
|
84 |
class_colors = [torch.tensor(x).float() for x in class_colors]
|
85 |
|
86 |
-
|
87 |
-
argmaxed
|
|
|
88 |
# print(argmaxed.shape)
|
89 |
|
90 |
seg_rgb = torch.ones(
|
|
|
83 |
class_colors = class_colors.values()
|
84 |
class_colors = [torch.tensor(x).float() for x in class_colors]
|
85 |
|
86 |
+
threshold = 0.25
|
87 |
+
argmaxed = torch.argmax((y > threshold).float(), dim=1) # Take argmax
|
88 |
+
argmaxed[torch.all(y <= threshold, dim=1)] = num_class
|
89 |
# print(argmaxed.shape)
|
90 |
|
91 |
seg_rgb = torch.ones(
|