fix: reset lr on load checkpoint
detector/model.py  (+7 -0)
@@ -156,6 +156,7 @@ class FontDetector(ptl.LightningModule):
         self.betas = betas
         self.num_warmup_iters = num_warmup_iters
         self.num_iters = num_iters
+        self.load_step = 0

     def forward(self, x):
         return self.model(x)
@@ -239,6 +240,9 @@ class FontDetector(ptl.LightningModule):
         self.scheduler = CosineWarmupScheduler(
             optimizer, self.num_warmup_iters, self.num_iters
         )
+        for _ in range(self.load_step):
+            self.scheduler.step()
+        print("Current learning rate set to:", self.scheduler.get_last_lr())
         return optimizer

     def optimizer_step(
@@ -255,3 +259,6 @@ class FontDetector(ptl.LightningModule):
         )
         self.log("lr", self.scheduler.get_last_lr()[0])
         self.scheduler.step()
+
+    def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
+        self.load_step = checkpoint["global_step"]