glenn-jocher
commited on
Commit
·
02445d1
1
Parent(s):
c80b249
improved model.yaml source tracking
Browse files- detect.py +1 -1
- models/yolo.py +19 -13
detect.py
CHANGED
@@ -128,7 +128,7 @@ def detect(save_img=False):
|
|
128 |
|
129 |
if save_txt or save_img:
|
130 |
print('Results saved to %s' % os.getcwd() + os.sep + out)
|
131 |
-
if platform == 'darwin': # MacOS
|
132 |
os.system('open ' + save_path)
|
133 |
|
134 |
print('Done. (%.3fs)' % (time.time() - t0))
|
|
|
128 |
|
129 |
if save_txt or save_img:
|
130 |
print('Results saved to %s' % os.getcwd() + os.sep + out)
|
131 |
+
if platform == 'darwin' and not opt.update: # MacOS
|
132 |
os.system('open ' + save_path)
|
133 |
|
134 |
print('Done. (%.3fs)' % (time.time() - t0))
|
models/yolo.py
CHANGED
@@ -1,4 +1,5 @@
|
|
1 |
import argparse
|
|
|
2 |
|
3 |
from models.experimental import *
|
4 |
|
@@ -43,20 +44,21 @@ class Detect(nn.Module):
|
|
43 |
|
44 |
|
45 |
class Model(nn.Module):
|
46 |
-
def __init__(self,
|
47 |
super(Model, self).__init__()
|
48 |
-
if
|
49 |
-
self.
|
50 |
else: # is *.yaml
|
51 |
import yaml # for torch hub
|
52 |
-
|
53 |
-
|
|
|
54 |
|
55 |
# Define model
|
56 |
-
if nc and nc != self.
|
57 |
-
print('Overriding %s nc=%g with nc=%g' % (
|
58 |
-
self.
|
59 |
-
self.model, self.save = parse_model(self.
|
60 |
# print([x.shape for x in self.forward(torch.zeros(1, ch, 64, 64))])
|
61 |
|
62 |
# Build strides, anchors
|
@@ -148,17 +150,21 @@ class Model(nn.Module):
|
|
148 |
m.conv = torch_utils.fuse_conv_and_bn(m.conv, m.bn) # update conv
|
149 |
m.bn = None # remove batchnorm
|
150 |
m.forward = m.fuseforward # update forward
|
151 |
-
|
152 |
return self
|
153 |
|
154 |
-
def
|
|
|
|
|
|
|
|
|
155 |
print('\n%3s%18s%3s%10s %-40s%-30s' % ('', 'from', 'n', 'params', 'module', 'arguments'))
|
156 |
-
anchors, nc, gd, gw =
|
157 |
na = (len(anchors[0]) // 2) # number of anchors
|
158 |
no = na * (nc + 5) # number of outputs = anchors * (classes + 5)
|
159 |
|
160 |
layers, save, c2 = [], [], ch[-1] # layers, savelist, ch out
|
161 |
-
for i, (f, n, m, args) in enumerate(
|
162 |
m = eval(m) if isinstance(m, str) else m # eval strings
|
163 |
for j, a in enumerate(args):
|
164 |
try:
|
|
|
1 |
import argparse
|
2 |
+
from copy import deepcopy
|
3 |
|
4 |
from models.experimental import *
|
5 |
|
|
|
44 |
|
45 |
|
46 |
class Model(nn.Module):
|
47 |
+
def __init__(self, cfg='yolov5s.yaml', ch=3, nc=None): # model, input channels, number of classes
|
48 |
super(Model, self).__init__()
|
49 |
+
if isinstance(cfg, dict):
|
50 |
+
self.yaml = cfg # model dict
|
51 |
else: # is *.yaml
|
52 |
import yaml # for torch hub
|
53 |
+
self.yaml_file = Path(cfg).name
|
54 |
+
with open(cfg) as f:
|
55 |
+
self.yaml = yaml.load(f, Loader=yaml.FullLoader) # model dict
|
56 |
|
57 |
# Define model
|
58 |
+
if nc and nc != self.yaml['nc']:
|
59 |
+
print('Overriding %s nc=%g with nc=%g' % (cfg, self.yaml['nc'], nc))
|
60 |
+
self.yaml['nc'] = nc # override yaml value
|
61 |
+
self.model, self.save = parse_model(deepcopy(self.yaml), ch=[ch]) # model, savelist, ch_out
|
62 |
# print([x.shape for x in self.forward(torch.zeros(1, ch, 64, 64))])
|
63 |
|
64 |
# Build strides, anchors
|
|
|
150 |
m.conv = torch_utils.fuse_conv_and_bn(m.conv, m.bn) # update conv
|
151 |
m.bn = None # remove batchnorm
|
152 |
m.forward = m.fuseforward # update forward
|
153 |
+
self.info()
|
154 |
return self
|
155 |
|
156 |
+
def info(self): # print model information
|
157 |
+
torch_utils.model_info(self)
|
158 |
+
|
159 |
+
|
160 |
+
def parse_model(d, ch): # model_dict, input_channels(3)
|
161 |
print('\n%3s%18s%3s%10s %-40s%-30s' % ('', 'from', 'n', 'params', 'module', 'arguments'))
|
162 |
+
anchors, nc, gd, gw = d['anchors'], d['nc'], d['depth_multiple'], d['width_multiple']
|
163 |
na = (len(anchors[0]) // 2) # number of anchors
|
164 |
no = na * (nc + 5) # number of outputs = anchors * (classes + 5)
|
165 |
|
166 |
layers, save, c2 = [], [], ch[-1] # layers, savelist, ch out
|
167 |
+
for i, (f, n, m, args) in enumerate(d['backbone'] + d['head']): # from, number, module, args
|
168 |
m = eval(m) if isinstance(m, str) else m # eval strings
|
169 |
for j, a in enumerate(args):
|
170 |
try:
|