|
|
|
|
|
|
|
|
|
|
|
import copy |
|
import torch |
|
|
|
from ldm.modules.karlo.kakao.modules import create_gaussian_diffusion |
|
from ldm.modules.karlo.kakao.modules.xf import PriorTransformer |
|
|
|
|
|
class PriorDiffusionModel(torch.nn.Module):
    """
    A prior that generates clip image feature based on the text prompt.

    :param config: config object; ``config.model.hparams`` supplies the
        ``PriorTransformer`` hyper-parameters and ``config.diffusion`` the
        Gaussian-diffusion settings.
    :param tokenizer: tokenizer used in clip.
    :param clip_mean: mean to normalize the clip image feature (zero-mean, unit variance).
    :param clip_std: std to normalize the clip image feature (zero-mean, unit variance).
    """

    def __init__(self, config, tokenizer, clip_mean, clip_std):
        super().__init__()

        self._conf = config
        self._model_conf = config.model.hparams
        # Stored as plain kwargs so a fresh diffusion object can be built per
        # call with a different ``timestep_respacing`` (see ``get_sample_fn``).
        self._diffusion_kwargs = dict(
            steps=config.diffusion.steps,
            learn_sigma=config.diffusion.learn_sigma,
            sigma_small=config.diffusion.sigma_small,
            noise_schedule=config.diffusion.noise_schedule,
            use_kl=config.diffusion.use_kl,
            predict_xstart=config.diffusion.predict_xstart,
            rescale_learned_sigmas=config.diffusion.rescale_learned_sigmas,
            timestep_respacing=config.diffusion.timestep_respacing,
        )
        self._tokenizer = tokenizer

        # Non-persistent buffers: they follow the module across devices but
        # are not written to / expected in the state dict.
        self.register_buffer("clip_mean", clip_mean[None, :], persistent=False)
        self.register_buffer("clip_std", clip_std[None, :], persistent=False)

        causal_mask = self.get_causal_mask()
        self.register_buffer("causal_mask", causal_mask, persistent=False)

        self.model = PriorTransformer(
            text_ctx=self._model_conf.text_ctx,
            xf_width=self._model_conf.xf_width,
            xf_layers=self._model_conf.xf_layers,
            xf_heads=self._model_conf.xf_heads,
            xf_final_ln=self._model_conf.xf_final_ln,
            clip_dim=self._model_conf.clip_dim,
        )

        # Must run after ``self.model`` exists: it reads ``self.model.text_ctx``.
        cf_token, cf_mask = self.set_cf_text_tensor()
        self.register_buffer("cf_token", cf_token, persistent=False)
        self.register_buffer("cf_mask", cf_mask, persistent=False)

    @classmethod
    def load_from_checkpoint(
        cls, config, tokenizer, clip_mean, clip_std, ckpt_path, strict: bool = True
    ):
        """
        Build the model and load weights from a checkpoint file.

        :param ckpt_path: path to a checkpoint whose top-level ``"state_dict"``
            entry holds the module weights.
        :param strict: forwarded to ``load_state_dict``.
        :return: the constructed model with weights loaded (on CPU).
        """
        # NOTE(security): torch.load unpickles the file, which can execute
        # arbitrary code. Only load checkpoints from trusted sources.
        ckpt = torch.load(ckpt_path, map_location="cpu")["state_dict"]

        model = cls(config, tokenizer, clip_mean, clip_std)
        model.load_state_dict(ckpt, strict=strict)
        return model

    def set_cf_text_tensor(self):
        """
        Tokenize the empty prompt ``""`` padded to ``self.model.text_ctx``.

        :return: whatever ``tokenizer.padded_tokens_and_mask`` returns; the
            constructor unpacks it as ``(tokens, mask)``.
        """
        return self._tokenizer.padded_tokens_and_mask([""], self.model.text_ctx)

    def _resolve_respacing(self, timestep_respacing):
        """
        Pick the effective respacing and decide whether DDIM sampling applies.

        :param timestep_respacing: requested respacing, or ``None`` to fall
            back to the value taken from the config at construction time.
        :return: ``(effective_respacing, use_ddim)``; ``use_ddim`` is True only
            for string respacings starting with ``"ddim"`` or ``"fast"``.
        """
        if timestep_respacing is None:
            timestep_respacing = self._diffusion_kwargs["timestep_respacing"]
        # Guard the prefix test so a non-string respacing (e.g. None coming
        # from the config) selects the ancestral sampler instead of crashing.
        use_ddim = isinstance(timestep_respacing, str) and timestep_respacing.startswith(
            ("ddim", "fast")
        )
        return timestep_respacing, use_ddim

    def get_sample_fn(self, timestep_respacing):
        """
        Create a diffusion object for the given respacing and return its
        sampling loop.

        :param timestep_respacing: respacing spec; ``None`` uses the respacing
            stored from the config.
        :return: ``diffusion.ddim_sample_loop`` when the respacing starts with
            ``"ddim"`` or ``"fast"``, otherwise ``diffusion.p_sample_loop``.
        """
        timestep_respacing, use_ddim = self._resolve_respacing(timestep_respacing)

        # Copy so the per-call override never leaks into the stored defaults.
        diffusion_kwargs = copy.deepcopy(self._diffusion_kwargs)
        diffusion_kwargs.update(timestep_respacing=timestep_respacing)
        diffusion = create_gaussian_diffusion(**diffusion_kwargs)

        return diffusion.ddim_sample_loop if use_ddim else diffusion.p_sample_loop

    def get_causal_mask(self):
        """
        Build an additive causal attention mask.

        :return: float tensor of shape ``(1, L, L)`` with ``L = text_ctx + 4``;
            entries strictly above the diagonal are ``-inf``, the rest are 0.
        """
        # NOTE(review): the "+ 4" presumably accounts for extra non-text
        # tokens in the PriorTransformer sequence - confirm against that module.
        seq_len = self._model_conf.text_ctx + 4
        mask = torch.empty(seq_len, seq_len)
        mask.fill_(float("-inf"))
        mask.triu_(1)  # keeps -inf above the diagonal, zeroes everything else
        return mask[None, ...]

    def forward(
        self,
        txt_feat,
        txt_feat_seq,
        mask,
        cf_guidance_scales=None,
        timestep_respacing=None,
        denoised_fn=True,
    ):
        """
        Sample clip image features with classifier-free guidance.

        :param txt_feat: text embedding passed to the model as ``text_emb``;
            its first dimension is the (doubled) batch size.
        :param txt_feat_seq: text encoding passed to the model as ``text_enc``.
        :param mask: mask passed to the model as ``mask``.
        :param cf_guidance_scales: tensor of strictly positive guidance scales
            (required); reshaped to ``(-1, 1)`` and applied to the first half
            of the batch.
        :param timestep_respacing: respacing spec; ``None`` uses the respacing
            from the config.
        :param denoised_fn: accepted for interface compatibility but not used;
            the sampler always receives a clamp to ``[-10, 10]``.
        :return: de-normalized samples for the first half of the batch.
        """
        # NOTE(review): the batch is assumed to hold the conditional inputs
        # followed by an equal number of unconditional ones (inferred from the
        # cond/uncond split below) - confirm against the caller.
        assert cf_guidance_scales is not None and bool(
            (cf_guidance_scales > 0.0).all()
        ), "cf_guidance_scales must be provided and strictly positive"

        bsz_ = txt_feat.shape[0]
        bsz = bsz_ // 2

        def guided_model_fn(x_t, ts, **kwargs):
            # Duplicate the first half so both passes share the same noisy input.
            half = x_t[: len(x_t) // 2]
            combined = torch.cat([half, half], dim=0)
            model_out = self.model(combined, ts, **kwargs)

            # The first x_t.shape[1] channels are the prediction to guide;
            # any remaining channels are passed through untouched.
            feat_dim = int(x_t.shape[1])
            eps, rest = model_out[:, :feat_dim], model_out[:, feat_dim:]

            cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
            half_eps = uncond_eps + cf_guidance_scales.view(-1, 1) * (
                cond_eps - uncond_eps
            )
            eps = torch.cat([half_eps, half_eps], dim=0)
            return torch.cat([eps, rest], dim=1)

        cond = {
            "text_emb": txt_feat,
            "text_enc": txt_feat_seq,
            "mask": mask,
            "causal_mask": self.causal_mask,
        }
        sample_fn = self.get_sample_fn(timestep_respacing)
        sample = sample_fn(
            guided_model_fn,
            (bsz_, self.model.clip_dim),
            noise=None,
            device=txt_feat.device,
            clip_denoised=False,
            denoised_fn=lambda x: torch.clamp(x, -10, 10),
            model_kwargs=cond,
        )
        # Undo the zero-mean / unit-variance normalization of the clip feature.
        sample = (sample * self.clip_std) + self.clip_mean

        return sample[:bsz]
|
|