ai-forever committed on
Commit
78a6221
•
1 Parent(s): f179346

Create kandinsky2_1_model.py

Files changed (1)
  1. kandinsky2_1_model.py +656 -0
kandinsky2_1_model.py ADDED
@@ -0,0 +1,656 @@
from transformers import AutoTokenizer
from PIL import Image
import cv2
import torch
from omegaconf import OmegaConf
import math
from copy import deepcopy
import torch.nn.functional as F
import numpy as np
import clip

from kandinsky2.model.text_encoders import TextEncoder
from kandinsky2.vqgan.autoencoder import VQModelInterface, AutoencoderKL, MOVQ
from kandinsky2.model.samplers import DDIMSampler, PLMSSampler
from kandinsky2.model.model_creation import create_model, create_gaussian_diffusion
from kandinsky2.model.prior import PriorDiffusionModel, CustomizedTokenizer
from kandinsky2.utils import prepare_image, q_sample, process_images, prepare_mask


class Kandinsky2_1:

    def __init__(
        self,
        config,
        model_path,
        prior_path,
        device,
        task_type="text2img"
    ):
        self.config = config
        self.device = device
        self.use_fp16 = self.config["model_config"]["use_fp16"]
        self.task_type = task_type
        self.clip_image_size = config["clip_image_size"]
        if task_type == "text2img":
            self.config["model_config"]["up"] = False
            self.config["model_config"]["inpainting"] = False
        elif task_type == "inpainting":
            self.config["model_config"]["up"] = False
            self.config["model_config"]["inpainting"] = True
        else:
            raise ValueError("Only text2img and inpainting are available")

        self.tokenizer1 = AutoTokenizer.from_pretrained(self.config["tokenizer_name"])
        self.tokenizer2 = CustomizedTokenizer()
        clip_mean, clip_std = torch.load(
            config["prior"]["clip_mean_std_path"], map_location="cpu"
        )

        self.prior = PriorDiffusionModel(
            config["prior"]["params"],
            self.tokenizer2,
            clip_mean,
            clip_std,
        )
        self.prior.load_state_dict(torch.load(prior_path, map_location='cpu'), strict=False)
        if self.use_fp16:
            self.prior = self.prior.half()
        self.text_encoder = TextEncoder(**self.config["text_enc_params"])
        if self.use_fp16:
            self.text_encoder = self.text_encoder.half()

        self.clip_model, self.preprocess = clip.load(
            config["clip_name"], device=self.device, jit=False
        )
        self.clip_model.eval()

        if self.config["image_enc_params"] is not None:
            self.use_image_enc = True
            self.scale = self.config["image_enc_params"]["scale"]
            if self.config["image_enc_params"]["name"] == "AutoencoderKL":
                self.image_encoder = AutoencoderKL(
                    **self.config["image_enc_params"]["params"]
                )
            elif self.config["image_enc_params"]["name"] == "VQModelInterface":
                self.image_encoder = VQModelInterface(
                    **self.config["image_enc_params"]["params"]
                )
            elif self.config["image_enc_params"]["name"] == "MOVQ":
                self.image_encoder = MOVQ(**self.config["image_enc_params"]["params"])
            self.image_encoder.load_state_dict(
                torch.load(self.config["image_enc_params"]["ckpt_path"], map_location='cpu')
            )
            self.image_encoder.eval()
        else:
            self.use_image_enc = False

        self.config["model_config"]["cache_text_emb"] = True
        self.model = create_model(**self.config["model_config"])
        self.model.load_state_dict(torch.load(model_path, map_location='cpu'))
        if self.use_fp16:
            self.model.convert_to_fp16()
            self.image_encoder = self.image_encoder.half()
            self.model_dtype = torch.float16
        else:
            self.model_dtype = torch.float32

        self.image_encoder = self.image_encoder.to(self.device).eval()
        self.text_encoder = self.text_encoder.to(self.device).eval()
        self.prior = self.prior.to(self.device).eval()
        self.model.eval()
        self.model.to(self.device)

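    # Note: get_new_h_w rounds the requested pixel size up to a multiple of 64 and
    # returns the matching latent size (1/8 of the pixel resolution): e.g. h=w=512
    # gives a 64x64 latent, while h=520 gives 72 latent rows, which decode back to
    # 576 px and are cropped to 520 in generate_img.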
    def get_new_h_w(self, h, w):
        new_h = h // 64
        if h % 64 != 0:
            new_h += 1
        new_w = w // 64
        if w % 64 != 0:
            new_w += 1
        return new_h * 8, new_w * 8

    @torch.no_grad()
    def encode_text(self, text_encoder, tokenizer, prompt, batch_size):
        text_encoding = tokenizer(
            [prompt] * batch_size + [""] * batch_size,
            max_length=77,
            padding="max_length",
            truncation=True,
            return_attention_mask=True,
            add_special_tokens=True,
            return_tensors="pt",
        )

        tokens = text_encoding["input_ids"].to(self.device)
        mask = text_encoding["attention_mask"].to(self.device)

        full_emb, pooled_emb = text_encoder(tokens=tokens, mask=mask)
        return full_emb, pooled_emb

    @torch.no_grad()
    def generate_clip_emb(
        self,
        prompt,
        batch_size=1,
        prior_cf_scale=4,
        prior_steps="25",
        negative_prior_prompt="",
    ):
        prompts_batch = [prompt for _ in range(batch_size)]
        prior_cf_scales_batch = [prior_cf_scale] * len(prompts_batch)
        prior_cf_scales_batch = torch.tensor(prior_cf_scales_batch, device=self.device)
        max_txt_length = self.prior.model.text_ctx
        tok, mask = self.tokenizer2.padded_tokens_and_mask(
            prompts_batch, max_txt_length
        )
        cf_token, cf_mask = self.tokenizer2.padded_tokens_and_mask(
            [negative_prior_prompt], max_txt_length
        )
        if not (cf_token.shape == tok.shape):
            cf_token = cf_token.expand(tok.shape[0], -1)
            cf_mask = cf_mask.expand(tok.shape[0], -1)
        tok = torch.cat([tok, cf_token], dim=0)
        mask = torch.cat([mask, cf_mask], dim=0)
        tok, mask = tok.to(device=self.device), mask.to(device=self.device)

        x = self.clip_model.token_embedding(tok).type(self.clip_model.dtype)
        x = x + self.clip_model.positional_embedding.type(self.clip_model.dtype)
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.clip_model.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.clip_model.ln_final(x).type(self.clip_model.dtype)
        txt_feat_seq = x
        txt_feat = (x[torch.arange(x.shape[0]), tok.argmax(dim=-1)] @ self.clip_model.text_projection)
        txt_feat, txt_feat_seq = txt_feat.float().to(self.device), txt_feat_seq.float().to(self.device)
        img_feat = self.prior(
            txt_feat,
            txt_feat_seq,
            mask,
            prior_cf_scales_batch,
            timestep_respacing=prior_steps,
        )
        return img_feat.to(self.model_dtype)

    @torch.no_grad()
    def encode_images(self, image, is_pil=False):
        if is_pil:
            image = self.preprocess(image).unsqueeze(0).to(self.device)
        return self.clip_model.encode_image(image).to(self.model_dtype)

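    # generate_img runs the conditional and unconditional halves of the batch
    # through the decoder UNet in a single pass (model_fn duplicates the input)
    # and applies classifier-free guidance:
    # eps = uncond_eps + guidance_scale * (cond_eps - uncond_eps).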
    @torch.no_grad()
    def generate_img(
        self,
        prompt,
        img_prompt,
        batch_size=1,
        diffusion=None,
        guidance_scale=7,
        init_step=None,
        noise=None,
        init_img=None,
        img_mask=None,
        h=512,
        w=512,
        sampler="ddim_sampler",
        num_steps=50,
    ):
        new_h, new_w = self.get_new_h_w(h, w)
        full_batch_size = batch_size * 2
        model_kwargs = {}

        if init_img is not None and self.use_fp16:
            init_img = init_img.half()
        if img_mask is not None and self.use_fp16:
            img_mask = img_mask.half()
        model_kwargs["full_emb"], model_kwargs["pooled_emb"] = self.encode_text(
            text_encoder=self.text_encoder,
            tokenizer=self.tokenizer1,
            prompt=prompt,
            batch_size=batch_size,
        )
        model_kwargs["image_emb"] = img_prompt

        if self.task_type == "inpainting":
            init_img = init_img.to(self.device)
            img_mask = img_mask.to(self.device)
            model_kwargs["inpaint_image"] = init_img * img_mask
            model_kwargs["inpaint_mask"] = img_mask

        def model_fn(x_t, ts, **kwargs):
            half = x_t[: len(x_t) // 2]
            combined = torch.cat([half, half], dim=0)
            model_out = self.model(combined, ts, **kwargs)
            eps, rest = model_out[:, :4], model_out[:, 4:]
            cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
            half_eps = uncond_eps + guidance_scale * (cond_eps - uncond_eps)
            eps = torch.cat([half_eps, half_eps], dim=0)
            if sampler == "p_sampler":
                return torch.cat([eps, rest], dim=1)
            else:
                return eps

        if noise is not None:
            noise = noise.float()
        if self.task_type == "inpainting":
            def denoised_fun(x_start):
                x_start = x_start.clamp(-2, 2)
                return x_start * (1 - img_mask) + init_img * img_mask
        else:
            def denoised_fun(x):
                return x.clamp(-2, 2)

        if sampler == "p_sampler":
            self.model.del_cache()
            samples = diffusion.p_sample_loop(
                model_fn,
                (full_batch_size, 4, new_h, new_w),
                device=self.device,
                noise=noise,
                progress=True,
                model_kwargs=model_kwargs,
                init_step=init_step,
                denoised_fn=denoised_fun,
            )[:batch_size]
            self.model.del_cache()
        else:
            if sampler == "ddim_sampler":
                sampler = DDIMSampler(
                    model=model_fn,
                    old_diffusion=diffusion,
                    schedule="linear",
                )
            elif sampler == "plms_sampler":
                sampler = PLMSSampler(
                    model=model_fn,
                    old_diffusion=diffusion,
                    schedule="linear",
                )
            else:
                raise ValueError("Only ddim_sampler and plms_sampler are available")

            self.model.del_cache()
            samples, _ = sampler.sample(
                num_steps,
                batch_size * 2,
                (4, new_h, new_w),
                conditioning=model_kwargs,
                x_T=noise,
                init_step=init_step,
            )
            self.model.del_cache()
            samples = samples[:batch_size]

        if self.use_image_enc:
            if self.use_fp16:
                samples = samples.half()
            samples = self.image_encoder.decode(samples / self.scale)

        samples = samples[:, :, :h, :w]
        return process_images(samples)

    @torch.no_grad()
    def create_zero_img_emb(self, batch_size):
        img = torch.zeros(1, 3, self.clip_image_size, self.clip_image_size).to(self.device)
        return self.encode_images(img, is_pil=False).repeat(batch_size, 1)

    @torch.no_grad()
    def generate_text2img(
        self,
        prompt,
        num_steps=100,
        batch_size=1,
        guidance_scale=7,
        h=512,
        w=512,
        sampler="ddim_sampler",
        prior_cf_scale=4,
        prior_steps="25",
        negative_prior_prompt="",
        negative_decoder_prompt="",
    ):
        # generate clip embeddings
        image_emb = self.generate_clip_emb(
            prompt,
            batch_size=batch_size,
            prior_cf_scale=prior_cf_scale,
            prior_steps=prior_steps,
            negative_prior_prompt=negative_prior_prompt,
        )
        if negative_decoder_prompt == "":
            zero_image_emb = self.create_zero_img_emb(batch_size=batch_size)
        else:
            zero_image_emb = self.generate_clip_emb(
                negative_decoder_prompt,
                batch_size=batch_size,
                prior_cf_scale=prior_cf_scale,
                prior_steps=prior_steps,
                negative_prior_prompt=negative_prior_prompt,
            )

        image_emb = torch.cat([image_emb, zero_image_emb], dim=0).to(self.device)

        # load diffusion
        config = deepcopy(self.config)
        if sampler == "p_sampler":
            config["diffusion_config"]["timestep_respacing"] = str(num_steps)
        diffusion = create_gaussian_diffusion(**config["diffusion_config"])

        return self.generate_img(
            prompt=prompt,
            img_prompt=image_emb,
            batch_size=batch_size,
            guidance_scale=guidance_scale,
            h=h,
            w=w,
            sampler=sampler,
            num_steps=num_steps,
            diffusion=diffusion,
        )

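    # mix_images forms one CLIP image embedding as a weighted sum over the inputs:
    # text prompts go through the prior, PIL images through the CLIP image encoder.
    # The mixed embedding is then decoded exactly like text2img, with an empty
    # decoder prompt.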
    @torch.no_grad()
    def mix_images(
        self,
        images_texts,
        weights,
        num_steps=100,
        batch_size=1,
        guidance_scale=7,
        h=512,
        w=512,
        sampler="ddim_sampler",
        prior_cf_scale=4,
        prior_steps="25",
        negative_prior_prompt="",
        negative_decoder_prompt="",
    ):
        assert len(images_texts) == len(weights) and len(images_texts) > 0

        # generate clip embeddings
        image_emb = None
        for i in range(len(images_texts)):
            if image_emb is None:
                if type(images_texts[i]) == str:
                    image_emb = weights[i] * self.generate_clip_emb(
                        images_texts[i],
                        batch_size=1,
                        prior_cf_scale=prior_cf_scale,
                        prior_steps=prior_steps,
                        negative_prior_prompt=negative_prior_prompt,
                    )
                else:
                    image_emb = self.encode_images(images_texts[i], is_pil=True) * weights[i]
            else:
                if type(images_texts[i]) == str:
                    image_emb = image_emb + weights[i] * self.generate_clip_emb(
                        images_texts[i],
                        batch_size=1,
                        prior_cf_scale=prior_cf_scale,
                        prior_steps=prior_steps,
                        negative_prior_prompt=negative_prior_prompt,
                    )
                else:
                    image_emb = image_emb + self.encode_images(images_texts[i], is_pil=True) * weights[i]

        image_emb = image_emb.repeat(batch_size, 1)
        if negative_decoder_prompt == "":
            zero_image_emb = self.create_zero_img_emb(batch_size=batch_size)
        else:
            zero_image_emb = self.generate_clip_emb(
                negative_decoder_prompt,
                batch_size=batch_size,
                prior_cf_scale=prior_cf_scale,
                prior_steps=prior_steps,
                negative_prior_prompt=negative_prior_prompt,
            )
        image_emb = torch.cat([image_emb, zero_image_emb], dim=0).to(self.device)

        # load diffusion
        config = deepcopy(self.config)
        if sampler == "p_sampler":
            config["diffusion_config"]["timestep_respacing"] = str(num_steps)
        diffusion = create_gaussian_diffusion(**config["diffusion_config"])
        return self.generate_img(
            prompt="",
            img_prompt=image_emb,
            batch_size=batch_size,
            guidance_scale=guidance_scale,
            h=h,
            w=w,
            sampler=sampler,
            num_steps=num_steps,
            diffusion=diffusion,
        )

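    # generate_img2img encodes the input image to latents, noises it with q_sample
    # at the step given by start_step = int(num_timesteps * (1 - strength)), and
    # resumes sampling from that step, so `strength` controls how far the result
    # may drift from the original image.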
    @torch.no_grad()
    def generate_img2img(
        self,
        prompt,
        pil_img,
        strength=0.7,
        num_steps=100,
        batch_size=1,
        guidance_scale=7,
        h=512,
        w=512,
        sampler="ddim_sampler",
        prior_cf_scale=4,
        prior_steps="25",
    ):
        # generate clip embeddings
        image_emb = self.generate_clip_emb(
            prompt,
            batch_size=batch_size,
            prior_cf_scale=prior_cf_scale,
            prior_steps=prior_steps,
        )
        zero_image_emb = self.create_zero_img_emb(batch_size=batch_size)
        image_emb = torch.cat([image_emb, zero_image_emb], dim=0).to(self.device)

        # load diffusion
        config = deepcopy(self.config)
        if sampler == "p_sampler":
            config["diffusion_config"]["timestep_respacing"] = str(num_steps)
        diffusion = create_gaussian_diffusion(**config["diffusion_config"])

        image = prepare_image(pil_img, h=h, w=w).to(self.device)
        if self.use_fp16:
            image = image.half()
        image = self.image_encoder.encode(image) * self.scale

        start_step = int(diffusion.num_timesteps * (1 - strength))
        image = q_sample(
            image,
            torch.tensor(diffusion.timestep_map[start_step - 1]).to(self.device),
            schedule_name=config["diffusion_config"]["noise_schedule"],
            num_steps=config["diffusion_config"]["steps"],
        )

        image = image.repeat(2, 1, 1, 1)
        return self.generate_img(
            prompt=prompt,
            img_prompt=image_emb,
            batch_size=batch_size,
            guidance_scale=guidance_scale,
            h=h,
            w=w,
            sampler=sampler,
            num_steps=num_steps,
            diffusion=diffusion,
            noise=image,
            init_step=start_step,
        )

    @torch.no_grad()
    def generate_inpainting(
        self,
        prompt,
        pil_img,
        img_mask,
        num_steps=100,
        batch_size=1,
        guidance_scale=7,
        h=512,
        w=512,
        sampler="ddim_sampler",
        prior_cf_scale=4,
        prior_steps="25",
        negative_prior_prompt="",
        negative_decoder_prompt="",
    ):
        # generate clip embeddings
        image_emb = self.generate_clip_emb(
            prompt,
            batch_size=batch_size,
            prior_cf_scale=prior_cf_scale,
            prior_steps=prior_steps,
            negative_prior_prompt=negative_prior_prompt,
        )
        zero_image_emb = self.create_zero_img_emb(batch_size=batch_size)
        image_emb = torch.cat([image_emb, zero_image_emb], dim=0).to(self.device)

        # load diffusion
        config = deepcopy(self.config)
        if sampler == "p_sampler":
            config["diffusion_config"]["timestep_respacing"] = str(num_steps)
        diffusion = create_gaussian_diffusion(**config["diffusion_config"])
        image = prepare_image(pil_img, w, h).to(self.device)
        if self.use_fp16:
            image = image.half()
        image = self.image_encoder.encode(image) * self.scale
        image_shape = tuple(image.shape[-2:])
        img_mask = torch.from_numpy(img_mask).unsqueeze(0).unsqueeze(0)
        img_mask = F.interpolate(
            img_mask,
            image_shape,
            mode="nearest",
        )
        img_mask = prepare_mask(img_mask).to(self.device)
        if self.use_fp16:
            img_mask = img_mask.half()
        image = image.repeat(2, 1, 1, 1)
        img_mask = img_mask.repeat(2, 1, 1, 1)

        return self.generate_img(
            prompt=prompt,
            img_prompt=image_emb,
            batch_size=batch_size,
            guidance_scale=guidance_scale,
            h=h,
            w=w,
            sampler=sampler,
            num_steps=num_steps,
            diffusion=diffusion,
            init_img=image,
            img_mask=img_mask,
        )


import os
from huggingface_hub import hf_hub_url, cached_download
from copy import deepcopy
from omegaconf.dictconfig import DictConfig

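# Note: CONFIG_2_1 (the default 2.1 config), get_kandinsky2_0 and Kandinsky2_2
# are referenced below but not defined in this file; they are assumed to be
# provided elsewhere in the kandinsky2 package.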
def get_kandinsky2_1(
    device,
    task_type="text2img",
    cache_dir="/tmp/kandinsky2",
    use_auth_token=None,
    use_flash_attention=False,
):
    cache_dir = os.path.join(cache_dir, "2_1")
    config = DictConfig(deepcopy(CONFIG_2_1))
    config["model_config"]["use_flash_attention"] = use_flash_attention
    if task_type == "text2img":
        model_name = "decoder_fp16.ckpt"
        config_file_url = hf_hub_url(repo_id="sberbank-ai/Kandinsky_2.1", filename=model_name)
    elif task_type == "inpainting":
        model_name = "inpainting_fp16.ckpt"
        config_file_url = hf_hub_url(repo_id="sberbank-ai/Kandinsky_2.1", filename=model_name)
    cached_download(
        config_file_url,
        cache_dir=cache_dir,
        force_filename=model_name,
        use_auth_token=use_auth_token,
    )
    prior_name = "prior_fp16.ckpt"
    config_file_url = hf_hub_url(repo_id="sberbank-ai/Kandinsky_2.1", filename=prior_name)
    cached_download(
        config_file_url,
        cache_dir=cache_dir,
        force_filename=prior_name,
        use_auth_token=use_auth_token,
    )

    cache_dir_text_en = os.path.join(cache_dir, "text_encoder")
    for name in [
        "config.json",
        "pytorch_model.bin",
        "sentencepiece.bpe.model",
        "special_tokens_map.json",
        "tokenizer.json",
        "tokenizer_config.json",
    ]:
        config_file_url = hf_hub_url(repo_id="sberbank-ai/Kandinsky_2.1", filename=f"text_encoder/{name}")
        cached_download(
            config_file_url,
            cache_dir=cache_dir_text_en,
            force_filename=name,
            use_auth_token=use_auth_token,
        )

    config_file_url = hf_hub_url(repo_id="sberbank-ai/Kandinsky_2.1", filename="movq_final.ckpt")
    cached_download(
        config_file_url,
        cache_dir=cache_dir,
        force_filename="movq_final.ckpt",
        use_auth_token=use_auth_token,
    )

    config_file_url = hf_hub_url(repo_id="sberbank-ai/Kandinsky_2.1", filename="ViT-L-14_stats.th")
    cached_download(
        config_file_url,
        cache_dir=cache_dir,
        force_filename="ViT-L-14_stats.th",
        use_auth_token=use_auth_token,
    )

    config["tokenizer_name"] = cache_dir_text_en
    config["text_enc_params"]["model_path"] = cache_dir_text_en
    config["prior"]["clip_mean_std_path"] = os.path.join(cache_dir, "ViT-L-14_stats.th")
    config["image_enc_params"]["ckpt_path"] = os.path.join(cache_dir, "movq_final.ckpt")
    cache_model_name = os.path.join(cache_dir, model_name)
    cache_prior_name = os.path.join(cache_dir, prior_name)
    model = Kandinsky2_1(config, cache_model_name, cache_prior_name, device, task_type=task_type)
    return model


def get_kandinsky2(
    device,
    task_type="text2img",
    cache_dir="/tmp/kandinsky2",
    use_auth_token=None,
    model_version="2.1",
    use_flash_attention=False,
):
    if model_version == "2.0":
        model = get_kandinsky2_0(
            device,
            task_type=task_type,
            cache_dir=cache_dir,
            use_auth_token=use_auth_token,
        )
    elif model_version == "2.1":
        model = get_kandinsky2_1(
            device,
            task_type=task_type,
            cache_dir=cache_dir,
            use_auth_token=use_auth_token,
            use_flash_attention=use_flash_attention,
        )
    elif model_version == "2.2":
        model = Kandinsky2_2(device=device, task_type=task_type)
    else:
        raise ValueError("Only 2.0, 2.1 and 2.2 are available")

    return model
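

# Minimal usage sketch (illustrative values; the checkpoints are downloaded on
# the first call, and process_images is assumed to return PIL images):
if __name__ == "__main__":
    model = get_kandinsky2(
        "cuda",
        task_type="text2img",
        model_version="2.1",
        use_flash_attention=False,
    )
    images = model.generate_text2img(
        "red cat, 4k photo",
        num_steps=100,
        batch_size=1,
        guidance_scale=4,
        h=768,
        w=768,
        sampler="p_sampler",
        prior_cf_scale=4,
        prior_steps="5",
    )
    images[0].save("kandinsky_2_1_sample.png")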