from typing import Union, List, Optional

import PIL
import torch
import torchvision.transforms as T
from einops import repeat

from kandinsky3.model.unet import UNet
from kandinsky3.movq import MoVQ
from kandinsky3.condition_encoders import T5TextConditionEncoder
from kandinsky3.condition_processors import T5TextConditionProcessor
from kandinsky3.model.diffusion import BaseDiffusion, get_named_beta_schedule


class Kandinsky3T2IPipeline:

    def __init__(
        self,
        device_map: Union[str, torch.device, dict],
        dtype_map: Union[str, torch.dtype, dict],
        unet: UNet,
        null_embedding: torch.Tensor,
        t5_processor: T5TextConditionProcessor,
        t5_encoder: T5TextConditionEncoder,
        movq: MoVQ,
        gan: bool,
    ):
        self.device_map = device_map
        self.dtype_map = dtype_map
        self.to_pil = T.ToPILImage()

        self.unet = unet
        self.null_embedding = null_embedding
        self.t5_processor = t5_processor
        self.t5_encoder = t5_encoder
        self.movq = movq
        self.gan = gan

    def __call__(
        self,
        text: str,
        negative_text: Optional[str] = None,
        images_num: int = 1,
        bs: int = 1,
        width: int = 1024,
        height: int = 1024,
        guidance_scale: float = 3.0,
        steps: int = 50,
        eta: float = 1.0
    ) -> List[PIL.Image.Image]:
        # Cosine noise schedule over 1000 training timesteps; sampling visits `steps`
        # of them. In the distilled (GAN/refiner) mode a fixed 4-step schedule is used.
        betas = get_named_beta_schedule('cosine', 1000)
        base_diffusion = BaseDiffusion(betas, 0.99)

        times = list(range(999, 0, -1000 // steps))
        if self.gan:
            times = list(range(979, 0, -250))
        # Tokenize the prompt (and optional negative prompt), add a batch dimension
        # with [None], and move the tensors to the text-encoder device.
        condition_model_input, negative_condition_model_input = self.t5_processor.encode(text, negative_text)
        for input_type in condition_model_input:
            condition_model_input[input_type] = condition_model_input[input_type][None].to(
                self.device_map['text_encoder']
            )

        if negative_condition_model_input is not None:
            for input_type in negative_condition_model_input:
                negative_condition_model_input[input_type] = negative_condition_model_input[input_type][None].to(
                    self.device_map['text_encoder']
                )
        pil_images = []
        with torch.no_grad():
            # Encode the prompts once with T5, then split generation into minibatches
            # of size `bs` (plus a remainder batch) and repeat the context for each.
            with torch.cuda.amp.autocast(dtype=self.dtype_map['text_encoder']):
                context, context_mask = self.t5_encoder(condition_model_input)
                if negative_condition_model_input is not None:
                    negative_context, negative_context_mask = self.t5_encoder(negative_condition_model_input)
                else:
                    negative_context, negative_context_mask = None, None

            k, m = images_num // bs, images_num % bs
            for minibatch in [bs] * k + [m]:
                if minibatch == 0:
                    continue
                bs_context = repeat(context, '1 n d -> b n d', b=minibatch)
                bs_context_mask = repeat(context_mask, '1 n -> b n', b=minibatch)
                if negative_context is not None:
                    bs_negative_context = repeat(negative_context, '1 n d -> b n d', b=minibatch)
                    bs_negative_context_mask = repeat(negative_context_mask, '1 n -> b n', b=minibatch)
                else:
                    bs_negative_context, bs_negative_context_mask = None, None
                # Sample latents with classifier-free guidance, decode them with MoVQ
                # (in two chunks to limit peak memory), and convert to PIL images.
                with torch.cuda.amp.autocast(dtype=self.dtype_map['unet']):
                    images = base_diffusion.p_sample_loop(
                        self.unet, (minibatch, 4, height // 8, width // 8), times, self.device_map['unet'],
                        bs_context, bs_context_mask, self.null_embedding, guidance_scale, eta,
                        negative_context=bs_negative_context, negative_context_mask=bs_negative_context_mask,
                        gan=self.gan
                    )

                with torch.cuda.amp.autocast(dtype=self.dtype_map['movq']):
                    images = torch.cat([self.movq.decode(image) for image in images.chunk(2)])
                    images = torch.clip((images + 1.) / 2., 0., 1.)
                    for images_chunk in images.chunk(1):
                        pil_images += [self.to_pil(image) for image in images_chunk]

        return pil_images
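

if __name__ == '__main__':
    # Usage sketch (not part of the original pipeline code): a minimal example of
    # driving the pipeline end to end, exercising the __call__ signature defined
    # above. It assumes checkpoints are assembled by a loader helper such as
    # `get_T2I_pipeline` from the Kandinsky-3 repo; that helper's name and exact
    # signature, the dtype choices, and the output filename below are assumptions,
    # not something this file defines.
    from kandinsky3 import get_T2I_pipeline  # assumed loader helper

    device_map = torch.device('cuda:0')
    dtype_map = {
        'unet': torch.float32,
        'text_encoder': torch.float16,
        'movq': torch.float32,
    }

    t2i_pipe = get_T2I_pipeline(device_map, dtype_map)
    images = t2i_pipe(
        'A photo of a corgi surfing a wave at sunset',
        negative_text='low quality, blurry',
        images_num=2,
        guidance_scale=3.0,
        steps=50,
    )
    images[0].save('corgi_t2i.png')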