glenn-jocher
commited on
Commit
•
cce95e7
1
Parent(s):
d9b64c2
backbone as FP16, save default to FP32
Browse files- train.py +1 -1
- utils/utils.py +1 -2
train.py
CHANGED
@@ -332,7 +332,7 @@ def train(hyp):
|
|
332 |
ckpt = {'epoch': epoch,
|
333 |
'best_fitness': best_fitness,
|
334 |
'training_results': f.read(),
|
335 |
-
'model': ema.ema.module
|
336 |
'optimizer': None if final_epoch else optimizer.state_dict()}
|
337 |
|
338 |
# Save last, best and delete
|
|
|
332 |
ckpt = {'epoch': epoch,
|
333 |
'best_fitness': best_fitness,
|
334 |
'training_results': f.read(),
|
335 |
+
'model': ema.ema.module if hasattr(model, 'module') else ema.ema,
|
336 |
'optimizer': None if final_epoch else optimizer.state_dict()}
|
337 |
|
338 |
# Save last, best and delete
|
utils/utils.py
CHANGED
@@ -627,13 +627,12 @@ def strip_optimizer(f='weights/best.pt'): # from utils.utils import *; strip_op
|
|
627 |
def create_backbone(f='weights/best.pt', s='weights/backbone.pt'): # from utils.utils import *; create_backbone()
|
628 |
# create backbone 's' from 'f'
|
629 |
device = torch.device('cpu')
|
630 |
-
x = torch.load(f, map_location=device)
|
631 |
-
torch.save(x, s) # update model if SourceChangeWarning
|
632 |
x = torch.load(s, map_location=device)
|
633 |
|
634 |
x['optimizer'] = None
|
635 |
x['training_results'] = None
|
636 |
x['epoch'] = -1
|
|
|
637 |
for p in x['model'].parameters():
|
638 |
p.requires_grad = True
|
639 |
torch.save(x, s)
|
|
|
627 |
def create_backbone(f='weights/best.pt', s='weights/backbone.pt'): # from utils.utils import *; create_backbone()
|
628 |
# create backbone 's' from 'f'
|
629 |
device = torch.device('cpu')
|
|
|
|
|
630 |
x = torch.load(s, map_location=device)
|
631 |
|
632 |
x['optimizer'] = None
|
633 |
x['training_results'] = None
|
634 |
x['epoch'] = -1
|
635 |
+
x['model'].half() # to FP16
|
636 |
for p in x['model'].parameters():
|
637 |
p.requires_grad = True
|
638 |
torch.save(x, s)
|