import os
import random
import time

import clip
import einops
import ml_collections
import numpy as np
import torch
import torchvision.transforms as standard_transforms
from absl import logging
from PIL import Image
from torchvision.utils import save_image, make_grid

from dpm_solver_pp import NoiseScheduleVP, DPM_Solver
from libs.autoencoder import get_model
from libs.clip import FrozenCLIPEmbedder
from utils import center_crop, set_logger, get_nnet


def stable_diffusion_beta_schedule(linear_start=0.00085, linear_end=0.0120, n_timestep=1000):
    """Return the beta schedule as a float64 NumPy array of length ``n_timestep``.

    The betas are the squares of ``n_timestep`` values spaced evenly between
    ``sqrt(linear_start)`` and ``sqrt(linear_end)`` (both endpoints included),
    so the first entry equals ``linear_start`` and the last ``linear_end``.

    Args:
        linear_start: first beta of the schedule.
        linear_end: last beta of the schedule.
        n_timestep: number of entries to generate.
    """
    # Interpolate in sqrt-space, then square; float64 keeps the small betas precise.
    sqrt_betas = torch.linspace(
        linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64
    )
    return (sqrt_betas ** 2).numpy()


def prepare_contexts(config, clip_text_model, clip_img_model, clip_img_model_preprocess, autoencoder):
    """Build the conditioning inputs required by ``config.mode``.

    Args:
        config: object exposing ``mode``, ``n_samples``, ``z_shape``,
            ``clip_text_dim``, ``clip_img_dim`` and, depending on the mode,
            ``prompt`` (text modes) or ``img`` (image modes).
        clip_text_model: object whose ``encode(list_of_prompts)`` result is
            used as the text context. Only used for ``'t2i'`` / ``'t2i2t'``.
        clip_img_model: object with ``encode_image``. Only used for
            ``'i2t'`` / ``'i2t2i'``.
        clip_img_model_preprocess: callable applied to a PIL image before
            ``clip_img_model.encode_image``.
        autoencoder: object with ``encode_moments``. Only used for
            ``'i2t'`` / ``'i2t2i'``.

    Returns:
        ``(contexts, img_contexts, clip_imgs)``. Entries that the selected
        mode does not condition on are random placeholders.
    """
    # NOTE(review): the factor 8 is presumably the autoencoder's spatial
    # down-sampling ratio (latent side -> pixel side) - confirm.
    resolution = config.z_shape[-1] * 8
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Random placeholders for whatever the mode does not condition on.
    # Only `contexts` is moved to `device`; the other two are left where they
    # are drawn so that seeded runs consume the random generators exactly as
    # before.
    contexts = torch.randn(config.n_samples, 77, config.clip_text_dim).to(device)
    img_contexts = torch.randn(config.n_samples, 2 * config.z_shape[0], config.z_shape[1], config.z_shape[2])
    clip_imgs = torch.randn(config.n_samples, 1, config.clip_img_dim)

    if config.mode in ['t2i', 't2i2t']:
        # The same prompt is repeated once per requested sample.
        prompts = [config.prompt] * config.n_samples
        contexts = clip_text_model.encode(prompts)

    elif config.mode in ['i2t', 'i2t2i']:
        from PIL import Image

        def get_img_feature(image):
            """Return ``(clip feature, autoencoder moments)`` for one image."""
            image = np.array(image).astype(np.uint8)
            image = center_crop(resolution, resolution, image)
            # The CLIP branch consumes the cropped uint8 image via its own preprocess.
            clip_img_feature = clip_img_model.encode_image(
                clip_img_model_preprocess(Image.fromarray(image)).unsqueeze(0).to(device)
            )

            # The autoencoder branch gets the same crop rescaled from [0, 255]
            # to [-1, 1] and laid out as a single-item channels-first batch.
            image = (image / 127.5 - 1.0).astype(np.float32)
            image = einops.rearrange(image, 'h w c -> 1 c h w')
            image = torch.tensor(image, device=device)
            moments = autoencoder.encode_moments(image)

            return clip_img_feature, moments

        image = Image.open(config.img).convert('RGB')
        clip_img, img_context = get_img_feature(image)

        # Repeat the single conditioning image once per requested sample.
        img_contexts = torch.concat([img_context] * config.n_samples, dim=0)
        clip_imgs = torch.stack([clip_img] * config.n_samples, dim=0)

    return contexts, img_contexts, clip_imgs


