glenn-jocher
commited on
Commit
•
e71fd0e
1
Parent(s):
9ae8683
Model freeze capability (#679)
Browse files
train.py
CHANGED
@@ -73,6 +73,14 @@ def train(hyp, opt, device, tb_writer=None):
|
|
73 |
else:
|
74 |
model = Model(opt.cfg, ch=3, nc=nc).to(device) # create
|
75 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
76 |
# Optimizer
|
77 |
nbs = 64 # nominal batch size
|
78 |
accumulate = max(round(nbs / total_batch_size), 1) # accumulate loss before optimizing
|
@@ -125,7 +133,7 @@ def train(hyp, opt, device, tb_writer=None):
|
|
125 |
epochs += ckpt['epoch'] # finetune additional epochs
|
126 |
|
127 |
del ckpt, state_dict
|
128 |
-
|
129 |
# Image sizes
|
130 |
gs = int(max(model.stride)) # grid size (max stride)
|
131 |
imgsz, imgsz_test = [check_img_size(x, gs) for x in opt.img_size] # verify imgsz are gs-multiples
|
|
|
73 |
else:
|
74 |
model = Model(opt.cfg, ch=3, nc=nc).to(device) # create
|
75 |
|
76 |
+
# Freeze
|
77 |
+
freeze = ['', ] # parameter names to freeze (full or partial)
|
78 |
+
if any(freeze):
|
79 |
+
for k, v in model.named_parameters():
|
80 |
+
if any(x in k for x in freeze):
|
81 |
+
print('freezing %s' % k)
|
82 |
+
v.requires_grad = False
|
83 |
+
|
84 |
# Optimizer
|
85 |
nbs = 64 # nominal batch size
|
86 |
accumulate = max(round(nbs / total_batch_size), 1) # accumulate loss before optimizing
|
|
|
133 |
epochs += ckpt['epoch'] # finetune additional epochs
|
134 |
|
135 |
del ckpt, state_dict
|
136 |
+
|
137 |
# Image sizes
|
138 |
gs = int(max(model.stride)) # grid size (max stride)
|
139 |
imgsz, imgsz_test = [check_img_size(x, gs) for x in opt.img_size] # verify imgsz are gs-multiples
|