dreamlessx committed on
Commit
41f1384
·
verified ·
1 Parent(s): 893c358

Update landmarkdiff/checkpoint_manager.py to v0.3.2

Browse files
Files changed (1) hide show
  1. landmarkdiff/checkpoint_manager.py +28 -11
landmarkdiff/checkpoint_manager.py CHANGED
@@ -106,6 +106,19 @@ class CheckpointManager:
106
  self._index = json.load(f)
107
  if "checkpoints" not in self._index:
108
  self._index["checkpoints"] = {}
 
 
 
 
 
 
 
 
 
 
 
 
 
109
 
110
  def _save_index(self) -> None:
111
  with open(self._index_path(), "w") as f:
@@ -166,7 +179,9 @@ class CheckpointManager:
166
  torch.save(state, ckpt_dir / "training_state.pt")
167
 
168
  # Compute checkpoint size
169
- size_mb = sum(f.stat().st_size for f in ckpt_dir.rglob("*") if f.is_file()) / (1024 * 1024)
 
 
170
 
171
  # Create metadata
172
  meta = CheckpointMetadata(
@@ -214,7 +229,7 @@ class CheckpointManager:
214
  entries.sort(key=lambda x: x[1], reverse=not self.lower_is_better)
215
 
216
  # Mark best
217
- best_names = {e[0] for e in entries[: self.keep_best]}
218
  for name, meta in self._index["checkpoints"].items():
219
  meta["is_best"] = name in best_names
220
 
@@ -243,11 +258,11 @@ class CheckpointManager:
243
  val = meta.get("metrics", {}).get(self.metric)
244
  if val is None:
245
  continue
246
- if (
247
- best_val is None
248
- or (self.lower_is_better and val < best_val)
249
- or (not self.lower_is_better and val > best_val)
250
- ):
251
  best, best_val = name, val
252
  return best
253
 
@@ -278,7 +293,7 @@ class CheckpointManager:
278
  keep = set()
279
 
280
  # Keep latest
281
- for name in all_names[-self.keep_latest :]:
282
  keep.add(name)
283
 
284
  # Keep best
@@ -292,7 +307,7 @@ class CheckpointManager:
292
  ckpt_dir = self.output_dir / name
293
  if ckpt_dir.exists():
294
  shutil.rmtree(ckpt_dir)
295
- del self._index["checkpoints"][name]
296
 
297
  self._save_index()
298
 
@@ -321,7 +336,10 @@ class CheckpointManager:
321
 
322
  def total_size_mb(self) -> float:
323
  """Return total disk size of all tracked checkpoints."""
324
- return sum(meta.get("size_mb", 0.0) for meta in self._index["checkpoints"].values())
 
 
 
325
 
326
  def summary(self) -> str:
327
  """Return a human-readable summary of checkpoint state."""
@@ -346,7 +364,6 @@ class CheckpointManager:
346
  # Helpers
347
  # ------------------------------------------------------------------
348
 
349
-
350
  def _get_state_dict(module: torch.nn.Module) -> dict:
351
  """Extract state dict, handling DDP wrapper."""
352
  if hasattr(module, "module"):
 
106
  self._index = json.load(f)
107
  if "checkpoints" not in self._index:
108
  self._index["checkpoints"] = {}
109
+ # Remove entries whose directories no longer exist on disk
110
+ # (can happen after a crash during pruning)
111
+ missing = [
112
+ name
113
+ for name in list(self._index["checkpoints"])
114
+ if not (self.output_dir / name).exists()
115
+ ]
116
+ if missing:
117
+ for name in missing:
118
+ self._index["checkpoints"].pop(name, None)
119
+ self._update_best()
120
+ self._save_index()
121
+ self._update_symlinks()
122
 
123
  def _save_index(self) -> None:
124
  with open(self._index_path(), "w") as f:
 
179
  torch.save(state, ckpt_dir / "training_state.pt")
180
 
181
  # Compute checkpoint size
182
+ size_mb = sum(
183
+ f.stat().st_size for f in ckpt_dir.rglob("*") if f.is_file()
184
+ ) / (1024 * 1024)
185
 
186
  # Create metadata
187
  meta = CheckpointMetadata(
 
229
  entries.sort(key=lambda x: x[1], reverse=not self.lower_is_better)
230
 
231
  # Mark best
232
+ best_names = {e[0] for e in entries[:self.keep_best]}
233
  for name, meta in self._index["checkpoints"].items():
234
  meta["is_best"] = name in best_names
235
 
 
258
  val = meta.get("metrics", {}).get(self.metric)
259
  if val is None:
260
  continue
261
+ if best_val is None:
262
+ best, best_val = name, val
263
+ elif self.lower_is_better and val < best_val:
264
+ best, best_val = name, val
265
+ elif not self.lower_is_better and val > best_val:
266
  best, best_val = name, val
267
  return best
268
 
 
293
  keep = set()
294
 
295
  # Keep latest
296
+ for name in all_names[-self.keep_latest:]:
297
  keep.add(name)
298
 
299
  # Keep best
 
307
  ckpt_dir = self.output_dir / name
308
  if ckpt_dir.exists():
309
  shutil.rmtree(ckpt_dir)
310
+ self._index["checkpoints"].pop(name, None)
311
 
312
  self._save_index()
313
 
 
336
 
337
  def total_size_mb(self) -> float:
338
  """Return total disk size of all tracked checkpoints."""
339
+ return sum(
340
+ meta.get("size_mb", 0.0)
341
+ for meta in self._index["checkpoints"].values()
342
+ )
343
 
344
  def summary(self) -> str:
345
  """Return a human-readable summary of checkpoint state."""
 
364
  # Helpers
365
  # ------------------------------------------------------------------
366
 
 
367
  def _get_state_dict(module: torch.nn.Module) -> dict:
368
  """Extract state dict, handling DDP wrapper."""
369
  if hasattr(module, "module"):