Spaces:
Running
on
Zero
Running
on
Zero
| #!/usr/bin/env python3 | |
| # -*- coding: utf-8 -*- | |
| """ | |
| Upgrade an existing prototype DB .pt file to include artist names (label_names). | |
| This is useful for older prototype files that only store: | |
| - centers: [N, D] | |
| - labels: [N] | |
| We infer label_names from the `dataset/` folder (sorted artist directories), matching | |
| `train_style_ddp.TriViewDataset` label assignment. | |
| """ | |
| from __future__ import annotations | |
| import argparse | |
| from pathlib import Path | |
| import torch | |
def infer_label_names(dataset_dir: Path) -> list[str]:
    """Return the artist names found directly under *dataset_dir*, in sorted order.

    Only immediate subdirectories are considered; regular files are ignored.
    Names are sorted as plain strings (code-point order). This ordering is
    meant to match the label indices assigned during training — TODO confirm
    against ``train_style_ddp.TriViewDataset`` if either side changes. Note
    that hidden directories (e.g. ``.cache``) are NOT filtered out.

    Args:
        dataset_dir: Dataset root whose subfolders are the artists.

    Returns:
        Sorted list of subdirectory names.

    Raises:
        FileNotFoundError: if *dataset_dir* does not exist.
        RuntimeError: if *dataset_dir* contains no subdirectories.
    """
    if not dataset_dir.exists():
        raise FileNotFoundError(f"dataset dir not found: {dataset_dir}")

    artist_dirs = [entry.name for entry in dataset_dir.iterdir() if entry.is_dir()]
    if not artist_dirs:
        raise RuntimeError(f"No artist folders found under: {dataset_dir}")

    artist_dirs.sort()
    return artist_dirs
def _check_label_coverage(labels: object, label_names: list[str]) -> None:
    """Raise ValueError if some stored label index has no entry in *label_names*."""
    # NOTE(review): assumes labels are integer class indices (the module docstring
    # describes them as `labels: [N]`); torch.as_tensor accepts a tensor, list or ndarray.
    label_tensor = torch.as_tensor(labels)
    if label_tensor.numel() == 0:
        return  # no labels -> nothing to validate
    lo = int(label_tensor.min())
    hi = int(label_tensor.max())
    if lo < 0 or hi >= len(label_names):
        raise ValueError(
            f"labels span [{lo}, {hi}] but only {len(label_names)} artist folders were found; "
            "the dataset directory does not match this prototype file."
        )


def _save_atomic(obj: dict, out_path: Path) -> None:
    """Write *obj* to *out_path* via a temporary sibling file and a rename.

    Missing parent directories are created. Because the final step is a rename,
    an interrupted save never leaves a truncated file at *out_path*. This matters
    because *out_path* defaults to the input file.
    """
    out_path.parent.mkdir(parents=True, exist_ok=True)
    tmp_path = out_path.with_name(out_path.name + ".tmp")
    try:
        torch.save(obj, str(tmp_path))
        # Path.replace uses os.replace semantics: atomic on POSIX within one filesystem.
        tmp_path.replace(out_path)
    except BaseException:
        # Remove the partial temp file; the existing out_path (if any) is untouched.
        try:
            tmp_path.unlink()
        except FileNotFoundError:
            pass
        raise


def main(argv: list[str] | None = None) -> None:
    """Add ``label_names`` to an existing prototype DB ``.pt`` file.

    Loads ``--in`` (expected: a dict with ``centers`` and ``labels``). It then
    infers artist names from the subfolders of ``--dataset-dir`` and writes the
    updated dict to ``--out`` (default: overwrite the input). If a non-empty
    ``label_names`` list is already present, it is kept unchanged. The file is
    copied to ``--out`` only when that differs from the input.

    Args:
        argv: Command-line arguments without the program name. ``None`` (the
            default) lets argparse read ``sys.argv[1:]``.

    Raises:
        ValueError: if the file is not a dict with ``centers`` and ``labels``, or
            if some label index has no corresponding artist folder.
    """
    p = argparse.ArgumentParser(description="Add label_names to an existing prototype DB .pt")
    p.add_argument("--in", dest="in_path", required=True, help="Input .pt prototype file")
    p.add_argument("--out", dest="out_path", default=None, help="Output .pt (default: overwrite input)")
    p.add_argument("--dataset-dir", type=str, default="dataset", help="Dataset root to infer artist names from")
    args = p.parse_args(argv)

    in_path = Path(args.in_path)
    out_path = Path(args.out_path) if args.out_path else in_path
    dataset_dir = Path(args.dataset_dir)

    # SECURITY NOTE: torch.load unpickles the file unless weights_only=True is in effect
    # (the default only in newer PyTorch releases). Run this only on trusted files.
    obj = torch.load(str(in_path), map_location="cpu")
    if not isinstance(obj, dict) or "centers" not in obj or "labels" not in obj:
        raise ValueError("Unsupported prototype file format (expected dict with centers+labels).")

    existing = obj.get("label_names")
    if isinstance(existing, list) and existing:
        print("label_names already present; nothing to do.")
        # Compare resolved paths so "a.pt" and "/abs/a.pt" count as the same file.
        if out_path.resolve() != in_path.resolve():
            _save_atomic(obj, out_path)
            print("saved:", out_path)
        return

    label_names = infer_label_names(dataset_dir)
    # Refuse to write names that cannot cover every stored label index.
    _check_label_coverage(obj["labels"], label_names)
    obj["label_names"] = label_names
    _save_atomic(obj, out_path)
    print("saved:", out_path)


if __name__ == "__main__":
    main()