# jsflow/back/sample_from_checkpoint_ddp.py
# Uploaded by xiangzai with the upload-large-folder tool (commit 5484dca, verified).
#!/usr/bin/env python3
"""
DDP 多卡采样脚本(单路径,不做 dual-compare,不保存 t_c 中间态图)。
用法(4 卡示例):
torchrun --nproc_per_node=4 sample_from_checkpoint_ddp.py \
--ckpt exps/jsflow-experiment/checkpoints/0290000.pt \
--out-dir ./my_samples_ddp \
--num-images 50000 \
--batch-size 16 \
--t-c 0.75 --steps-before-tc 100 --steps-after-tc 5 \
--sampler em_image_noise_before_tc
"""
from __future__ import annotations
import argparse
import math
import os
import sys
import types
import numpy as np
import torch
import torch.distributed as dist
from diffusers.models import AutoencoderKL
from PIL import Image
from tqdm import tqdm
from models.sit import SiT_models
from samplers import (
euler_maruyama_image_noise_before_tc_sampler,
euler_maruyama_image_noise_sampler,
euler_maruyama_sampler,
euler_ode_sampler,
)
def create_npz_from_sample_folder(sample_dir: str, num: int):
    """Bundle ``num`` PNGs named 000000.png.. under ``sample_dir`` into one .npz.

    The array is stored under the key ``arr_0`` (ADM/FID evaluation convention).
    Returns the path of the written .npz file.
    """
    frames = [
        np.asarray(
            Image.open(os.path.join(sample_dir, f"{idx:06d}.png"))
        ).astype(np.uint8)
        for idx in tqdm(range(num), desc="Building .npz file from samples")
    ]
    stacked = np.stack(frames)
    npz_path = f"{sample_dir}.npz"
    np.savez(npz_path, arr_0=stacked)
    print(f"Saved .npz file to {npz_path} [shape={stacked.shape}].")
    return npz_path
def semantic_dim_from_enc_type(enc_type):
    """Map an encoder-type string to its semantic feature dimension.

    Recognizes ViT-G (1536), ViT-L (1024) and ViT-S (384) markers,
    case-insensitively; anything else (including None) defaults to 768.
    """
    if enc_type is None:
        return 768
    name = str(enc_type).lower()
    for markers, dim in (
        (("vit-g", "vitg"), 1536),
        (("vit-l", "vitl"), 1024),
        (("vit-s", "vits"), 384),
    ):
        if any(marker in name for marker in markers):
            return dim
    return 768
def load_train_args_from_ckpt(ckpt: dict) -> argparse.Namespace | None:
    """Extract the training args stored under ``ckpt['args']``.

    Normalizes dict / SimpleNamespace payloads to an argparse.Namespace;
    returns None when absent or of an unrecognized type.
    """
    raw = ckpt.get("args")
    if isinstance(raw, argparse.Namespace):
        return raw
    if isinstance(raw, dict):
        return argparse.Namespace(**raw)
    if isinstance(raw, types.SimpleNamespace):
        return argparse.Namespace(**vars(raw))
    # None, or some payload we don't know how to convert.
    return None
