# dee-z-image/microscope/inspect_model.py
# Author: Javad Taghia
# Commit: Add microscope config-only inspector
# Revision: 9ed2e4d
#!/usr/bin/env python3
"""
Repository model inspector.
This script is designed to work in `--config-only` mode without importing
PyTorch/Diffusers/Transformers. It reads JSON configs from a local Diffusers
repository layout and prints a summary.
With `--params`, it can also compute parameter counts by scanning
`*.safetensors` headers (without loading tensor data into RAM).
"""
import argparse
import json
import math
from pathlib import Path
from typing import Any, Dict, Iterable, Optional
def load_json(path: Path) -> Dict[str, Any]:
    """Parse the UTF-8 encoded JSON document stored at `path` and return it."""
    with path.open("r", encoding="utf-8") as handle:
        return json.load(handle)
def human_params(value: Optional[int]) -> str:
    """
    Format a parameter count for display.

    `None` becomes "n/a"; counts of one billion or more are shown in billions
    ("B"), everything else in millions ("M"), both with two decimals.
    """
    if value is None:
        return "n/a"
    # Pick the divisor and suffix first, then format once.
    divisor, suffix = (1e9, "B") if value >= 1_000_000_000 else (1e6, "M")
    return f"{value / divisor:.2f}{suffix}"
def read_model_index(model_dir: Path) -> Dict[str, Any]:
    """
    Load `model_index.json` from `model_dir`.

    Returns an empty dict when the file does not exist.
    """
    index_file = model_dir.joinpath("model_index.json")
    return load_json(index_file) if index_file.exists() else {}
def describe_model_index(model_index: Dict[str, Any]) -> None:
    """
    Print the pipeline components listed in a parsed `model_index.json`.

    Keys starting with an underscore are skipped. Nothing is printed for an
    empty mapping; otherwise the listing is followed by a blank line.
    """
    if model_index:
        print("Pipeline pieces (model_index.json):")
        components = (
            (name, spec)
            for name, spec in model_index.items()
            if not name.startswith("_")
        )
        for name, spec in components:
            print(f" {name:14s} -> {spec}")
        print()
def detect_pipeline_kind(model_index: Dict[str, Any]) -> str:
    """
    Classify a parsed `model_index.json` as "zimage", "sdxl_like" or "unknown".

    The decision uses the lower-cased `_class_name` value and the presence of
    the `transformer` / `unet` keys; the "zimage" test takes precedence.
    """
    class_name = str(model_index.get("_class_name", "")).lower()
    has_unet = "unet" in model_index
    has_transformer = "transformer" in model_index

    if "zimage" in class_name or (has_transformer and not has_unet):
        return "zimage"
    if "stable" in class_name or has_unet:
        return "sdxl_like"
    return "unknown"
def iter_safetensors_files(directory: Path) -> Iterable[Path]:
    """
    Return the `*.safetensors` files located directly inside `directory`.

    The search is not recursive, and only regular files whose suffix is
    exactly `.safetensors` are returned, sorted by path.

    Args:
        directory: Directory to scan.

    Returns:
        A sorted list of matching paths; empty when `directory` is missing
        or is not a directory.
    """
    # is_dir() rather than exists(): a path that exists as a regular file
    # would make iterdir() raise NotADirectoryError.
    if not directory.is_dir():
        return []
    return sorted(
        entry
        for entry in directory.iterdir()
        if entry.is_file() and entry.suffix == ".safetensors"
    )
def count_params_from_safetensors(files: Iterable[Path]) -> int:
    """
    Sum the element counts of every tensor stored in the given files.

    Each tensor contributes the product of its shape, obtained through
    `get_slice(...).get_shape()` rather than by fetching the tensor itself.
    The `safetensors` package is imported lazily so that it is only required
    when this function is actually called.
    """
    from safetensors import safe_open

    param_count = 0
    for path in files:
        with safe_open(str(path), framework="np") as handle:
            param_count += sum(
                math.prod(handle.get_slice(name).get_shape())
                for name in handle.keys()
            )
    return int(param_count)
