gyrojeff commited on
Commit
49d8194
1 Parent(s): 8d9c0ef

fix: reset lr on load checkpoint

Browse files
Files changed (1) hide show
  1. detector/model.py +7 -0
detector/model.py CHANGED
@@ -156,6 +156,7 @@ class FontDetector(ptl.LightningModule):
156
  self.betas = betas
157
  self.num_warmup_iters = num_warmup_iters
158
  self.num_iters = num_iters
 
159
 
160
  def forward(self, x):
161
  return self.model(x)
@@ -239,6 +240,9 @@ class FontDetector(ptl.LightningModule):
239
  self.scheduler = CosineWarmupScheduler(
240
  optimizer, self.num_warmup_iters, self.num_iters
241
  )
 
 
 
242
  return optimizer
243
 
244
  def optimizer_step(
@@ -255,3 +259,6 @@ class FontDetector(ptl.LightningModule):
255
  )
256
  self.log("lr", self.scheduler.get_last_lr()[0])
257
  self.scheduler.step()
 
 
 
 
156
  self.betas = betas
157
  self.num_warmup_iters = num_warmup_iters
158
  self.num_iters = num_iters
159
+ self.load_step = 0
160
 
161
  def forward(self, x):
162
  return self.model(x)
 
240
  self.scheduler = CosineWarmupScheduler(
241
  optimizer, self.num_warmup_iters, self.num_iters
242
  )
243
+ for _ in range(self.load_step):
244
+ self.scheduler.step()
245
+ print("Current learning rate set to:", self.scheduler.get_last_lr())
246
  return optimizer
247
 
248
  def optimizer_step(
 
259
  )
260
  self.log("lr", self.scheduler.get_last_lr()[0])
261
  self.scheduler.step()
262
+
263
+ def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
264
+ self.load_step = checkpoint["global_step"]