"""Sampling / evaluation utilities for a joint text-image diffusion model.

Supports text-to-image ('t2i'), image-to-text ('i2t'), joint, and round-trip
('t2i2t', 'i2t2i') sampling with a U-ViT backbone, a KL autoencoder for image
latents, CLIP text/image encoders, and a caption decoder for text latents.
"""

import random
import time
from typing import List, Optional, Tuple, Union

import clip
import einops
import ml_collections
import numpy as np
import torch
import torchvision.transforms as standard_transforms
from absl import logging
from PIL import Image
from torch import nn
from torchvision.utils import save_image, make_grid
from transformers import (
    CLIPFeatureExtractor,
    CLIPProcessor,
    CLIPTextModel,
    CLIPTokenizer,
    CLIPVisionModel,
    GPT2LMHeadModel,
    GPT2Tokenizer,
)

import libs.autoencoder
import libs.clip
import utils
from dpm_solver_pp import NoiseScheduleVP, DPM_Solver
from libs.autoencoder import Encoder, Decoder
from libs.caption_decoder import generate2, generate_beam
from libs.clip import AbstractEncoder


class TestAutoencoderKL(nn.Module):
    def __init__(self, ddconfig, embed_dim, pretrained_path, scale_factor=0.18215):
        super().__init__()
        print(f'Create autoencoder with scale_factor={scale_factor}')
        self.encoder = Encoder(**ddconfig)
        self.decoder = Decoder(**ddconfig)
        assert ddconfig["double_z"]
        self.quant_conv = torch.nn.Conv2d(2 * ddconfig["z_channels"], 2 * embed_dim, 1)
        self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
        self.embed_dim = embed_dim
        self.scale_factor = scale_factor
        m, u = self.load_state_dict(torch.load(pretrained_path, map_location='cpu'))
        assert len(m) == 0 and len(u) == 0
        self.eval()
        self.requires_grad_(False)

    def encode_moments(self, x):
        # Moments = (mean, logvar) of the latent posterior, packed along the channel dim.
        h = self.encoder(x)
        moments = self.quant_conv(h)
        return moments

    def sample(self, moments, noise=None, generator=None, device="cuda"):
        # Reparameterization: z = mean + std * eps, then scale to the diffusion latent space.
        mean, logvar = torch.chunk(moments, 2, dim=1)
        if noise is None:
            noise = randn_tensor(mean.shape, generator=generator)
        noise = noise.to(device)
        logvar = torch.clamp(logvar, -30.0, 20.0)
        std = torch.exp(0.5 * logvar)
        z = mean + std * noise
        z = self.scale_factor * z
        return z

    def get_moment_params(self, moments):
        mean, logvar = torch.chunk(moments, 2, dim=1)
        return mean, logvar

    def encode(self, x):
        moments = self.encode_moments(x)
        mean, logvar = self.get_moment_params(moments)
        return mean

    def decode(self, z):
        z = (1. / self.scale_factor) * z
        z = self.post_quant_conv(z)
        dec = self.decoder(z)
        return dec

    def forward(self, inputs, fn):
        if fn == 'encode_moments':
            return self.encode_moments(inputs)
        elif fn == 'encode':
            return self.encode(inputs)
        elif fn == 'decode':
            return self.decode(inputs)
        else:
            raise NotImplementedError

    def freeze(self):
        self.eval()
        self.requires_grad_(False)


def get_test_autoencoder(pretrained_path, scale_factor=0.18215):
    ddconfig = dict(
        double_z=True,
        z_channels=4,
        resolution=256,
        in_channels=3,
        out_ch=3,
        ch=128,
        ch_mult=[1, 2, 4, 4],
        num_res_blocks=2,
        attn_resolutions=[],
        dropout=0.0
    )
    vae_scale_factor = 2 ** (len(ddconfig['ch_mult']) - 1)
    return TestAutoencoderKL(ddconfig, 4, pretrained_path, scale_factor), vae_scale_factor


def randn_tensor(
    shape: Union[Tuple, List],
    generator: Optional[Union[List["torch.Generator"], "torch.Generator"]] = None,
    device: Optional["torch.device"] = None,
    dtype: Optional["torch.dtype"] = None,
    layout: Optional["torch.layout"] = None,
):
    """Helper that creates random tensors on the desired `device` with the desired `dtype`.

    When a list of generators is passed, each batch element can be seeded individually.
    If a CPU generator is passed, the tensor is sampled on CPU and then moved to `device`.
    """
    rand_device = device
    batch_size = shape[0]

    layout = layout or torch.strided
    device = device or torch.device("cpu")

    if generator is not None:
        gen_device_type = generator.device.type if not isinstance(generator, list) else generator[0].device.type
        if gen_device_type != device.type and gen_device_type == "cpu":
            rand_device = "cpu"
            if device != "mps":
                logging.info(
                    f"The passed generator was created on 'cpu' even though a tensor on {device} was expected."
                    f" Tensors will be created on 'cpu' and then moved to {device}. Note that one can probably"
                    f" slightly speed up this function by passing a generator that was created on the {device} device."
                )
        elif gen_device_type != device.type and gen_device_type == "cuda":
            raise ValueError(f"Cannot generate a {device} tensor from a generator of type {gen_device_type}.")

    if isinstance(generator, list):
        # One generator per batch element: sample each element separately, then concatenate.
        shape = (1,) + shape[1:]
        latents = [
            torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype, layout=layout)
            for i in range(batch_size)
        ]
        latents = torch.cat(latents, dim=0).to(device)
    else:
        latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype, layout=layout).to(device)

    return latents


def prepare_latents(
    config,
    clip_text_model,
    clip_img_model,
    clip_img_model_preprocess,
    autoencoder,
    vae_scale_factor,
    device,
):
    resolution = config.z_shape[-1] * vae_scale_factor

    # Placeholder latents are drawn on CPU with a seeded generator so they are
    # reproducible across devices, then moved to `device` at the end.
    latent_device = "cpu"
    latent_torch_device = torch.device(latent_device)
    generator = torch.Generator(device=latent_torch_device).manual_seed(config.seed)

    contexts = randn_tensor((config.n_samples, 77, config.clip_text_dim), generator=generator, device=latent_torch_device)
    img_contexts = randn_tensor((config.n_samples, config.z_shape[0], config.z_shape[1], config.z_shape[2]), generator=generator, device=latent_torch_device)
    clip_imgs = randn_tensor((config.n_samples, 1, config.clip_img_dim), generator=generator, device=latent_torch_device)

    if config.mode in ['t2i', 't2i2t']:
        prompts = [config.prompt] * config.n_samples
        contexts = clip_text_model.encode(prompts)
    elif config.mode in ['i2t', 'i2t2i']:
        img_contexts = []
        clip_imgs = []

        def get_img_feature(image):
            image = np.array(image).astype(np.uint8)
            image = utils.center_crop(resolution, resolution, image)
            clip_img_feature = clip_img_model.encode_image(clip_img_model_preprocess(Image.fromarray(image)).unsqueeze(0).to(device))

            image = (image / 127.5 - 1.0).astype(np.float32)
            image = einops.rearrange(image, 'h w c -> 1 c h w')
            image = torch.tensor(image, device=device)
            logging.info(f'Preprocessed VAE image {image}')
            logging.info(f"Preprocessed VAE image shape {image.shape}")

            moments = autoencoder.encode_moments(image)
            moment_mean, moment_logvar = autoencoder.get_moment_params(moments)
            print(f"Moment dist mean: {moment_mean}")
            print(f"Moment dist logvar: {moment_logvar}")
            moments = autoencoder.sample(moments, generator=generator, device=device)

            return clip_img_feature, moments

        image = Image.open(config.img).convert('RGB')
        clip_img, img_context = get_img_feature(image)

        img_contexts.append(img_context)
        clip_imgs.append(clip_img)
        img_contexts = img_contexts * config.n_samples
        clip_imgs = clip_imgs * config.n_samples

        img_contexts = torch.concat(img_contexts, dim=0)
        clip_imgs = torch.stack(clip_imgs, dim=0)

    contexts = contexts.to(device)
    img_contexts = img_contexts.to(device)
    clip_imgs = clip_imgs.to(device)
    return contexts, img_contexts, clip_imgs


def stable_diffusion_beta_schedule(linear_start=0.00085, linear_end=0.0120, n_timestep=1000):
    _betas = (
        torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
    )
    return _betas.numpy()


def prepare_contexts(config, clip_text_model, clip_img_model, clip_img_model_preprocess, autoencoder):
    resolution = config.z_shape[-1] * 8
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    contexts = torch.randn(config.n_samples, 77, config.clip_text_dim).to(device)
    img_contexts = torch.randn(config.n_samples, 2 * config.z_shape[0], config.z_shape[1], config.z_shape[2])
    clip_imgs = torch.randn(config.n_samples, 1, config.clip_img_dim)

    if config.mode in ['t2i', 't2i2t']:
        prompts = [config.prompt] * config.n_samples
        contexts = clip_text_model.encode(prompts)

    elif config.mode in ['i2t', 'i2t2i']:
        img_contexts = []
        clip_imgs = []

        def get_img_feature(image):
            image = np.array(image).astype(np.uint8)
            image = utils.center_crop(resolution, resolution, image)
            clip_img_feature = clip_img_model.encode_image(clip_img_model_preprocess(Image.fromarray(image)).unsqueeze(0).to(device))

            image = (image / 127.5 - 1.0).astype(np.float32)
            image = einops.rearrange(image, 'h w c -> 1 c h w')
            image = torch.tensor(image, device=device)
            moments = autoencoder.encode_moments(image)

            return clip_img_feature, moments

        image = Image.open(config.img).convert('RGB')
        clip_img, img_context = get_img_feature(image)

        img_contexts.append(img_context)
        clip_imgs.append(clip_img)
        img_contexts = img_contexts * config.n_samples
        clip_imgs = clip_imgs * config.n_samples

        img_contexts = torch.concat(img_contexts, dim=0)
        clip_imgs = torch.stack(clip_imgs, dim=0)

    return contexts, img_contexts, clip_imgs


def unpreprocess(v):
    v = 0.5 * (v + 1.)
    v.clamp_(0., 1.)
    return v


def set_seed(seed: int):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


def evaluate(config):
    if config.get('benchmark', False):
        torch.backends.cudnn.benchmark = True
        torch.backends.cudnn.deterministic = False

    device = config.sample.device
    torch_device = torch.device(device)
    set_seed(config.seed)

    generator = torch.Generator(device=torch_device).manual_seed(config.seed)

    config = ml_collections.FrozenConfigDict(config)
    if config.sample.log_dir is not None:
        log_filename = config.sample.log_dir + "/" + config.mode + ".txt"
        utils.set_logger(log_level=config.sample.log_level, fname=log_filename)
    else:
        utils.set_logger(log_level=config.sample.log_level)

    _betas = stable_diffusion_beta_schedule()
    N = len(_betas)

    nnet = utils.get_nnet(**config.nnet)
    logging.info(f'load nnet from {config.nnet_path}')
    nnet.load_state_dict(torch.load(config.nnet_path, map_location='cpu'))
    nnet.to(device)
    nnet.eval()

    # The caption decoder maps between the CLIP text embedding space and the nnet's
    # lower-dimensional text latent space; it is only skipped for plain t2i sampling
    # when text_dim already matches clip_text_dim.
    use_caption_decoder = config.text_dim < config.clip_text_dim or config.mode != 't2i'
    if use_caption_decoder:
        from libs.caption_decoder import CaptionDecoder
        caption_decoder = CaptionDecoder(device=device, **config.caption_decoder)
    else:
        caption_decoder = None

    clip_text_model = libs.clip.FrozenCLIPEmbedder(device=device)
    clip_text_model.eval()
    clip_text_model.to(device)

    autoencoder, vae_scale_factor = get_test_autoencoder(**config.autoencoder)
    autoencoder.to(device)

    clip_img_model, clip_img_model_preprocess = clip.load("ViT-B/32", device=device, jit=False)

    empty_context = clip_text_model.encode([''])[0]

    def split(x):
        C, H, W = config.z_shape
        z_dim = C * H * W
        z, clip_img = x.split([z_dim, config.clip_img_dim], dim=1)
        z = einops.rearrange(z, 'B (C H W) -> B C H W', C=C, H=H, W=W)
        clip_img = einops.rearrange(clip_img, 'B (L D) -> B L D', L=1, D=config.clip_img_dim)
        return z, clip_img

    def combine(z, clip_img):
        z = einops.rearrange(z, 'B C H W -> B (C H W)')
        clip_img = einops.rearrange(clip_img, 'B L D -> B (L D)')
        return torch.concat([z, clip_img], dim=-1)

    def t2i_nnet(x, timesteps, text):
        """
        1. calculate the conditional model output
        2. calculate the unconditional model output
            config.sample.t2i_cfg_mode == 'empty_token': use the original cfg with the empty string
            config.sample.t2i_cfg_mode == 'true_uncond': use the unconditional model learned by our method
        3. return a linear combination of the conditional and unconditional outputs
        """
        z, clip_img = split(x)

        t_text = torch.zeros(timesteps.size(0), dtype=torch.int, device=device)

        z_out, clip_img_out, text_out = nnet(z, clip_img, text=text, t_img=timesteps, t_text=t_text,
                                             data_type=torch.zeros_like(t_text, device=device, dtype=torch.int) + config.data_type)
        x_out = combine(z_out, clip_img_out)

        if config.sample.scale == 0.:
            return x_out

        if config.sample.t2i_cfg_mode == 'empty_token':
            _empty_context = einops.repeat(empty_context, 'L D -> B L D', B=x.size(0))
            if use_caption_decoder:
                _empty_context = caption_decoder.encode_prefix(_empty_context)
            z_out_uncond, clip_img_out_uncond, text_out_uncond = nnet(z, clip_img, text=_empty_context, t_img=timesteps, t_text=t_text,
                                                                      data_type=torch.zeros_like(t_text, device=device, dtype=torch.int) + config.data_type)
            x_out_uncond = combine(z_out_uncond, clip_img_out_uncond)
        elif config.sample.t2i_cfg_mode == 'true_uncond':
            text_N = randn_tensor(text.shape, generator=generator, device=torch_device)
            z_out_uncond, clip_img_out_uncond, text_out_uncond = nnet(z, clip_img, text=text_N, t_img=timesteps, t_text=torch.ones_like(timesteps) * N,
                                                                      data_type=torch.zeros_like(t_text, device=device, dtype=torch.int) + config.data_type)
            x_out_uncond = combine(z_out_uncond, clip_img_out_uncond)
        else:
            raise NotImplementedError

        return x_out + config.sample.scale * (x_out - x_out_uncond)

    def i_nnet(x, timesteps):
        z, clip_img = split(x)
        text = torch.randn(x.size(0), 77, config.text_dim, device=device)
        t_text = torch.ones_like(timesteps) * N
        z_out, clip_img_out, text_out = nnet(z, clip_img, text=text, t_img=timesteps, t_text=t_text,
                                             data_type=torch.zeros_like(t_text, device=device, dtype=torch.int) + config.data_type)
        x_out = combine(z_out, clip_img_out)
        return x_out

    def t_nnet(x, timesteps):
        z = torch.randn(x.size(0), *config.z_shape, device=device)
        clip_img = torch.randn(x.size(0), 1, config.clip_img_dim, device=device)
        z_out, clip_img_out, text_out = nnet(z, clip_img, text=x, t_img=torch.ones_like(timesteps) * N, t_text=timesteps,
                                             data_type=torch.zeros_like(timesteps, device=device, dtype=torch.int) + config.data_type)
        return text_out

    def i2t_nnet(x, timesteps, z, clip_img):
        """
        1. calculate the conditional model output
        2. calculate unconditional model output
        3. return linear combination of conditional output and unconditional output
        """
        t_img = torch.zeros(timesteps.size(0), dtype=torch.int, device=device)

        z_out, clip_img_out, text_out = nnet(z, clip_img, text=x, t_img=t_img, t_text=timesteps,
                                             data_type=torch.zeros_like(t_img, device=device, dtype=torch.int) + config.data_type)

        if config.sample.scale == 0.:
            return text_out

        z_N = randn_tensor(z.shape, generator=generator, device=torch_device)
        clip_img_N = randn_tensor(clip_img.shape, generator=generator, device=torch_device)
        z_out_uncond, clip_img_out_uncond, text_out_uncond = nnet(z_N, clip_img_N, text=x, t_img=torch.ones_like(timesteps) * N, t_text=timesteps,
                                                                  data_type=torch.zeros_like(timesteps, device=device, dtype=torch.int) + config.data_type)

        return text_out + config.sample.scale * (text_out - text_out_uncond)

    def split_joint(x):
        C, H, W = config.z_shape
        z_dim = C * H * W
        z, clip_img, text = x.split([z_dim, config.clip_img_dim, 77 * config.text_dim], dim=1)
        z = einops.rearrange(z, 'B (C H W) -> B C H W', C=C, H=H, W=W)
        clip_img = einops.rearrange(clip_img, 'B (L D) -> B L D', L=1, D=config.clip_img_dim)
        text = einops.rearrange(text, 'B (L D) -> B L D', L=77, D=config.text_dim)
        return z, clip_img, text

    def combine_joint(z, clip_img, text):
        z = einops.rearrange(z, 'B C H W -> B (C H W)')
        clip_img = einops.rearrange(clip_img, 'B L D -> B (L D)')
        text = einops.rearrange(text, 'B L D -> B (L D)')
        return torch.concat([z, clip_img, text], dim=-1)

    def joint_nnet(x, timesteps):
        z, clip_img, text = split_joint(x)
        z_out, clip_img_out, text_out = nnet(z, clip_img, text=text, t_img=timesteps, t_text=timesteps,
                                             data_type=torch.zeros_like(timesteps, device=device, dtype=torch.int) + config.data_type)
        x_out = combine_joint(z_out, clip_img_out, text_out)

        if config.sample.scale == 0.:
            return x_out

        z_noise = randn_tensor((x.size(0), *config.z_shape), generator=generator, device=torch_device)
        clip_img_noise = randn_tensor((x.size(0), 1, config.clip_img_dim), generator=generator, device=torch_device)
        text_noise = randn_tensor((x.size(0), 77, config.text_dim), generator=generator, device=torch_device)

        _, _, text_out_uncond = nnet(z_noise, clip_img_noise, text=text, t_img=torch.ones_like(timesteps) * N, t_text=timesteps,
                                     data_type=torch.zeros_like(timesteps, device=device, dtype=torch.int) + config.data_type)
        z_out_uncond, clip_img_out_uncond, _ = nnet(z, clip_img, text=text_noise, t_img=timesteps, t_text=torch.ones_like(timesteps) * N,
                                                    data_type=torch.zeros_like(timesteps, device=device, dtype=torch.int) + config.data_type)

        x_out_uncond = combine_joint(z_out_uncond, clip_img_out_uncond, text_out_uncond)

        return x_out + config.sample.scale * (x_out - x_out_uncond)

    @torch.cuda.amp.autocast()
    def encode(_batch):
        return autoencoder.encode(_batch)

    @torch.cuda.amp.autocast()
    def decode(_batch):
        return autoencoder.decode(_batch)

    logging.info(config.sample)
    logging.info(f'N={N}')

    contexts, img_contexts, clip_imgs = prepare_latents(
        config,
        clip_text_model,
        clip_img_model,
        clip_img_model_preprocess,
        autoencoder,
        vae_scale_factor,
        device,
    )

    # contexts are CLIP text embeddings; contexts_low_dim are the nnet's
    # lower-dimensional text latents (via the caption decoder's encode_prefix).
    contexts_low_dim = contexts if not use_caption_decoder else caption_decoder.encode_prefix(contexts)

    logging.debug(f"Text latents: {contexts}")
    logging.debug(f"Text latents shape: {contexts.shape}")
    logging.debug(f"Low dim text latents: {contexts_low_dim}")
    logging.debug(f"Low dim text latents shape: {contexts_low_dim.shape}")

    # For image-conditioned modes these hold sampled (scaled) VAE latents and CLIP
    # image features of the input image; otherwise they are placeholder noise.
    z_img = img_contexts

    logging.debug(f"Encoded image VAE latents: {z_img}")
    logging.debug(f"Encoded image VAE latents shape: {z_img.shape}")
    logging.debug(f"Encoded image CLIP latents: {clip_imgs}")
    logging.debug(f"Encoded image CLIP latents shape: {clip_imgs.shape}")

    if config.mode in ['t2i', 't2i2t']:
        _n_samples = contexts_low_dim.size(0)
    elif config.mode in ['i2t', 'i2t2i']:
        _n_samples = img_contexts.size(0)
    else:
        _n_samples = config.n_samples

    def sample_fn(mode, **kwargs):
        """Sample with DPM-Solver starting from pure Gaussian noise for the given mode."""
        _z_init = randn_tensor((_n_samples, *config.z_shape), generator=generator, device=torch_device)
        _clip_img_init = randn_tensor((_n_samples, 1, config.clip_img_dim), generator=generator, device=torch_device)
        _text_init = randn_tensor((_n_samples, 77, config.text_dim), generator=generator, device=torch_device)
        if mode == 'joint':
            _x_init = combine_joint(_z_init, _clip_img_init, _text_init)
        elif mode in ['t2i', 'i']:
            _x_init = combine(_z_init, _clip_img_init)
        elif mode in ['i2t', 't']:
            _x_init = _text_init
        noise_schedule = NoiseScheduleVP(schedule='discrete', betas=torch.tensor(_betas, device=device).float())

        def model_fn(x, t_continuous):
            t = t_continuous * N
            if mode == 'joint':
                return joint_nnet(x, t)
            elif mode == 't2i':
                return t2i_nnet(x, t, **kwargs)
            elif mode == 'i2t':
                return i2t_nnet(x, t, **kwargs)
            elif mode == 'i':
                return i_nnet(x, t)
            elif mode == 't':
                return t_nnet(x, t)

        dpm_solver = DPM_Solver(model_fn, noise_schedule, predict_x0=True, thresholding=False)
        with torch.no_grad():
            with torch.autocast(device_type=device):
                start_time = time.time()
                x = dpm_solver.sample(_x_init, steps=config.sample.sample_steps, eps=1. / N, T=1.)
                end_time = time.time()
                print(f'\ngenerate {_n_samples} samples with {config.sample.sample_steps} steps takes {end_time - start_time:.2f}s')

        if mode == 'joint':
            _z, _clip_img, _text = split_joint(x)
            return _z, _clip_img, _text
        elif mode in ['t2i', 'i']:
            _z, _clip_img = split(x)
            return _z, _clip_img
        elif mode in ['i2t', 't']:
            return x

    def test_sample_fn(mode, **kwargs):
        """Like `sample_fn`, but the solver is initialized from the prepared latents
        (z_img / clip_imgs / contexts_low_dim) instead of fresh noise."""
        if mode == 'joint':
            _x_init = combine_joint(z_img, clip_imgs, contexts_low_dim)
        elif mode in ['t2i', 'i']:
            _x_init = combine(z_img, clip_imgs)
        elif mode in ['i2t', 't']:
            _x_init = contexts_low_dim

        logging.debug(f"Latents: {_x_init}")
        logging.debug(f"Latents shape: {_x_init.shape}")

        noise_schedule = NoiseScheduleVP(schedule='discrete', betas=torch.tensor(_betas, device=device).float())

        def model_fn(x, t_continuous):
            t = t_continuous * N
            if mode == 'joint':
                noise_pred = joint_nnet(x, t)
                logging.debug(f"Joint noise pred for time {t}: {noise_pred}")
                logging.debug(f"Joint noise pred for time {t} shape: {noise_pred.shape}")
                return noise_pred

            elif mode == 't2i':
                noise_pred = t2i_nnet(x, t, **kwargs)
                logging.debug(f"t2i noise pred for time {t}: {noise_pred}")
                logging.debug(f"t2i noise pred for time {t} shape: {noise_pred.shape}")
                return noise_pred

            elif mode == 'i2t':
                noise_pred = i2t_nnet(x, t, **kwargs)
                logging.debug(f"i2t noise pred for time {t}: {noise_pred}")
                logging.debug(f"i2t noise pred for time {t} shape: {noise_pred.shape}")
                return noise_pred

            elif mode == 'i':
                return i_nnet(x, t)
            elif mode == 't':
                return t_nnet(x, t)

        dpm_solver = DPM_Solver(model_fn, noise_schedule, predict_x0=True, thresholding=False)
        with torch.no_grad():
            with torch.autocast(device_type=device):
                start_time = time.time()
                x = dpm_solver.sample(_x_init, steps=config.sample.sample_steps, eps=1. / N, T=1.)
                end_time = time.time()
                print(f'\ngenerate {_n_samples} samples with {config.sample.sample_steps} steps takes {end_time - start_time:.2f}s')

        logging.debug(f"Full UNet sample: {x}")
        logging.debug(f"Full UNet sample shape: {x.shape}")

        if mode == 'joint':
            _z, _clip_img, _text = split_joint(x)
            return _z, _clip_img, _text
        elif mode in ['t2i', 'i']:
            _z, _clip_img = split(x)
            return _z, _clip_img
        elif mode in ['i2t', 't']:
            return x

    output_images = None
    output_text = None

    if config.mode in ['joint']:
        _z, _clip_img, _text = test_sample_fn(config.mode)

        logging.debug(f"Text output: {_text}")
        logging.debug(f"Text output shape: {_text.shape}")
        logging.debug(f"VAE output: {_z}")
        logging.debug(f"VAE output shape: {_z.shape}")
        logging.debug(f"CLIP output: {_clip_img}")
        logging.debug(f"CLIP output shape: {_clip_img.shape}")

        samples = unpreprocess(decode(_z))

        logging.debug(f"VAE decoded sample: {samples}")
        logging.debug(f"VAE decoded sample shape: {samples.shape}")

        prompts = caption_decoder.generate_captions(_text)

        logging.debug(f"Generated text: {prompts}")

        output_images = samples
        output_text = prompts

    elif config.mode in ['t2i', 'i', 'i2t2i']:
        if config.mode == 't2i':
            _z, _clip_img = test_sample_fn(config.mode, text=contexts_low_dim)
        elif config.mode == 'i':
            _z, _clip_img = test_sample_fn(config.mode)
        elif config.mode == 'i2t2i':
            _text = sample_fn('i2t', z=z_img, clip_img=clip_imgs)
            _z, _clip_img = sample_fn('t2i', text=_text)
        samples = unpreprocess(decode(_z))

        output_images = samples

    elif config.mode in ['i2t', 't', 't2i2t']:
        if config.mode == 'i2t':
            _text = test_sample_fn(config.mode, z=z_img, clip_img=clip_imgs)
        elif config.mode == 't':
            _text = test_sample_fn(config.mode)
        elif config.mode == 't2i2t':
            _z, _clip_img = sample_fn('t2i', text=contexts_low_dim)
            _text = sample_fn('i2t', z=_z, clip_img=_clip_img)
        samples = caption_decoder.generate_captions(_text)
        logging.info(samples)
        output_text = samples

    print(f'\nGPU memory usage: {torch.cuda.max_memory_reserved() / 1024 ** 3:.2f} GB')

    if output_images is not None:
        # Convert the decoded image tensors to PIL images for the caller.
        output_images = [standard_transforms.ToPILImage()(sample) for sample in output_images]

    return output_images, output_text


def d(**kwargs):
    """Helper for creating a config dict."""
    return ml_collections.ConfigDict(initial_dictionary=kwargs)


def get_config():
    config = ml_collections.ConfigDict()

    config.seed = 0
    config.pred = 'noise_pred'
    config.z_shape = (4, 64, 64)
    config.clip_img_dim = 512
    config.clip_text_dim = 768
    config.text_dim = 64
    config.data_type = 1

    config.autoencoder = d(
        pretrained_path='models/autoencoder_kl.pth',
    )

    config.caption_decoder = d(
        pretrained_path="models/caption_decoder.pth",
        hidden_dim=config.get_ref('text_dim')
    )

    config.nnet = d(
        name='uvit_multi_post_ln_v1',
        img_size=64,
        in_chans=4,
        patch_size=2,
        embed_dim=1536,
        depth=30,
        num_heads=24,
        mlp_ratio=4,
        qkv_bias=False,
        pos_drop_rate=0.,
        drop_rate=0.,
        attn_drop_rate=0.,
        mlp_time_embed=False,
        text_dim=config.get_ref('text_dim'),
        num_text_tokens=77,
        clip_img_dim=config.get_ref('clip_img_dim'),
        use_checkpoint=True
    )

    config.sample = d(
        sample_steps=3,
        scale=7.,
        t2i_cfg_mode='true_uncond',
        device="cuda",
        log_level="debug",
        log_dir=None,
    )

    return config


def sample(mode, prompt, image, sample_steps=50, scale=7.0, seed=None):
    config = get_config()

    config.nnet_path = "models/uvit_v1.pth"
    config.n_samples = 1
    config.nrow = 1

    config.mode = mode
    config.prompt = prompt
    config.img = image

    config.sample.sample_steps = sample_steps
    config.sample.scale = scale
    if seed is not None:
        config.seed = seed

    sample_images, sample_text = evaluate(config)
    return sample_images, sample_text
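

# Example usage (illustrative sketch): drive `sample` directly. The prompt, seed,
# and output filenames below are placeholders; image-conditioned modes ('i2t',
# 'i2t2i') additionally need `image` to point at an existing image file, and the
# model weights referenced in `get_config` / `sample` must be present.
if __name__ == '__main__':
    images, texts = sample(
        mode='t2i',
        prompt='a photo of a corgi wearing sunglasses',  # placeholder prompt
        image=None,
        sample_steps=50,
        scale=7.0,
        seed=1234,
    )
    if texts is not None:
        print(texts)
    if images is not None:
        # evaluate() returns PIL images, so they can be saved directly.
        for idx, im in enumerate(images):
            im.save(f'sample_{idx}.png')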