glenn-jocher commited on
Commit
f8464b4
1 Parent(s): 7b833e3

Update yolo.py channel array (#2223)

Browse files
Files changed (1) hide show
  1. models/yolo.py +10 -25
models/yolo.py CHANGED
@@ -2,7 +2,6 @@ import argparse
2
  import logging
3
  import sys
4
  from copy import deepcopy
5
- from pathlib import Path
6
 
7
  sys.path.append('./') # to run '$ python *.py' files in subdirectories
8
  logger = logging.getLogger(__name__)
@@ -213,43 +212,27 @@ def parse_model(d, ch): # model_dict, input_channels(3)
213
  if m in [Conv, GhostConv, Bottleneck, GhostBottleneck, SPP, DWConv, MixConv2d, Focus, CrossConv, BottleneckCSP,
214
  C3]:
215
  c1, c2 = ch[f], args[0]
216
-
217
- # Normal
218
- # if i > 0 and args[0] != no: # channel expansion factor
219
- # ex = 1.75 # exponential (default 2.0)
220
- # e = math.log(c2 / ch[1]) / math.log(2)
221
- # c2 = int(ch[1] * ex ** e)
222
- # if m != Focus:
223
-
224
- c2 = make_divisible(c2 * gw, 8) if c2 != no else c2
225
-
226
- # Experimental
227
- # if i > 0 and args[0] != no: # channel expansion factor
228
- # ex = 1 + gw # exponential (default 2.0)
229
- # ch1 = 32 # ch[1]
230
- # e = math.log(c2 / ch1) / math.log(2) # level 1-n
231
- # c2 = int(ch1 * ex ** e)
232
- # if m != Focus:
233
- # c2 = make_divisible(c2, 8) if c2 != no else c2
234
 
235
  args = [c1, c2, *args[1:]]
236
  if m in [BottleneckCSP, C3]:
237
- args.insert(2, n)
238
  n = 1
239
  elif m is nn.BatchNorm2d:
240
  args = [ch[f]]
241
  elif m is Concat:
242
- c2 = sum([ch[x if x < 0 else x + 1] for x in f])
243
  elif m is Detect:
244
- args.append([ch[x + 1] for x in f])
245
  if isinstance(args[1], int): # number of anchors
246
  args[1] = [list(range(args[1] * 2))] * len(f)
247
  elif m is Contract:
248
- c2 = ch[f if f < 0 else f + 1] * args[0] ** 2
249
  elif m is Expand:
250
- c2 = ch[f if f < 0 else f + 1] // args[0] ** 2
251
  else:
252
- c2 = ch[f if f < 0 else f + 1]
253
 
254
  m_ = nn.Sequential(*[m(*args) for _ in range(n)]) if n > 1 else m(*args) # module
255
  t = str(m)[8:-2].replace('__main__.', '') # module type
@@ -258,6 +241,8 @@ def parse_model(d, ch): # model_dict, input_channels(3)
258
  logger.info('%3s%18s%3s%10.0f %-40s%-30s' % (i, f, n, np, t, args)) # print
259
  save.extend(x % i for x in ([f] if isinstance(f, int) else f) if x != -1) # append to savelist
260
  layers.append(m_)
 
 
261
  ch.append(c2)
262
  return nn.Sequential(*layers), sorted(save)
263
 
 
2
  import logging
3
  import sys
4
  from copy import deepcopy
 
5
 
6
  sys.path.append('./') # to run '$ python *.py' files in subdirectories
7
  logger = logging.getLogger(__name__)
 
212
  if m in [Conv, GhostConv, Bottleneck, GhostBottleneck, SPP, DWConv, MixConv2d, Focus, CrossConv, BottleneckCSP,
213
  C3]:
214
  c1, c2 = ch[f], args[0]
215
+ if c2 != no: # if not output
216
+ c2 = make_divisible(c2 * gw, 8)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
217
 
218
  args = [c1, c2, *args[1:]]
219
  if m in [BottleneckCSP, C3]:
220
+ args.insert(2, n) # number of repeats
221
  n = 1
222
  elif m is nn.BatchNorm2d:
223
  args = [ch[f]]
224
  elif m is Concat:
225
+ c2 = sum([ch[x] for x in f])
226
  elif m is Detect:
227
+ args.append([ch[x] for x in f])
228
  if isinstance(args[1], int): # number of anchors
229
  args[1] = [list(range(args[1] * 2))] * len(f)
230
  elif m is Contract:
231
+ c2 = ch[f] * args[0] ** 2
232
  elif m is Expand:
233
+ c2 = ch[f] // args[0] ** 2
234
  else:
235
+ c2 = ch[f]
236
 
237
  m_ = nn.Sequential(*[m(*args) for _ in range(n)]) if n > 1 else m(*args) # module
238
  t = str(m)[8:-2].replace('__main__.', '') # module type
 
241
  logger.info('%3s%18s%3s%10.0f %-40s%-30s' % (i, f, n, np, t, args)) # print
242
  save.extend(x % i for x in ([f] if isinstance(f, int) else f) if x != -1) # append to savelist
243
  layers.append(m_)
244
+ if i == 0:
245
+ ch = []
246
  ch.append(c2)
247
  return nn.Sequential(*layers), sorted(save)
248