Braien's picture
Upload folder using huggingface_hub
900b898 verified
import torch
import os
class CheckpointLoader:
def __init__(self, load_dir, device="cuda"):
self.load_dir = load_dir
self.device = device
def get_latest_checkpoint(self):
"""Megkeresi a legfrissebb mentést idő alapján."""
if not os.path.exists(self.load_dir):
return None
files = [
os.path.join(self.load_dir, f)
for f in os.listdir(self.load_dir)
if f.endswith(".pt") and f.startswith("step_")
]
if not files:
return None
# Idő szerint rendezés (legutolsó a legfrissebb)
files.sort(key=os.path.getmtime)
return files[-1]
def load_latest(self, model, optimizer=None, scheduler=None):
"""
Betölti a legutolsó checkpointot a modellbe és az optimizerbe.
Visszaadja a (state_dict, filename) párt.
"""
path = self.get_latest_checkpoint()
if not path:
print("[INFO] Nem találtam checkpointot, tiszta lappal indul a tanítás.")
return None, None
print(f"[INFO] Checkpoint betöltése: {path}")
filename = os.path.basename(path)
try:
checkpoint = torch.load(path, map_location=self.device, weights_only=False)
# 1. Model State
sd = None
if "model_state_dict" in checkpoint:
sd = checkpoint["model_state_dict"]
elif "trainer" in checkpoint:
sd = checkpoint["trainer"]
else:
sd = checkpoint
was_skipped = False
if sd:
# Szűrés méretbeli eltérésre
model_sd = model.state_dict()
filtered_sd = {}
for k, v in sd.items():
if k in model_sd:
if v.shape == model_sd[k].shape:
filtered_sd[k] = v
else:
print(
f"[WARN] Méretbeli eltérés, kihagyom: {k} ({v.shape} vs {model_sd[k].shape})"
)
was_skipped = True
else:
filtered_sd[k] = v
model.load_state_dict(filtered_sd, strict=False)
print("[OK] Modell súlyok betöltve.")
# 2. Optimizer State - CSAK ha nem volt architektúra váltás!
if optimizer and "optimizer_state_dict" in checkpoint:
if was_skipped:
print(
"[INFO] Architektúra váltást észleltem, az Optimizer állapota nem kerül betöltésre (tiszta indítás)."
)
else:
try:
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
print("[OK] Optimizer állapot betöltve.")
except Exception as e:
print(f"[WARN] Optimizer betöltése sikertelen: {e}")
# 3. Scheduler State (ha van és kértük)
if scheduler and "scheduler_state_dict" in checkpoint:
try:
scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
print("[OK] Scheduler állapot betöltve.")
except Exception as e:
print(f"[WARN] Scheduler betöltése sikertelen: {e}")
return checkpoint, filename
except Exception as e:
print(f"[HIBA] Checkpoint betöltése sikertelen: {e}")
return None, None