glenn-jocher commited on
Commit
5ba1de0
1 Parent(s): 38f5c1a

update experimental.py with Ensemble() module

Browse files
Files changed (1) hide show
  1. models/experimental.py +12 -0
models/experimental.py CHANGED
@@ -107,3 +107,15 @@ class MixConv2d(nn.Module):
107
 
108
  def forward(self, x):
109
  return x + self.act(self.bn(torch.cat([m(x) for m in self.m], 1)))
 
 
 
 
 
 
 
 
 
 
 
 
 
107
 
108
  def forward(self, x):
109
  return x + self.act(self.bn(torch.cat([m(x) for m in self.m], 1)))
110
+
111
+
112
+ class Ensemble(nn.ModuleList):
113
+ # Ensemble of models
114
+ def __init__(self):
115
+ super(Ensemble, self).__init__()
116
+
117
+ def forward(self, x, augment=False):
118
+ y = []
119
+ for module in self:
120
+ y.append(module(x, augment)[0])
121
+ return torch.cat(y, 1), None # ensembled inference output, train output