| |
| """ |
| DDP 多卡采样脚本(单路径,不做 dual-compare,不保存 t_c 中间态图)。 |
| |
| 用法(4 卡示例): |
| torchrun --nproc_per_node=4 sample_from_checkpoint_ddp.py \ |
| --ckpt exps/jsflow-experiment/checkpoints/0290000.pt \ |
| --out-dir ./my_samples_ddp \ |
| --num-images 50000 \ |
| --batch-size 16 \ |
| --t-c 0.75 --steps-before-tc 100 --steps-after-tc 5 \ |
| --sampler em_image_noise_before_tc |
| """ |
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import math |
| import os |
| import sys |
| import types |
| import numpy as np |
|
|
| import torch |
| import torch.distributed as dist |
| from diffusers.models import AutoencoderKL |
| from PIL import Image |
| from tqdm import tqdm |
|
|
| from models.sit import SiT_models |
| from samplers import ( |
| euler_maruyama_image_noise_before_tc_sampler, |
| euler_maruyama_image_noise_sampler, |
| euler_maruyama_sampler, |
| euler_ode_sampler, |
| ) |
|
|
|
|
def create_npz_from_sample_folder(sample_dir: str, num: int) -> str:
    """Pack `num` PNGs named 000000.png ... under `sample_dir` into one .npz.

    The archive is written next to the folder as `{sample_dir}.npz` with the
    images stacked under key `arr_0` (the layout FID tooling expects).

    Returns the path of the written .npz file.
    """
    samples = []
    for i in tqdm(range(num), desc="Building .npz file from samples"):
        # Use a context manager so each file handle is closed promptly;
        # the original left up to `num` (e.g. 50k) handles to the GC.
        with Image.open(os.path.join(sample_dir, f"{i:06d}.png")) as sample_pil:
            # np.asarray forces the pixel data to load before the file closes.
            samples.append(np.asarray(sample_pil).astype(np.uint8))
    samples = np.stack(samples)
    npz_path = f"{sample_dir}.npz"
    np.savez(npz_path, arr_0=samples)
    print(f"Saved .npz file to {npz_path} [shape={samples.shape}].")
    return npz_path
|
|
|
|
def semantic_dim_from_enc_type(enc_type):
    """Map an encoder-type string to its semantic feature dimension.

    Recognizes ViT-G (1536), ViT-L (1024) and ViT-S (384) markers in the
    (case-insensitive) name; anything else — including None — yields the
    ViT-B default of 768.
    """
    if enc_type is None:
        return 768
    name = str(enc_type).lower()
    # Checked in the same order as the original cascade: G, then L, then S.
    for markers, dim in (
        (("vit-g", "vitg"), 1536),
        (("vit-l", "vitl"), 1024),
        (("vit-s", "vits"), 384),
    ):
        if any(marker in name for marker in markers):
            return dim
    return 768
|
|
|
|
def load_train_args_from_ckpt(ckpt: dict) -> argparse.Namespace | None:
    """Recover the training-time arguments stored under ``ckpt["args"]``.

    Accepts an argparse.Namespace (returned as-is), a plain dict, or a
    types.SimpleNamespace (both converted to argparse.Namespace). A missing
    key or any other type yields None.
    """
    stored = ckpt.get("args")
    if isinstance(stored, argparse.Namespace):
        return stored
    if isinstance(stored, dict):
        return argparse.Namespace(**stored)
    if isinstance(stored, types.SimpleNamespace):
        return argparse.Namespace(**vars(stored))
    # None, or an unrecognized container type.
    return None
