AeroXi's picture
Upload folder using huggingface_hub
ece766c
# ------------------------------------------------------------------------------------
# Karlo-v1.0.alpha
# Copyright (c) 2022 KakaoBrain. All Rights Reserved.
# ------------------------------------------------------------------------------------
import copy
import torch
from ldm.modules.karlo.kakao.modules import create_gaussian_diffusion
from ldm.modules.karlo.kakao.modules.xf import PriorTransformer
class PriorDiffusionModel(torch.nn.Module):
    """
    A prior that generates clip image feature based on the text prompt.

    :param config: yaml config to define the prior (``config.model.hparams`` configures
        the PriorTransformer, ``config.diffusion`` configures the gaussian diffusion).
    :param tokenizer: tokenizer used in clip.
    :param clip_mean: mean to normalize the clip image feature (zero-mean, unit variance).
    :param clip_std: std to normalize the clip image feature (zero-mean, unit variance).
    """

    def __init__(self, config, tokenizer, clip_mean, clip_std):
        super().__init__()

        self._conf = config
        self._model_conf = config.model.hparams
        # Keyword arguments for create_gaussian_diffusion; get_sample_fn overrides
        # timestep_respacing per call.
        self._diffusion_kwargs = dict(
            steps=config.diffusion.steps,
            learn_sigma=config.diffusion.learn_sigma,
            sigma_small=config.diffusion.sigma_small,
            noise_schedule=config.diffusion.noise_schedule,
            use_kl=config.diffusion.use_kl,
            predict_xstart=config.diffusion.predict_xstart,
            rescale_learned_sigmas=config.diffusion.rescale_learned_sigmas,
            timestep_respacing=config.diffusion.timestep_respacing,
        )
        self._tokenizer = tokenizer

        # Non-persistent buffers: they follow .to(device) but are not part of the
        # state_dict. A leading axis is added so they broadcast over the batch
        # (assumes clip_mean / clip_std are 1-D feature vectors — TODO confirm).
        self.register_buffer("clip_mean", clip_mean[None, :], persistent=False)
        self.register_buffer("clip_std", clip_std[None, :], persistent=False)

        causal_mask = self.get_causal_mask()
        self.register_buffer("causal_mask", causal_mask, persistent=False)

        self.model = PriorTransformer(
            text_ctx=self._model_conf.text_ctx,
            xf_width=self._model_conf.xf_width,
            xf_layers=self._model_conf.xf_layers,
            xf_heads=self._model_conf.xf_heads,
            xf_final_ln=self._model_conf.xf_final_ln,
            clip_dim=self._model_conf.clip_dim,
        )

        # Tokens/mask of the empty prompt, i.e. the classifier-free ("unconditional")
        # text input. Must run after self.model exists (reads self.model.text_ctx).
        cf_token, cf_mask = self.set_cf_text_tensor()
        self.register_buffer("cf_token", cf_token, persistent=False)
        self.register_buffer("cf_mask", cf_mask, persistent=False)

    @classmethod
    def load_from_checkpoint(
        cls, config, tokenizer, clip_mean, clip_std, ckpt_path, strict: bool = True
    ):
        """
        Build a model and load weights from ``ckpt_path``.

        The file is expected to hold a dict with a ``"state_dict"`` entry. Weights are
        loaded on CPU; move the returned model to the desired device afterwards.
        """
        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        ckpt = torch.load(ckpt_path, map_location="cpu")["state_dict"]

        model = cls(config, tokenizer, clip_mean, clip_std)
        model.load_state_dict(ckpt, strict=strict)
        return model

    def set_cf_text_tensor(self):
        """Return (tokens, mask) for the empty prompt, padded to ``self.model.text_ctx``."""
        return self._tokenizer.padded_tokens_and_mask([""], self.model.text_ctx)

    def _resolve_respacing(self, timestep_respacing):
        """Return ``timestep_respacing``, or the configured respacing when it is None."""
        if timestep_respacing is None:
            return self._diffusion_kwargs["timestep_respacing"]
        return timestep_respacing

    @staticmethod
    def _is_ddim_respacing(timestep_respacing):
        """True when a string respacing spec selects DDIM ("ddim..." / "fast...")."""
        return isinstance(timestep_respacing, str) and timestep_respacing.startswith(
            ("ddim", "fast")
        )

    def get_sample_fn(self, timestep_respacing=None):
        """
        Build a diffusion sampling loop for the given timestep respacing.

        :param timestep_respacing: respacing spec forwarded to ``create_gaussian_diffusion``.
            Strings starting with ``"ddim"`` or ``"fast"`` select DDIM sampling; anything
            else uses ancestral sampling. ``None`` falls back to
            ``config.diffusion.timestep_respacing``.
        :return: ``diffusion.ddim_sample_loop`` or ``diffusion.p_sample_loop``.
        """
        timestep_respacing = self._resolve_respacing(timestep_respacing)
        use_ddim = self._is_ddim_respacing(timestep_respacing)

        diffusion_kwargs = copy.deepcopy(self._diffusion_kwargs)
        diffusion_kwargs.update(timestep_respacing=timestep_respacing)
        diffusion = create_gaussian_diffusion(**diffusion_kwargs)

        return diffusion.ddim_sample_loop if use_ddim else diffusion.p_sample_loop

    def get_causal_mask(self):
        """
        Return an additive causal attention mask of shape (1, L, L).

        L is ``text_ctx + 4``. Entries above the diagonal are -inf; the diagonal and
        everything below it are 0.
        """
        # NOTE(review): the +4 presumably covers the extra non-text tokens that
        # PriorTransformer appends to the text sequence — verify against xf.PriorTransformer.
        seq_len = self._model_conf.text_ctx + 4
        mask = torch.full((seq_len, seq_len), float("-inf"))
        mask.triu_(1)
        return mask[None, ...]

    def forward(
        self,
        txt_feat,
        txt_feat_seq,
        mask,
        cf_guidance_scales=None,
        timestep_respacing=None,
        denoised_fn=True,
    ):
        """
        Sample CLIP image features with classifier-free guidance.

        :param txt_feat: text embeddings. The batch is split in half: the first half is
            treated as conditional and the second as unconditional.
        :param txt_feat_seq: per-token text encodings, batched like ``txt_feat``.
        :param mask: text attention mask, batched like ``txt_feat``.
        :param cf_guidance_scales: tensor of positive guidance scales. It is reshaped to
            (-1, 1) and broadcast over the conditional half (presumably one scale per
            prompt or a single scale — confirm with callers).
        :param timestep_respacing: respacing spec; ``None`` uses the configured value.
        :param denoised_fn: accepted for API compatibility but ignored; intermediate
            predictions are always clamped to [-10, 10].
        :return: denormalized samples for the conditional half, shape
            (batch // 2, clip_dim).
        """
        # cfg should be enabled in inference
        assert cf_guidance_scales is not None and all(
            cf_guidance_scales > 0.0
        ), "cf_guidance_scales must be provided and strictly positive"

        bsz_ = txt_feat.shape[0]
        bsz = bsz_ // 2

        def guided_model_fn(x_t, ts, **kwargs):
            # Feed the same noisy input to both the conditional and unconditional
            # halves so their predictions can be mixed.
            half = x_t[: len(x_t) // 2]
            combined = torch.cat([half, half], dim=0)
            model_out = self.model(combined, ts, **kwargs)
            # The first x_t.shape[1] channels are eps; any remainder (e.g. learned
            # variance) is passed through untouched.
            eps, rest = (
                model_out[:, : int(x_t.shape[1])],
                model_out[:, int(x_t.shape[1]) :],
            )
            cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
            half_eps = uncond_eps + cf_guidance_scales.view(-1, 1) * (
                cond_eps - uncond_eps
            )
            # Duplicate the guided eps so both halves of the batch stay identical.
            eps = torch.cat([half_eps, half_eps], dim=0)
            return torch.cat([eps, rest], dim=1)

        cond = {
            "text_emb": txt_feat,
            "text_enc": txt_feat_seq,
            "mask": mask,
            "causal_mask": self.causal_mask,
        }
        sample_fn = self.get_sample_fn(timestep_respacing)
        sample = sample_fn(
            guided_model_fn,
            (bsz_, self.model.clip_dim),
            noise=None,
            device=txt_feat.device,
            clip_denoised=False,
            denoised_fn=lambda x: torch.clamp(x, -10, 10),
            model_kwargs=cond,
        )
        # Undo the (zero-mean, unit-variance) normalization of the CLIP feature space.
        sample = (sample * self.clip_std) + self.clip_mean

        return sample[:bsz]