glenn-jocher
commited on
Commit
•
5ba1de0
1
Parent(s):
38f5c1a
update experimental.py with Ensemble() module
Browse files- models/experimental.py +12 -0
models/experimental.py
CHANGED
@@ -107,3 +107,15 @@ class MixConv2d(nn.Module):
|
|
107 |
|
108 |
def forward(self, x):
|
109 |
return x + self.act(self.bn(torch.cat([m(x) for m in self.m], 1)))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
107 |
|
108 |
def forward(self, x):
|
109 |
return x + self.act(self.bn(torch.cat([m(x) for m in self.m], 1)))
|
110 |
+
|
111 |
+
|
112 |
+
class Ensemble(nn.ModuleList):
|
113 |
+
# Ensemble of models
|
114 |
+
def __init__(self):
|
115 |
+
super(Ensemble, self).__init__()
|
116 |
+
|
117 |
+
def forward(self, x, augment=False):
|
118 |
+
y = []
|
119 |
+
for module in self:
|
120 |
+
y.append(module(x, augment)[0])
|
121 |
+
return torch.cat(y, 1), None # ensembled inference output, train output
|