|
|
|
|
def load_vae(device: torch.device):
    """Load the stabilityai/sd-vae-ft-mse autoencoder onto ``device`` in eval mode.

    Resolution order:
      1. offline load from the project cache dir (via preprocessing.dnnlib);
      2. offline load from a snapshot found by scanning common local cache
         locations for an sd-vae-ft-mse directory containing config.json;
      3. a regular HuggingFace Hub download (may hit the network).
    """
    try:
        from preprocessing import dnnlib

        cache_dir = dnnlib.make_cache_dir_path("diffusers")
        # Quiet the hub and point HF_HOME at the project cache so offline
        # lookups resolve there. HF_HOME is overwritten unconditionally.
        os.environ.setdefault("HF_HUB_DISABLE_SYMLINKS_WARNING", "1")
        os.environ.setdefault("HF_HUB_DISABLE_PROGRESS_BARS", "1")
        os.environ["HF_HOME"] = cache_dir
        try:
            vae = AutoencoderKL.from_pretrained(
                "stabilityai/sd-vae-ft-mse",
                cache_dir=cache_dir,
                local_files_only=True,
            ).to(device)
            vae.eval()
            return vae
        except Exception:
            # Not in the project cache — fall through to the filesystem scan.
            pass
        candidate_dir = None
        for root_dir in [
            cache_dir,
            os.path.join(os.path.expanduser("~"), ".cache", "dnnlib", "diffusers"),
            os.path.join(os.path.expanduser("~"), ".cache", "diffusers"),
            os.path.join(os.path.expanduser("~"), ".cache", "huggingface", "hub"),
        ]:
            if not os.path.isdir(root_dir):
                continue
            # A snapshot dir is identified by a config.json whose path mentions
            # the model name (backslashes normalized for Windows paths).
            for root, _, files in os.walk(root_dir):
                if "config.json" in files and "sd-vae-ft-mse" in root.replace("\\", "/"):
                    candidate_dir = root
                    break
            if candidate_dir is not None:
                break
        if candidate_dir is not None:
            vae = AutoencoderKL.from_pretrained(candidate_dir, local_files_only=True).to(device)
            vae.eval()
            return vae
    except Exception:
        # Any offline-path failure (including a missing `preprocessing`
        # package) falls back to a fresh hub download below.
        pass
    vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse").to(device)
    vae.eval()
    return vae
|
|
|
|
def build_model_from_train_args(ta: argparse.Namespace, device: torch.device):
    """Instantiate the SiT model described by the training args on ``device``.

    Returns ``(model, semantic_dim)`` where ``semantic_dim`` is the projection
    dimension implied by the checkpoint's encoder type.
    """
    resolution = int(getattr(ta, "resolution", 256))
    sem_dim = semantic_dim_from_enc_type(getattr(ta, "enc_type", "dinov2-vit-b"))
    cfg_prob = float(getattr(ta, "cfg_prob", 0.1))
    if ta.model not in SiT_models:
        raise ValueError(f"未知 model={ta.model!r},可选:{list(SiT_models.keys())}")
    model = SiT_models[ta.model](
        input_size=resolution // 8,  # SD-VAE downsamples by a factor of 8
        num_classes=int(getattr(ta, "num_classes", 1000)),
        use_cfg=(cfg_prob > 0),
        z_dims=[sem_dim],
        encoder_depth=int(getattr(ta, "encoder_depth", 8)),
        fused_attn=getattr(ta, "fused_attn", True),
        qk_norm=getattr(ta, "qk_norm", False),
    ).to(device)
    return model, sem_dim
|
|
|
|
def resolve_tc_schedule(cli, ta):
    """Resolve the piecewise time-grid settings (t_c, steps before, steps after).

    Returns ``(None, None, None)`` when neither split-step option was given
    (caller falls back to a uniform grid). Exits the process with an error
    message when the options are inconsistent or t_c cannot be determined.
    """
    before, after = cli.steps_before_tc, cli.steps_after_tc
    if before is None and after is None:
        # No piecewise schedule requested.
        return None, None, None
    if before is None or after is None:
        print("使用分段步数时必须同时指定 --steps-before-tc 与 --steps-after-tc。", file=sys.stderr)
        sys.exit(1)
    t_c = cli.t_c
    if t_c is None:
        # Fall back to the t_c stored in the checkpoint's training args.
        t_c = getattr(ta, "t_c", None) if ta is not None else None
    if t_c is None:
        print("分段采样需要 --t-c,或检查点 args 中含 t_c。", file=sys.stderr)
        sys.exit(1)
    return float(t_c), int(before), int(after)
