glenn-jocher commited on
Commit
02445d1
·
1 Parent(s): c80b249

improved model.yaml source tracking

Browse files
Files changed (2) hide show
  1. detect.py +1 -1
  2. models/yolo.py +19 -13
detect.py CHANGED
@@ -128,7 +128,7 @@ def detect(save_img=False):
128
 
129
  if save_txt or save_img:
130
  print('Results saved to %s' % os.getcwd() + os.sep + out)
131
- if platform == 'darwin': # MacOS
132
  os.system('open ' + save_path)
133
 
134
  print('Done. (%.3fs)' % (time.time() - t0))
 
128
 
129
  if save_txt or save_img:
130
  print('Results saved to %s' % os.getcwd() + os.sep + out)
131
+ if platform == 'darwin' and not opt.update: # MacOS
132
  os.system('open ' + save_path)
133
 
134
  print('Done. (%.3fs)' % (time.time() - t0))
models/yolo.py CHANGED
@@ -1,4 +1,5 @@
1
  import argparse
 
2
 
3
  from models.experimental import *
4
 
@@ -43,20 +44,21 @@ class Detect(nn.Module):
43
 
44
 
45
  class Model(nn.Module):
46
- def __init__(self, model_cfg='yolov5s.yaml', ch=3, nc=None): # model, input channels, number of classes
47
  super(Model, self).__init__()
48
- if type(model_cfg) is dict:
49
- self.md = model_cfg # model dict
50
  else: # is *.yaml
51
  import yaml # for torch hub
52
- with open(model_cfg) as f:
53
- self.md = yaml.load(f, Loader=yaml.FullLoader) # model dict
 
54
 
55
  # Define model
56
- if nc and nc != self.md['nc']:
57
- print('Overriding %s nc=%g with nc=%g' % (model_cfg, self.md['nc'], nc))
58
- self.md['nc'] = nc # override yaml value
59
- self.model, self.save = parse_model(self.md, ch=[ch]) # model, savelist, ch_out
60
  # print([x.shape for x in self.forward(torch.zeros(1, ch, 64, 64))])
61
 
62
  # Build strides, anchors
@@ -148,17 +150,21 @@ class Model(nn.Module):
148
  m.conv = torch_utils.fuse_conv_and_bn(m.conv, m.bn) # update conv
149
  m.bn = None # remove batchnorm
150
  m.forward = m.fuseforward # update forward
151
- torch_utils.model_info(self)
152
  return self
153
 
154
- def parse_model(md, ch): # model_dict, input_channels(3)
 
 
 
 
155
  print('\n%3s%18s%3s%10s %-40s%-30s' % ('', 'from', 'n', 'params', 'module', 'arguments'))
156
- anchors, nc, gd, gw = md['anchors'], md['nc'], md['depth_multiple'], md['width_multiple']
157
  na = (len(anchors[0]) // 2) # number of anchors
158
  no = na * (nc + 5) # number of outputs = anchors * (classes + 5)
159
 
160
  layers, save, c2 = [], [], ch[-1] # layers, savelist, ch out
161
- for i, (f, n, m, args) in enumerate(md['backbone'] + md['head']): # from, number, module, args
162
  m = eval(m) if isinstance(m, str) else m # eval strings
163
  for j, a in enumerate(args):
164
  try:
 
1
  import argparse
2
+ from copy import deepcopy
3
 
4
  from models.experimental import *
5
 
 
44
 
45
 
46
  class Model(nn.Module):
47
+ def __init__(self, cfg='yolov5s.yaml', ch=3, nc=None): # model, input channels, number of classes
48
  super(Model, self).__init__()
49
+ if isinstance(cfg, dict):
50
+ self.yaml = cfg # model dict
51
  else: # is *.yaml
52
  import yaml # for torch hub
53
+ self.yaml_file = Path(cfg).name
54
+ with open(cfg) as f:
55
+ self.yaml = yaml.load(f, Loader=yaml.FullLoader) # model dict
56
 
57
  # Define model
58
+ if nc and nc != self.yaml['nc']:
59
+ print('Overriding %s nc=%g with nc=%g' % (cfg, self.yaml['nc'], nc))
60
+ self.yaml['nc'] = nc # override yaml value
61
+ self.model, self.save = parse_model(deepcopy(self.yaml), ch=[ch]) # model, savelist, ch_out
62
  # print([x.shape for x in self.forward(torch.zeros(1, ch, 64, 64))])
63
 
64
  # Build strides, anchors
 
150
  m.conv = torch_utils.fuse_conv_and_bn(m.conv, m.bn) # update conv
151
  m.bn = None # remove batchnorm
152
  m.forward = m.fuseforward # update forward
153
+ self.info()
154
  return self
155
 
156
+ def info(self): # print model information
157
+ torch_utils.model_info(self)
158
+
159
+
160
+ def parse_model(d, ch): # model_dict, input_channels(3)
161
  print('\n%3s%18s%3s%10s %-40s%-30s' % ('', 'from', 'n', 'params', 'module', 'arguments'))
162
+ anchors, nc, gd, gw = d['anchors'], d['nc'], d['depth_multiple'], d['width_multiple']
163
  na = (len(anchors[0]) // 2) # number of anchors
164
  no = na * (nc + 5) # number of outputs = anchors * (classes + 5)
165
 
166
  layers, save, c2 = [], [], ch[-1] # layers, savelist, ch out
167
+ for i, (f, n, m, args) in enumerate(d['backbone'] + d['head']): # from, number, module, args
168
  m = eval(m) if isinstance(m, str) else m # eval strings
169
  for j, a in enumerate(args):
170
  try: