import ml_collections
import torch
import random
import utils
from dpm_solver_pp import NoiseScheduleVP, DPM_Solver
from absl import logging
import einops
import libs.autoencoder
import libs.clip
from torchvision.utils import save_image, make_grid
import torchvision.transforms as standard_transforms
import numpy as np
import clip
from PIL import Image
import time
from typing import Optional, Union, List, Tuple
from torch import nn
from transformers import (
from libs.autoencoder import Encoder, Decoder
from libs.clip import AbstractEncoder
from libs.caption_decoder import generate2, generate_beam
# ----Define Testing Versions of Classes----
class TestAutoencoderKL(nn.Module):
    """KL autoencoder wrapper used to test the sampling pipeline.

    Builds an ``Encoder``/``Decoder`` pair from ``ddconfig``, loads a state dict from
    ``pretrained_path`` (every key must match), and exposes ``encode_moments``,
    ``encode``, ``sample`` and ``decode`` plus a string-dispatching ``forward``.
    Latents are multiplied by ``scale_factor`` when produced (``sample``/``encode``)
    and divided by it again in ``decode``.
    """

    def __init__(self, ddconfig, embed_dim, pretrained_path, scale_factor=0.18215):
        """Create the encoder/decoder and load pretrained weights.

        Args:
            ddconfig: keyword arguments forwarded to ``Encoder`` and ``Decoder``;
                ``ddconfig["double_z"]`` must be truthy and ``ddconfig["z_channels"]``
                sets the channel count of the 1x1 projection convs.
            embed_dim: number of latent channels.
            pretrained_path: checkpoint file read with ``torch.load``.
            scale_factor: latent scale applied by ``sample``/``encode`` and undone by ``decode``.
        """
        # nn.Module.__init__ must run before any submodule is assigned; without it the
        # ``self.encoder = ...`` assignment raises AttributeError.
        super().__init__()
        print(f'Create autoencoder with scale_factor={scale_factor}')
        self.encoder = Encoder(**ddconfig)
        self.decoder = Decoder(**ddconfig)
        assert ddconfig["double_z"]
        # Projects the encoder output to 2 * embed_dim channels; get_moment_params splits
        # them into mean (first half) and log-variance (second half).
        self.quant_conv = torch.nn.Conv2d(2 * ddconfig["z_channels"], 2 * embed_dim, 1)
        self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
        self.embed_dim = embed_dim
        self.scale_factor = scale_factor
        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        m, u = self.load_state_dict(torch.load(pretrained_path, map_location='cpu'))
        assert len(m) == 0 and len(u) == 0

    def encode_moments(self, x):
        """Return the raw moments tensor (mean and log-variance stacked on dim 1) for ``x``."""
        h = self.encoder(x)
        return self.quant_conv(h)

    def sample(self, moments, noise=None, generator=None, device="cuda"):
        """Draw a scaled latent ``scale_factor * (mean + std * noise)`` from ``moments``.

        Args:
            moments: tensor whose dim 1 holds mean and log-variance halves.
            noise: optional noise tensor; when ``None`` it is drawn on the CPU with
                ``generator`` (so results do not depend on the target device) and then
                moved to ``device``.
            generator: ``torch.Generator`` used when ``noise`` is ``None``.
            device: device the generated noise is moved to.
        """
        mean, logvar = self.get_moment_params(moments)
        if noise is None:
            noise = randn_tensor(mean.shape, generator=generator).to(device)
        # Clamp the log-variance to keep exp() numerically safe.
        logvar = torch.clamp(logvar, -30.0, 20.0)
        std = torch.exp(0.5 * logvar)
        z = mean + std * noise
        return self.scale_factor * z

    def get_moment_params(self, moments):
        """Split ``moments`` along dim 1 into ``(mean, logvar)``."""
        mean, logvar = torch.chunk(moments, 2, dim=1)
        return mean, logvar

    def encode(self, x):
        """Encode ``x`` deterministically to scaled latents using the Gaussian's mode (mean).

        The mean is scaled by ``scale_factor`` so that it lives in the same space as
        ``sample`` outputs and is correctly inverted by ``decode``.
        """
        mean, _ = self.get_moment_params(self.encode_moments(x))
        return self.scale_factor * mean

    def decode(self, z):
        """Undo ``scale_factor`` and decode latents ``z`` back to image space."""
        z = (1. / self.scale_factor) * z
        z = self.post_quant_conv(z)
        return self.decoder(z)

    def forward(self, inputs, fn):
        """Dispatch to ``encode_moments``, ``encode`` or ``decode`` by name ``fn``."""
        if fn == 'encode_moments':
            return self.encode_moments(inputs)
        elif fn == 'encode':
            return self.encode(inputs)
        elif fn == 'decode':
            return self.decode(inputs)
        raise NotImplementedError(f"Unsupported fn: {fn!r}")

    def freeze(self):
        """Switch to eval mode and disable gradients for every parameter."""
        self.eval()
        self.requires_grad_(False)
# ----Define Testing Utility Functions----
# get_test_autoencoder: build a TestAutoencoderKL with embed_dim=4 from `ddconfig`
# and return it together with the VAE spatial downsampling factor
# 2 ** (len(ch_mult) - 1), which is 8 for ch_mult=[1, 2, 4, 4].
# NOTE(review): the remaining ddconfig entries and the closing parenthesis of
# dict(...) appear to be missing from this listing; TestAutoencoderKL needs at
# least the "double_z" and "z_channels" keys — confirm against version control.
def get_test_autoencoder(pretrained_path, scale_factor=0.18215):
ddconfig = dict(
ch_mult=[1, 2, 4, 4],
# Each ch_mult level after the first halves the spatial resolution.
vae_scale_factor = 2 ** (len(ddconfig['ch_mult']) - 1)
return TestAutoencoderKL(ddconfig, 4, pretrained_path, scale_factor), vae_scale_factor
# Modified from diffusers.utils.randn_tensor
def randn_tensor(
    shape: Union[Tuple, List],
    generator: Optional[Union[List["torch.Generator"], "torch.Generator"]] = None,
    device: Optional["torch.device"] = None,
    dtype: Optional["torch.dtype"] = None,
    layout: Optional["torch.layout"] = None,
):
    """Create a standard-normal tensor of ``shape`` on ``device`` with ``dtype``.

    When a list of generators is passed, each element along the first (batch)
    axis is drawn from its own generator, so every batch element can be seeded
    individually. If the generator(s) live on the CPU while ``device`` is a
    different device, the tensor is drawn on the CPU and then moved to ``device``.

    Args:
        shape: output shape (tuple or list); ``shape[0]`` is the batch size when
            ``generator`` is a list.
        generator: a ``torch.Generator`` or a list with one generator per batch element.
        device: target device (``torch.device`` or device string); defaults to CPU.
        dtype: output dtype; ``None`` uses torch's default dtype.
        layout: tensor layout; defaults to ``torch.strided``.

    Returns:
        The sampled tensor, placed on ``device``.

    Raises:
        ValueError: if a CUDA generator is asked to produce a tensor for another device type.
    """
    # Device the random numbers are drawn on; may be switched to CPU below.
    rand_device = device

    layout = layout or torch.strided
    # Accept device strings such as "cuda" as well as torch.device objects.
    device = torch.device(device) if device is not None else torch.device("cpu")

    if generator is not None:
        gen_device_type = generator[0].device.type if isinstance(generator, list) else generator.device.type
        if gen_device_type != device.type and gen_device_type == "cpu":
            rand_device = "cpu"
            # Compare the device *type*: a torch.device never equals the plain string "mps".
            if device.type != "mps":
                logging.info(
                    f"The passed generator was created on 'cpu' even though a tensor on {device} was expected."
                    f" Tensors will be created on 'cpu' and then moved to {device}. Note that one can probably"
                    f" slighly speed up this function by passing a generator that was created on the {device} device."
                )
        elif gen_device_type != device.type and gen_device_type == "cuda":
            raise ValueError(f"Cannot generate a {device} tensor from a generator of type {gen_device_type}.")

    if isinstance(generator, list):
        # One (1, *shape[1:]) draw per generator, concatenated along the batch axis.
        # Unpacking works for both tuple and list shapes.
        batch_size = shape[0]
        item_shape = (1, *shape[1:])
        latents = [
            torch.randn(item_shape, generator=generator[i], device=rand_device, dtype=dtype, layout=layout)
            for i in range(batch_size)
        ]
        latents = torch.cat(latents, dim=0).to(device)
    else:
        latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype, layout=layout).to(device)

    return latents
# Sample from the autoencoder latent space directly instead of sampling the autoencoder moment.
# prepare_latents: build the conditioning latents used by evaluate():
# (contexts, img_contexts, clip_imgs). All three start as noise drawn on the CPU
# from a generator seeded with config.seed, so they do not depend on the sampling
# device; depending on config.mode they are then replaced by real conditioning,
# and finally all three are moved to `device`.
# NOTE(review): the parameter list of this signature appears to be missing from the
# listing; the body reads config, clip_text_model, clip_img_model,
# clip_img_model_preprocess, autoencoder, vae_scale_factor and device — confirm.
def prepare_latents(
# Pixel resolution of the input image = latent side length * VAE downsampling factor.
resolution = config.z_shape[-1] * vae_scale_factor
# Fix device to CPU for reproducibility.
latent_device = "cpu"
latent_torch_device = torch.device(latent_device)
generator = torch.Generator(device=latent_torch_device).manual_seed(config.seed)
contexts = randn_tensor((config.n_samples, 77, config.clip_text_dim), generator=generator, device=latent_torch_device)
# Image latents are drawn directly in the latent shape z_shape (not the 2x-channel moment shape).
img_contexts = randn_tensor((config.n_samples, config.z_shape[0], config.z_shape[1], config.z_shape[2]), generator=generator, device=latent_torch_device)
clip_imgs = randn_tensor((config.n_samples, 1, config.clip_img_dim), generator=generator, device=latent_torch_device)
if config.mode in ['t2i', 't2i2t']:
# Text-conditioned modes: replace the noise contexts with the prompt's text embedding.
prompts = [ config.prompt ] * config.n_samples
contexts = clip_text_model.encode(prompts)
elif config.mode in ['i2t', 'i2t2i']:
# Image-conditioned modes: encode config.img into CLIP features and VAE latents.
from PIL import Image
img_contexts = []
clip_imgs = []
def get_img_feature(image):
# Center-crop to `resolution`, compute the CLIP image embedding, then map pixels
# from uint8 [0, 255] to float32 [-1, 1] for the autoencoder.
image = np.array(image).astype(np.uint8)
image = utils.center_crop(resolution, resolution, image)
clip_img_feature = clip_img_model.encode_image(clip_img_model_preprocess(Image.fromarray(image)).unsqueeze(0).to(device))
image = (image / 127.5 - 1.0).astype(np.float32)
image = einops.rearrange(image, 'h w c -> 1 c h w')
image = torch.tensor(image, device=device)
# Get the moments (mean / log-variance) of the diagonal Gaussian posterior
moments = autoencoder.encode_moments(image)
# Draw a scaled latent sample (not the mode) from the moments with the shared CPU
# generator; from here on `moments` holds that latent, not the moments.
moments = autoencoder.sample(moments, generator=generator, device=device)
return clip_img_feature, moments
image = Image.open(config.img).convert('RGB')
clip_img, img_context = get_img_feature(image)
# NOTE(review): nothing appends clip_img / img_context to the lists initialized above,
# so the concat/stack below would receive empty lists; `img_contexts.append(img_context)`
# and `clip_imgs.append(clip_img)` appear to be missing from this listing.
# Repeat the single image's features once per sample.
img_contexts = img_contexts * config.n_samples
clip_imgs = clip_imgs * config.n_samples
img_contexts = torch.concat(img_contexts, dim=0)
clip_imgs = torch.stack(clip_imgs, dim=0)
contexts = contexts.to(device)
img_contexts = img_contexts.to(device)
clip_imgs = clip_imgs.to(device)
return contexts, img_contexts, clip_imgs
# ----END----
def stable_diffusion_beta_schedule(linear_start=0.00085, linear_end=0.0120, n_timestep=1000):
    """Return the Stable Diffusion "scaled linear" beta schedule as a float64 NumPy array.

    The betas are spaced linearly in square-root space between ``sqrt(linear_start)``
    and ``sqrt(linear_end)`` and then squared, giving ``n_timestep`` values from
    ``linear_start`` to ``linear_end``.
    """
    sqrt_betas = torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64)
    return (sqrt_betas ** 2).numpy()
def prepare_contexts(config, clip_text_model, clip_img_model, clip_img_model_preprocess, autoencoder):
    """Build the conditioning inputs ``(contexts, img_contexts, clip_imgs)``.

    Each output starts as Gaussian noise with the shape the sampler expects and is
    replaced by real conditioning depending on ``config.mode``:

    * ``'t2i'`` / ``'t2i2t'``: ``contexts`` becomes ``clip_text_model.encode`` of
      ``config.prompt`` repeated ``config.n_samples`` times.
    * ``'i2t'`` / ``'i2t2i'``: the image at ``config.img`` is center-cropped and
      encoded into CLIP image embeddings and autoencoder moments, each repeated
      ``config.n_samples`` times along the batch axis.

    All three outputs are moved to the selected device before returning.
    """
    # Hard-coded VAE downsampling factor of 8 between latent and pixel resolution.
    resolution = config.z_shape[-1] * 8
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    contexts = torch.randn(config.n_samples, 77, config.clip_text_dim).to(device)
    # Moments carry mean and log-variance, hence twice the latent channel count.
    img_contexts = torch.randn(config.n_samples, 2 * config.z_shape[0], config.z_shape[1], config.z_shape[2])
    clip_imgs = torch.randn(config.n_samples, 1, config.clip_img_dim)

    if config.mode in ['t2i', 't2i2t']:
        prompts = [config.prompt] * config.n_samples
        contexts = clip_text_model.encode(prompts)
    elif config.mode in ['i2t', 'i2t2i']:
        from PIL import Image

        def get_img_feature(image):
            # Crop, embed with CLIP, then map uint8 pixels to float32 [-1, 1] for the VAE.
            image = np.array(image).astype(np.uint8)
            image = utils.center_crop(resolution, resolution, image)
            clip_inputs = clip_img_model_preprocess(images=image, return_tensors="pt")
            clip_img_feature = clip_img_model(**clip_inputs).image_embeds
            image = (image / 127.5 - 1.0).astype(np.float32)
            image = einops.rearrange(image, 'h w c -> 1 c h w')
            image = torch.tensor(image, device=device)
            moments = autoencoder.encode_moments(image)
            return clip_img_feature, moments

        image = Image.open(config.img).convert('RGB')
        clip_img, img_context = get_img_feature(image)
        # Repeat the single image's features once per sample (the computed features
        # must actually be used; concatenating an empty list raises).
        img_contexts = torch.concat([img_context] * config.n_samples, dim=0)
        clip_imgs = torch.stack([clip_img] * config.n_samples, dim=0)

    # Keep all outputs on the same device, matching prepare_latents.
    contexts = contexts.to(device)
    img_contexts = img_contexts.to(device)
    clip_imgs = clip_imgs.to(device)
    return contexts, img_contexts, clip_imgs
def unpreprocess(v):
    """Map tensor values from the model range [-1, 1] to [0, 1].

    A new tensor is returned (the input is left untouched), its shape is unchanged
    (decoded images here are (B, C, H, W)), and out-of-range values are clamped.
    """
    rescaled = v.add(1.).mul(0.5)
    return rescaled.clamp_(0., 1.)
def set_seed(seed: int):
    """Seed Python's ``random``, NumPy's global RNG and PyTorch (CPU and all CUDA devices)."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Safe without CUDA: the call is silently ignored when CUDA is unavailable.
    torch.cuda.manual_seed_all(seed)
# evaluate(config): run UniDiffuser sampling for config.mode and return
# (output_images, output_text); whichever output the mode does not produce stays None.
# Image-producing modes: 'joint', 't2i', 'i', 'i2t2i'. Text-producing modes:
# 'joint', 'i2t', 't', 't2i2t'.
# It loads the nnet from config.nnet_path, the caption decoder (when needed), a
# FrozenCLIPEmbedder text encoder, the test autoencoder and CLIP ViT-B/32, then
# samples with DPM-Solver on the discrete Stable Diffusion beta schedule.
# NOTE(review): this listing appears to have lost its indentation and several lines
# (e.g. `else:` lines, docstring quotes, call arguments); the comments below flag
# the spots where a line seems to be missing — confirm against version control.
def evaluate(config):
if config.get('benchmark', False):
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.deterministic = False
# device = 'cuda' if torch.cuda.is_available() else 'cpu'
# The sampling device comes from config.sample.device.
device = config.sample.device
torch_device = torch.device(device)
# Instantiate generator
# One seeded generator on the sampling device drives every randn_tensor call below.
generator = torch.Generator(device=torch_device).manual_seed(config.seed)
config = ml_collections.FrozenConfigDict(config)
# utils.set_logger(log_level='info')
# utils.set_logger(log_level='debug', fname="./logs/test.txt")
_betas = stable_diffusion_beta_schedule()
# N = number of discrete timesteps (1000 with the default schedule). Passing t = N
# for a modality marks it as unconditional (its input is replaced by noise).
N = len(_betas)
nnet = utils.get_nnet(**config.nnet)
logging.info(f'load nnet from {config.nnet_path}')
# NOTE(review): torch.load unpickles the checkpoint; only load trusted files.
nnet.load_state_dict(torch.load(config.nnet_path, map_location='cpu'))
# The caption decoder provides encode_prefix (clip_text_dim -> text_dim contexts for the
# nnet) and generate_captions (text latents -> strings).
use_caption_decoder = config.text_dim < config.clip_text_dim or config.mode != 't2i'
if use_caption_decoder:
from libs.caption_decoder import CaptionDecoder
caption_decoder = CaptionDecoder(device=device, **config.caption_decoder)
# NOTE(review): an `else:` line appears to be missing before the next statement; as
# listed, caption_decoder is always reset to None, which would break the
# encode_prefix / generate_captions calls below.
caption_decoder = None
clip_text_model = libs.clip.FrozenCLIPEmbedder(device=device)
# autoencoder = libs.autoencoder.get_model(**config.autoencoder)
# Load test autoencoder
# get_test_autoencoder returns (autoencoder, vae_scale_factor).
autoencoder, vae_scale_factor = get_test_autoencoder(**config.autoencoder)
# print(f"VAE scale factor: {vae_scale_factor}")
clip_img_model, clip_img_model_preprocess = clip.load("ViT-B/32", device=device, jit=False)
# Embedding of the empty prompt, used by the 'empty_token' guidance mode.
empty_context = clip_text_model.encode([''])[0]
# split / combine convert between the flat [z | clip_img] vector the solver works on
# and the (B, C, H, W) latent plus (B, 1, clip_img_dim) CLIP image embedding.
def split(x):
C, H, W = config.z_shape
z_dim = C * H * W
z, clip_img = x.split([z_dim, config.clip_img_dim], dim=1)
z = einops.rearrange(z, 'B (C H W) -> B C H W', C=C, H=H, W=W)
clip_img = einops.rearrange(clip_img, 'B (L D) -> B L D', L=1, D=config.clip_img_dim)
return z, clip_img
def combine(z, clip_img):
z = einops.rearrange(z, 'B C H W -> B (C H W)')
clip_img = einops.rearrange(clip_img, 'B L D -> B (L D)')
return torch.concat([z, clip_img], dim=-1)
def t2i_nnet(x, timesteps, text): # text is the low dimension version of the text clip embedding
# NOTE(review): the docstring quotes around the next five lines appear to be missing.
1. calculate the conditional model output
2. calculate unconditional model output
config.sample.t2i_cfg_mode == 'empty_token': using the original cfg with the empty string
config.sample.t2i_cfg_mode == 'true_uncond: using the unconditional model learned by our method
3. return linear combination of conditional output and unconditional output
z, clip_img = split(x)
# Text is clean (t_text = 0) for the conditional pass.
t_text = torch.zeros(timesteps.size(0), dtype=torch.int, device=device)
z_out, clip_img_out, text_out = nnet(z, clip_img, text=text, t_img=timesteps, t_text=t_text)
logging.debug(f"Conditional VAE out: {z_out}")
logging.debug(f"Conditional VAE out shape: {z_out.shape}")
logging.debug(f"Conditional CLIP out: {clip_img_out}")
logging.debug(f"Conditional CLIP out shape: {clip_img_out.shape}")
x_out = combine(z_out, clip_img_out)
if config.sample.scale == 0.:
return x_out
if config.sample.t2i_cfg_mode == 'empty_token':
_empty_context = einops.repeat(empty_context, 'L D -> B L D', B=x.size(0))
if use_caption_decoder:
_empty_context = caption_decoder.encode_prefix(_empty_context)
z_out_uncond, clip_img_out_uncond, text_out_uncond = nnet(z, clip_img, text=_empty_context, t_img=timesteps, t_text=t_text)
x_out_uncond = combine(z_out_uncond, clip_img_out_uncond)
elif config.sample.t2i_cfg_mode == 'true_uncond':
# text_N = torch.randn_like(text) # 3 other possible choices
# Unconditional pass: noise text at timestep N.
text_N = randn_tensor(text.shape, generator=generator, device=torch_device)
logging.debug(f"Unconditional random text: {text_N}")
logging.debug(f"Unconditional random text shape: {text_N.shape}")
z_out_uncond, clip_img_out_uncond, text_out_uncond = nnet(z, clip_img, text=text_N, t_img=timesteps, t_text=torch.ones_like(timesteps) * N)
logging.debug(f"Unconditional VAE out: {z_out_uncond}")
logging.debug(f"Unconditional VAE out shape: {z_out_uncond.shape}")
logging.debug(f"Unconditional CLIP out: {clip_img_out_uncond}")
logging.debug(f"Unconditional CLIP out shape: {clip_img_out_uncond.shape}")
x_out_uncond = combine(z_out_uncond, clip_img_out_uncond)
# NOTE(review): presumably the `else:` branch for unknown t2i_cfg_mode values; the
# `else:` line appears to be missing.
raise NotImplementedError
# Classifier-free guidance: out + scale * (out - out_uncond).
return x_out + config.sample.scale * (x_out - x_out_uncond)
# Unconditional image generation: text input is noise at timestep N.
def i_nnet(x, timesteps):
z, clip_img = split(x)
# text = torch.randn(x.size(0), 77, config.text_dim, device=device)
text = randn_tensor((x.size(0), 77, config.text_dim), generator=generator, device=torch_device)
t_text = torch.ones_like(timesteps) * N
z_out, clip_img_out, text_out = nnet(z, clip_img, text=text, t_img=timesteps, t_text=t_text)
x_out = combine(z_out, clip_img_out)
return x_out
# Unconditional text generation: image inputs are noise at timestep N.
def t_nnet(x, timesteps):
# z = torch.randn(x.size(0), *config.z_shape, device=device)
# clip_img = torch.randn(x.size(0), 1, config.clip_img_dim, device=device)
z = randn_tensor((x.size(0), *config.z_shape), generator=generator, device=torch_device)
clip_img = randn_tensor((x.size(0), 1, config.clip_img_dim), generator=generator, device=torch_device)
z_out, clip_img_out, text_out = nnet(z, clip_img, text=x, t_img=torch.ones_like(timesteps) * N, t_text=timesteps)
return text_out
def i2t_nnet(x, timesteps, z, clip_img):
# NOTE(review): the docstring quotes around the next three lines appear to be missing.
1. calculate the conditional model output
2. calculate unconditional model output
3. return linear combination of conditional output and unconditional output
# Conditional pass: the image inputs are clean (t_img = 0).
t_img = torch.zeros(timesteps.size(0), dtype=torch.int, device=device)
z_out, clip_img_out, text_out = nnet(z, clip_img, text=x, t_img=t_img, t_text=timesteps)
if config.sample.scale == 0.:
return text_out
# z_N = torch.randn_like(z) # 3 other possible choices
# clip_img_N = torch.randn_like(clip_img)
z_N = randn_tensor(z.shape, generator=generator, device=torch_device)
clip_img_N = randn_tensor(clip_img.shape, generator=generator, device=torch_device)
z_out_uncond, clip_img_out_uncond, text_out_uncond = nnet(z_N, clip_img_N, text=x, t_img=torch.ones_like(timesteps) * N, t_text=timesteps)
# Classifier-free guidance on the text output.
return text_out + config.sample.scale * (text_out - text_out_uncond)
# split_joint / combine_joint: flat [z | clip_img | text (77 * text_dim)] layout for joint sampling.
def split_joint(x):
C, H, W = config.z_shape
z_dim = C * H * W
z, clip_img, text = x.split([z_dim, config.clip_img_dim, 77 * config.text_dim], dim=1)
z = einops.rearrange(z, 'B (C H W) -> B C H W', C=C, H=H, W=W)
clip_img = einops.rearrange(clip_img, 'B (L D) -> B L D', L=1, D=config.clip_img_dim)
text = einops.rearrange(text, 'B (L D) -> B L D', L=77, D=config.text_dim)
return z, clip_img, text
def combine_joint(z, clip_img, text):
z = einops.rearrange(z, 'B C H W -> B (C H W)')
clip_img = einops.rearrange(clip_img, 'B L D -> B (L D)')
text = einops.rearrange(text, 'B L D -> B (L D)')
return torch.concat([z, clip_img, text], dim=-1)
# Joint sampling: image and text share the same timestep. For guidance, the text
# output's unconditional term uses noise images at t_img = N, and the image outputs'
# unconditional term uses noise text at t_text = N.
def joint_nnet(x, timesteps):
logging.debug(f"Timestep: {timesteps}")
z, clip_img, text = split_joint(x)
z_out, clip_img_out, text_out = nnet(z, clip_img, text=text, t_img=timesteps, t_text=timesteps)
logging.debug(f"Conditional VAE out: {z_out}")
logging.debug(f"Conditional VAE out shape: {z_out.shape}")
logging.debug(f"Conditional CLIP out: {clip_img_out}")
logging.debug(f"Conditional CLIP out shape: {clip_img_out.shape}")
logging.debug(f"Conditional text out: {text_out}")
logging.debug(f"Conditional text out shape: {text_out.shape}")
x_out = combine_joint(z_out, clip_img_out, text_out)
if config.sample.scale == 0.:
return x_out
# z_noise = torch.randn(x.size(0), *config.z_shape, device=device)
# clip_img_noise = torch.randn(x.size(0), 1, config.clip_img_dim, device=device)
# text_noise = torch.randn(x.size(0), 77, config.text_dim, device=device)
z_noise = randn_tensor((x.size(0), *config.z_shape), generator=generator, device=torch_device, dtype=z_out.dtype)
clip_img_noise = randn_tensor((x.size(0), 1, config.clip_img_dim), generator=generator, device=torch_device, dtype=clip_img_out.dtype)
text_noise = randn_tensor((x.size(0), 77, config.text_dim), generator=generator, device=torch_device, dtype=text_out.dtype)
_, _, text_out_uncond = nnet(z_noise, clip_img_noise, text=text, t_img=torch.ones_like(timesteps) * N, t_text=timesteps)
logging.debug(f"Unconditional text out: {text_out_uncond}")
logging.debug(f"Unconditional text out shape: {text_out_uncond.shape}")
z_out_uncond, clip_img_out_uncond, _ = nnet(z, clip_img, text=text_noise, t_img=timesteps, t_text=torch.ones_like(timesteps) * N)
logging.debug(f"Unconditional VAE out: {z_out_uncond}")
logging.debug(f"Unconditional VAE out shape: {z_out_uncond.shape}")
logging.debug(f"Unconditional CLIP out: {clip_img_out_uncond}")
logging.debug(f"Unconditional CLIP out shape: {clip_img_out_uncond.shape}")
x_out_uncond = combine_joint(z_out_uncond, clip_img_out_uncond, text_out_uncond)
return x_out + config.sample.scale * (x_out - x_out_uncond)
def encode(_batch):
return autoencoder.encode(_batch)
def decode(_batch):
return autoencoder.decode(_batch)
# contexts, img_contexts, clip_imgs = prepare_contexts(config, clip_text_model, clip_img_model, clip_img_model_preprocess, autoencoder)
# NOTE(review): the argument list of this prepare_latents(...) call appears to be missing.
contexts, img_contexts, clip_imgs = prepare_latents(
logging.debug(f"Text latents: {contexts}")
logging.debug(f"Text latents shape: {contexts.shape}")
contexts = contexts # the clip embedding of conditioned texts
contexts_low_dim = contexts if not use_caption_decoder else caption_decoder.encode_prefix(contexts) # the low dimensional version of the contexts, which is the input to the nnet
logging.debug(f"Low dim text latents: {contexts_low_dim}")
logging.debug(f"Low dim text latents shape: {contexts_low_dim.shape}")
img_contexts = img_contexts # img_contexts is the autoencoder moment
# z_img = autoencoder.sample(img_contexts, generator=cpu_generator, device=device)
z_img = img_contexts # sample autoencoder latents directly, no need to call sample()
clip_imgs = clip_imgs # the clip embedding of conditioned image
logging.debug(f"VAE latents: {z_img}")
logging.debug(f"VAE latents shape: {z_img.shape}")
logging.debug(f"CLIP latents: {clip_imgs}")
logging.debug(f"CLIP latents shape: {clip_imgs.shape}")
if config.mode in ['t2i', 't2i2t']:
_n_samples = contexts_low_dim.size(0)
elif config.mode in ['i2t', 'i2t2i']:
_n_samples = img_contexts.size(0)
# NOTE(review): this looks like the body of a missing `else:`; as listed it overrides
# the mode-specific values above.
_n_samples = config.n_samples
# sample_fn(mode, **kwargs): draw fresh noise for the initial latents and integrate with
# DPM-Solver (predict_x0=True) from T=1 down to eps=1/N under torch.autocast; returns
# the split outputs for `mode`.
def sample_fn(mode, **kwargs):
# _z_init = torch.randn(_n_samples, *config.z_shape, device=device)
# _clip_img_init = torch.randn(_n_samples, 1, config.clip_img_dim, device=device)
# _text_init = torch.randn(_n_samples, 77, config.text_dim, device=device)
_z_init = randn_tensor((_n_samples, *config.z_shape), generator=generator, device=torch_device)
_clip_img_init = randn_tensor((_n_samples, 1, config.clip_img_dim), generator=generator, device=torch_device)
_text_init = randn_tensor((_n_samples, 77, config.text_dim), generator=generator, device=torch_device)
if mode == 'joint':
_x_init = combine_joint(_z_init, _clip_img_init, _text_init)
elif mode in ['t2i', 'i']:
_x_init = combine(_z_init, _clip_img_init)
elif mode in ['i2t', 't']:
_x_init = _text_init
logging.debug(f"Latents: {_x_init}")
logging.debug(f"Latents shape: {_x_init.shape}")
noise_schedule = NoiseScheduleVP(schedule='discrete', betas=torch.tensor(_betas, device=device).float())
def model_fn(x, t_continuous):
# The solver works in continuous time; scale it to the nnet's discrete range [0, N].
t = t_continuous * N
if mode == 'joint':
return joint_nnet(x, t)
elif mode == 't2i':
return t2i_nnet(x, t, **kwargs)
elif mode == 'i2t':
return i2t_nnet(x, t, **kwargs)
elif mode == 'i':
return i_nnet(x, t)
elif mode == 't':
return t_nnet(x, t)
dpm_solver = DPM_Solver(model_fn, noise_schedule, predict_x0=True, thresholding=False)
with torch.no_grad():
with torch.autocast(device_type=device):
start_time = time.time()
x = dpm_solver.sample(_x_init, steps=config.sample.sample_steps, eps=1. / N, T=1.)
end_time = time.time()
print(f'\ngenerate {_n_samples} samples with {config.sample.sample_steps} steps takes {end_time - start_time:.2f}s')
# os.makedirs(config.output_path, exist_ok=True)
if mode == 'joint':
_z, _clip_img, _text = split_joint(x)
return _z, _clip_img, _text
elif mode in ['t2i', 'i']:
_z, _clip_img = split(x)
return _z, _clip_img
elif mode in ['i2t', 't']:
return x
# test_sample_fn(mode, **kwargs): like sample_fn, but starts from the latents returned by
# prepare_latents (z_img, clip_imgs, contexts_low_dim) instead of fresh noise, runs
# without autocast, and logs intermediate predictions for 'joint' and 't2i'.
def test_sample_fn(mode, **kwargs):
if mode == 'joint':
_x_init = combine_joint(z_img, clip_imgs, contexts_low_dim)
elif mode in ['t2i', 'i']:
_x_init = combine(z_img, clip_imgs)
elif mode in ['i2t', 't']:
_x_init = contexts_low_dim
logging.debug(f"Latents: {_x_init}")
logging.debug(f"Latents shape: {_x_init.shape}")
noise_schedule = NoiseScheduleVP(schedule='discrete', betas=torch.tensor(_betas, device=device).float())
def model_fn(x, t_continuous):
t = t_continuous * N
if mode == 'joint':
noise_pred = joint_nnet(x, t)
logging.debug(f"Noise pred for time {t}: {noise_pred}")
logging.debug(f"Noise pred for time {t} shape: {noise_pred.shape}")
return noise_pred
# return joint_nnet(x, t)
elif mode == 't2i':
noise_pred = t2i_nnet(x, t, **kwargs)
logging.debug(f"Noise pred for time {t}: {noise_pred}")
logging.debug(f"Noise pred for time {t} shape: {noise_pred.shape}")
return noise_pred
# return t2i_nnet(x, t, **kwargs)
elif mode == 'i2t':
return i2t_nnet(x, t, **kwargs)
elif mode == 'i':
return i_nnet(x, t)
elif mode == 't':
return t_nnet(x, t)
dpm_solver = DPM_Solver(model_fn, noise_schedule, predict_x0=True, thresholding=False)
with torch.no_grad():
# Remove autocast to run in full precision for testing on CPU
start_time = time.time()
x = dpm_solver.sample(_x_init, steps=config.sample.sample_steps, eps=1. / N, T=1.)
end_time = time.time()
print(f'\ngenerate {_n_samples} samples with {config.sample.sample_steps} steps takes {end_time - start_time:.2f}s')
logging.debug(f"Full UNet sample: {x}")
logging.debug(f"Full UNet sample shape: {x.shape}")
# os.makedirs(config.output_path, exist_ok=True)
if mode == 'joint':
_z, _clip_img, _text = split_joint(x)
return _z, _clip_img, _text
elif mode in ['t2i', 'i']:
_z, _clip_img = split(x)
return _z, _clip_img
elif mode in ['i2t', 't']:
return x
# Dispatch on config.mode. Image latents are decoded by the autoencoder and mapped to
# [0, 1]; text latents are turned into captions by the caption decoder. The two-stage
# modes (i2t2i, t2i2t) use sample_fn (fresh noise); the others use test_sample_fn.
output_images = None
output_text = None
if config.mode in ['joint']:
# _z, _clip_img, _text = sample_fn(config.mode)
_z, _clip_img, _text = test_sample_fn(config.mode)
logging.debug(f"Text output: {_text}")
logging.debug(f"Text output shape: {_text.shape}")
logging.debug(f"VAE output: {_z}")
logging.debug(f"VAE output shape: {_z.shape}")
logging.debug(f"CLIP output: {_clip_img}")
logging.debug(f"CLIP output shape: {_clip_img.shape}")
samples = unpreprocess(decode(_z))
logging.debug(f"VAE decoded sample: {samples}")
logging.debug(f"VAE decoded sample shape: {samples.shape}")
prompts = caption_decoder.generate_captions(_text)
logging.debug(f"Generated text: {prompts}")
output_images = samples
output_text = prompts
elif config.mode in ['t2i', 'i', 'i2t2i']:
if config.mode == 't2i':
# _z, _clip_img = sample_fn(config.mode, text=contexts_low_dim) # conditioned on the text embedding
_z, _clip_img = test_sample_fn(config.mode, text=contexts_low_dim)
logging.debug(f"VAE output: {_z}")
logging.debug(f"VAE output shape: {_z.shape}")
logging.debug(f"CLIP output: {_clip_img}")
logging.debug(f"CLIP output shape: {_clip_img.shape}")
elif config.mode == 'i':
# _z, _clip_img = sample_fn(config.mode)
_z, _clip_img = test_sample_fn(config.mode)
elif config.mode == 'i2t2i':
_text = sample_fn('i2t', z=z_img, clip_img=clip_imgs) # conditioned on the image embedding
_z, _clip_img = sample_fn('t2i', text=_text)
samples = unpreprocess(decode(_z))
output_images = samples
elif config.mode in ['i2t', 't', 't2i2t']:
if config.mode == 'i2t':
# _text = sample_fn(config.mode, z=z_img, clip_img=clip_imgs) # conditioned on the image embedding
_text = test_sample_fn(config.mode, z=z_img, clip_img=clip_imgs) # conditioned on the image embedding
elif config.mode == 't':
# _text = sample_fn(config.mode)
_text = test_sample_fn(config.mode)
elif config.mode == 't2i2t':
_z, _clip_img = sample_fn('t2i', text=contexts_low_dim)
_text = sample_fn('i2t', z=_z, clip_img=_clip_img)
samples = caption_decoder.generate_captions(_text)
output_text = samples
# Reports CUDA caching-allocator memory; presumably 0 on CPU-only runs — TODO confirm.
print(f'\nGPU memory usage: {torch.cuda.max_memory_reserved() / 1024 ** 3:.2f} GB')
# print(f'\nresults are saved in {os.path.join(config.output_path, config.mode)} :)')
return output_images, output_text
def d(**kwargs):
    """Wrap the given keyword arguments in an ``ml_collections.ConfigDict``."""
    config_dict = ml_collections.ConfigDict(initial_dictionary=kwargs)
    return config_dict
# get_config: build the default sampling ConfigDict. It sets seed 0, pred =
# 'noise_pred' (not read anywhere in the visible code), latent shape z_shape =
# (4, 64, 64), CLIP image embedding dim 512, CLIP text embedding dim 768, and a
# reduced text dim of 64 used for the nnet's text input.
# NOTE(review): the keyword arguments (and closing parentheses) of the d(...) calls
# below appear to be missing from this listing. evaluate() reads
# config.autoencoder as kwargs for get_test_autoencoder, config.caption_decoder as
# kwargs for CaptionDecoder, config.nnet as kwargs for utils.get_nnet, and
# config.sample.{device, sample_steps, scale, t2i_cfg_mode} — confirm against VCS.
def get_config():
config = ml_collections.ConfigDict()
config.seed = 0
config.pred = 'noise_pred'
config.z_shape = (4, 64, 64)
config.clip_img_dim = 512
config.clip_text_dim = 768
config.text_dim = 64 # reduce dimension
config.autoencoder = d(
config.caption_decoder = d(
config.nnet = d(
config.sample = d(
return config
def sample(mode, prompt, image, sample_steps=50, scale=7.0, seed=None):
    """Run one sampling job through ``evaluate`` and return ``(images, text)``.

    Args:
        mode: stored as ``config.mode`` (e.g. 't2i', 'i2t', 'joint').
        prompt: stored as ``config.prompt``; used by the text-conditioned modes.
        image: stored as ``config.img``; opened with PIL in image-conditioned modes.
        sample_steps: DPM-Solver step count (``config.sample.sample_steps``).
        scale: classifier-free guidance scale (``config.sample.scale``).
        seed: when not ``None``, replaces the default seed from ``get_config``.
    """
    cfg = get_config()
    top_level_overrides = (
        ("nnet_path", "models/uvit_v0.pth"),
        ("n_samples", 1),
        ("nrow", 1),
        ("mode", mode),
        ("prompt", prompt),
        ("img", image),
    )
    for field_name, field_value in top_level_overrides:
        setattr(cfg, field_name, field_value)
    cfg.sample.sample_steps = sample_steps
    cfg.sample.scale = scale
    if seed is not None:
        cfg.seed = seed
    images, text = evaluate(cfg)
    return images, text