|
|
|
|
def parse_cli():
    """Parse command-line options for checkpoint sampling.

    Flags that default to None act as overrides: when omitted, the value
    stored in the checkpoint's training args (or a hard-coded fallback in
    main()) is used instead.
    """
    p = argparse.ArgumentParser(description="REG DDP 检查点采样(单路径,无 at_tc 图)")
    # Required I/O and run configuration.
    p.add_argument("--ckpt", type=str, required=True)
    p.add_argument("--out-dir", type=str, required=True)
    p.add_argument("--num-images", type=int, required=True)
    p.add_argument("--batch-size", type=int, default=16)
    p.add_argument("--seed", type=int, default=0)
    p.add_argument("--weights", type=str, choices=("ema", "model"), default="ema")
    p.add_argument("--device", type=str, default="cuda")
    # Time grid: uniform --num-steps, or a piecewise split around --t-c.
    p.add_argument("--num-steps", type=int, default=50)
    p.add_argument("--t-c", type=float, default=None)
    p.add_argument("--steps-before-tc", type=int, default=None)
    p.add_argument("--steps-after-tc", type=int, default=None)
    # Guidance configuration.
    p.add_argument("--cfg-scale", type=float, default=1.0)
    p.add_argument("--cls-cfg-scale", type=float, default=0.0)
    p.add_argument("--guidance-low", type=float, default=0.0)
    p.add_argument("--guidance-high", type=float, default=1.0)
    p.add_argument("--path-type", type=str, default=None, choices=["linear", "cosine"])
    p.add_argument("--legacy", action=argparse.BooleanOptionalAction, default=False)
    # Architecture overrides (None => take the value from the checkpoint args).
    p.add_argument("--model", type=str, default=None)
    p.add_argument("--resolution", type=int, default=None, choices=[256, 512])
    # BUGFIX: default was 1000, which — unlike every other override flag —
    # made main() unconditionally override the checkpoint's num_classes.
    # None matches the other flags, and main() already falls back to 1000
    # when the checkpoint carries no args.
    p.add_argument("--num-classes", type=int, default=None)
    p.add_argument("--encoder-depth", type=int, default=None)
    p.add_argument("--enc-type", type=str, default=None)
    p.add_argument("--fused-attn", action=argparse.BooleanOptionalAction, default=None)
    p.add_argument("--qk-norm", action=argparse.BooleanOptionalAction, default=None)
    p.add_argument("--cfg-prob", type=float, default=None)
    p.add_argument(
        "--sampler",
        type=str,
        default="em_image_noise_before_tc",
        choices=["ode", "em", "em_image_noise", "em_image_noise_before_tc"],
    )
    p.add_argument(
        "--save-fixed-trajectory",
        action="store_true",
        help="保存本 rank 轨迹(npy)到 out-dir/trajectory_rank{rank}",
    )
    p.add_argument(
        "--save-npz",
        action=argparse.BooleanOptionalAction,
        default=True,
        help="采样结束后由 rank0 汇总 PNG 并保存 out-dir.npz",
    )
    return p.parse_args()
|
|
|
|
| def _decode_to_uint8_hwc(latents, latents_bias, latents_scale, vae): |
| imgs = vae.decode((latents - latents_bias) / latents_scale).sample |
| imgs = (imgs + 1) / 2.0 |
| imgs = torch.clamp(imgs, 0, 1) |
| return ( |
| (imgs * 255.0) |
| .round() |
| .to(torch.uint8) |
| .permute(0, 2, 3, 1) |
| .cpu() |
| .numpy() |
| ) |
|
|
|
|
def init_ddp():
    """Initialize torch.distributed from torchrun environment variables.

    Returns ``(is_distributed, rank, world_size, local_rank)``. When RANK or
    WORLD_SIZE is absent (plain ``python`` launch) no process group is
    created and the single-process fallback ``(False, 0, 1, 0)`` is returned.
    """
    env = os.environ
    if "RANK" not in env or "WORLD_SIZE" not in env:
        return False, 0, 1, 0
    rank = int(env["RANK"])
    world_size = int(env["WORLD_SIZE"])
    local_rank = int(env.get("LOCAL_RANK", 0))
    dist.init_process_group(backend="nccl", init_method="env://")
    # Bind this process to its local GPU before any CUDA work happens.
    torch.cuda.set_device(local_rank)
    return True, rank, world_size, local_rank