def load_vae(device: torch.device):
    """Load the stabilityai/sd-vae-ft-mse AutoencoderKL onto ``device`` in eval mode.

    Resolution order:
      1. dnnlib-managed cache dir with local_files_only=True (offline),
      2. filesystem scan of common cache roots for an sd-vae-ft-mse snapshot,
      3. plain from_pretrained, which may download from the Hugging Face Hub.
    """
    try:
        # Optional project helper; any failure inside this try falls through
        # to the network download at the bottom.
        from preprocessing import dnnlib
        cache_dir = dnnlib.make_cache_dir_path("diffusers")
        os.environ.setdefault("HF_HUB_DISABLE_SYMLINKS_WARNING", "1")
        os.environ.setdefault("HF_HUB_DISABLE_PROGRESS_BARS", "1")
        # NOTE(review): unconditionally overrides HF_HOME for the whole
        # process, affecting any later Hub access — confirm this is intended.
        os.environ["HF_HOME"] = cache_dir
        try:
            vae = AutoencoderKL.from_pretrained(
                "stabilityai/sd-vae-ft-mse",
                cache_dir=cache_dir,
                local_files_only=True,
            ).to(device)
            vae.eval()
            return vae
        except Exception:
            pass  # cache miss — fall back to scanning the disk below
        # Scan likely cache roots for a snapshot directory: it must contain a
        # config.json and have "sd-vae-ft-mse" somewhere in its path.
        candidate_dir = None
        for root_dir in [
            cache_dir,
            os.path.join(os.path.expanduser("~"), ".cache", "dnnlib", "diffusers"),
            os.path.join(os.path.expanduser("~"), ".cache", "diffusers"),
            os.path.join(os.path.expanduser("~"), ".cache", "huggingface", "hub"),
        ]:
            if not os.path.isdir(root_dir):
                continue
            for root, _, files in os.walk(root_dir):
                # Normalize separators so the substring test works on Windows.
                if "config.json" in files and "sd-vae-ft-mse" in root.replace("\\", "/"):
                    candidate_dir = root
                    break
            if candidate_dir is not None:
                break
        if candidate_dir is not None:
            vae = AutoencoderKL.from_pretrained(candidate_dir, local_files_only=True).to(device)
            vae.eval()
            return vae
    except Exception:
        pass  # offline resolution failed entirely — try the network
    # Last resort: regular from_pretrained (requires network access).
    vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse").to(device)
    vae.eval()
    return vae
def build_model_from_train_args(ta: argparse.Namespace, device: torch.device):
    """Instantiate the SiT model described by training args ``ta`` on ``device``.

    Returns (model, semantic_dim) where semantic_dim matches the encoder type.
    """
    resolution = int(getattr(ta, "resolution", 256))
    encoder_type = getattr(ta, "enc_type", "dinov2-vit-b")
    z_dims = [semantic_dim_from_enc_type(encoder_type)]
    cfg_prob = float(getattr(ta, "cfg_prob", 0.1))
    if ta.model not in SiT_models:
        raise ValueError(f"未知 model={ta.model!r},可选:{list(SiT_models.keys())}")
    model = SiT_models[ta.model](
        input_size=resolution // 8,  # VAE downsamples by 8x
        num_classes=int(getattr(ta, "num_classes", 1000)),
        use_cfg=(cfg_prob > 0),
        z_dims=z_dims,
        encoder_depth=int(getattr(ta, "encoder_depth", 8)),
        fused_attn=getattr(ta, "fused_attn", True),
        qk_norm=getattr(ta, "qk_norm", False),
    ).to(device)
    return model, z_dims[0]
def resolve_tc_schedule(cli, ta):
    """Resolve the piecewise time grid around t_c.

    Returns (t_c, steps_before, steps_after); all three are None when no
    piecewise step counts were requested (uniform grid). Exits with an error
    if only one of the two step counts is given, or if t_c cannot be found on
    the CLI or in the checkpoint's training args.
    """
    before, after = cli.steps_before_tc, cli.steps_after_tc
    if before is None and after is None:
        return None, None, None
    if before is None or after is None:
        print("使用分段步数时必须同时指定 --steps-before-tc 与 --steps-after-tc。", file=sys.stderr)
        sys.exit(1)
    tc = cli.t_c
    if tc is None and ta is not None:
        # Fall back to the t_c recorded in the checkpoint's training args.
        tc = getattr(ta, "t_c", None)
    if tc is None:
        print("分段采样需要 --t-c,或检查点 args 中含 t_c。", file=sys.stderr)
        sys.exit(1)
    return float(tc), int(before), int(after)