def zimage_config_only_summary(model_dir: Path, include_params: bool) -> Dict[str, Any]:
    """
    Print a per-component report for a Z-Image style pipeline directory.

    Reads `text_encoder/config.json`, `transformer/config.json`,
    `vae/config.json` and `scheduler/scheduler_config.json` below `model_dir`.
    A config that is missing (or loads as an empty object) produces a
    `[warn]` line instead of its details.

    Args:
        model_dir: Root of the pipeline directory.
        include_params: When true, parameter totals for the text encoder,
            transformer and VAE are computed from their `*.safetensors`
            files; otherwise they stay `None` and are printed as "n/a".

    Returns:
        A dict with the keys `kind`, `pipeline`, `text_encoder`,
        `transformer`, `vae` and `scheduler` holding the loaded configs and
        parameter totals.
    """
    absent = "n/a"

    def optional_config(component: str, filename: str = "config.json") -> Dict[str, Any]:
        # A missing file is represented by an empty dict.
        cfg_file = model_dir / component / filename
        if cfg_file.exists():
            return load_json(cfg_file)
        return {}

    def weights_total(component: str) -> Optional[int]:
        # Parameter counting is opt-in; None means "not computed".
        if not include_params:
            return None
        return count_params_from_safetensors(iter_safetensors_files(model_dir / component))

    def pick(cfg: Dict[str, Any], key: str) -> Any:
        return cfg.get(key, absent)

    pipeline = read_model_index(model_dir)

    # Everything is loaded and counted before the first line is printed, so a
    # failure while reading does not leave a partially printed report.
    encoder = optional_config("text_encoder")
    denoiser = optional_config("transformer")
    autoencoder = optional_config("vae")
    sched = optional_config("scheduler", "scheduler_config.json")

    encoder_total = weights_total("text_encoder")
    denoiser_total = weights_total("transformer")
    autoencoder_total = weights_total("vae")

    print("[Text encoder]")
    if not encoder:
        print(" [warn] missing text_encoder/config.json")
    else:
        architectures = encoder.get("architectures", [])
        if isinstance(architectures, list) and architectures:
            first_arch = architectures[0]
        else:
            first_arch = absent
        print(f" architecture={first_arch}")
        print(
            f" layers={pick(encoder, 'num_hidden_layers')}, "
            f"hidden={pick(encoder, 'hidden_size')}, "
            f"heads={pick(encoder, 'num_attention_heads')}, "
            f"intermediate={pick(encoder, 'intermediate_size')}"
        )
        print(
            f" vocab={pick(encoder, 'vocab_size')}, "
            f"max_positions={pick(encoder, 'max_position_embeddings')}"
        )
    print(f" params={human_params(encoder_total)}")
    print()

    print("[Transformer]")
    if not denoiser:
        print(" [warn] missing transformer/config.json")
    else:
        print(f" class={pick(denoiser, '_class_name')}")
        print(
            f" dim={pick(denoiser, 'dim')}, "
            f"layers={pick(denoiser, 'n_layers')}, "
            f"heads={pick(denoiser, 'n_heads')}"
        )
        print(
            f" in_channels={pick(denoiser, 'in_channels')}, "
            f"cap_feat_dim={pick(denoiser, 'cap_feat_dim')}"
        )
        print(
            f" patch_size={pick(denoiser, 'all_patch_size')}, "
            f"f_patch_size={pick(denoiser, 'all_f_patch_size')}"
        )
    print(f" params={human_params(denoiser_total)}")
    print()

    print("[VAE]")
    if not autoencoder:
        print(" [warn] missing vae/config.json")
    else:
        print(f" class={pick(autoencoder, '_class_name')}")
        print(
            f" sample_size={pick(autoencoder, 'sample_size')}, "
            f"in_channels={pick(autoencoder, 'in_channels')}, "
            f"latent_channels={pick(autoencoder, 'latent_channels')}, "
            f"out_channels={pick(autoencoder, 'out_channels')}"
        )
        print(
            f" block_out_channels={pick(autoencoder, 'block_out_channels')}, "
            f"scaling_factor={pick(autoencoder, 'scaling_factor')}"
        )
    print(f" params={human_params(autoencoder_total)}")
    print()

    print("[Scheduler]")
    if not sched:
        print(" [warn] missing scheduler/scheduler_config.json")
    else:
        print(
            f" class={pick(sched, '_class_name')}, "
            f"timesteps={pick(sched, 'num_train_timesteps')}, "
            f"shift={pick(sched, 'shift')}"
        )
    print()

    report: Dict[str, Any] = {"kind": "zimage", "pipeline": pipeline}
    report["text_encoder"] = {"config": encoder, "params": encoder_total}
    report["transformer"] = {"config": denoiser, "params": denoiser_total}
    report["vae"] = {"config": autoencoder, "params": autoencoder_total}
    report["scheduler"] = {"config": sched}
    return report
def main() -> None:
    """
    Command-line entry point.

    Parses the arguments, prints the contents of `model_index.json`, prints
    the config-only summary and, when `--json-out` is given, writes that
    summary to disk as JSON.

    Raises:
        SystemExit: if `model_index.json` is not found (or is empty), if
            `--config-only` was not passed, or if the pipeline is not
            detected as a Z-Image layout.
    """
    cli = argparse.ArgumentParser(description="Inspect a local Diffusers-style repository layout.")
    cli.add_argument("--model-dir", type=Path, default=Path(".."), help="Path to the diffusers pipeline directory.")
    # --device and --fp16 are accepted but never read.
    cli.add_argument("--device", default="cpu", help="Unused (kept for CLI compatibility).")
    cli.add_argument("--fp16", action="store_true", help="Unused (kept for CLI compatibility).")
    cli.add_argument("--config-only", action="store_true", help="Read JSON configs and print a summary.")
    cli.add_argument("--params", action="store_true", help="Count parameters from *.safetensors headers (no tensor loading).")
    cli.add_argument("--json-out", type=Path, default=None, help="Write a JSON summary to this path.")
    opts = cli.parse_args()

    pipeline_index = read_model_index(opts.model_dir)
    if not pipeline_index:
        raise SystemExit(f"model_index.json not found under {opts.model_dir}")

    # The component listing is printed before the mode and kind checks.
    describe_model_index(pipeline_index)
    pipeline_kind = detect_pipeline_kind(pipeline_index)

    if not opts.config_only:
        raise SystemExit("Only --config-only mode is supported by this inspector.")
    if pipeline_kind != "zimage":
        raise SystemExit(f"Unsupported pipeline kind: {pipeline_kind} (expected ZImagePipeline-style layout)")

    report = zimage_config_only_summary(opts.model_dir, include_params=opts.params)

    destination = opts.json_out
    if destination is None:
        return
    destination.parent.mkdir(parents=True, exist_ok=True)
    serialized = json.dumps(report, indent=2, ensure_ascii=False)
    destination.write_text(serialized + "\n", encoding="utf-8")
    print(f"[info] wrote JSON summary to {destination}")
# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()