Dragunflie-420 committed on
Commit c26cbcb
1 Parent(s): 4de488d

Create sample_ddp.py

Files changed (1)
  1. sample_ddp.py +166 -0
sample_ddp.py ADDED
@@ -0,0 +1,166 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ """
+ Samples a large number of images from a pre-trained DiT model using DDP.
+ Subsequently saves a .npz file that can be used to compute FID and other
+ evaluation metrics via the ADM repo: https://github.com/openai/guided-diffusion/tree/main/evaluations
+
+ For a simple single-GPU/CPU sampling script, see sample.py.
+ """
+ import torch
+ import torch.distributed as dist
+ from models import DiT_models
+ from download import find_model
+ from diffusion import create_diffusion
+ from diffusers.models import AutoencoderKL
+ from tqdm import tqdm
+ import os
+ from PIL import Image
+ import numpy as np
+ import math
+ import argparse
+
+
+ def create_npz_from_sample_folder(sample_dir, num=50_000):
+     """
+     Builds a single .npz file from a folder of .png samples.
+     """
+     samples = []
+     for i in tqdm(range(num), desc="Building .npz file from samples"):
+         sample_pil = Image.open(f"{sample_dir}/{i:06d}.png")
+         sample_np = np.asarray(sample_pil).astype(np.uint8)
+         samples.append(sample_np)
+     samples = np.stack(samples)
+     assert samples.shape == (num, samples.shape[1], samples.shape[2], 3)
+     npz_path = f"{sample_dir}.npz"
+     np.savez(npz_path, arr_0=samples)
+     print(f"Saved .npz file to {npz_path} [shape={samples.shape}].")
+     return npz_path
+
+
+ def main(args):
+     """
+     Run sampling.
+     """
+     torch.backends.cuda.matmul.allow_tf32 = args.tf32  # True: fast but may lead to some small numerical differences
+     assert torch.cuda.is_available(), "Sampling with DDP requires at least one GPU. sample.py supports CPU-only usage"
+     torch.set_grad_enabled(False)
+
+     # Setup DDP:
+     dist.init_process_group("nccl")
+     rank = dist.get_rank()
+     device = rank % torch.cuda.device_count()
+     seed = args.global_seed * dist.get_world_size() + rank
+     torch.manual_seed(seed)
+     torch.cuda.set_device(device)
+     print(f"Starting rank={rank}, seed={seed}, world_size={dist.get_world_size()}.")
+
+     if args.ckpt is None:
+         assert args.model == "DiT-XL/2", "Only DiT-XL/2 models are available for auto-download."
+         assert args.image_size in [256, 512]
+         assert args.num_classes == 1000
+
+     # Load model:
+     latent_size = args.image_size // 8
+     model = DiT_models[args.model](
+         input_size=latent_size,
+         num_classes=args.num_classes
+     ).to(device)
+     # Auto-download a pre-trained model or load a custom DiT checkpoint from train.py:
+     ckpt_path = args.ckpt or f"DiT-XL-2-{args.image_size}x{args.image_size}.pt"
+     state_dict = find_model(ckpt_path)
+     model.load_state_dict(state_dict)
+     model.eval()  # important!
+     diffusion = create_diffusion(str(args.num_sampling_steps))
+     vae = AutoencoderKL.from_pretrained(f"stabilityai/sd-vae-ft-{args.vae}").to(device)
+     assert args.cfg_scale >= 1.0, "In almost all cases, cfg_scale should be >= 1.0"
+     using_cfg = args.cfg_scale > 1.0
+
+     # Create folder to save samples:
+     model_string_name = args.model.replace("/", "-")
+     ckpt_string_name = os.path.basename(args.ckpt).replace(".pt", "") if args.ckpt else "pretrained"
+     folder_name = f"{model_string_name}-{ckpt_string_name}-size-{args.image_size}-vae-{args.vae}-" \
+                   f"cfg-{args.cfg_scale}-seed-{args.global_seed}"
+     sample_folder_dir = f"{args.sample_dir}/{folder_name}"
+     if rank == 0:
+         os.makedirs(sample_folder_dir, exist_ok=True)
+         print(f"Saving .png samples at {sample_folder_dir}")
+     dist.barrier()
+
+     # Figure out how many samples we need to generate on each GPU and how many iterations we need to run:
+     n = args.per_proc_batch_size
+     global_batch_size = n * dist.get_world_size()
+     # To make things evenly-divisible, we'll sample a bit more than we need and then discard the extra samples:
+     total_samples = int(math.ceil(args.num_fid_samples / global_batch_size) * global_batch_size)
+     if rank == 0:
+         print(f"Total number of images that will be sampled: {total_samples}")
+     assert total_samples % dist.get_world_size() == 0, "total_samples must be divisible by world_size"
+     samples_needed_this_gpu = int(total_samples // dist.get_world_size())
+     assert samples_needed_this_gpu % n == 0, "samples_needed_this_gpu must be divisible by the per-GPU batch size"
+     iterations = int(samples_needed_this_gpu // n)
+     pbar = range(iterations)
+     pbar = tqdm(pbar) if rank == 0 else pbar
+     total = 0
+     for _ in pbar:
+         # Sample inputs:
+         z = torch.randn(n, model.in_channels, latent_size, latent_size, device=device)
+         y = torch.randint(0, args.num_classes, (n,), device=device)
+
+         # Setup classifier-free guidance:
+         if using_cfg:
+             z = torch.cat([z, z], 0)
+             y_null = torch.tensor([1000] * n, device=device)
+             y = torch.cat([y, y_null], 0)
+             model_kwargs = dict(y=y, cfg_scale=args.cfg_scale)
+             sample_fn = model.forward_with_cfg
+         else:
+             model_kwargs = dict(y=y)
+             sample_fn = model.forward
+
+         # Sample images:
+         samples = diffusion.p_sample_loop(
+             sample_fn, z.shape, z, clip_denoised=False, model_kwargs=model_kwargs, progress=False, device=device
+         )
+         if using_cfg:
+             samples, _ = samples.chunk(2, dim=0)  # Remove null class samples
+
+         samples = vae.decode(samples / 0.18215).sample
+         samples = torch.clamp(127.5 * samples + 128.0, 0, 255).permute(0, 2, 3, 1).to("cpu", dtype=torch.uint8).numpy()
+
+         # Save samples to disk as individual .png files
+         for i, sample in enumerate(samples):
+             index = i * dist.get_world_size() + rank + total  # globally unique index: ranks interleave within each global batch
+             Image.fromarray(sample).save(f"{sample_folder_dir}/{index:06d}.png")
+         total += global_batch_size
+
+     # Make sure all processes have finished saving their samples before attempting to convert to .npz
+     dist.barrier()
+     if rank == 0:
+         create_npz_from_sample_folder(sample_folder_dir, args.num_fid_samples)
+         print("Done.")
+     dist.barrier()
+     dist.destroy_process_group()
+
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--model", type=str, choices=list(DiT_models.keys()), default="DiT-XL/2")
+     parser.add_argument("--vae", type=str, choices=["ema", "mse"], default="ema")
+     parser.add_argument("--sample-dir", type=str, default="samples")
+     parser.add_argument("--per-proc-batch-size", type=int, default=32)
+     parser.add_argument("--num-fid-samples", type=int, default=50_000)
+     parser.add_argument("--image-size", type=int, choices=[256, 512], default=256)
+     parser.add_argument("--num-classes", type=int, default=1000)
+     parser.add_argument("--cfg-scale", type=float, default=1.5)
+     parser.add_argument("--num-sampling-steps", type=int, default=250)
+     parser.add_argument("--global-seed", type=int, default=0)
+     parser.add_argument("--tf32", action=argparse.BooleanOptionalAction, default=True,
+                         help="By default, use TF32 matmuls. This massively accelerates sampling on Ampere GPUs.")
+     parser.add_argument("--ckpt", type=str, default=None,
+                         help="Optional path to a DiT checkpoint (default: auto-download a pre-trained DiT-XL/2 model).")
+     args = parser.parse_args()
+     main(args)
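
Usage note: the script initializes DDP via dist.init_process_group("nccl") and reads its rank and world size from the environment, so it is meant to be launched with a multi-process launcher such as PyTorch's torchrun. A typical single-node launch might look like the following line (the GPU count and flag values are illustrative, not mandated by this commit; the flags map to the argparse options defined above):

torchrun --nnodes=1 --nproc_per_node=8 sample_ddp.py --model DiT-XL/2 --image-size 256 --cfg-scale 1.5 --num-fid-samples 50000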
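With the defaults above (auto-downloaded DiT-XL/2, 256x256 images, ema VAE, cfg-scale 1.5, seed 0), the final archive is written to samples/DiT-XL-2-pretrained-size-256-vae-ema-cfg-1.5-seed-0.npz. Per the ADM evaluations repo linked in the module docstring, FID and related metrics are then computed by passing a reference batch and this sample batch to that repo's evaluator script; as a rough sketch (the reference-batch filename is the one distributed by that repo, and the exact invocation should be checked against its README):

python evaluator.py VIRTUAL_imagenet256_labeled.npz samples/DiT-XL-2-pretrained-size-256-vae-ema-cfg-1.5-seed-0.npz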