KEN commited on
Commit
647223a
1 Parent(s): a544d59

`Ensemble()` visualize fix (#3973)

Browse files

* fix visualize error

* Revert "fix visualize error"

* add visualise profile

Files changed (1) hide show
  1. models/experimental.py +2 -2
models/experimental.py CHANGED
@@ -100,10 +100,10 @@ class Ensemble(nn.ModuleList):
100
  def __init__(self):
101
  super(Ensemble, self).__init__()
102
 
103
- def forward(self, x, augment=False):
104
  y = []
105
  for module in self:
106
- y.append(module(x, augment)[0])
107
  # y = torch.stack(y).max(0)[0] # max ensemble
108
  # y = torch.stack(y).mean(0) # mean ensemble
109
  y = torch.cat(y, 1) # nms ensemble
 
100
  def __init__(self):
101
  super(Ensemble, self).__init__()
102
 
103
+ def forward(self, x, augment=False, profile=False, visualize=False):
104
  y = []
105
  for module in self:
106
+ y.append(module(x, augment, profile, visualize)[0])
107
  # y = torch.stack(y).max(0)[0] # max ensemble
108
  # y = torch.stack(y).mean(0) # mean ensemble
109
  y = torch.cat(y, 1) # nms ensemble