def unpreprocess(v):
    """Map values from the ``[-1, 1]`` range to ``[0, 1]``.

    Computes ``0.5 * (v + 1)`` and clamps the result to ``[0, 1]``. The input
    tensor is left untouched; the clamp runs in place on the freshly created
    result only.
    """
    return (0.5 * (v + 1.)).clamp_(0., 1.)


def set_seed(seed: int) -> None:
    """Seed Python's ``random``, NumPy and PyTorch (CPU and all CUDA devices)."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


def evaluate(config):
    """Load the models named in ``config`` and run one sampling pass.

    ``config.mode`` selects what is generated:

    * ``'joint'``  - an image and a caption sampled together
    * ``'t2i'``    - an image conditioned on ``config.prompt``
    * ``'i2t'``    - a caption conditioned on the image in ``config.img``
    * ``'i'``      - an image only
    * ``'t'``      - a caption only
    * ``'i2t2i'``  - ``i2t`` followed by ``t2i`` on the generated text
    * ``'t2i2t'``  - ``t2i`` followed by ``i2t`` on the generated image

    Args:
        config: an ``ml_collections`` config providing at least ``seed``,
            ``mode``, ``n_samples``, ``z_shape``, ``clip_img_dim``,
            ``clip_text_dim``, ``text_dim``, ``data_type``, ``nnet``,
            ``nnet_path``, ``autoencoder``, ``caption_decoder`` and
            ``sample`` (``sample_steps``, ``scale``, ``t2i_cfg_mode``).

    Returns:
        ``(output_images, output_text)``. ``output_images`` is
        ``unpreprocess(decode(z))`` for the image-producing modes
        (``joint``, ``t2i``, ``i``, ``i2t2i``) and ``None`` otherwise.
        ``output_text`` is the result of ``caption_decoder.generate_captions``
        for the text-producing modes (``joint``, ``i2t``, ``t``, ``t2i2t``)
        and ``None`` otherwise.
    """
    if config.get('benchmark', False):
        torch.backends.cudnn.benchmark = True
        torch.backends.cudnn.deterministic = False

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    set_seed(config.seed)

    # Freeze the config so nothing below can modify it by accident.
    config = ml_collections.FrozenConfigDict(config)
    set_logger(log_level='info')

    _betas = stable_diffusion_beta_schedule()
    N = len(_betas)  # number of entries in the discrete beta schedule

    nnet = get_nnet(**config.nnet)
    logging.info(f'load nnet from {config.nnet_path}')
    nnet.load_state_dict(torch.load(config.nnet_path, map_location='cpu'))
    nnet.to(device)
    nnet.eval()

    # The caption decoder is needed whenever text has to be generated, or when
    # the network's text width is smaller than the CLIP text embedding width.
    use_caption_decoder = config.text_dim < config.clip_text_dim or config.mode != 't2i'
    if use_caption_decoder:
        from libs.caption_decoder import CaptionDecoder
        caption_decoder = CaptionDecoder(device=device, **config.caption_decoder)
    else:
        caption_decoder = None

    clip_text_model = FrozenCLIPEmbedder(device=device)
    clip_text_model.eval()
    clip_text_model.to(device)

    autoencoder = get_model(**config.autoencoder)
    autoencoder.to(device)

    clip_img_model, clip_img_model_preprocess = clip.load("ViT-B/32", device=device, jit=False)

    # Embedding of the empty prompt, used by the 'empty_token' guidance mode.
    empty_context = clip_text_model.encode([''])[0]

    def data_type_like(ref):
        """Return an int tensor shaped like ``ref`` filled with ``config.data_type``."""
        return torch.zeros_like(ref, device=device, dtype=torch.int) + config.data_type

    def split(x):
        """Split a flat ``(B, C*H*W + clip_img_dim)`` tensor into ``(z, clip_img)``."""
        C, H, W = config.z_shape
        z_dim = C * H * W
        z, clip_img = x.split([z_dim, config.clip_img_dim], dim=1)
        z = einops.rearrange(z, 'B (C H W) -> B C H W', C=C, H=H, W=W)
        clip_img = einops.rearrange(clip_img, 'B (L D) -> B L D', L=1, D=config.clip_img_dim)
        return z, clip_img

    def combine(z, clip_img):
        """Inverse of ``split``: flatten and concatenate ``z`` and ``clip_img``."""
        z = einops.rearrange(z, 'B C H W -> B (C H W)')
        clip_img = einops.rearrange(clip_img, 'B L D -> B (L D)')
        return torch.concat([z, clip_img], dim=-1)

    def t2i_nnet(x, timesteps, text):
        """Text-to-image model output with classifier-free guidance.

        1. calculate the conditional model output
        2. calculate the unconditional model output
            config.sample.t2i_cfg_mode == 'empty_token': use the empty string
                as the unconditional text
            config.sample.t2i_cfg_mode == 'true_uncond': use random text with
                the text timestep set to N
        3. return the linear combination of conditional and unconditional output
        """
        z, clip_img = split(x)

        # The conditioning text is passed with text timestep 0.
        t_text = torch.zeros(timesteps.size(0), dtype=torch.int, device=device)

        z_out, clip_img_out, text_out = nnet(z, clip_img, text=text, t_img=timesteps, t_text=t_text,
                                             data_type=data_type_like(t_text))
        x_out = combine(z_out, clip_img_out)

        if config.sample.scale == 0.:
            return x_out

        if config.sample.t2i_cfg_mode == 'empty_token':
            _empty_context = einops.repeat(empty_context, 'L D -> B L D', B=x.size(0))
            if use_caption_decoder:
                _empty_context = caption_decoder.encode_prefix(_empty_context)
            z_out_uncond, clip_img_out_uncond, text_out_uncond = nnet(
                z, clip_img, text=_empty_context, t_img=timesteps, t_text=t_text,
                data_type=data_type_like(t_text))
            x_out_uncond = combine(z_out_uncond, clip_img_out_uncond)
        elif config.sample.t2i_cfg_mode == 'true_uncond':
            text_N = torch.randn_like(text)
            z_out_uncond, clip_img_out_uncond, text_out_uncond = nnet(
                z, clip_img, text=text_N, t_img=timesteps, t_text=torch.ones_like(timesteps) * N,
                data_type=data_type_like(t_text))
            x_out_uncond = combine(z_out_uncond, clip_img_out_uncond)
        else:
            raise NotImplementedError

        return x_out + config.sample.scale * (x_out - x_out_uncond)

    def i_nnet(x, timesteps):
        """Image-only model output: text is random with text timestep N."""
        z, clip_img = split(x)
        text = torch.randn(x.size(0), 77, config.text_dim, device=device)
        t_text = torch.ones_like(timesteps) * N
        z_out, clip_img_out, text_out = nnet(z, clip_img, text=text, t_img=timesteps, t_text=t_text,
                                             data_type=data_type_like(t_text))
        x_out = combine(z_out, clip_img_out)
        return x_out

    def t_nnet(x, timesteps):
        """Text-only model output: image inputs are random with image timestep N."""
        z = torch.randn(x.size(0), *config.z_shape, device=device)
        clip_img = torch.randn(x.size(0), 1, config.clip_img_dim, device=device)
        z_out, clip_img_out, text_out = nnet(z, clip_img, text=x, t_img=torch.ones_like(timesteps) * N,
                                             t_text=timesteps, data_type=data_type_like(timesteps))
        return text_out

    def i2t_nnet(x, timesteps, z, clip_img):
        """Image-to-text model output with classifier-free guidance.

        1. calculate the conditional model output
        2. calculate the unconditional model output
        3. return the linear combination of conditional and unconditional output
        """
        # The conditioning image is passed with image timestep 0.
        t_img = torch.zeros(timesteps.size(0), dtype=torch.int, device=device)

        z_out, clip_img_out, text_out = nnet(z, clip_img, text=x, t_img=t_img, t_text=timesteps,
                                             data_type=data_type_like(t_img))

        if config.sample.scale == 0.:
            return text_out

        z_N = torch.randn_like(z)
        clip_img_N = torch.randn_like(clip_img)
        z_out_uncond, clip_img_out_uncond, text_out_uncond = nnet(
            z_N, clip_img_N, text=x, t_img=torch.ones_like(timesteps) * N, t_text=timesteps,
            data_type=data_type_like(timesteps))

        return text_out + config.sample.scale * (text_out - text_out_uncond)

    def split_joint(x):
        """Split a flat joint tensor into ``(z, clip_img, text)``."""
        C, H, W = config.z_shape
        z_dim = C * H * W
        z, clip_img, text = x.split([z_dim, config.clip_img_dim, 77 * config.text_dim], dim=1)
        z = einops.rearrange(z, 'B (C H W) -> B C H W', C=C, H=H, W=W)
        clip_img = einops.rearrange(clip_img, 'B (L D) -> B L D', L=1, D=config.clip_img_dim)
        text = einops.rearrange(text, 'B (L D) -> B L D', L=77, D=config.text_dim)
        return z, clip_img, text

    def combine_joint(z, clip_img, text):
        """Inverse of ``split_joint``: flatten and concatenate the three parts."""
        z = einops.rearrange(z, 'B C H W -> B (C H W)')
        clip_img = einops.rearrange(clip_img, 'B L D -> B (L D)')
        text = einops.rearrange(text, 'B L D -> B (L D)')
        return torch.concat([z, clip_img, text], dim=-1)

    def joint_nnet(x, timesteps):
        """Joint image+text model output with classifier-free guidance."""
        z, clip_img, text = split_joint(x)
        z_out, clip_img_out, text_out = nnet(z, clip_img, text=text, t_img=timesteps, t_text=timesteps,
                                             data_type=data_type_like(timesteps))
        x_out = combine_joint(z_out, clip_img_out, text_out)

        if config.sample.scale == 0.:
            return x_out

        z_noise = torch.randn(x.size(0), *config.z_shape, device=device)
        clip_img_noise = torch.randn(x.size(0), 1, config.clip_img_dim, device=device)
        text_noise = torch.randn(x.size(0), 77, config.text_dim, device=device)

        # Unconditional text output: image inputs replaced by noise at timestep N.
        _, _, text_out_uncond = nnet(
            z_noise, clip_img_noise, text=text, t_img=torch.ones_like(timesteps) * N, t_text=timesteps,
            data_type=data_type_like(timesteps))
        # Unconditional image output: text input replaced by noise at timestep N.
        z_out_uncond, clip_img_out_uncond, _ = nnet(
            z, clip_img, text=text_noise, t_img=timesteps, t_text=torch.ones_like(timesteps) * N,
            data_type=data_type_like(timesteps))

        x_out_uncond = combine_joint(z_out_uncond, clip_img_out_uncond, text_out_uncond)

        return x_out + config.sample.scale * (x_out - x_out_uncond)

    @torch.cuda.amp.autocast()
    def decode(_batch):
        return autoencoder.decode(_batch)

    logging.info(config.sample)
    logging.info(f'N={N}')

    contexts, img_contexts, clip_imgs = prepare_contexts(
        config, clip_text_model, clip_img_model, clip_img_model_preprocess, autoencoder)

    # Text context in the width expected by the network: passed through the
    # caption decoder's prefix encoder whenever the decoder is in use.
    contexts_low_dim = contexts if not use_caption_decoder else caption_decoder.encode_prefix(contexts)

    # Latent drawn by the autoencoder from the image contexts.
    z_img = autoencoder.sample(img_contexts)

    if config.mode in ['t2i', 't2i2t']:
        _n_samples = contexts_low_dim.size(0)
    elif config.mode in ['i2t', 'i2t2i']:
        _n_samples = img_contexts.size(0)
    else:
        _n_samples = config.n_samples

    def sample_fn(mode, **kwargs):
        """Run the DPM solver for ``mode``; ``kwargs`` are forwarded to the
        conditional model functions (``text=`` for t2i, ``z=``/``clip_img=``
        for i2t)."""
        # All three initial noises are always drawn, whatever the mode, so the
        # random stream for a given seed does not depend on the branch taken.
        _z_init = torch.randn(_n_samples, *config.z_shape, device=device)
        _clip_img_init = torch.randn(_n_samples, 1, config.clip_img_dim, device=device)
        _text_init = torch.randn(_n_samples, 77, config.text_dim, device=device)
        if mode == 'joint':
            _x_init = combine_joint(_z_init, _clip_img_init, _text_init)
        elif mode in ['t2i', 'i']:
            _x_init = combine(_z_init, _clip_img_init)
        elif mode in ['i2t', 't']:
            _x_init = _text_init
        else:
            # Without this branch an unknown mode would fail later with an
            # unbound-local error.
            raise NotImplementedError(mode)
        noise_schedule = NoiseScheduleVP(schedule='discrete', betas=torch.tensor(_betas, device=device).float())

        def model_fn(x, t_continuous):
            # The solver passes a continuous time; the network is fed that
            # time multiplied by N.
            t = t_continuous * N
            if mode == 'joint':
                return joint_nnet(x, t)
            elif mode == 't2i':
                return t2i_nnet(x, t, **kwargs)
            elif mode == 'i2t':
                return i2t_nnet(x, t, **kwargs)
            elif mode == 'i':
                return i_nnet(x, t)
            elif mode == 't':
                return t_nnet(x, t)

        dpm_solver = DPM_Solver(model_fn, noise_schedule, predict_x0=True, thresholding=False)
        with torch.no_grad():
            with torch.autocast(device_type=device):
                start_time = time.time()
                x = dpm_solver.sample(_x_init, steps=config.sample.sample_steps, eps=1. / N, T=1.)
                end_time = time.time()
                print(f'\ngenerate {_n_samples} samples with {config.sample.sample_steps} steps takes {end_time - start_time:.2f}s')

        if mode == 'joint':
            _z, _clip_img, _text = split_joint(x)
            return _z, _clip_img, _text
        elif mode in ['t2i', 'i']:
            _z, _clip_img = split(x)
            return _z, _clip_img
        elif mode in ['i2t', 't']:
            return x

    output_images = None
    output_text = None

    if config.mode in ['joint']:
        _z, _clip_img, _text = sample_fn(config.mode)
        samples = unpreprocess(decode(_z))
        prompts = caption_decoder.generate_captions(_text)

        output_images = samples
        output_text = prompts

    elif config.mode in ['t2i', 'i', 'i2t2i']:
        if config.mode == 't2i':
            _z, _clip_img = sample_fn(config.mode, text=contexts_low_dim)
        elif config.mode == 'i':
            _z, _clip_img = sample_fn(config.mode)
        elif config.mode == 'i2t2i':
            _text = sample_fn('i2t', z=z_img, clip_img=clip_imgs)
            _z, _clip_img = sample_fn('t2i', text=_text)
        samples = unpreprocess(decode(_z))
        output_images = samples

    elif config.mode in ['i2t', 't', 't2i2t']:
        if config.mode == 'i2t':
            _text = sample_fn(config.mode, z=z_img, clip_img=clip_imgs)
        elif config.mode == 't':
            _text = sample_fn(config.mode)
        elif config.mode == 't2i2t':
            _z, _clip_img = sample_fn('t2i', text=contexts_low_dim)
            _text = sample_fn('i2t', z=_z, clip_img=_clip_img)
        samples = caption_decoder.generate_captions(_text)
        output_text = samples

    print(f'\nGPU memory usage: {torch.cuda.max_memory_reserved() / 1024 ** 3:.2f} GB')

    # Images are returned as the tensor batch produced above. Callers that
    # need PIL images must convert each item themselves (for example with
    # torchvision.transforms.ToPILImage); the previous per-item conversion
    # here discarded its result and has been removed.
    return output_images, output_text


def d(**kwargs):
    """Helper for creating a config dict from keyword arguments."""
    return ml_collections.ConfigDict(initial_dictionary=kwargs)


def get_config():
    """Return the default ``ml_collections.ConfigDict`` for sampling.

    The result holds the seed, latent/embedding dimensions, and the
    ``autoencoder``, ``caption_decoder``, ``nnet`` and ``sample`` sub-configs.
    Run-specific fields (``nnet_path``, ``n_samples``, ``mode``, ``prompt``,
    ``img``) are not set here.
    """
    config = ml_collections.ConfigDict()

    config.seed = 1234
    config.pred = 'noise_pred'
    config.z_shape = (4, 64, 64)
    config.clip_img_dim = 512
    config.clip_text_dim = 768
    config.text_dim = 64
    config.data_type = 1

    config.autoencoder = d(
        pretrained_path='models/autoencoder_kl.pth',
    )

    # `get_ref` stores a reference to the top-level field rather than a copy
    # of its current value.
    config.caption_decoder = d(
        pretrained_path="models/caption_decoder.pth",
        hidden_dim=config.get_ref('text_dim')
    )

    config.nnet = d(
        name='uvit_multi_post_ln_v1',
        img_size=64,
        in_chans=4,
        patch_size=2,
        embed_dim=1536,
        depth=30,
        num_heads=24,
        mlp_ratio=4,
        qkv_bias=False,
        pos_drop_rate=0.,
        drop_rate=0.,
        attn_drop_rate=0.,
        mlp_time_embed=False,
        text_dim=config.get_ref('text_dim'),
        num_text_tokens=77,
        clip_img_dim=config.get_ref('clip_img_dim'),
        use_checkpoint=True
    )

    config.sample = d(
        sample_steps=50,
        scale=7.,
        t2i_cfg_mode='true_uncond'
    )

    return config


def sample(mode, prompt, image, sample_steps=50, scale=7.0, seed=None):
    """Generate one sample with the default configuration.

    Args:
        mode: sampling mode, stored as ``config.mode``.
        prompt: text prompt, stored as ``config.prompt``.
        image: conditioning image, stored as ``config.img``.
        sample_steps: number of solver steps (``config.sample.sample_steps``).
        scale: guidance scale (``config.sample.scale``).
        seed: random seed; when ``None`` the default from ``get_config`` is
            kept.

    Returns:
        The ``(images, text)`` pair returned by ``evaluate``.
    """
    config = get_config()

    config.nnet_path = "models/uvit_v1.pth"
    config.n_samples = 1
    config.nrow = 1

    config.mode = mode
    config.prompt = prompt
    config.img = image

    config.sample.sample_steps = sample_steps
    config.sample.scale = scale
    if seed is not None:
        config.seed = seed

    sample_images, sample_text = evaluate(config)
    return sample_images, sample_text