dreamlessx commited on
Commit
30cc2b8
·
verified ·
1 Parent(s): 0434bde

Update landmarkdiff/model_registry.py to v0.3.2

Browse files
Files changed (1) hide show
  1. landmarkdiff/model_registry.py +25 -10
landmarkdiff/model_registry.py CHANGED
@@ -139,7 +139,9 @@ class ModelRegistry:
139
  step = int(parts[-1])
140
 
141
  # Compute size
142
- size_mb = sum(f.stat().st_size for f in ckpt_dir.rglob("*") if f.is_file()) / (1024 * 1024)
 
 
143
 
144
  return ModelEntry(
145
  name=ckpt_dir.name,
@@ -193,15 +195,16 @@ class ModelRegistry:
193
  Returns:
194
  Best ModelEntry, or None if no models have the metric.
195
  """
196
- candidates = [m for m in self._models.values() if metric in m.metrics]
 
 
 
197
  if not candidates:
198
  return None
199
 
200
- return (
201
- min(candidates, key=lambda m: m.metrics[metric])
202
- if lower_is_better
203
- else max(candidates, key=lambda m: m.metrics[metric])
204
- )
205
 
206
  def get_by_step(self, step: int) -> ModelEntry | None:
207
  """Get a model by its training step."""
@@ -246,30 +249,40 @@ class ModelRegistry:
246
  self,
247
  name: str,
248
  use_ema: bool = True,
 
249
  ) -> Any:
250
  """Load a ControlNet model from checkpoint.
251
 
252
  Args:
253
  name: Checkpoint name.
254
  use_ema: If True, load EMA weights (preferred for inference).
 
 
255
 
256
  Returns:
257
  ControlNetModel instance.
258
  """
259
  from diffusers import ControlNetModel
260
 
 
 
 
261
  entry = self._models.get(name)
262
  if entry is None:
263
  raise KeyError(f"Checkpoint '{name}' not found in registry")
264
 
265
  if use_ema and entry.has_ema:
266
- return ControlNetModel.from_pretrained(str(entry.path / "controlnet_ema"))
 
 
 
267
 
268
  # Fallback: load from training state
269
  state = self.load(name)
270
  model = ControlNetModel.from_pretrained(
271
- "lllyasviel/control_v11p_sd15_openpose",
272
  subfolder="diffusion_sd15",
 
273
  )
274
  key = "ema_controlnet" if use_ema else "controlnet"
275
  model.load_state_dict(state[key])
@@ -351,7 +364,9 @@ class ModelRegistry:
351
  for metric in sorted(all_metrics):
352
  values = [m.metrics[metric] for m in models if metric in m.metrics]
353
  if values:
354
- lines.append(f" {metric}: {min(values):.4f} — {max(values):.4f}")
 
 
355
 
356
  return "\n".join(lines)
357
 
 
139
  step = int(parts[-1])
140
 
141
  # Compute size
142
+ size_mb = sum(
143
+ f.stat().st_size for f in ckpt_dir.rglob("*") if f.is_file()
144
+ ) / (1024 * 1024)
145
 
146
  return ModelEntry(
147
  name=ckpt_dir.name,
 
195
  Returns:
196
  Best ModelEntry, or None if no models have the metric.
197
  """
198
+ candidates = [
199
+ m for m in self._models.values()
200
+ if metric in m.metrics
201
+ ]
202
  if not candidates:
203
  return None
204
 
205
+ return min(candidates, key=lambda m: m.metrics[metric]) \
206
+ if lower_is_better else \
207
+ max(candidates, key=lambda m: m.metrics[metric])
 
 
208
 
209
  def get_by_step(self, step: int) -> ModelEntry | None:
210
  """Get a model by its training step."""
 
249
  self,
250
  name: str,
251
  use_ema: bool = True,
252
+ torch_dtype: torch.dtype | None = None,
253
  ) -> Any:
254
  """Load a ControlNet model from checkpoint.
255
 
256
  Args:
257
  name: Checkpoint name.
258
  use_ema: If True, load EMA weights (preferred for inference).
259
+ torch_dtype: Weight dtype (e.g. torch.float16). Defaults to
260
+ float16 on CUDA, float32 on CPU.
261
 
262
  Returns:
263
  ControlNetModel instance.
264
  """
265
  from diffusers import ControlNetModel
266
 
267
+ if torch_dtype is None:
268
+ torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
269
+
270
  entry = self._models.get(name)
271
  if entry is None:
272
  raise KeyError(f"Checkpoint '{name}' not found in registry")
273
 
274
  if use_ema and entry.has_ema:
275
+ return ControlNetModel.from_pretrained(
276
+ str(entry.path / "controlnet_ema"),
277
+ torch_dtype=torch_dtype,
278
+ )
279
 
280
  # Fallback: load from training state
281
  state = self.load(name)
282
  model = ControlNetModel.from_pretrained(
283
+ "CrucibleAI/ControlNetMediaPipeFace",
284
  subfolder="diffusion_sd15",
285
+ torch_dtype=torch_dtype,
286
  )
287
  key = "ema_controlnet" if use_ema else "controlnet"
288
  model.load_state_dict(state[key])
 
364
  for metric in sorted(all_metrics):
365
  values = [m.metrics[metric] for m in models if metric in m.metrics]
366
  if values:
367
+ lines.append(
368
+ f" {metric}: {min(values):.4f} — {max(values):.4f}"
369
+ )
370
 
371
  return "\n".join(lines)
372