dreamlessx committed on
Commit
387e567
·
verified ·
1 Parent(s): b163477

Upload landmarkdiff/model_registry.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. landmarkdiff/model_registry.py +369 -0
landmarkdiff/model_registry.py ADDED
@@ -0,0 +1,369 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Model registry for checkpoint discovery and management.
2
+
3
+ Provides a unified interface for finding, loading, and comparing model
4
+ checkpoints across local directories and remote sources.
5
+
6
+ Usage:
7
+ from landmarkdiff.model_registry import ModelRegistry
8
+
9
+ registry = ModelRegistry("checkpoints/")
10
+
11
+ # Discover all checkpoints
12
+ models = registry.list_models()
13
+
14
+ # Get best checkpoint by metric
15
+ best = registry.get_best("loss")
16
+
17
+ # Load a specific checkpoint
18
+ state = registry.load("checkpoint-5000")
19
+
20
+ # Compare multiple checkpoints
21
+ comparison = registry.compare(["checkpoint-1000", "checkpoint-5000"])
22
+ """
23
+
24
+ from __future__ import annotations
25
+
26
+ import json
27
+ from dataclasses import dataclass, field
28
+ from pathlib import Path
29
+ from typing import Any
30
+
31
+ import torch
32
+
33
+
34
+ @dataclass
35
+ class ModelEntry:
36
+ """Metadata for a registered model checkpoint."""
37
+
38
+ name: str
39
+ path: Path
40
+ step: int = 0
41
+ phase: str = ""
42
+ metrics: dict[str, float] = field(default_factory=dict)
43
+ size_mb: float = 0.0
44
+ has_ema: bool = False
45
+ has_training_state: bool = False
46
+
47
+ @property
48
+ def inference_path(self) -> Path | None:
49
+ """Path to inference-ready weights (EMA preferred)."""
50
+ ema_dir = self.path / "controlnet_ema"
51
+ if ema_dir.exists():
52
+ return ema_dir
53
+ # Fallback to training state
54
+ state_path = self.path / "training_state.pt"
55
+ if state_path.exists():
56
+ return state_path
57
+ return None
58
+
59
+
60
+ class ModelRegistry:
61
+ """Central registry for discovering and managing model checkpoints.
62
+
63
+ Args:
64
+ checkpoint_dirs: One or more directories to scan for checkpoints.
65
+ scan_on_init: Whether to scan directories immediately on creation.
66
+ """
67
+
68
+ def __init__(
69
+ self,
70
+ *checkpoint_dirs: str | Path,
71
+ scan_on_init: bool = True,
72
+ ) -> None:
73
+ self.checkpoint_dirs = [Path(d) for d in checkpoint_dirs]
74
+ self._models: dict[str, ModelEntry] = {}
75
+
76
+ if scan_on_init:
77
+ self.scan()
78
+
79
+ def scan(self) -> int:
80
+ """Scan checkpoint directories and register all found models.
81
+
82
+ Returns:
83
+ Number of models found.
84
+ """
85
+ self._models.clear()
86
+ for base_dir in self.checkpoint_dirs:
87
+ if not base_dir.exists():
88
+ continue
89
+ self._scan_directory(base_dir)
90
+ return len(self._models)
91
+
92
+ def _scan_directory(self, base_dir: Path) -> None:
93
+ """Scan a single directory for checkpoint subdirectories."""
94
+ # Look for checkpoint-* directories
95
+ for ckpt_dir in sorted(base_dir.glob("checkpoint-*")):
96
+ if not ckpt_dir.is_dir():
97
+ continue
98
+ entry = self._load_entry(ckpt_dir)
99
+ if entry is not None:
100
+ self._models[entry.name] = entry
101
+
102
+ # Also check for "final" and "best" directories/symlinks
103
+ for special in ["final", "best", "latest"]:
104
+ special_dir = base_dir / special
105
+ if special_dir.exists() and special_dir.is_dir():
106
+ entry = self._load_entry(special_dir)
107
+ if entry is not None:
108
+ entry.name = f"{base_dir.name}/{special}"
109
+ self._models[entry.name] = entry
110
+
111
+ def _load_entry(self, ckpt_dir: Path) -> ModelEntry | None:
112
+ """Load metadata for a single checkpoint directory."""
113
+ has_training = (ckpt_dir / "training_state.pt").exists()
114
+ has_ema = (ckpt_dir / "controlnet_ema").exists()
115
+
116
+ if not has_training and not has_ema:
117
+ return None
118
+
119
+ # Try to load metadata.json (from CheckpointManager)
120
+ meta_path = ckpt_dir / "metadata.json"
121
+ if meta_path.exists():
122
+ with open(meta_path) as f:
123
+ meta = json.load(f)
124
+ return ModelEntry(
125
+ name=ckpt_dir.name,
126
+ path=ckpt_dir,
127
+ step=meta.get("step", 0),
128
+ phase=meta.get("phase", ""),
129
+ metrics=meta.get("metrics", {}),
130
+ size_mb=meta.get("size_mb", 0.0),
131
+ has_ema=has_ema,
132
+ has_training_state=has_training,
133
+ )
134
+
135
+ # Fallback: extract step from directory name
136
+ step = 0
137
+ parts = ckpt_dir.name.split("-")
138
+ if len(parts) >= 2 and parts[-1].isdigit():
139
+ step = int(parts[-1])
140
+
141
+ # Compute size
142
+ size_mb = sum(
143
+ f.stat().st_size for f in ckpt_dir.rglob("*") if f.is_file()
144
+ ) / (1024 * 1024)
145
+
146
+ return ModelEntry(
147
+ name=ckpt_dir.name,
148
+ path=ckpt_dir,
149
+ step=step,
150
+ size_mb=round(size_mb, 1),
151
+ has_ema=has_ema,
152
+ has_training_state=has_training,
153
+ )
154
+
155
+ # ------------------------------------------------------------------
156
+ # Queries
157
+ # ------------------------------------------------------------------
158
+
159
+ def list_models(self, sort_by: str = "step") -> list[ModelEntry]:
160
+ """List all registered models.
161
+
162
+ Args:
163
+ sort_by: Sort key — "step", "name", or a metric name.
164
+
165
+ Returns:
166
+ Sorted list of ModelEntry objects.
167
+ """
168
+ models = list(self._models.values())
169
+ if sort_by == "step":
170
+ models.sort(key=lambda m: m.step)
171
+ elif sort_by == "name":
172
+ models.sort(key=lambda m: m.name)
173
+ else:
174
+ # Sort by metric value
175
+ models.sort(
176
+ key=lambda m: m.metrics.get(sort_by, float("inf")),
177
+ )
178
+ return models
179
+
180
+ def get(self, name: str) -> ModelEntry | None:
181
+ """Get a model entry by name."""
182
+ return self._models.get(name)
183
+
184
+ def get_best(
185
+ self,
186
+ metric: str = "loss",
187
+ lower_is_better: bool = True,
188
+ ) -> ModelEntry | None:
189
+ """Get the best model by a specific metric.
190
+
191
+ Args:
192
+ metric: Metric name to rank by.
193
+ lower_is_better: If True, lower values are better.
194
+
195
+ Returns:
196
+ Best ModelEntry, or None if no models have the metric.
197
+ """
198
+ candidates = [
199
+ m for m in self._models.values()
200
+ if metric in m.metrics
201
+ ]
202
+ if not candidates:
203
+ return None
204
+
205
+ return min(candidates, key=lambda m: m.metrics[metric]) \
206
+ if lower_is_better else \
207
+ max(candidates, key=lambda m: m.metrics[metric])
208
+
209
+ def get_by_step(self, step: int) -> ModelEntry | None:
210
+ """Get a model by its training step."""
211
+ for model in self._models.values():
212
+ if model.step == step:
213
+ return model
214
+ return None
215
+
216
+ # ------------------------------------------------------------------
217
+ # Loading
218
+ # ------------------------------------------------------------------
219
+
220
+ def load(
221
+ self,
222
+ name: str,
223
+ map_location: str = "cpu",
224
+ ) -> dict[str, Any]:
225
+ """Load training state from a checkpoint.
226
+
227
+ Args:
228
+ name: Checkpoint name (e.g. "checkpoint-5000").
229
+ map_location: Device to load tensors to.
230
+
231
+ Returns:
232
+ State dict containing controlnet, ema_controlnet, optimizer, etc.
233
+
234
+ Raises:
235
+ KeyError: If checkpoint not found.
236
+ FileNotFoundError: If training_state.pt missing.
237
+ """
238
+ entry = self._models.get(name)
239
+ if entry is None:
240
+ raise KeyError(f"Checkpoint '{name}' not found in registry")
241
+
242
+ state_path = entry.path / "training_state.pt"
243
+ if not state_path.exists():
244
+ raise FileNotFoundError(f"No training_state.pt in {entry.path}")
245
+
246
+ return torch.load(state_path, map_location=map_location, weights_only=True)
247
+
248
+ def load_controlnet(
249
+ self,
250
+ name: str,
251
+ use_ema: bool = True,
252
+ ) -> Any:
253
+ """Load a ControlNet model from checkpoint.
254
+
255
+ Args:
256
+ name: Checkpoint name.
257
+ use_ema: If True, load EMA weights (preferred for inference).
258
+
259
+ Returns:
260
+ ControlNetModel instance.
261
+ """
262
+ from diffusers import ControlNetModel
263
+
264
+ entry = self._models.get(name)
265
+ if entry is None:
266
+ raise KeyError(f"Checkpoint '{name}' not found in registry")
267
+
268
+ if use_ema and entry.has_ema:
269
+ return ControlNetModel.from_pretrained(
270
+ str(entry.path / "controlnet_ema")
271
+ )
272
+
273
+ # Fallback: load from training state
274
+ state = self.load(name)
275
+ model = ControlNetModel.from_pretrained(
276
+ "lllyasviel/control_v11p_sd15_openpose",
277
+ subfolder="diffusion_sd15",
278
+ )
279
+ key = "ema_controlnet" if use_ema else "controlnet"
280
+ model.load_state_dict(state[key])
281
+ return model
282
+
283
+ # ------------------------------------------------------------------
284
+ # Comparison
285
+ # ------------------------------------------------------------------
286
+
287
+ def compare(
288
+ self,
289
+ names: list[str],
290
+ metrics: list[str] | None = None,
291
+ ) -> dict[str, Any]:
292
+ """Compare multiple checkpoints side-by-side.
293
+
294
+ Args:
295
+ names: List of checkpoint names to compare.
296
+ metrics: Specific metrics to include. None = all available.
297
+
298
+ Returns:
299
+ Dict with comparison table data.
300
+ """
301
+ entries = []
302
+ for name in names:
303
+ entry = self._models.get(name)
304
+ if entry is not None:
305
+ entries.append(entry)
306
+
307
+ if not entries:
308
+ return {"error": "No valid checkpoints found"}
309
+
310
+ # Collect all available metrics
311
+ if metrics is None:
312
+ all_metrics: set[str] = set()
313
+ for e in entries:
314
+ all_metrics.update(e.metrics.keys())
315
+ metrics = sorted(all_metrics)
316
+
317
+ rows = []
318
+ for e in entries:
319
+ row = {
320
+ "name": e.name,
321
+ "step": e.step,
322
+ "phase": e.phase,
323
+ "size_mb": e.size_mb,
324
+ }
325
+ for m in metrics:
326
+ row[m] = e.metrics.get(m)
327
+ rows.append(row)
328
+
329
+ return {
330
+ "metrics": metrics,
331
+ "rows": rows,
332
+ "count": len(rows),
333
+ }
334
+
335
+ # ------------------------------------------------------------------
336
+ # Summary
337
+ # ------------------------------------------------------------------
338
+
339
+ def summary(self) -> str:
340
+ """Return a human-readable summary."""
341
+ models = self.list_models()
342
+ if not models:
343
+ return "No models registered."
344
+
345
+ total_size = sum(m.size_mb for m in models)
346
+ lines = [
347
+ f"Model Registry: {len(models)} checkpoints ({total_size:.0f} MB)",
348
+ f" Steps: {models[0].step} — {models[-1].step}",
349
+ ]
350
+
351
+ # Show metrics ranges
352
+ all_metrics: set[str] = set()
353
+ for m in models:
354
+ all_metrics.update(m.metrics.keys())
355
+
356
+ for metric in sorted(all_metrics):
357
+ values = [m.metrics[metric] for m in models if metric in m.metrics]
358
+ if values:
359
+ lines.append(
360
+ f" {metric}: {min(values):.4f} — {max(values):.4f}"
361
+ )
362
+
363
+ return "\n".join(lines)
364
+
365
+ def __len__(self) -> int:
366
+ return len(self._models)
367
+
368
+ def __contains__(self, name: str) -> bool:
369
+ return name in self._models