glenn-jocher
commited on
Commit
•
0529b77
1
Parent(s):
d5e363f
Update common.py lists for tuples (#7063)
Browse files- models/common.py +5 -5
models/common.py
CHANGED
@@ -31,7 +31,7 @@ from utils.torch_utils import copy_attr, time_sync
|
|
31 |
def autopad(k, p=None): # kernel, padding
|
32 |
# Pad to 'same'
|
33 |
if p is None:
|
34 |
-
p = k // 2 if isinstance(k, int) else
|
35 |
return p
|
36 |
|
37 |
|
@@ -133,7 +133,7 @@ class C3(nn.Module):
|
|
133 |
self.cv2 = Conv(c1, c_, 1, 1)
|
134 |
self.cv3 = Conv(2 * c_, c2, 1) # act=FReLU(c2)
|
135 |
self.m = nn.Sequential(*(Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)))
|
136 |
-
# self.m = nn.Sequential(*
|
137 |
|
138 |
def forward(self, x):
|
139 |
return self.cv3(torch.cat((self.m(self.cv1(x)), self.cv2(x)), dim=1))
|
@@ -194,7 +194,7 @@ class SPPF(nn.Module):
|
|
194 |
warnings.simplefilter('ignore') # suppress torch 1.9.0 max_pool2d() warning
|
195 |
y1 = self.m(x)
|
196 |
y2 = self.m(y1)
|
197 |
-
return self.cv2(torch.cat(
|
198 |
|
199 |
|
200 |
class Focus(nn.Module):
|
@@ -205,7 +205,7 @@ class Focus(nn.Module):
|
|
205 |
# self.contract = Contract(gain=2)
|
206 |
|
207 |
def forward(self, x): # x(b,c,w,h) -> y(b,4c,w/2,h/2)
|
208 |
-
return self.conv(torch.cat(
|
209 |
# return self.conv(self.contract(x))
|
210 |
|
211 |
|
@@ -219,7 +219,7 @@ class GhostConv(nn.Module):
|
|
219 |
|
220 |
def forward(self, x):
|
221 |
y = self.cv1(x)
|
222 |
-
return torch.cat(
|
223 |
|
224 |
|
225 |
class GhostBottleneck(nn.Module):
|
|
|
31 |
def autopad(k, p=None): # kernel, padding
|
32 |
# Pad to 'same'
|
33 |
if p is None:
|
34 |
+
p = k // 2 if isinstance(k, int) else (x // 2 for x in k) # auto-pad
|
35 |
return p
|
36 |
|
37 |
|
|
|
133 |
self.cv2 = Conv(c1, c_, 1, 1)
|
134 |
self.cv3 = Conv(2 * c_, c2, 1) # act=FReLU(c2)
|
135 |
self.m = nn.Sequential(*(Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)))
|
136 |
+
# self.m = nn.Sequential(*(CrossConv(c_, c_, 3, 1, g, 1.0, shortcut) for _ in range(n)))
|
137 |
|
138 |
def forward(self, x):
|
139 |
return self.cv3(torch.cat((self.m(self.cv1(x)), self.cv2(x)), dim=1))
|
|
|
194 |
warnings.simplefilter('ignore') # suppress torch 1.9.0 max_pool2d() warning
|
195 |
y1 = self.m(x)
|
196 |
y2 = self.m(y1)
|
197 |
+
return self.cv2(torch.cat((x, y1, y2, self.m(y2)), 1))
|
198 |
|
199 |
|
200 |
class Focus(nn.Module):
|
|
|
205 |
# self.contract = Contract(gain=2)
|
206 |
|
207 |
def forward(self, x): # x(b,c,w,h) -> y(b,4c,w/2,h/2)
|
208 |
+
return self.conv(torch.cat((x[..., ::2, ::2], x[..., 1::2, ::2], x[..., ::2, 1::2], x[..., 1::2, 1::2]), 1))
|
209 |
# return self.conv(self.contract(x))
|
210 |
|
211 |
|
|
|
219 |
|
220 |
def forward(self, x):
|
221 |
y = self.cv1(x)
|
222 |
+
return torch.cat((y, self.cv2(y)), 1)
|
223 |
|
224 |
|
225 |
class GhostBottleneck(nn.Module):
|