def parse_cli():
    """Build and parse the command-line interface for DDP checkpoint sampling."""
    parser = argparse.ArgumentParser(description="REG DDP 检查点采样(单路径,无 at_tc 图)")
    add = parser.add_argument
    # Required I/O.
    add("--ckpt", type=str, required=True)
    add("--out-dir", type=str, required=True)
    add("--num-images", type=int, required=True)
    # Batch / reproducibility / weight selection.
    add("--batch-size", type=int, default=16)
    add("--seed", type=int, default=0)
    add("--weights", type=str, choices=("ema", "model"), default="ema")
    add("--device", type=str, default="cuda")
    # Time-grid controls (uniform, or piecewise around t_c).
    add("--num-steps", type=int, default=50)
    add("--t-c", type=float, default=None)
    add("--steps-before-tc", type=int, default=None)
    add("--steps-after-tc", type=int, default=None)
    # Guidance.
    add("--cfg-scale", type=float, default=1.0)
    add("--cls-cfg-scale", type=float, default=0.0)
    add("--guidance-low", type=float, default=0.0)
    add("--guidance-high", type=float, default=1.0)
    add("--path-type", type=str, default=None, choices=["linear", "cosine"])
    add("--legacy", action=argparse.BooleanOptionalAction, default=False)
    # Model overrides; None means "take the value from the checkpoint args".
    add("--model", type=str, default=None)
    add("--resolution", type=int, default=None, choices=[256, 512])
    add("--num-classes", type=int, default=1000)
    add("--encoder-depth", type=int, default=None)
    add("--enc-type", type=str, default=None)
    add("--fused-attn", action=argparse.BooleanOptionalAction, default=None)
    add("--qk-norm", action=argparse.BooleanOptionalAction, default=None)
    add("--cfg-prob", type=float, default=None)
    add(
        "--sampler",
        type=str,
        default="em_image_noise_before_tc",
        choices=["ode", "em", "em_image_noise", "em_image_noise_before_tc"],
    )
    add(
        "--save-fixed-trajectory",
        action="store_true",
        help="保存本 rank 轨迹(npy)到 out-dir/trajectory_rank{rank}",
    )
    add(
        "--save-npz",
        action=argparse.BooleanOptionalAction,
        default=True,
        help="采样结束后由 rank0 汇总 PNG 并保存 out-dir.npz",
    )
    return parser.parse_args()
def _decode_to_uint8_hwc(latents, latents_bias, latents_scale, vae):
    """Decode normalized VAE latents to a uint8 numpy array in NHWC layout.

    Undoes the (x - bias) / scale latent normalization, maps decoder output
    from [-1, 1] to [0, 255], and moves channels last for PIL consumption.
    """
    decoded = vae.decode((latents - latents_bias) / latents_scale).sample
    pixels = ((decoded + 1) / 2.0).clamp(0, 1)
    as_bytes = (pixels * 255.0).round().to(torch.uint8)
    return as_bytes.permute(0, 2, 3, 1).cpu().numpy()
def init_ddp():
    """Initialize the NCCL process group from torchrun environment variables.

    Returns (use_ddp, rank, world_size, local_rank). When RANK/WORLD_SIZE are
    absent the process runs standalone and (False, 0, 1, 0) is returned.
    """
    env = os.environ
    if "RANK" not in env or "WORLD_SIZE" not in env:
        return False, 0, 1, 0
    rank = int(env["RANK"])
    world_size = int(env["WORLD_SIZE"])
    local_rank = int(env.get("LOCAL_RANK", 0))
    dist.init_process_group(backend="nccl", init_method="env://")
    # Pin this process to its local GPU before any CUDA work happens.
    torch.cuda.set_device(local_rank)
    return True, rank, world_size, local_rank
