from transformers import AutoTokenizer from PIL import Image import cv2 import torch from omegaconf import OmegaConf import math from copy import deepcopy import torch.nn.functional as F import numpy as np import clip from transformers import AutoTokenizer from kandinsky2.model.text_encoders import TextEncoder from kandinsky2.vqgan.autoencoder import VQModelInterface, AutoencoderKL, MOVQ from kandinsky2.model.samplers import DDIMSampler, PLMSSampler from kandinsky2.model.model_creation import create_model, create_gaussian_diffusion from kandinsky2.model.prior import PriorDiffusionModel, CustomizedTokenizer from kandinsky2.utils import prepare_image, q_sample, process_images, prepare_mask class Kandinsky2_1: def __init__( self, config, model_path, prior_path, device, task_type="text2img" ): self.config = config self.device = device self.use_fp16 = self.config["model_config"]["use_fp16"] self.task_type = task_type self.clip_image_size = config["clip_image_size"] if task_type == "text2img": self.config["model_config"]["up"] = False self.config["model_config"]["inpainting"] = False elif task_type == "inpainting": self.config["model_config"]["up"] = False self.config["model_config"]["inpainting"] = True else: raise ValueError("Only text2img and inpainting is available") self.tokenizer1 = AutoTokenizer.from_pretrained(self.config["tokenizer_name"]) self.tokenizer2 = CustomizedTokenizer() clip_mean, clip_std = torch.load( config["prior"]["clip_mean_std_path"], map_location="cpu" ) self.prior = PriorDiffusionModel( config["prior"]["params"], self.tokenizer2, clip_mean, clip_std, ) self.prior.load_state_dict(torch.load(prior_path, map_location='cpu'), strict=False) if self.use_fp16: self.prior = self.prior.half() self.text_encoder = TextEncoder(**self.config["text_enc_params"]) if self.use_fp16: self.text_encoder = self.text_encoder.half() self.clip_model, self.preprocess = clip.load( config["clip_name"], device=self.device, jit=False ) self.clip_model.eval() if self.config["image_enc_params"] is not None: self.use_image_enc = True self.scale = self.config["image_enc_params"]["scale"] if self.config["image_enc_params"]["name"] == "AutoencoderKL": self.image_encoder = AutoencoderKL( **self.config["image_enc_params"]["params"] ) elif self.config["image_enc_params"]["name"] == "VQModelInterface": self.image_encoder = VQModelInterface( **self.config["image_enc_params"]["params"] ) elif self.config["image_enc_params"]["name"] == "MOVQ": self.image_encoder = MOVQ(**self.config["image_enc_params"]["params"]) self.image_encoder.load_state_dict( torch.load(self.config["image_enc_params"]["ckpt_path"], map_location='cpu') ) self.image_encoder.eval() else: self.use_image_enc = False self.config["model_config"]["cache_text_emb"] = True self.model = create_model(**self.config["model_config"]) self.model.load_state_dict(torch.load(model_path, map_location='cpu')) if self.use_fp16: self.model.convert_to_fp16() self.image_encoder = self.image_encoder.half() self.model_dtype = torch.float16 else: self.model_dtype = torch.float32 self.image_encoder = self.image_encoder.to(self.device).eval() self.text_encoder = self.text_encoder.to(self.device).eval() self.prior = self.prior.to(self.device).eval() self.model.eval() self.model.to(self.device) def get_new_h_w(self, h, w): new_h = h // 64 if h % 64 != 0: new_h += 1 new_w = w // 64 if w % 64 != 0: new_w += 1 return new_h * 8, new_w * 8 @torch.no_grad() def encode_text(self, text_encoder, tokenizer, prompt, batch_size): text_encoding = tokenizer( [prompt] * batch_size + [""] * batch_size, max_length=77, padding="max_length", truncation=True, return_attention_mask=True, add_special_tokens=True, return_tensors="pt", ) tokens = text_encoding["input_ids"].to(self.device) mask = text_encoding["attention_mask"].to(self.device) full_emb, pooled_emb = text_encoder(tokens=tokens, mask=mask) return full_emb, pooled_emb @torch.no_grad() def generate_clip_emb( self, prompt, batch_size=1, prior_cf_scale=4, prior_steps="25", negative_prior_prompt="", ): prompts_batch = [prompt for _ in range(batch_size)] prior_cf_scales_batch = [prior_cf_scale] * len(prompts_batch) prior_cf_scales_batch = torch.tensor(prior_cf_scales_batch, device=self.device) max_txt_length = self.prior.model.text_ctx tok, mask = self.tokenizer2.padded_tokens_and_mask( prompts_batch, max_txt_length ) cf_token, cf_mask = self.tokenizer2.padded_tokens_and_mask( [negative_prior_prompt], max_txt_length ) if not (cf_token.shape == tok.shape): cf_token = cf_token.expand(tok.shape[0], -1) cf_mask = cf_mask.expand(tok.shape[0], -1) tok = torch.cat([tok, cf_token], dim=0) mask = torch.cat([mask, cf_mask], dim=0) tok, mask = tok.to(device=self.device), mask.to(device=self.device) x = self.clip_model.token_embedding(tok).type(self.clip_model.dtype) x = x + self.clip_model.positional_embedding.type(self.clip_model.dtype) x = x.permute(1, 0, 2) # NLD -> LND| x = self.clip_model.transformer(x) x = x.permute(1, 0, 2) # LND -> NLD x = self.clip_model.ln_final(x).type(self.clip_model.dtype) txt_feat_seq = x txt_feat = (x[torch.arange(x.shape[0]), tok.argmax(dim=-1)] @ self.clip_model.text_projection) txt_feat, txt_feat_seq = txt_feat.float().to(self.device), txt_feat_seq.float().to(self.device) img_feat = self.prior( txt_feat, txt_feat_seq, mask, prior_cf_scales_batch, timestep_respacing=prior_steps, ) return img_feat.to(self.model_dtype) @torch.no_grad() def encode_images(self, image, is_pil=False): if is_pil: image = self.preprocess(image).unsqueeze(0).to(self.device) return self.clip_model.encode_image(image).to(self.model_dtype) @torch.no_grad() def generate_img( self, prompt, img_prompt, batch_size=1, diffusion=None, guidance_scale=7, init_step=None, noise=None, init_img=None, img_mask=None, h=512, w=512, sampler="ddim_sampler", num_steps=50, ): new_h, new_w = self.get_new_h_w(h, w) full_batch_size = batch_size * 2 model_kwargs = {} if init_img is not None and self.use_fp16: init_img = init_img.half() if img_mask is not None and self.use_fp16: img_mask = img_mask.half() model_kwargs["full_emb"], model_kwargs["pooled_emb"] = self.encode_text( text_encoder=self.text_encoder, tokenizer=self.tokenizer1, prompt=prompt, batch_size=batch_size, ) model_kwargs["image_emb"] = img_prompt if self.task_type == "inpainting": init_img = init_img.to(self.device) img_mask = img_mask.to(self.device) model_kwargs["inpaint_image"] = init_img * img_mask model_kwargs["inpaint_mask"] = img_mask def model_fn(x_t, ts, **kwargs): half = x_t[: len(x_t) // 2] combined = torch.cat([half, half], dim=0) model_out = self.model(combined, ts, **kwargs) eps, rest = model_out[:, :4], model_out[:, 4:] cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0) half_eps = uncond_eps + guidance_scale * (cond_eps - uncond_eps) eps = torch.cat([half_eps, half_eps], dim=0) if sampler == "p_sampler": return torch.cat([eps, rest], dim=1) else: return eps if noise is not None: noise = noise.float() if self.task_type == "inpainting": def denoised_fun(x_start): x_start = x_start.clamp(-2, 2) return x_start * (1 - img_mask) + init_img * img_mask else: def denoised_fun(x): return x.clamp(-2, 2) if sampler == "p_sampler": self.model.del_cache() samples = diffusion.p_sample_loop( model_fn, (full_batch_size, 4, new_h, new_w), device=self.device, noise=noise, progress=True, model_kwargs=model_kwargs, init_step=init_step, denoised_fn=denoised_fun, )[:batch_size] self.model.del_cache() else: if sampler == "ddim_sampler": sampler = DDIMSampler( model=model_fn, old_diffusion=diffusion, schedule="linear", ) elif sampler == "plms_sampler": sampler = PLMSSampler( model=model_fn, old_diffusion=diffusion, schedule="linear", ) else: raise ValueError("Only ddim_sampler and plms_sampler is available") self.model.del_cache() samples, _ = sampler.sample( num_steps, batch_size * 2, (4, new_h, new_w), conditioning=model_kwargs, x_T=noise, init_step=init_step, ) self.model.del_cache() samples = samples[:batch_size] if self.use_image_enc: if self.use_fp16: samples = samples.half() samples = self.image_encoder.decode(samples / self.scale) samples = samples[:, :, :h, :w] return process_images(samples) @torch.no_grad() def create_zero_img_emb(self, batch_size): img = torch.zeros(1, 3, self.clip_image_size, self.clip_image_size).to(self.device) return self.encode_images(img, is_pil=False).repeat(batch_size, 1) @torch.no_grad() def generate_text2img( self, prompt, num_steps=100, batch_size=1, guidance_scale=7, h=512, w=512, sampler="ddim_sampler", prior_cf_scale=4, prior_steps="25", negative_prior_prompt="", negative_decoder_prompt="", ): # generate clip embeddings image_emb = self.generate_clip_emb( prompt, batch_size=batch_size, prior_cf_scale=prior_cf_scale, prior_steps=prior_steps, negative_prior_prompt=negative_prior_prompt, ) if negative_decoder_prompt == "": zero_image_emb = self.create_zero_img_emb(batch_size=batch_size) else: zero_image_emb = self.generate_clip_emb( negative_decoder_prompt, batch_size=batch_size, prior_cf_scale=prior_cf_scale, prior_steps=prior_steps, negative_prior_prompt=negative_prior_prompt, ) image_emb = torch.cat([image_emb, zero_image_emb], dim=0).to(self.device) # load diffusion config = deepcopy(self.config) if sampler == "p_sampler": config["diffusion_config"]["timestep_respacing"] = str(num_steps) diffusion = create_gaussian_diffusion(**config["diffusion_config"]) return self.generate_img( prompt=prompt, img_prompt=image_emb, batch_size=batch_size, guidance_scale=guidance_scale, h=h, w=w, sampler=sampler, num_steps=num_steps, diffusion=diffusion, ) @torch.no_grad() def mix_images( self, images_texts, weights, num_steps=100, batch_size=1, guidance_scale=7, h=512, w=512, sampler="ddim_sampler", prior_cf_scale=4, prior_steps="25", negative_prior_prompt="", negative_decoder_prompt="", ): assert len(images_texts) == len(weights) and len(images_texts) > 0 # generate clip embeddings image_emb = None for i in range(len(images_texts)): if image_emb is None: if type(images_texts[i]) == str: image_emb = weights[i] * self.generate_clip_emb( images_texts[i], batch_size=1, prior_cf_scale=prior_cf_scale, prior_steps=prior_steps, negative_prior_prompt=negative_prior_prompt, ) else: image_emb = self.encode_images(images_texts[i], is_pil=True) * weights[i] else: if type(images_texts[i]) == str: image_emb = image_emb + weights[i] * self.generate_clip_emb( images_texts[i], batch_size=1, prior_cf_scale=prior_cf_scale, prior_steps=prior_steps, negative_prior_prompt=negative_prior_prompt, ) else: image_emb = image_emb + self.encode_images(images_texts[i], is_pil=True) * weights[i] image_emb = image_emb.repeat(batch_size, 1) if negative_decoder_prompt == "": zero_image_emb = self.create_zero_img_emb(batch_size=batch_size) else: zero_image_emb = self.generate_clip_emb( negative_decoder_prompt, batch_size=batch_size, prior_cf_scale=prior_cf_scale, prior_steps=prior_steps, negative_prior_prompt=negative_prior_prompt, ) image_emb = torch.cat([image_emb, zero_image_emb], dim=0).to(self.device) # load diffusion config = deepcopy(self.config) if sampler == "p_sampler": config["diffusion_config"]["timestep_respacing"] = str(num_steps) diffusion = create_gaussian_diffusion(**config["diffusion_config"]) return self.generate_img( prompt="", img_prompt=image_emb, batch_size=batch_size, guidance_scale=guidance_scale, h=h, w=w, sampler=sampler, num_steps=num_steps, diffusion=diffusion, ) @torch.no_grad() def generate_img2img( self, prompt, pil_img, strength=0.7, num_steps=100, batch_size=1, guidance_scale=7, h=512, w=512, sampler="ddim_sampler", prior_cf_scale=4, prior_steps="25", ): # generate clip embeddings image_emb = self.generate_clip_emb( prompt, batch_size=batch_size, prior_cf_scale=prior_cf_scale, prior_steps=prior_steps, ) zero_image_emb = self.create_zero_img_emb(batch_size=batch_size) image_emb = torch.cat([image_emb, zero_image_emb], dim=0).to(self.device) # load diffusion config = deepcopy(self.config) if sampler == "p_sampler": config["diffusion_config"]["timestep_respacing"] = str(num_steps) diffusion = create_gaussian_diffusion(**config["diffusion_config"]) image = prepare_image(pil_img, h=h, w=w).to(self.device) if self.use_fp16: image = image.half() image = self.image_encoder.encode(image) * self.scale start_step = int(diffusion.num_timesteps * (1 - strength)) image = q_sample( image, torch.tensor(diffusion.timestep_map[start_step - 1]).to(self.device), schedule_name=config["diffusion_config"]["noise_schedule"], num_steps=config["diffusion_config"]["steps"], ) image = image.repeat(2, 1, 1, 1) return self.generate_img( prompt=prompt, img_prompt=image_emb, batch_size=batch_size, guidance_scale=guidance_scale, h=h, w=w, sampler=sampler, num_steps=num_steps, diffusion=diffusion, noise=image, init_step=start_step, ) @torch.no_grad() def generate_inpainting( self, prompt, pil_img, img_mask, num_steps=100, batch_size=1, guidance_scale=7, h=512, w=512, sampler="ddim_sampler", prior_cf_scale=4, prior_steps="25", negative_prior_prompt="", negative_decoder_prompt="", ): # generate clip embeddings image_emb = self.generate_clip_emb( prompt, batch_size=batch_size, prior_cf_scale=prior_cf_scale, prior_steps=prior_steps, negative_prior_prompt=negative_prior_prompt, ) zero_image_emb = self.create_zero_img_emb(batch_size=batch_size) image_emb = torch.cat([image_emb, zero_image_emb], dim=0).to(self.device) # load diffusion config = deepcopy(self.config) if sampler == "p_sampler": config["diffusion_config"]["timestep_respacing"] = str(num_steps) diffusion = create_gaussian_diffusion(**config["diffusion_config"]) image = prepare_image(pil_img, w, h).to(self.device) if self.use_fp16: image = image.half() image = self.image_encoder.encode(image) * self.scale image_shape = tuple(image.shape[-2:]) img_mask = torch.from_numpy(img_mask).unsqueeze(0).unsqueeze(0) img_mask = F.interpolate( img_mask, image_shape, mode="nearest", ) img_mask = prepare_mask(img_mask).to(self.device) if self.use_fp16: img_mask = img_mask.half() image = image.repeat(2, 1, 1, 1) img_mask = img_mask.repeat(2, 1, 1, 1) return self.generate_img( prompt=prompt, img_prompt=image_emb, batch_size=batch_size, guidance_scale=guidance_scale, h=h, w=w, sampler=sampler, num_steps=num_steps, diffusion=diffusion, init_img=image, img_mask=img_mask, ) import os from huggingface_hub import hf_hub_url, cached_download from copy import deepcopy from omegaconf.dictconfig import DictConfig def get_kandinsky2_1( device, task_type="text2img", cache_dir="/tmp/kandinsky2", use_auth_token=None, use_flash_attention=False, ): cache_dir = os.path.join(cache_dir, "2_1") config = DictConfig(deepcopy(CONFIG_2_1)) config["model_config"]["use_flash_attention"] = use_flash_attention if task_type == "text2img": model_name = "decoder_fp16.ckpt" config_file_url = hf_hub_url(repo_id="sberbank-ai/Kandinsky_2.1", filename=model_name) elif task_type == "inpainting": model_name = "inpainting_fp16.ckpt" config_file_url = hf_hub_url(repo_id="sberbank-ai/Kandinsky_2.1", filename=model_name) cached_download( config_file_url, cache_dir=cache_dir, force_filename=model_name, use_auth_token=use_auth_token, ) prior_name = "prior_fp16.ckpt" config_file_url = hf_hub_url(repo_id="sberbank-ai/Kandinsky_2.1", filename=prior_name) cached_download( config_file_url, cache_dir=cache_dir, force_filename=prior_name, use_auth_token=use_auth_token, ) cache_dir_text_en = os.path.join(cache_dir, "text_encoder") for name in [ "config.json", "pytorch_model.bin", "sentencepiece.bpe.model", "special_tokens_map.json", "tokenizer.json", "tokenizer_config.json", ]: config_file_url = hf_hub_url(repo_id="sberbank-ai/Kandinsky_2.1", filename=f"text_encoder/{name}") cached_download( config_file_url, cache_dir=cache_dir_text_en, force_filename=name, use_auth_token=use_auth_token, ) config_file_url = hf_hub_url(repo_id="sberbank-ai/Kandinsky_2.1", filename="movq_final.ckpt") cached_download( config_file_url, cache_dir=cache_dir, force_filename="movq_final.ckpt", use_auth_token=use_auth_token, ) config_file_url = hf_hub_url(repo_id="sberbank-ai/Kandinsky_2.1", filename="ViT-L-14_stats.th") cached_download( config_file_url, cache_dir=cache_dir, force_filename="ViT-L-14_stats.th", use_auth_token=use_auth_token, ) config["tokenizer_name"] = cache_dir_text_en config["text_enc_params"]["model_path"] = cache_dir_text_en config["prior"]["clip_mean_std_path"] = os.path.join(cache_dir, "ViT-L-14_stats.th") config["image_enc_params"]["ckpt_path"] = os.path.join(cache_dir, "movq_final.ckpt") cache_model_name = os.path.join(cache_dir, model_name) cache_prior_name = os.path.join(cache_dir, prior_name) model = Kandinsky2_1(config, cache_model_name, cache_prior_name, device, task_type=task_type) return model def get_kandinsky2( device, task_type="text2img", cache_dir="/tmp/kandinsky2", use_auth_token=None, model_version="2.1", use_flash_attention=False, ): if model_version == "2.0": model = get_kandinsky2_0( device, task_type=task_type, cache_dir=cache_dir, use_auth_token=use_auth_token, ) elif model_version == "2.1": model = get_kandinsky2_1( device, task_type=task_type, cache_dir=cache_dir, use_auth_token=use_auth_token, use_flash_attention=use_flash_attention, ) elif model_version == "2.2": model = Kandinsky2_2(device=device, task_type=task_type) else: raise ValueError("Only 2.0 and 2.1 is available") return model