Spaces:
Running
Running
Update landmarkdiff/model_registry.py to v0.3.2
Browse files- landmarkdiff/model_registry.py +25 -10
landmarkdiff/model_registry.py
CHANGED
|
@@ -139,7 +139,9 @@ class ModelRegistry:
|
|
| 139 |
step = int(parts[-1])
|
| 140 |
|
| 141 |
# Compute size
|
| 142 |
-
size_mb = sum(
|
|
|
|
|
|
|
| 143 |
|
| 144 |
return ModelEntry(
|
| 145 |
name=ckpt_dir.name,
|
|
@@ -193,15 +195,16 @@ class ModelRegistry:
|
|
| 193 |
Returns:
|
| 194 |
Best ModelEntry, or None if no models have the metric.
|
| 195 |
"""
|
| 196 |
-
candidates = [
|
|
|
|
|
|
|
|
|
|
| 197 |
if not candidates:
|
| 198 |
return None
|
| 199 |
|
| 200 |
-
return (
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
else max(candidates, key=lambda m: m.metrics[metric])
|
| 204 |
-
)
|
| 205 |
|
| 206 |
def get_by_step(self, step: int) -> ModelEntry | None:
|
| 207 |
"""Get a model by its training step."""
|
|
@@ -246,30 +249,40 @@ class ModelRegistry:
|
|
| 246 |
self,
|
| 247 |
name: str,
|
| 248 |
use_ema: bool = True,
|
|
|
|
| 249 |
) -> Any:
|
| 250 |
"""Load a ControlNet model from checkpoint.
|
| 251 |
|
| 252 |
Args:
|
| 253 |
name: Checkpoint name.
|
| 254 |
use_ema: If True, load EMA weights (preferred for inference).
|
|
|
|
|
|
|
| 255 |
|
| 256 |
Returns:
|
| 257 |
ControlNetModel instance.
|
| 258 |
"""
|
| 259 |
from diffusers import ControlNetModel
|
| 260 |
|
|
|
|
|
|
|
|
|
|
| 261 |
entry = self._models.get(name)
|
| 262 |
if entry is None:
|
| 263 |
raise KeyError(f"Checkpoint '{name}' not found in registry")
|
| 264 |
|
| 265 |
if use_ema and entry.has_ema:
|
| 266 |
-
return ControlNetModel.from_pretrained(
|
|
|
|
|
|
|
|
|
|
| 267 |
|
| 268 |
# Fallback: load from training state
|
| 269 |
state = self.load(name)
|
| 270 |
model = ControlNetModel.from_pretrained(
|
| 271 |
-
"
|
| 272 |
subfolder="diffusion_sd15",
|
|
|
|
| 273 |
)
|
| 274 |
key = "ema_controlnet" if use_ema else "controlnet"
|
| 275 |
model.load_state_dict(state[key])
|
|
@@ -351,7 +364,9 @@ class ModelRegistry:
|
|
| 351 |
for metric in sorted(all_metrics):
|
| 352 |
values = [m.metrics[metric] for m in models if metric in m.metrics]
|
| 353 |
if values:
|
| 354 |
-
lines.append(
|
|
|
|
|
|
|
| 355 |
|
| 356 |
return "\n".join(lines)
|
| 357 |
|
|
|
|
| 139 |
step = int(parts[-1])
|
| 140 |
|
| 141 |
# Compute size
|
| 142 |
+
size_mb = sum(
|
| 143 |
+
f.stat().st_size for f in ckpt_dir.rglob("*") if f.is_file()
|
| 144 |
+
) / (1024 * 1024)
|
| 145 |
|
| 146 |
return ModelEntry(
|
| 147 |
name=ckpt_dir.name,
|
|
|
|
| 195 |
Returns:
|
| 196 |
Best ModelEntry, or None if no models have the metric.
|
| 197 |
"""
|
| 198 |
+
candidates = [
|
| 199 |
+
m for m in self._models.values()
|
| 200 |
+
if metric in m.metrics
|
| 201 |
+
]
|
| 202 |
if not candidates:
|
| 203 |
return None
|
| 204 |
|
| 205 |
+
return min(candidates, key=lambda m: m.metrics[metric]) \
|
| 206 |
+
if lower_is_better else \
|
| 207 |
+
max(candidates, key=lambda m: m.metrics[metric])
|
|
|
|
|
|
|
| 208 |
|
| 209 |
def get_by_step(self, step: int) -> ModelEntry | None:
|
| 210 |
"""Get a model by its training step."""
|
|
|
|
| 249 |
self,
|
| 250 |
name: str,
|
| 251 |
use_ema: bool = True,
|
| 252 |
+
torch_dtype: torch.dtype | None = None,
|
| 253 |
) -> Any:
|
| 254 |
"""Load a ControlNet model from checkpoint.
|
| 255 |
|
| 256 |
Args:
|
| 257 |
name: Checkpoint name.
|
| 258 |
use_ema: If True, load EMA weights (preferred for inference).
|
| 259 |
+
torch_dtype: Weight dtype (e.g. torch.float16). Defaults to
|
| 260 |
+
float16 on CUDA, float32 on CPU.
|
| 261 |
|
| 262 |
Returns:
|
| 263 |
ControlNetModel instance.
|
| 264 |
"""
|
| 265 |
from diffusers import ControlNetModel
|
| 266 |
|
| 267 |
+
if torch_dtype is None:
|
| 268 |
+
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
|
| 269 |
+
|
| 270 |
entry = self._models.get(name)
|
| 271 |
if entry is None:
|
| 272 |
raise KeyError(f"Checkpoint '{name}' not found in registry")
|
| 273 |
|
| 274 |
if use_ema and entry.has_ema:
|
| 275 |
+
return ControlNetModel.from_pretrained(
|
| 276 |
+
str(entry.path / "controlnet_ema"),
|
| 277 |
+
torch_dtype=torch_dtype,
|
| 278 |
+
)
|
| 279 |
|
| 280 |
# Fallback: load from training state
|
| 281 |
state = self.load(name)
|
| 282 |
model = ControlNetModel.from_pretrained(
|
| 283 |
+
"CrucibleAI/ControlNetMediaPipeFace",
|
| 284 |
subfolder="diffusion_sd15",
|
| 285 |
+
torch_dtype=torch_dtype,
|
| 286 |
)
|
| 287 |
key = "ema_controlnet" if use_ema else "controlnet"
|
| 288 |
model.load_state_dict(state[key])
|
|
|
|
| 364 |
for metric in sorted(all_metrics):
|
| 365 |
values = [m.metrics[metric] for m in models if metric in m.metrics]
|
| 366 |
if values:
|
| 367 |
+
lines.append(
|
| 368 |
+
f" {metric}: {min(values):.4f} — {max(values):.4f}"
|
| 369 |
+
)
|
| 370 |
|
| 371 |
return "\n".join(lines)
|
| 372 |
|