glenn-jocher committed
Commit: a586751
Parent(s): 5de4e25

multi-gpu ckpt filesize bug fix #253

Files changed:
- train.py: +1 -1
- utils/torch_utils.py: +11 -12
train.py CHANGED

@@ -287,7 +287,7 @@ def train(hyp):
         scheduler.step()
 
         # mAP
-        ema.update_attr(model)
+        ema.update_attr(model, include=['md', 'nc', 'hyp', 'names', 'stride'])
         final_epoch = epoch + 1 == epochs
         if not opt.notest or final_epoch:  # Calculate mAP
             results, maps, times = test.test(opt.data,
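
Why this helps with checkpoint size: everything copied onto the EMA object travels with it when the EMA is pickled into a checkpoint. The old unfiltered ema.update_attr(model) copied every public attribute of the (possibly DataParallel/DDP-wrapped) model onto ema.ema, while the new include whitelist keeps only lightweight metadata such as md, nc, hyp, names and stride (via the new copy_attr helper in utils/torch_utils.py below). A minimal sketch of that effect, using made-up dummy classes and plain pickle instead of the real Model and torch.save:

# Illustration only (not part of the commit): dummy stand-ins for Model/ModelEMA,
# showing how unfiltered attribute copies bloat whatever gets pickled.
import io
import pickle


def copy_attr(a, b, include=(), exclude=()):
    # same helper as added in utils/torch_utils.py below
    for k, v in b.__dict__.items():
        if (len(include) and k not in include) or k.startswith('_') or k in exclude:
            continue
        else:
            setattr(a, k, v)


class DummyModel:
    def __init__(self):
        self.nc = 80                          # lightweight metadata worth keeping
        self.names = ['person', 'car']
        self.stride = [8, 16, 32]
        self.heavy_state = [0.0] * 1_000_000  # stand-in for a bulky attribute


class DummyEMA:
    pass


def pickled_size(obj):
    buf = io.BytesIO()
    pickle.dump(obj, buf)
    return buf.tell()


model, ema_all, ema_filtered = DummyModel(), DummyEMA(), DummyEMA()
copy_attr(ema_all, model)                                          # old behaviour: copy everything
copy_attr(ema_filtered, model, include=['nc', 'names', 'stride'])  # new behaviour: whitelist only
print(pickled_size(ema_all), '>', pickled_size(ema_filtered))      # e.g. several MB vs. a few hundred bytes
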
utils/torch_utils.py CHANGED

@@ -173,22 +173,23 @@ def scale_img(img, ratio=1.0, same_shape=False):  # img(16,3,256,416), r=ratio
     return F.pad(img, [0, w - s[1], 0, h - s[0]], value=0.447)  # value = imagenet mean
 
 
+def copy_attr(a, b, include=(), exclude=()):
+    # Copy attributes from b to a, options to only include [...] and to exclude [...]
+    for k, v in b.__dict__.items():
+        if (len(include) and k not in include) or k.startswith('_') or k in exclude:
+            continue
+        else:
+            setattr(a, k, v)
+
+
 class ModelEMA:
     """ Model Exponential Moving Average from https://github.com/rwightman/pytorch-image-models
     Keep a moving average of everything in the model state_dict (parameters and buffers).
     This is intended to allow functionality like
     https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
     A smoothed version of the weights is necessary for some training schemes to perform well.
-    E.g. Google's hyper-params for training MNASNet, MobileNet-V3, EfficientNet, etc that use
-    RMSprop with a short 2.4-3 epoch decay period and slow LR decay rate of .96-.99 requires EMA
-    smoothing of weights to match results. Pay attention to the decay constant you are using
-    relative to your update count per epoch.
-    To keep EMA from using GPU resources, set device='cpu'. This will save a bit of memory but
-    disable validation of the EMA weights. Validation will have to be done manually in a separate
-    process, or after the training stops converging.
     This class is sensitive where it is initialized in the sequence of model init,
     GPU assignment and distributed training wrappers.
-    I've tested with the sequence in my own train.py for torch.DataParallel, apex.DDP, and single-GPU.
     """
 
     def __init__(self, model, decay=0.9999, updates=0):
@@ -211,8 +212,6 @@ class ModelEMA:
                     v *= d
                     v += (1. - d) * msd[k].detach()
 
-    def update_attr(self, model):
+    def update_attr(self, model, include=(), exclude=('process_group', 'reducer')):
         # Update EMA attributes
-        for k, v in model.__dict__.items():
-            if not k.startswith('_') and k not in ["process_group", "reducer"]:
-                setattr(self.ema, k, v)
+        copy_attr(self.ema, model, include, exclude)
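
For reference, the filtering rules of the new copy_attr helper: a non-empty include acts as a whitelist, and keys that start with an underscore or appear in exclude are always skipped. A small usage sketch (the Src/Dst classes and attribute values are made up; it assumes the snippet is run from the repository root so utils.torch_utils is importable):

from utils.torch_utils import copy_attr


class Src:
    def __init__(self):
        self.nc = 80              # copied: named in include
        self.hyp = {'lr0': 0.01}  # copied: named in include
        self.names = ['person']   # skipped: include is non-empty and omits it
        self._internal = 'skip'   # skipped: leading underscore
        self.reducer = object()   # skipped: named in exclude


class Dst:
    pass


dst = Dst()
copy_attr(dst, Src(), include=('nc', 'hyp'), exclude=('reducer',))
print(vars(dst))  # -> {'nc': 80, 'hyp': {'lr0': 0.01}}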