glenn-jocher committed
Commit e02a189 · unverified · 2 Parent(s): 13f6977 f02481c

Merge pull request #245 from yxNONG/patch-2

Unify the checkpoint handling for single- and multi-GPU training
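
Background on the change: nn.DataParallel and DistributedDataParallel wrap the network and prefix every state_dict() key with 'module.', so an EMA copy or checkpoint taken from the wrapped object does not line up with one taken from a bare single-GPU model. The commit below routes attribute and state access through the underlying .module so both cases produce the same result. A minimal sketch of the key-prefix behaviour, using a toy model rather than repo code:

    import torch.nn as nn

    model = nn.Sequential(nn.Conv2d(3, 8, 3))    # toy single-GPU model
    wrapped = nn.DataParallel(model)             # multi-GPU wrapper (constructs on CPU too)

    print(list(model.state_dict())[0])           # '0.weight'
    print(list(wrapped.state_dict())[0])         # 'module.0.weight'  <- extra prefix
    print(list(wrapped.module.state_dict())[0])  # '0.weight'         <- unified again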

Files changed (2):
  1. train.py +2 -2
  2. utils/torch_utils.py +13 -7

train.py CHANGED

@@ -79,7 +79,6 @@ def train(hyp):
     # Create model
     model = Model(opt.cfg).to(device)
     assert model.md['nc'] == nc, '%s nc=%g classes but %s nc=%g classes' % (opt.data, nc, opt.cfg, model.md['nc'])
-    model.names = data_dict['names']
 
     # Image sizes
     gs = int(max(model.stride))  # grid size (max stride)
@@ -178,6 +177,7 @@ def train(hyp):
     model.hyp = hyp  # attach hyperparameters to model
     model.gr = 1.0  # giou loss ratio (obj_loss = 1.0 or giou)
     model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device)  # attach class weights
+    model.names = data_dict['names']
 
     # Class frequency
     labels = np.concatenate(dataset.labels, 0)
@@ -294,7 +294,7 @@ def train(hyp):
                                  batch_size=batch_size,
                                  imgsz=imgsz_test,
                                  save_json=final_epoch and opt.data.endswith(os.sep + 'coco.yaml'),
-                                 model=ema.ema,
+                                 model=ema.ema.module if hasattr(model, 'module') else ema.ema,
                                  single_cls=opt.single_cls,
                                  dataloader=testloader)
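
On the model= change just above: under DP/DDP the EMA copy is wrapped as well, so test.test() should receive the underlying module; hasattr(model, 'module') is the wrapper check used here, and the torch_utils.py change below adds is_parallel() for the same test. A small sketch of the unwrapping idea, where unwrap() is a hypothetical helper name, not a function from this repo:

    import torch.nn as nn

    def unwrap(model: nn.Module) -> nn.Module:
        # DP/DDP keep the real network under .module; bare modules have no such attribute
        return model.module if hasattr(model, 'module') else model

    net = nn.Linear(4, 2)
    assert unwrap(net) is net                   # single-GPU model: returned unchanged
    assert unwrap(nn.DataParallel(net)) is net  # wrapped model: unwrapped to the same object

Either way, the object passed to test.test() is the bare model, so evaluation and the resulting checkpoint are identical for single- and multi-GPU runs.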
 
utils/torch_utils.py CHANGED

@@ -54,6 +54,11 @@ def time_synchronized():
     return time.time()
 
 
+def is_parallel(model):
+    # is model parallel with DP or DDP
+    return type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)
+
+
 def initialize_weights(model):
     for m in model.modules():
         t = type(m)
@@ -111,8 +116,8 @@ def model_info(model, verbose=False):
 
     try:  # FLOPS
         from thop import profile
-        macs, _ = profile(model, inputs=(torch.zeros(1, 3, 480, 640),), verbose=False)
-        fs = ', %.1f GFLOPS' % (macs / 1E9 * 2)
+        flops = profile(deepcopy(model), inputs=(torch.zeros(1, 3, 64, 64),), verbose=False)[0] / 1E9 * 2
+        fs = ', %.1f GFLOPS' % (flops * 100)  # 640x640 FLOPS
     except:
         fs = ''
 
@@ -185,7 +190,7 @@ class ModelEMA:
         self.updates += 1
         d = self.decay(self.updates)
         with torch.no_grad():
-            if type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel):
+            if is_parallel(model):
                 msd, esd = model.module.state_dict(), self.ema.module.state_dict()
             else:
                 msd, esd = model.state_dict(), self.ema.state_dict()
@@ -196,7 +201,8 @@
                     v += (1. - d) * msd[k].detach()
 
     def update_attr(self, model):
-        # Assign attributes (which may change during training)
-        for k in model.__dict__.keys():
-            if not k.startswith('_'):
-                setattr(self.ema, k, getattr(model, k))
+        # Update class attributes
+        ema = self.ema.module if is_parallel(model) else self.ema
+        for k, v in model.__dict__.items():
+            if not k.startswith('_') and k != 'module':
+                setattr(ema, k, v)
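
A quick standalone check of the new helper and what the reworked update_attr() achieves: is_parallel() lets EMA and checkpoint code branch uniformly on DP/DDP, and update_attr() now copies non-private attributes (such as names, hyp and gr attached in train.py) onto the unwrapped EMA while skipping the wrapper's module reference. The snippet re-declares is_parallel locally so it runs outside the repo; it is a sketch, not repo code:

    import torch.nn as nn

    def is_parallel(model):
        # mirrors the helper added in this commit
        return type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)

    m = nn.Linear(4, 2)
    dp = nn.DataParallel(m)
    print(is_parallel(m))    # False -> read state via m.state_dict()
    print(is_parallel(dp))   # True  -> read state via dp.module.state_dict()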