Cherie Ho commited on
Commit
283b3f6
1 Parent(s): b684d11

adjust threshold for viz

Browse files
Files changed (1) hide show
  1. mapper/utils/viz_2d.py +3 -2
mapper/utils/viz_2d.py CHANGED
@@ -83,8 +83,9 @@ def one_hot_argmax_to_rgb(y, num_class):
83
  class_colors = class_colors.values()
84
  class_colors = [torch.tensor(x).float() for x in class_colors]
85
 
86
- argmaxed = torch.argmax((y > 0.5).float(), dim=1) # Take argmax
87
- argmaxed[torch.all(y <= 0.5, dim=1)] = num_class
 
88
  # print(argmaxed.shape)
89
 
90
  seg_rgb = torch.ones(
 
83
  class_colors = class_colors.values()
84
  class_colors = [torch.tensor(x).float() for x in class_colors]
85
 
86
+ threshold = 0.25
87
+ argmaxed = torch.argmax((y > threshold).float(), dim=1) # Take argmax
88
+ argmaxed[torch.all(y <= threshold, dim=1)] = num_class
89
  # print(argmaxed.shape)
90
 
91
  seg_rgb = torch.ones(