def main():
    """Sample images from a checkpoint across DDP ranks.

    Each rank generates its share of the (rounded-up) request, decodes latents
    through the SD-VAE, and writes PNGs with interleaved global indices so the
    ranks never collide. Rank 0 optionally bundles everything into an .npz.
    """
    cli = parse_cli()
    use_ddp, rank, world_size, local_rank = init_ddp()
    # Under DDP each rank pins its own GPU; standalone runs honor --device.
    if torch.cuda.is_available():
        device = torch.device(f"cuda:{local_rank}" if use_ddp else cli.device)
        torch.backends.cuda.matmul.allow_tf32 = True
    else:
        device = torch.device("cpu")
    # weights_only is only accepted by newer torch versions; retry without it.
    try:
        ckpt = torch.load(cli.ckpt, map_location="cpu", weights_only=False)
    except TypeError:
        ckpt = torch.load(cli.ckpt, map_location="cpu")
    ta = load_train_args_from_ckpt(ckpt)
    if ta is None:
        # Checkpoint carries no training args: require the minimum on the CLI
        # and fill the rest with the training-time defaults.
        if cli.model is None or cli.resolution is None or cli.enc_type is None:
            print("检查点中无 args,请至少指定:--model --resolution --enc-type", file=sys.stderr)
            sys.exit(1)
        ta = argparse.Namespace(
            model=cli.model,
            resolution=cli.resolution,
            num_classes=cli.num_classes if cli.num_classes is not None else 1000,
            encoder_depth=cli.encoder_depth if cli.encoder_depth is not None else 8,
            enc_type=cli.enc_type,
            fused_attn=cli.fused_attn if cli.fused_attn is not None else True,
            qk_norm=cli.qk_norm if cli.qk_norm is not None else False,
            cfg_prob=cli.cfg_prob if cli.cfg_prob is not None else 0.1,
        )
    else:
        # Explicit CLI flags override the checkpoint's stored training args.
        if cli.model is not None:
            ta.model = cli.model
        if cli.resolution is not None:
            ta.resolution = cli.resolution
        if cli.num_classes is not None:
            ta.num_classes = cli.num_classes
        if cli.encoder_depth is not None:
            ta.encoder_depth = cli.encoder_depth
        if cli.enc_type is not None:
            ta.enc_type = cli.enc_type
        if cli.fused_attn is not None:
            ta.fused_attn = cli.fused_attn
        if cli.qk_norm is not None:
            ta.qk_norm = cli.qk_norm
        if cli.cfg_prob is not None:
            ta.cfg_prob = cli.cfg_prob
    path_type = cli.path_type if cli.path_type is not None else getattr(ta, "path_type", "linear")
    # (t_c, steps_before, steps_after) or (None, None, None) for a uniform grid.
    tc_split = resolve_tc_schedule(cli, ta)
    if rank == 0:
        if tc_split[0] is not None:
            print(
                f"时间网格:t_c={tc_split[0]}, 步数 (1→t_c)={tc_split[1]}, (t_c→0)={tc_split[2]}"
            )
        else:
            print(f"时间网格:均匀 num_steps={cli.num_steps}")
    # Map the --sampler choice to its implementation.
    if cli.sampler == "ode":
        sampler_fn = euler_ode_sampler
    elif cli.sampler == "em":
        sampler_fn = euler_maruyama_sampler
    elif cli.sampler == "em_image_noise_before_tc":
        sampler_fn = euler_maruyama_image_noise_before_tc_sampler
    else:
        sampler_fn = euler_maruyama_image_noise_sampler
    model, cls_dim = build_model_from_train_args(ta, device)
    wkey = cli.weights
    if wkey not in ckpt:
        raise KeyError(f"检查点中无 '{wkey}' 键,现有键:{list(ckpt.keys())}")
    state = ckpt[wkey]
    if cli.legacy:
        # Legacy checkpoints need their keys remapped before loading.
        from utils import load_legacy_checkpoints
        state = load_legacy_checkpoints(
            state_dict=state, encoder_depth=int(getattr(ta, "encoder_depth", 8))
        )
    model.load_state_dict(state, strict=True)
    model.eval()
    vae = load_vae(device)
    # SD-VAE latent normalization constants (scale 0.18215, zero bias).
    latents_scale = torch.tensor([0.18215] * 4, device=device).view(1, 4, 1, 1)
    latents_bias = torch.tensor([0.0] * 4, device=device).view(1, 4, 1, 1)
    sampler_args = argparse.Namespace(cls_cfg_scale=float(cli.cls_cfg_scale))
    os.makedirs(cli.out_dir, exist_ok=True)
    traj_dir = None
    # Trajectory saving is unsupported for the plain "em" sampler.
    if cli.save_fixed_trajectory and cli.sampler != "em":
        traj_dir = os.path.join(cli.out_dir, f"trajectory_rank{rank}")
        os.makedirs(traj_dir, exist_ok=True)
    latent_size = int(getattr(ta, "resolution", 256)) // 8
    n_total = int(cli.num_images)
    b = max(1, int(cli.batch_size))
    # Round the request up to a multiple of the global batch; surplus images
    # are generated but never written (gidx >= n_total check below).
    global_batch_size = b * world_size
    total_samples = int(math.ceil(n_total / global_batch_size) * global_batch_size)
    samples_needed_this_gpu = int(total_samples // world_size)
    if samples_needed_this_gpu % b != 0:
        raise ValueError("samples_needed_this_gpu must be divisible by per-rank batch size")
    iterations = int(samples_needed_this_gpu // b)
    # Per-rank seed offset so ranks draw independent noise.
    seed_rank = int(cli.seed) + int(rank)
    torch.manual_seed(seed_rank)
    if device.type == "cuda":
        torch.cuda.manual_seed_all(seed_rank)
    if rank == 0:
        print(f"Total number of images that will be sampled: {total_samples}")
    pbar = range(iterations)
    pbar = tqdm(pbar, desc="sampling") if rank == 0 else pbar
    total = 0  # images generated globally (across all ranks) so far
    written_local = 0  # PNGs actually written by this rank
    for _ in pbar:
        cur = b
        z = torch.randn(cur, model.in_channels, latent_size, latent_size, device=device)
        y = torch.randint(0, int(ta.num_classes), (cur,), device=device)
        cls_z = torch.randn(cur, cls_dim, device=device)
        with torch.no_grad():
            em_kw = dict(
                num_steps=cli.num_steps,
                cfg_scale=cli.cfg_scale,
                guidance_low=cli.guidance_low,
                guidance_high=cli.guidance_high,
                path_type=path_type,
                cls_latents=cls_z,
                args=sampler_args,
            )
            # Piecewise grid kwargs are only forwarded when t_c is in play.
            if tc_split[0] is not None:
                em_kw["t_c"] = tc_split[0]
                em_kw["num_steps_before_tc"] = tc_split[1]
                em_kw["num_steps_after_tc"] = tc_split[2]
            if cli.save_fixed_trajectory and cli.sampler != "em":
                # NOTE(review): both branches below pass identical arguments —
                # the inner if/else on the sampler name is redundant and could
                # be collapsed to a single call.
                if cli.sampler == "em_image_noise_before_tc":
                    latents, traj = sampler_fn(
                        model, z, y, **em_kw, return_trajectory=True
                    )
                else:
                    latents, traj = sampler_fn(
                        model, z, y, **em_kw, return_trajectory=True
                    )
            else:
                latents = sampler_fn(model, z, y, **em_kw)
                traj = None
            # Decoding stays under no_grad: .numpy() below would raise on a
            # grad-requiring tensor produced by the VAE's trainable weights.
            latents = latents.to(torch.float32)
            imgs = _decode_to_uint8_hwc(latents, latents_bias, latents_scale, vae)
            # Interleaved global indexing: image i of this batch on this rank
            # gets index i*world_size + rank + total, so ranks never collide.
            for i, img in enumerate(imgs):
                gidx = i * world_size + rank + total
                if gidx < n_total:
                    Image.fromarray(img).save(os.path.join(cli.out_dir, f"{gidx:06d}.png"))
                    written_local += 1
            if traj is not None and traj_dir is not None:
                # One trajectory per batch, named after the batch's first
                # global index for this rank.
                traj_np = torch.stack(traj, dim=0).to(torch.float32).cpu().numpy()
                first_idx = rank + total
                if first_idx < n_total:
                    np.save(os.path.join(traj_dir, f"{first_idx:06d}_traj.npy"), traj_np)
        total += global_batch_size
        # Keep ranks in lockstep per batch so progress/seeding stay aligned.
        if use_ddp:
            dist.barrier()
    if rank == 0 and hasattr(pbar, "close"):
        pbar.close()
    # Wait for every rank to finish writing before rank 0 aggregates.
    if use_ddp:
        dist.barrier()
    if rank == 0:
        if cli.save_npz:
            create_npz_from_sample_folder(cli.out_dir, n_total)
        print(f"Done. Saved {n_total} images under {cli.out_dir} (world_size={world_size}).")
    if use_ddp:
        dist.destroy_process_group()


if __name__ == "__main__":
    main()