ringhyacinth committed on
Commit
8c9bbe5
1 Parent(s): df648bb

Upload txt2img.py

Files changed (1)
  1. txt2img.py +290 -0
txt2img.py ADDED
@@ -0,0 +1,290 @@
+ import argparse, os, sys, glob
+ import torch
+ import numpy as np
+ from omegaconf import OmegaConf
+ from PIL import Image
+ from tqdm import tqdm, trange
+ from itertools import islice
+ from einops import rearrange
+ from torchvision.utils import make_grid
+ import time
+ from pytorch_lightning import seed_everything
+ from torch import autocast
+ from contextlib import contextmanager, nullcontext
+
+ from ldm.util import instantiate_from_config
+ from ldm.models.diffusion.ddim import DDIMSampler
+ from ldm.models.diffusion.plms import PLMSSampler
+
+
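+ # Helper: split an iterable into fixed-size tuples; used below to batch prompts read from --from-file.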
+ def chunk(it, size):
+     it = iter(it)
+     return iter(lambda: tuple(islice(it, size)), ())
+
+
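+ # Load checkpoint weights into the model built from the OmegaConf config.
+ # Note: the model is moved to the GPU via model.cuda(), so a CUDA device is assumed here.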
+ def load_model_from_config(config, ckpt, verbose=False):
+     print(f"Loading model from {ckpt}")
+     pl_sd = torch.load(ckpt, map_location="cpu")
+     if "global_step" in pl_sd:
+         print(f"Global Step: {pl_sd['global_step']}")
+     sd = pl_sd["state_dict"]
+     model = instantiate_from_config(config.model)
+     m, u = model.load_state_dict(sd, strict=False)
+     if len(m) > 0 and verbose:
+         print("missing keys:")
+         print(m)
+     if len(u) > 0 and verbose:
+         print("unexpected keys:")
+         print(u)
+
+     model.cuda()
+     model.eval()
+     return model
+
+
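+ # Parse command-line options, load the model and sampler, then run text-to-image sampling.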
+ def main():
+     parser = argparse.ArgumentParser()
+
+     parser.add_argument(
+         "--prompt",
+         type=str,
+         nargs="?",
+         default="a painting of a virus monster playing guitar",
+         help="the prompt to render"
+     )
+
+     parser.add_argument(
+         "--outdir",
+         type=str,
+         nargs="?",
+         help="dir to write results to",
+         default="outputs/txt2img-samples"
+     )
+
+     parser.add_argument(
+         "--skip_grid",
+         action='store_true',
+         help="do not save a grid, only individual samples. Helpful when evaluating lots of samples",
+     )
+
+     parser.add_argument(
+         "--skip_save",
+         action='store_true',
+         help="do not save individual samples. For speed measurements.",
+     )
+
+     parser.add_argument(
+         "--ddim_steps",
+         type=str,
+         default="50",
+         help="number of ddim sampling steps",
+     )
+
+     parser.add_argument(
+         "--plms",
+         action='store_true',
+         help="use plms sampling",
+     )
+     parser.add_argument(
+         "--fixed_code",
+         action='store_true',
+         help="if enabled, uses the same starting code across all samples",
+     )
+
+     parser.add_argument(
+         "--ddim_eta",
+         type=str,
+         default="0.0",
+         help="ddim eta (eta=0.0 corresponds to deterministic sampling)",
+     )
+     parser.add_argument(
+         "--n_iter",
+         type=int,
+         default=1,
+         help="sample this often",
+     )
+
+     parser.add_argument(
+         "--H",
+         type=int,
+         default=256,
+         help="image height, in pixel space",
+     )
+
+     parser.add_argument(
+         "--W",
+         type=int,
+         default=256,
+         help="image width, in pixel space",
+     )
+
+     parser.add_argument(
+         "--C",
+         type=int,
+         default=4,
+         help="latent channels",
+     )
+     parser.add_argument(
+         "--f",
+         type=int,
+         default=8,
+         help="downsampling factor, most often 8 or 16",
+     )
+
+     parser.add_argument(
+         "--n_samples",
+         type=str,
+         default="8",
+         help="how many samples to produce for each given prompt, a.k.a. batch size",
+     )
+
+     parser.add_argument(
+         "--n_rows",
+         type=int,
+         default=0,
+         help="rows in the grid (default: n_samples)",
+     )
+
+     parser.add_argument(
+         "--scale",
+         type=str,
+         default='5.0',
+         help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))",
+     )
+
+     parser.add_argument(
+         "--dyn",
+         type=float,
+         help="dynamic thresholding from Imagen, in latent space (TODO: try in pixel space with intermediate decode)",
+     )
+     parser.add_argument(
+         "--from-file",
+         type=str,
+         help="if specified, load prompts from this file",
+     )
+     parser.add_argument(
+         "--config",
+         type=str,
+         default="logs/f8-kl-clip-encoder-256x256-run1/configs/2022-06-01T22-11-40-project.yaml",
+         help="path to config which constructs model",
+     )
+     parser.add_argument(
+         "--ckpt",
+         type=str,
+         default="logs/f8-kl-clip-encoder-256x256-run1/checkpoints/last.ckpt",
+         help="path to checkpoint of model",
+     )
+     parser.add_argument(
+         "--seed",
+         type=int,
+         default=42,
+         help="the seed (for reproducible sampling)",
+     )
+     parser.add_argument(
+         "--precision",
+         type=str,
+         help="evaluate at this precision",
+         choices=["full", "autocast"],
+         default="autocast"
+     )
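+     # Several numeric options (--ddim_steps, --ddim_eta, --n_samples, --scale) are declared as
+     # strings above, so they are cast to int/float here after parsing.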
+     opt = parser.parse_args()
+     opt.n_samples = int(opt.n_samples)
+     opt.ddim_steps = int(opt.ddim_steps)
+     opt.scale = float(opt.scale)
+     opt.ddim_eta = float(opt.ddim_eta)
+     opt.seed = int(opt.seed)
+     seed_everything(opt.seed)
+
+     config = OmegaConf.load(f"{opt.config}")
+     model = load_model_from_config(config, f"{opt.ckpt}")
+
+     device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+     model = model.to(device)
+
+     if opt.plms:
+         sampler = PLMSSampler(model)
+     else:
+         sampler = DDIMSampler(model)
+
+     os.makedirs(opt.outdir, exist_ok=True)
+     outpath = opt.outdir
+
+     batch_size = opt.n_samples
+     n_rows = opt.n_rows if opt.n_rows > 0 else batch_size
+     if not opt.from_file:
+         prompt = opt.prompt
+         assert prompt is not None
+         data = [batch_size * [prompt]]
+
+     else:
+         print(f"reading prompts from {opt.from_file}")
+         with open(opt.from_file, "r") as f:
+             data = f.read().splitlines()
+             data = list(chunk(data, batch_size))
+
+     sample_path = os.path.join(outpath, "samples")
+     os.makedirs(sample_path, exist_ok=True)
+     base_count = len(os.listdir(sample_path))
+     grid_count = len(os.listdir(outpath)) - 1
+
+     start_code = None
+     if opt.fixed_code:
+         start_code = torch.randn([opt.n_samples, opt.C, opt.H // opt.f, opt.W // opt.f], device=device)
+
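+     # Sample without gradients, optionally under torch autocast (mixed precision) and with the
+     # model's EMA weights.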
+     precision_scope = autocast if opt.precision=="autocast" else nullcontext
+     with torch.no_grad():
+         with precision_scope("cuda"):
+             with model.ema_scope():
+                 tic = time.time()
+                 all_samples = list()
+                 for n in trange(opt.n_iter, desc="Sampling"):
+                     for prompts in tqdm(data, desc="data"):
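+                         # Classifier-free guidance: when scale != 1.0, the empty prompt is also
+                         # embedded and passed to the sampler as the unconditional conditioning.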
+                         uc = None
+                         if opt.scale != 1.0:
+                             uc = model.get_learned_conditioning(batch_size * [""])
+                         if isinstance(prompts, tuple):
+                             prompts = list(prompts)
+                         c = model.get_learned_conditioning(prompts)
+                         shape = [opt.C, opt.H // opt.f, opt.W // opt.f]
+                         samples_ddim, _ = sampler.sample(S=opt.ddim_steps,
+                                                          conditioning=c,
+                                                          batch_size=opt.n_samples,
+                                                          shape=shape,
+                                                          verbose=False,
+                                                          unconditional_guidance_scale=opt.scale,
+                                                          unconditional_conditioning=uc,
+                                                          eta=opt.ddim_eta,
+                                                          dynamic_threshold=opt.dyn,
+                                                          x_T=start_code)
+
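+                         # Decode latents back to pixel space and rescale from [-1, 1] to [0, 1].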
+                         x_samples_ddim = model.decode_first_stage(samples_ddim)
+                         x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
+
+                         if not opt.skip_save:
+                             for x_sample in x_samples_ddim:
+                                 x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
+                                 Image.fromarray(x_sample.astype(np.uint8)).save(
+                                     os.path.join(sample_path, f"{base_count:05}.png"))
+                                 base_count += 1
+                         all_samples.append(x_samples_ddim)
+
+                 if not opt.skip_grid:
+                     # additionally, save as grid
+                     grid = torch.stack(all_samples, 0)
+                     grid = rearrange(grid, 'n b c h w -> (n b) c h w')
+                     grid = make_grid(grid, nrow=n_rows)
+
+                     # to image
+                     grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
+                     Image.fromarray(grid.astype(np.uint8)).save(os.path.join(outpath, f'grid-{grid_count:04}.png'))
+                     grid_count += 1
+
+                 toc = time.time()
+
+     print(f"Your samples are ready and waiting for you here: \n{outpath} \n"
+           f"Sampling took {toc - tic}s, i.e. produced {opt.n_iter * opt.n_samples / (toc - tic):.2f} samples/sec."
+           f" \nEnjoy.")
+
+
+ if __name__ == "__main__":
+     main()