glenn-jocher commited on
Commit
e71fd0e
1 Parent(s): 9ae8683

Model freeze capability (#679)

Browse files
Files changed (1) hide show
  1. train.py +9 -1
train.py CHANGED
@@ -73,6 +73,14 @@ def train(hyp, opt, device, tb_writer=None):
73
  else:
74
  model = Model(opt.cfg, ch=3, nc=nc).to(device) # create
75
 
 
 
 
 
 
 
 
 
76
  # Optimizer
77
  nbs = 64 # nominal batch size
78
  accumulate = max(round(nbs / total_batch_size), 1) # accumulate loss before optimizing
@@ -125,7 +133,7 @@ def train(hyp, opt, device, tb_writer=None):
125
  epochs += ckpt['epoch'] # finetune additional epochs
126
 
127
  del ckpt, state_dict
128
-
129
  # Image sizes
130
  gs = int(max(model.stride)) # grid size (max stride)
131
  imgsz, imgsz_test = [check_img_size(x, gs) for x in opt.img_size] # verify imgsz are gs-multiples
 
73
  else:
74
  model = Model(opt.cfg, ch=3, nc=nc).to(device) # create
75
 
76
+ # Freeze
77
+ freeze = ['', ] # parameter names to freeze (full or partial)
78
+ if any(freeze):
79
+ for k, v in model.named_parameters():
80
+ if any(x in k for x in freeze):
81
+ print('freezing %s' % k)
82
+ v.requires_grad = False
83
+
84
  # Optimizer
85
  nbs = 64 # nominal batch size
86
  accumulate = max(round(nbs / total_batch_size), 1) # accumulate loss before optimizing
 
133
  epochs += ckpt['epoch'] # finetune additional epochs
134
 
135
  del ckpt, state_dict
136
+
137
  # Image sizes
138
  gs = int(max(model.stride)) # grid size (max stride)
139
  imgsz, imgsz_test = [check_img_size(x, gs) for x in opt.img_size] # verify imgsz are gs-multiples