glenn-jocher
commited on
Commit
•
bb3c346
1
Parent(s):
1fdaa49
model.yaml nc inherited from dataset.yaml
Browse files- models/yolo.py +2 -1
- train.py +1 -2
models/yolo.py
CHANGED
@@ -52,7 +52,8 @@ class Model(nn.Module):
|
|
52 |
self.md = yaml.load(f, Loader=yaml.FullLoader) # model dict
|
53 |
|
54 |
# Define model
|
55 |
-
if nc:
|
|
|
56 |
self.md['nc'] = nc # override yaml value
|
57 |
self.model, self.save = parse_model(self.md, ch=[ch]) # model, savelist, ch_out
|
58 |
# print([x.shape for x in self.forward(torch.zeros(1, ch, 64, 64))])
|
|
|
52 |
self.md = yaml.load(f, Loader=yaml.FullLoader) # model dict
|
53 |
|
54 |
# Define model
|
55 |
+
if nc and nc != self.md['nc']:
|
56 |
+
print('Overriding %s nc=%g with nc=%g' % (model_cfg, self.md['nc'], nc))
|
57 |
self.md['nc'] = nc # override yaml value
|
58 |
self.model, self.save = parse_model(self.md, ch=[ch]) # model, savelist, ch_out
|
59 |
# print([x.shape for x in self.forward(torch.zeros(1, ch, 64, 64))])
|
train.py
CHANGED
@@ -77,8 +77,7 @@ def train(hyp):
|
|
77 |
os.remove(f)
|
78 |
|
79 |
# Create model
|
80 |
-
model = Model(opt.cfg).to(device)
|
81 |
-
assert model.md['nc'] == nc, '%s nc=%g classes but %s nc=%g classes' % (opt.data, nc, opt.cfg, model.md['nc'])
|
82 |
|
83 |
# Image sizes
|
84 |
gs = int(max(model.stride)) # grid size (max stride)
|
|
|
77 |
os.remove(f)
|
78 |
|
79 |
# Create model
|
80 |
+
model = Model(opt.cfg, nc=data_dict['nc']).to(device)
|
|
|
81 |
|
82 |
# Image sizes
|
83 |
gs = int(max(model.stride)) # grid size (max stride)
|