glenn-jocher commited on
Commit
0529b77
1 Parent(s): d5e363f

Update common.py lists for tuples (#7063)

Browse files
Files changed (1) hide show
  1. models/common.py +5 -5
models/common.py CHANGED
@@ -31,7 +31,7 @@ from utils.torch_utils import copy_attr, time_sync
31
  def autopad(k, p=None): # kernel, padding
32
  # Pad to 'same'
33
  if p is None:
34
- p = k // 2 if isinstance(k, int) else [x // 2 for x in k] # auto-pad
35
  return p
36
 
37
 
@@ -133,7 +133,7 @@ class C3(nn.Module):
133
  self.cv2 = Conv(c1, c_, 1, 1)
134
  self.cv3 = Conv(2 * c_, c2, 1) # act=FReLU(c2)
135
  self.m = nn.Sequential(*(Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)))
136
- # self.m = nn.Sequential(*[CrossConv(c_, c_, 3, 1, g, 1.0, shortcut) for _ in range(n)])
137
 
138
  def forward(self, x):
139
  return self.cv3(torch.cat((self.m(self.cv1(x)), self.cv2(x)), dim=1))
@@ -194,7 +194,7 @@ class SPPF(nn.Module):
194
  warnings.simplefilter('ignore') # suppress torch 1.9.0 max_pool2d() warning
195
  y1 = self.m(x)
196
  y2 = self.m(y1)
197
- return self.cv2(torch.cat([x, y1, y2, self.m(y2)], 1))
198
 
199
 
200
  class Focus(nn.Module):
@@ -205,7 +205,7 @@ class Focus(nn.Module):
205
  # self.contract = Contract(gain=2)
206
 
207
  def forward(self, x): # x(b,c,w,h) -> y(b,4c,w/2,h/2)
208
- return self.conv(torch.cat([x[..., ::2, ::2], x[..., 1::2, ::2], x[..., ::2, 1::2], x[..., 1::2, 1::2]], 1))
209
  # return self.conv(self.contract(x))
210
 
211
 
@@ -219,7 +219,7 @@ class GhostConv(nn.Module):
219
 
220
  def forward(self, x):
221
  y = self.cv1(x)
222
- return torch.cat([y, self.cv2(y)], 1)
223
 
224
 
225
  class GhostBottleneck(nn.Module):
 
31
  def autopad(k, p=None): # kernel, padding
32
  # Pad to 'same'
33
  if p is None:
34
+ p = k // 2 if isinstance(k, int) else (x // 2 for x in k) # auto-pad
35
  return p
36
 
37
 
 
133
  self.cv2 = Conv(c1, c_, 1, 1)
134
  self.cv3 = Conv(2 * c_, c2, 1) # act=FReLU(c2)
135
  self.m = nn.Sequential(*(Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)))
136
+ # self.m = nn.Sequential(*(CrossConv(c_, c_, 3, 1, g, 1.0, shortcut) for _ in range(n)))
137
 
138
  def forward(self, x):
139
  return self.cv3(torch.cat((self.m(self.cv1(x)), self.cv2(x)), dim=1))
 
194
  warnings.simplefilter('ignore') # suppress torch 1.9.0 max_pool2d() warning
195
  y1 = self.m(x)
196
  y2 = self.m(y1)
197
+ return self.cv2(torch.cat((x, y1, y2, self.m(y2)), 1))
198
 
199
 
200
  class Focus(nn.Module):
 
205
  # self.contract = Contract(gain=2)
206
 
207
  def forward(self, x): # x(b,c,w,h) -> y(b,4c,w/2,h/2)
208
+ return self.conv(torch.cat((x[..., ::2, ::2], x[..., 1::2, ::2], x[..., ::2, 1::2], x[..., 1::2, 1::2]), 1))
209
  # return self.conv(self.contract(x))
210
 
211
 
 
219
 
220
  def forward(self, x):
221
  y = self.cv1(x)
222
+ return torch.cat((y, self.cv2(y)), 1)
223
 
224
 
225
  class GhostBottleneck(nn.Module):