Juno360219 committed
Commit
d88ce1a
1 Parent(s): aa9ba50

Create common_google_play_services_enable_text

common_google_play_services_enable_text ADDED
@@ -0,0 +1,344 @@
+import argparse, os, sys, glob
+import cv2
+import torch
+import numpy as np
+from omegaconf import OmegaConf
+from PIL import Image
+from tqdm import tqdm, trange
+from imwatermark import WatermarkEncoder
+from itertools import islice
+from einops import rearrange
+from torchvision.utils import make_grid
+import time
+from pytorch_lightning import seed_everything
+from torch import autocast
+from contextlib import contextmanager, nullcontext
+
+from ldm.util import instantiate_from_config
+from ldm.models.diffusion.ddim import DDIMSampler
+from ldm.models.diffusion.plms import PLMSSampler
+
+from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
+from transformers import AutoFeatureExtractor
+
+
+# load safety model
+safety_model_id = "CompVis/stable-diffusion-safety-checker"
+safety_feature_extractor = AutoFeatureExtractor.from_pretrained(safety_model_id)
+safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id)
+
+
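+# split an iterable into fixed-size tuples; e.g. list(chunk(range(5), 2)) == [(0, 1), (2, 3), (4,)]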
+def chunk(it, size):
+    it = iter(it)
+    return iter(lambda: tuple(islice(it, size)), ())
+
+
+def numpy_to_pil(images):
+    """
+    Convert a numpy image or a batch of images to a PIL image.
+    """
+    if images.ndim == 3:
+        images = images[None, ...]
+    images = (images * 255).round().astype("uint8")
+    pil_images = [Image.fromarray(image) for image in images]
+
+    return pil_images
+
+
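+# instantiate the model described by the config, load checkpoint weights (non-strict), and put it on GPU in eval mode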
+def load_model_from_config(config, ckpt, verbose=False):
+    print(f"Loading model from {ckpt}")
+    pl_sd = torch.load(ckpt, map_location="cpu")
+    if "global_step" in pl_sd:
+        print(f"Global Step: {pl_sd['global_step']}")
+    sd = pl_sd["state_dict"]
+    model = instantiate_from_config(config.model)
+    m, u = model.load_state_dict(sd, strict=False)
+    if len(m) > 0 and verbose:
+        print("missing keys:")
+        print(m)
+    if len(u) > 0 and verbose:
+        print("unexpected keys:")
+        print(u)
+
+    model.cuda()
+    model.eval()
+    return model
+
+
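+# embed an invisible DWT-DCT watermark into a PIL image; returns the image unchanged when wm_encoder is None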
+def put_watermark(img, wm_encoder=None):
+    if wm_encoder is not None:
+        img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
+        img = wm_encoder.encode(img, 'dwtDct')
+        img = Image.fromarray(img[:, :, ::-1])
+    return img
+
+
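+# replacement image used when the safety checker flags an output; falls back to the original array if loading fails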
+def load_replacement(x):
+    try:
+        hwc = x.shape
+        y = Image.open("assets/rick.jpeg").convert("RGB").resize((hwc[1], hwc[0]))
+        y = (np.array(y)/255.0).astype(x.dtype)
+        assert y.shape == x.shape
+        return y
+    except Exception:
+        return x
+
+
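+# run the diffusers safety checker on a batch of images in [0, 1] and swap flagged images for the replacement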
+def check_safety(x_image):
+    safety_checker_input = safety_feature_extractor(numpy_to_pil(x_image), return_tensors="pt")
+    x_checked_image, has_nsfw_concept = safety_checker(images=x_image, clip_input=safety_checker_input.pixel_values)
+    assert x_checked_image.shape[0] == len(has_nsfw_concept)
+    for i in range(len(has_nsfw_concept)):
+        if has_nsfw_concept[i]:
+            x_checked_image[i] = load_replacement(x_checked_image[i])
+    return x_checked_image, has_nsfw_concept
+
+
+def main():
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument(
+        "--prompt",
+        type=str,
+        nargs="?",
+        default="a painting of a virus monster playing guitar",
+        help="the prompt to render"
+    )
+    parser.add_argument(
+        "--outdir",
+        type=str,
+        nargs="?",
+        help="dir to write results to",
+        default="outputs/txt2img-samples"
+    )
+    parser.add_argument(
+        "--skip_grid",
+        action='store_true',
+        help="do not save a grid, only individual samples. Helpful when evaluating lots of samples",
+    )
+    parser.add_argument(
+        "--skip_save",
+        action='store_true',
+        help="do not save individual samples. For speed measurements.",
+    )
+    parser.add_argument(
+        "--ddim_steps",
+        type=int,
+        default=50,
+        help="number of ddim sampling steps",
+    )
+    parser.add_argument(
+        "--plms",
+        action='store_true',
+        help="use plms sampling",
+    )
+    parser.add_argument(
+        "--laion400m",
+        action='store_true',
+        help="uses the LAION400M model",
+    )
+    parser.add_argument(
+        "--fixed_code",
+        action='store_true',
+        help="if enabled, uses the same starting code across samples",
+    )
+    parser.add_argument(
+        "--ddim_eta",
+        type=float,
+        default=0.0,
+        help="ddim eta (eta=0.0 corresponds to deterministic sampling)",
+    )
+    parser.add_argument(
+        "--n_iter",
+        type=int,
+        default=2,
+        help="sample this often",
+    )
+    parser.add_argument(
+        "--H",
+        type=int,
+        default=512,
+        help="image height, in pixel space",
+    )
+    parser.add_argument(
+        "--W",
+        type=int,
+        default=512,
+        help="image width, in pixel space",
+    )
+    parser.add_argument(
+        "--C",
+        type=int,
+        default=4,
+        help="latent channels",
+    )
+    parser.add_argument(
+        "--f",
+        type=int,
+        default=8,
+        help="downsampling factor",
+    )
+    parser.add_argument(
+        "--n_samples",
+        type=int,
+        default=3,
+        help="how many samples to produce for each given prompt. A.k.a. batch size",
+    )
+    parser.add_argument(
+        "--n_rows",
+        type=int,
+        default=0,
+        help="rows in the grid (default: n_samples)",
+    )
+    parser.add_argument(
+        "--scale",
+        type=float,
+        default=7.5,
+        help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))",
+    )
+    parser.add_argument(
+        "--from-file",
+        type=str,
+        help="if specified, load prompts from this file",
+    )
+    parser.add_argument(
+        "--config",
+        type=str,
+        default="configs/stable-diffusion/v1-inference.yaml",
+        help="path to config which constructs model",
+    )
+    parser.add_argument(
+        "--ckpt",
+        type=str,
+        default="models/ldm/stable-diffusion-v1/model.ckpt",
+        help="path to checkpoint of model",
+    )
+    parser.add_argument(
+        "--seed",
+        type=int,
+        default=42,
+        help="the seed (for reproducible sampling)",
+    )
+    parser.add_argument(
+        "--precision",
+        type=str,
+        help="evaluate at this precision",
+        choices=["full", "autocast"],
+        default="autocast"
+    )
+    opt = parser.parse_args()
+
+    if opt.laion400m:
+        print("Falling back to LAION 400M model...")
+        opt.config = "configs/latent-diffusion/txt2img-1p4B-eval.yaml"
+        opt.ckpt = "models/ldm/text2img-large/model.ckpt"
+        opt.outdir = "outputs/txt2img-samples-laion400m"
+
+    seed_everything(opt.seed)
+
+    config = OmegaConf.load(f"{opt.config}")
+    model = load_model_from_config(config, f"{opt.ckpt}")
+
+    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+    model = model.to(device)
+
+    if opt.plms:
+        sampler = PLMSSampler(model)
+    else:
+        sampler = DDIMSampler(model)
+
+    os.makedirs(opt.outdir, exist_ok=True)
+    outpath = opt.outdir
+
+    print("Creating invisible watermark encoder (see https://github.com/ShieldMnt/invisible-watermark)...")
+    wm = "StableDiffusionV1"
+    wm_encoder = WatermarkEncoder()
+    wm_encoder.set_watermark('bytes', wm.encode('utf-8'))
+
+    batch_size = opt.n_samples
+    n_rows = opt.n_rows if opt.n_rows > 0 else batch_size
+    if not opt.from_file:
+        prompt = opt.prompt
+        assert prompt is not None
+        data = [batch_size * [prompt]]
+
+    else:
+        print(f"reading prompts from {opt.from_file}")
+        with open(opt.from_file, "r") as f:
+            data = f.read().splitlines()
+            data = list(chunk(data, batch_size))
+
+    sample_path = os.path.join(outpath, "samples")
+    os.makedirs(sample_path, exist_ok=True)
+    base_count = len(os.listdir(sample_path))
+    grid_count = len(os.listdir(outpath)) - 1
+
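+    # optional fixed starting noise (x_T) so every batch begins from the same initial latent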
+    start_code = None
+    if opt.fixed_code:
+        start_code = torch.randn([opt.n_samples, opt.C, opt.H // opt.f, opt.W // opt.f], device=device)
+
+    precision_scope = autocast if opt.precision=="autocast" else nullcontext
+    with torch.no_grad():
+        with precision_scope("cuda"):
+            with model.ema_scope():
+                tic = time.time()
+                all_samples = list()
+                for n in trange(opt.n_iter, desc="Sampling"):
+                    for prompts in tqdm(data, desc="data"):
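+                        # classifier-free guidance: embed the empty prompt as the unconditional conditioning when scale != 1.0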
+                        uc = None
+                        if opt.scale != 1.0:
+                            uc = model.get_learned_conditioning(batch_size * [""])
+                        if isinstance(prompts, tuple):
+                            prompts = list(prompts)
+                        c = model.get_learned_conditioning(prompts)
+                        shape = [opt.C, opt.H // opt.f, opt.W // opt.f]
+                        samples_ddim, _ = sampler.sample(S=opt.ddim_steps,
+                                                         conditioning=c,
+                                                         batch_size=opt.n_samples,
+                                                         shape=shape,
+                                                         verbose=False,
+                                                         unconditional_guidance_scale=opt.scale,
+                                                         unconditional_conditioning=uc,
+                                                         eta=opt.ddim_eta,
+                                                         x_T=start_code)
+
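+                        # decode the latents to pixel space and rescale from [-1, 1] to [0, 1]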
+                        x_samples_ddim = model.decode_first_stage(samples_ddim)
+                        x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
+                        x_samples_ddim = x_samples_ddim.cpu().permute(0, 2, 3, 1).numpy()
+
+                        x_checked_image, has_nsfw_concept = check_safety(x_samples_ddim)
+
+                        x_checked_image_torch = torch.from_numpy(x_checked_image).permute(0, 3, 1, 2)
+
+                        if not opt.skip_save:
+                            for x_sample in x_checked_image_torch:
+                                x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
+                                img = Image.fromarray(x_sample.astype(np.uint8))
+                                img = put_watermark(img, wm_encoder)
+                                img.save(os.path.join(sample_path, f"{base_count:05}.png"))
+                                base_count += 1
+
+                        if not opt.skip_grid:
+                            all_samples.append(x_checked_image_torch)
+
+                if not opt.skip_grid:
+                    # additionally, save as grid
+                    grid = torch.stack(all_samples, 0)
+                    grid = rearrange(grid, 'n b c h w -> (n b) c h w')
+                    grid = make_grid(grid, nrow=n_rows)
+
+                    # to image
+                    grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
+                    img = Image.fromarray(grid.astype(np.uint8))
+                    img = put_watermark(img, wm_encoder)
+                    img.save(os.path.join(outpath, f'grid-{grid_count:04}.png'))
+                    grid_count += 1
+
+                toc = time.time()
+
+    print(f"Your samples are ready and waiting for you here: \n{outpath} \n"
+          f" \nEnjoy.")
+
+
+if __name__ == "__main__":
+    main()
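
For reference, a minimal sketch of driving the script programmatically instead of from the shell. It assumes the file above is saved as txt2img.py somewhere on the import path and that the default config and checkpoint paths exist; the prompt is the script's default, and the smaller batch values are illustrative only.

    import sys
    import txt2img  # the script above, assumed to be saved as txt2img.py

    # argparse reads sys.argv inside main(), so set the CLI arguments before calling it
    sys.argv = ["txt2img.py", "--prompt", "a painting of a virus monster playing guitar",
                "--plms", "--n_samples", "1", "--n_iter", "1"]
    txt2img.main()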