|
|
|
|
def main():
    """Sample images from a checkpoint, optionally across multiple GPUs (DDP).

    Each rank draws its own batches with a rank-offset seed; image indices are
    interleaved across ranks so the union covers 000000..n_total-1 PNGs, which
    rank 0 can optionally pack into a single .npz afterwards.
    """
    cli = parse_cli()
    use_ddp, rank, world_size, local_rank = init_ddp()

    # Device selection: per-rank GPU under DDP, the CLI device otherwise.
    if torch.cuda.is_available():
        device = torch.device(f"cuda:{local_rank}" if use_ddp else cli.device)
        torch.backends.cuda.matmul.allow_tf32 = True
    else:
        device = torch.device("cpu")

    # Older torch versions do not accept weights_only; retry without it.
    try:
        ckpt = torch.load(cli.ckpt, map_location="cpu", weights_only=False)
    except TypeError:
        ckpt = torch.load(cli.ckpt, map_location="cpu")
    ta = load_train_args_from_ckpt(ckpt)
    if ta is None:
        # Checkpoint carries no training args: the essential architecture
        # options must come from the command line.
        if cli.model is None or cli.resolution is None or cli.enc_type is None:
            print("检查点中无 args,请至少指定:--model --resolution --enc-type", file=sys.stderr)
            sys.exit(1)
        ta = argparse.Namespace(
            model=cli.model,
            resolution=cli.resolution,
            num_classes=cli.num_classes if cli.num_classes is not None else 1000,
            encoder_depth=cli.encoder_depth if cli.encoder_depth is not None else 8,
            enc_type=cli.enc_type,
            fused_attn=cli.fused_attn if cli.fused_attn is not None else True,
            qk_norm=cli.qk_norm if cli.qk_norm is not None else False,
            cfg_prob=cli.cfg_prob if cli.cfg_prob is not None else 0.1,
        )
    else:
        # Explicit CLI flags override the checkpoint's stored training args.
        if cli.model is not None:
            ta.model = cli.model
        if cli.resolution is not None:
            ta.resolution = cli.resolution
        # NOTE(review): --num-classes defaults to 1000 (not None) in
        # parse_cli, so this branch always overrides the checkpoint value —
        # confirm that is intended.
        if cli.num_classes is not None:
            ta.num_classes = cli.num_classes
        if cli.encoder_depth is not None:
            ta.encoder_depth = cli.encoder_depth
        if cli.enc_type is not None:
            ta.enc_type = cli.enc_type
        if cli.fused_attn is not None:
            ta.fused_attn = cli.fused_attn
        if cli.qk_norm is not None:
            ta.qk_norm = cli.qk_norm
        if cli.cfg_prob is not None:
            ta.cfg_prob = cli.cfg_prob

    path_type = cli.path_type if cli.path_type is not None else getattr(ta, "path_type", "linear")
    # (t_c, steps_before, steps_after) for a piecewise grid, or all None.
    tc_split = resolve_tc_schedule(cli, ta)

    if rank == 0:
        if tc_split[0] is not None:
            print(
                f"时间网格:t_c={tc_split[0]}, 步数 (1→t_c)={tc_split[1]}, (t_c→0)={tc_split[2]}"
            )
        else:
            print(f"时间网格:均匀 num_steps={cli.num_steps}")

    # Pick the sampler; "em_image_noise" is the remaining (else) choice.
    if cli.sampler == "ode":
        sampler_fn = euler_ode_sampler
    elif cli.sampler == "em":
        sampler_fn = euler_maruyama_sampler
    elif cli.sampler == "em_image_noise_before_tc":
        sampler_fn = euler_maruyama_image_noise_before_tc_sampler
    else:
        sampler_fn = euler_maruyama_image_noise_sampler

    model, cls_dim = build_model_from_train_args(ta, device)
    wkey = cli.weights  # "ema" or "model"
    if wkey not in ckpt:
        raise KeyError(f"检查点中无 '{wkey}' 键,现有键:{list(ckpt.keys())}")
    state = ckpt[wkey]
    if cli.legacy:
        from utils import load_legacy_checkpoints

        # Remap old-format state dicts onto the current model layout.
        state = load_legacy_checkpoints(
            state_dict=state, encoder_depth=int(getattr(ta, "encoder_depth", 8))
        )
    model.load_state_dict(state, strict=True)
    model.eval()

    vae = load_vae(device)
    # SD-VAE latent normalization: scale 0.18215 per channel, zero bias.
    latents_scale = torch.tensor([0.18215] * 4, device=device).view(1, 4, 1, 1)
    latents_bias = torch.tensor([0.0] * 4, device=device).view(1, 4, 1, 1)
    sampler_args = argparse.Namespace(cls_cfg_scale=float(cli.cls_cfg_scale))

    os.makedirs(cli.out_dir, exist_ok=True)
    traj_dir = None
    # Trajectory dumping is per-rank and unsupported for the plain "em" sampler.
    if cli.save_fixed_trajectory and cli.sampler != "em":
        traj_dir = os.path.join(cli.out_dir, f"trajectory_rank{rank}")
        os.makedirs(traj_dir, exist_ok=True)

    latent_size = int(getattr(ta, "resolution", 256)) // 8  # SD-VAE downsamples 8x
    n_total = int(cli.num_images)
    b = max(1, int(cli.batch_size))
    global_batch_size = b * world_size
    # Round the workload up to whole global batches; surplus images
    # (gidx >= n_total) are still sampled but never written to disk.
    total_samples = int(math.ceil(n_total / global_batch_size) * global_batch_size)
    samples_needed_this_gpu = int(total_samples // world_size)
    if samples_needed_this_gpu % b != 0:
        raise ValueError("samples_needed_this_gpu must be divisible by per-rank batch size")
    iterations = int(samples_needed_this_gpu // b)

    # Rank-offset seed so each rank draws independent noise.
    seed_rank = int(cli.seed) + int(rank)
    torch.manual_seed(seed_rank)
    if device.type == "cuda":
        torch.cuda.manual_seed_all(seed_rank)

    if rank == 0:
        print(f"Total number of images that will be sampled: {total_samples}")
    pbar = range(iterations)
    pbar = tqdm(pbar, desc="sampling") if rank == 0 else pbar
    total = 0  # images produced globally so far (across all ranks)
    written_local = 0  # PNGs actually written by this rank
    for _ in pbar:
        cur = b
        # Initial latent noise, random class labels, and class-latent noise.
        z = torch.randn(cur, model.in_channels, latent_size, latent_size, device=device)
        y = torch.randint(0, int(ta.num_classes), (cur,), device=device)
        cls_z = torch.randn(cur, cls_dim, device=device)

        with torch.no_grad():
            # NOTE(review): assumes every selected sampler accepts
            # cls_latents/args (and the t_c kwargs when set) — confirm
            # against samplers.py.
            em_kw = dict(
                num_steps=cli.num_steps,
                cfg_scale=cli.cfg_scale,
                guidance_low=cli.guidance_low,
                guidance_high=cli.guidance_high,
                path_type=path_type,
                cls_latents=cls_z,
                args=sampler_args,
            )
            if tc_split[0] is not None:
                em_kw["t_c"] = tc_split[0]
                em_kw["num_steps_before_tc"] = tc_split[1]
                em_kw["num_steps_after_tc"] = tc_split[2]

            if cli.save_fixed_trajectory and cli.sampler != "em":
                # NOTE(review): both branches below are identical — this
                # split looks like a leftover from a removed special case.
                if cli.sampler == "em_image_noise_before_tc":
                    latents, traj = sampler_fn(
                        model, z, y, **em_kw, return_trajectory=True
                    )
                else:
                    latents, traj = sampler_fn(
                        model, z, y, **em_kw, return_trajectory=True
                    )
            else:
                latents = sampler_fn(model, z, y, **em_kw)
                traj = None

        latents = latents.to(torch.float32)
        imgs = _decode_to_uint8_hwc(latents, latents_bias, latents_scale, vae)
        # Global indices are interleaved across ranks (i*world_size + rank),
        # offset by the images already produced in earlier iterations.
        for i, img in enumerate(imgs):
            gidx = i * world_size + rank + total
            if gidx < n_total:
                Image.fromarray(img).save(os.path.join(cli.out_dir, f"{gidx:06d}.png"))
                written_local += 1

        if traj is not None and traj_dir is not None:
            # One trajectory file per batch, named after this rank's first
            # global image index in the batch.
            traj_np = torch.stack(traj, dim=0).to(torch.float32).cpu().numpy()
            first_idx = rank + total
            if first_idx < n_total:
                np.save(os.path.join(traj_dir, f"{first_idx:06d}_traj.npy"), traj_np)

        total += global_batch_size
        # Keep ranks in lockstep each iteration.
        if use_ddp:
            dist.barrier()
    if rank == 0 and hasattr(pbar, "close"):
        pbar.close()

    # Wait for all ranks to finish writing before rank 0 aggregates.
    if use_ddp:
        dist.barrier()
    if rank == 0:
        if cli.save_npz:
            create_npz_from_sample_folder(cli.out_dir, n_total)
        print(f"Done. Saved {n_total} images under {cli.out_dir} (world_size={world_size}).")
    if use_ddp:
        dist.destroy_process_group()
|
|
|
|
# Script entry point (launched directly or via torchrun for DDP).
if __name__ == "__main__":
    main()
|
|
|
|