glenn-jocher commited on
Commit
e16e9e4
1 Parent(s): cb527d3

new nc=len(names) check

Browse files
Files changed (1) hide show
  1. train.py +2 -2
train.py CHANGED
@@ -76,7 +76,7 @@ def train(hyp):
76
  os.remove(f)
77
 
78
  # Create model
79
- model = Model(opt.cfg, nc=data_dict['nc']).to(device)
80
 
81
  # Image sizes
82
  gs = int(max(model.stride)) # grid size (max stride)
@@ -177,7 +177,7 @@ def train(hyp):
177
  model.hyp = hyp # attach hyperparameters to model
178
  model.gr = 1.0 # giou loss ratio (obj_loss = 1.0 or giou)
179
  model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) # attach class weights
180
- model.names = data_dict['names']
181
 
182
  # Class frequency
183
  labels = np.concatenate(dataset.labels, 0)
 
76
  os.remove(f)
77
 
78
  # Create model
79
+ model = Model(opt.cfg, nc=nc).to(device)
80
 
81
  # Image sizes
82
  gs = int(max(model.stride)) # grid size (max stride)
 
177
  model.hyp = hyp # attach hyperparameters to model
178
  model.gr = 1.0 # giou loss ratio (obj_loss = 1.0 or giou)
179
  model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) # attach class weights
180
+ model.names = names
181
 
182
  # Class frequency
183
  labels = np.concatenate(dataset.labels, 0)