# unidiffuser / model.py
# Source: Hugging Face Hub file view - uploaded by hysts (HF staff),
# commit 639c25d ("Add files"), 20.5 kB.
from __future__ import annotations
import pathlib
import random
import sys
from typing import Callable
import clip
import einops
import numpy as np
import PIL.Image
import torch
from huggingface_hub import snapshot_download
# Put the ``unidiffuser`` directory that sits next to this file on sys.path
# so the project-local imports that follow (utils, configs, dpm_solver_pp,
# libs.*) can be resolved from it.
repo_dir = pathlib.Path(__file__).parent
submodule_dir = pathlib.Path(repo_dir, 'unidiffuser')
sys.path.append(submodule_dir.as_posix())
import utils
from configs.sample_unidiffuser_v1 import get_config
from dpm_solver_pp import DPM_Solver, NoiseScheduleVP
from libs.autoencoder import FrozenAutoencoderKL
from libs.autoencoder import get_model as get_autoencoder
from libs.caption_decoder import CaptionDecoder
from libs.clip import FrozenCLIPEmbedder
# Download the pretrained 'thu-ml/unidiffuser-v1' snapshot from the Hugging
# Face Hub into ``models/`` next to this file. The download is skipped
# entirely whenever that path already exists.
model_dir = repo_dir.joinpath('models')
if not model_dir.exists():
    snapshot_download(repo_id='thu-ml/unidiffuser-v1',
                      repo_type='model',
                      local_dir=model_dir)
def stable_diffusion_beta_schedule(linear_start=0.00085,
                                   linear_end=0.0120,
                                   n_timestep=1000):
    """Return the diffusion beta schedule as a NumPy array.

    The square roots of the betas are spaced linearly (in float64) from
    ``sqrt(linear_start)`` to ``sqrt(linear_end)`` over ``n_timestep``
    points and then squared. The first and last betas are therefore
    ``linear_start`` and ``linear_end`` up to floating-point rounding.

    Args:
        linear_start: beta at the first timestep.
        linear_end: beta at the last timestep.
        n_timestep: number of timesteps (length of the result).

    Returns:
        A float64 ``numpy.ndarray`` of shape ``(n_timestep,)``.
    """
    sqrt_betas = torch.linspace(linear_start**0.5,
                                linear_end**0.5,
                                n_timestep,
                                dtype=torch.float64)
    return sqrt_betas.square().numpy()
class Model:
    """Load UniDiffuser-v1 and sample from it in any of its generation modes.

    :meth:`run` is the entry point. Supported modes are ``'joint'`` (image
    and caption together), ``'t2i'``, ``'i'`` (image with a pure-noise text
    input), ``'i2t2i'``, ``'i2t'``, ``'t'`` (caption with pure-noise image
    inputs) and ``'t2i2t'``.
    """

    # Modes accepted by ``run``.
    _RUN_MODES = ('joint', 't2i', 'i', 'i2t2i', 'i2t', 't', 't2i2t')
    # Modes accepted by the sampler returned from ``get_sample_fn``.
    _SAMPLE_MODES = ('joint', 't2i', 'i', 'i2t', 't')

    def __init__(self):
        """Load all sub-models onto ``cuda:0`` when available, else the CPU."""
        self.device = torch.device(
            'cuda:0' if torch.cuda.is_available() else 'cpu')
        self.config = get_config()
        self.nnet = self.load_model()
        self.caption_decoder = CaptionDecoder(device=self.device,
                                              **self.config.caption_decoder)
        self.clip_text_model = self.load_clip_text_model()
        self.autoencoder = self.load_autoencoder()
        self.clip_img_model, self.clip_img_model_preprocess = clip.load(
            'ViT-B/32', device=self.device, jit=False)
        # CLIP text embedding of the empty prompt; used as the negative
        # context when ``config.sample.t2i_cfg_mode == 'empty_token'``.
        self.empty_context = self.clip_text_model.encode([''])[0]
        self.betas = stable_diffusion_beta_schedule()
        # Number of discrete diffusion timesteps in the schedule.
        self.N = len(self.betas)

    @property
    def use_caption_decoder(self) -> bool:
        """Whether text embeddings go through the caption decoder's prefix.

        True when the network's text dimension is smaller than CLIP's, or
        whenever the current mode is anything other than ``'t2i'``.
        """
        return (self.config.text_dim < self.config.clip_text_dim
                or self.config.mode != 't2i')

    def load_model(self,
                   model_path: str = 'models/uvit_v1.pth') -> torch.nn.Module:
        """Build the U-ViT from ``config.nnet`` and load its weights.

        A relative ``model_path`` is opened relative to the current working
        directory. The weights are loaded on the CPU, then the model is moved
        to ``self.device`` and switched to eval mode.
        """
        model = utils.get_nnet(**self.config.nnet)
        model.load_state_dict(torch.load(model_path, map_location='cpu'))
        model.to(self.device)
        model.eval()
        return model

    def load_clip_text_model(self) -> FrozenCLIPEmbedder:
        """Create the CLIP text embedder on ``self.device`` in eval mode."""
        clip_text_model = FrozenCLIPEmbedder(device=self.device)
        clip_text_model.to(self.device)
        clip_text_model.eval()
        return clip_text_model

    def load_autoencoder(self) -> FrozenAutoencoderKL:
        """Create the KL autoencoder from ``config.autoencoder``."""
        autoencoder = get_autoencoder(**self.config.autoencoder)
        autoencoder.to(self.device)
        return autoencoder

    def split(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Split a flat ``(B, C*H*W + clip_img_dim)`` batch into image parts.

        Returns ``z`` reshaped to ``(B, C, H, W)`` (``config.z_shape``) and
        ``clip_img`` reshaped to ``(B, 1, clip_img_dim)``.
        """
        C, H, W = self.config.z_shape
        z_dim = C * H * W
        z, clip_img = x.split([z_dim, self.config.clip_img_dim], dim=1)
        z = einops.rearrange(z, 'B (C H W) -> B C H W', C=C, H=H, W=W)
        clip_img = einops.rearrange(clip_img,
                                    'B (L D) -> B L D',
                                    L=1,
                                    D=self.config.clip_img_dim)
        return z, clip_img

    @staticmethod
    def combine(z, clip_img):
        """Inverse of :meth:`split`: flatten and concatenate along dim -1."""
        z = einops.rearrange(z, 'B C H W -> B (C H W)')
        clip_img = einops.rearrange(clip_img, 'B L D -> B (L D)')
        return torch.concat([z, clip_img], dim=-1)

    def _data_type_like(self, ref: torch.Tensor) -> torch.Tensor:
        """Return an int tensor shaped like ``ref`` filled with ``config.data_type``."""
        return torch.zeros_like(
            ref, device=self.device, dtype=torch.int) + self.config.data_type

    def t2i_nnet(self, x, timesteps, text):
        """Guided image prediction for text-to-image sampling.

        ``text`` is the low-dimensional version of the CLIP text embedding.

        1. calculate the conditional model output
        2. calculate unconditional model output
           config.sample.t2i_cfg_mode == 'empty_token': using the original
           cfg with the empty string
           config.sample.t2i_cfg_mode == 'true_uncond': using the
           unconditional model learned by our method
        3. return linear combination of conditional output and
           unconditional output (skipped when ``config.sample.scale == 0``)

        Raises:
            NotImplementedError: for any other ``t2i_cfg_mode``.
        """
        z, clip_img = self.split(x)
        # Condition on the text at timestep 0.
        t_text = torch.zeros(timesteps.size(0),
                             dtype=torch.int,
                             device=self.device)
        z_out, clip_img_out, _ = self.nnet(
            z,
            clip_img,
            text=text,
            t_img=timesteps,
            t_text=t_text,
            data_type=self._data_type_like(t_text))
        x_out = self.combine(z_out, clip_img_out)
        if self.config.sample.scale == 0.:
            return x_out

        cfg_mode = self.config.sample.t2i_cfg_mode
        if cfg_mode == 'empty_token':
            _empty_context = einops.repeat(self.empty_context,
                                           'L D -> B L D',
                                           B=x.size(0))
            if self.use_caption_decoder:
                _empty_context = self.caption_decoder.encode_prefix(
                    _empty_context)
            z_out_uncond, clip_img_out_uncond, _ = self.nnet(
                z,
                clip_img,
                text=_empty_context,
                t_img=timesteps,
                t_text=t_text,
                data_type=self._data_type_like(t_text))
        elif cfg_mode == 'true_uncond':
            # Pure-noise text at the last timestep N is the unconditional
            # branch.
            text_N = torch.randn_like(text)
            z_out_uncond, clip_img_out_uncond, _ = self.nnet(
                z,
                clip_img,
                text=text_N,
                t_img=timesteps,
                t_text=torch.ones_like(timesteps) * self.N,
                data_type=self._data_type_like(t_text))
        else:
            raise NotImplementedError(
                f'Unsupported t2i_cfg_mode: {cfg_mode!r}')
        x_out_uncond = self.combine(z_out_uncond, clip_img_out_uncond)
        return x_out + self.config.sample.scale * (x_out - x_out_uncond)

    def i_nnet(self, x, timesteps):
        """Image prediction with a pure-noise text input at timestep N."""
        z, clip_img = self.split(x)
        text = torch.randn(x.size(0),
                           77,
                           self.config.text_dim,
                           device=self.device)
        t_text = torch.ones_like(timesteps) * self.N
        z_out, clip_img_out, _ = self.nnet(
            z,
            clip_img,
            text=text,
            t_img=timesteps,
            t_text=t_text,
            data_type=self._data_type_like(t_text))
        return self.combine(z_out, clip_img_out)

    def t_nnet(self, x, timesteps):
        """Text prediction with pure-noise image inputs at timestep N."""
        z = torch.randn(x.size(0), *self.config.z_shape, device=self.device)
        clip_img = torch.randn(x.size(0),
                               1,
                               self.config.clip_img_dim,
                               device=self.device)
        _, _, text_out = self.nnet(
            z,
            clip_img,
            text=x,
            t_img=torch.ones_like(timesteps) * self.N,
            t_text=timesteps,
            data_type=self._data_type_like(timesteps))
        return text_out

    def i2t_nnet(self, x, timesteps, z, clip_img):
        """Guided text prediction for image-to-text sampling.

        1. calculate the conditional model output (image at timestep 0)
        2. calculate unconditional model output (pure-noise image inputs at
           timestep N)
        3. return linear combination of conditional output and
           unconditional output (skipped when ``config.sample.scale == 0``)
        """
        t_img = torch.zeros(timesteps.size(0),
                            dtype=torch.int,
                            device=self.device)
        _, _, text_out = self.nnet(
            z,
            clip_img,
            text=x,
            t_img=t_img,
            t_text=timesteps,
            data_type=self._data_type_like(t_img))
        if self.config.sample.scale == 0.:
            return text_out
        z_N = torch.randn_like(z)
        clip_img_N = torch.randn_like(clip_img)
        _, _, text_out_uncond = self.nnet(
            z_N,
            clip_img_N,
            text=x,
            t_img=torch.ones_like(timesteps) * self.N,
            t_text=timesteps,
            data_type=self._data_type_like(timesteps))
        return text_out + self.config.sample.scale * (text_out -
                                                      text_out_uncond)

    def split_joint(self, x):
        """Split a flat joint batch into ``(z, clip_img, text)``.

        Shapes: ``z`` is ``(B, C, H, W)``, ``clip_img`` is
        ``(B, 1, clip_img_dim)`` and ``text`` is ``(B, 77, text_dim)``.
        """
        C, H, W = self.config.z_shape
        z_dim = C * H * W
        z, clip_img, text = x.split(
            [z_dim, self.config.clip_img_dim, 77 * self.config.text_dim],
            dim=1)
        z = einops.rearrange(z, 'B (C H W) -> B C H W', C=C, H=H, W=W)
        clip_img = einops.rearrange(clip_img,
                                    'B (L D) -> B L D',
                                    L=1,
                                    D=self.config.clip_img_dim)
        text = einops.rearrange(text,
                                'B (L D) -> B L D',
                                L=77,
                                D=self.config.text_dim)
        return z, clip_img, text

    @staticmethod
    def combine_joint(z: torch.Tensor, clip_img: torch.Tensor,
                      text: torch.Tensor) -> torch.Tensor:
        """Inverse of :meth:`split_joint`: flatten and concatenate on dim -1."""
        z = einops.rearrange(z, 'B C H W -> B (C H W)')
        clip_img = einops.rearrange(clip_img, 'B L D -> B (L D)')
        text = einops.rearrange(text, 'B L D -> B (L D)')
        return torch.concat([z, clip_img, text], dim=-1)

    def joint_nnet(self, x, timesteps):
        """Guided joint image+text prediction.

        The guidance reference combines two passes:
        * the text output from a pass with noise image inputs at timestep N;
        * the image outputs from a pass with noise text at timestep N.
        """
        z, clip_img, text = self.split_joint(x)
        z_out, clip_img_out, text_out = self.nnet(
            z,
            clip_img,
            text=text,
            t_img=timesteps,
            t_text=timesteps,
            data_type=self._data_type_like(timesteps))
        x_out = self.combine_joint(z_out, clip_img_out, text_out)
        if self.config.sample.scale == 0.:
            return x_out
        z_noise = torch.randn(x.size(0),
                              *self.config.z_shape,
                              device=self.device)
        clip_img_noise = torch.randn(x.size(0),
                                     1,
                                     self.config.clip_img_dim,
                                     device=self.device)
        text_noise = torch.randn(x.size(0),
                                 77,
                                 self.config.text_dim,
                                 device=self.device)
        _, _, text_out_uncond = self.nnet(
            z_noise,
            clip_img_noise,
            text=text,
            t_img=torch.ones_like(timesteps) * self.N,
            t_text=timesteps,
            data_type=self._data_type_like(timesteps))
        z_out_uncond, clip_img_out_uncond, _ = self.nnet(
            z,
            clip_img,
            text=text_noise,
            t_img=timesteps,
            t_text=torch.ones_like(timesteps) * self.N,
            data_type=self._data_type_like(timesteps))
        x_out_uncond = self.combine_joint(z_out_uncond, clip_img_out_uncond,
                                          text_out_uncond)
        return x_out + self.config.sample.scale * (x_out - x_out_uncond)

    @torch.autocast(device_type='cuda')
    def encode(self, _batch):
        """Encode a batch with the autoencoder under CUDA autocast."""
        return self.autoencoder.encode(_batch)

    @torch.autocast(device_type='cuda')
    def decode(self, _batch):
        """Decode latents with the autoencoder under CUDA autocast."""
        return self.autoencoder.decode(_batch)

    def prepare_contexts(
            self) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Build ``(contexts, img_contexts, clip_imgs)`` for ``config.mode``.

        All three start as random tensors with ``config.n_samples`` rows
        (the two image placeholders are created on the CPU). They are then
        replaced as follows:
        * ``'t2i'``/``'t2i2t'``: ``contexts`` becomes the CLIP text
          embeddings of ``config.prompt``.
        * ``'i2t'``/``'i2t2i'``: ``img_contexts`` and ``clip_imgs`` become
          the autoencoder moments and CLIP image embedding of the image at
          ``config.img``, repeated ``n_samples`` times.
        """
        resolution = self.config.z_shape[-1] * 8
        n_samples = self.config.n_samples
        contexts = torch.randn(n_samples, 77,
                               self.config.clip_text_dim).to(self.device)
        img_contexts = torch.randn(n_samples,
                                   2 * self.config.z_shape[0],
                                   self.config.z_shape[1],
                                   self.config.z_shape[2])
        clip_imgs = torch.randn(n_samples, 1, self.config.clip_img_dim)
        if self.config.mode in ('t2i', 't2i2t'):
            prompts = [self.config.prompt] * n_samples
            contexts = self.clip_text_model.encode(prompts)
        elif self.config.mode in ('i2t', 'i2t2i'):
            # Release the file handle as soon as the pixels are decoded.
            with PIL.Image.open(self.config.img) as pil_image:
                image = pil_image.convert('RGB')
            clip_img, img_context = self._get_img_feature(image, resolution)
            img_contexts = torch.concat([img_context] * n_samples, dim=0)
            clip_imgs = torch.stack([clip_img] * n_samples, dim=0)
        return contexts, img_contexts, clip_imgs

    @torch.no_grad()
    def _get_img_feature(self, image: PIL.Image.Image,
                         resolution: int) -> tuple[torch.Tensor, torch.Tensor]:
        """Return ``(clip_img_feature, moments)`` for a center-cropped image.

        Runs without autograd: these features only feed sampling, so no
        graph needs to be kept alive.
        """
        image = np.array(image).astype(np.uint8)
        image = utils.center_crop(resolution, resolution, image)
        clip_img_feature = self.clip_img_model.encode_image(
            self.clip_img_model_preprocess(
                PIL.Image.fromarray(image)).unsqueeze(0).to(self.device))
        # Map uint8 pixels to [-1, 1] and reorder HWC -> 1CHW.
        image = (image / 127.5 - 1.0).astype(np.float32)
        image = einops.rearrange(image, 'h w c -> 1 c h w')
        image = torch.tensor(image, device=self.device)
        moments = self.autoencoder.encode_moments(image)
        return clip_img_feature, moments

    @staticmethod
    def unpreprocess(v: torch.Tensor) -> torch.Tensor:  # to B C H W and [0, 1]
        """Map values from ``[-1, 1]`` to ``[0, 1]``, clamping in place."""
        v = 0.5 * (v + 1.)
        v.clamp_(0., 1.)
        return v

    def get_sample_fn(self, _n_samples: int) -> Callable:
        """Return ``sample_fn(mode, **kwargs)`` that draws ``_n_samples`` samples.

        ``sample_fn`` runs DPM-Solver with ``config.sample.sample_steps``
        steps, using the ``*_nnet`` method for ``mode``. ``kwargs`` are
        forwarded to ``t2i_nnet`` (``text``) or ``i2t_nnet`` (``z``,
        ``clip_img``). It returns:
        * ``(z, clip_img, text)`` for ``'joint'``;
        * ``(z, clip_img)`` for ``'t2i'``/``'i'``;
        * the text tensor for ``'i2t'``/``'t'``.
        It raises ``ValueError`` for any other mode.
        """

        def sample_fn(mode: str, **kwargs):
            if mode not in self._SAMPLE_MODES:
                raise ValueError(f'Unknown sampling mode: {mode!r}')
            _z_init = torch.randn(_n_samples,
                                  *self.config.z_shape,
                                  device=self.device)
            _clip_img_init = torch.randn(_n_samples,
                                         1,
                                         self.config.clip_img_dim,
                                         device=self.device)
            _text_init = torch.randn(_n_samples,
                                     77,
                                     self.config.text_dim,
                                     device=self.device)
            if mode == 'joint':
                _x_init = self.combine_joint(_z_init, _clip_img_init,
                                             _text_init)
            elif mode in ('t2i', 'i'):
                _x_init = self.combine(_z_init, _clip_img_init)
            else:  # 'i2t' or 't'
                _x_init = _text_init
            noise_schedule = NoiseScheduleVP(schedule='discrete',
                                             betas=torch.tensor(
                                                 self.betas,
                                                 device=self.device).float())

            def model_fn(x, t_continuous):
                # Scale the solver's continuous time by N for the network.
                t = t_continuous * self.N
                if mode == 'joint':
                    return self.joint_nnet(x, t)
                if mode == 't2i':
                    return self.t2i_nnet(x, t, **kwargs)
                if mode == 'i2t':
                    return self.i2t_nnet(x, t, **kwargs)
                if mode == 'i':
                    return self.i_nnet(x, t)
                return self.t_nnet(x, t)

            dpm_solver = DPM_Solver(model_fn,
                                    noise_schedule,
                                    predict_x0=True,
                                    thresholding=False)
            with torch.inference_mode(), torch.autocast(
                    device_type=self.device.type):
                x = dpm_solver.sample(_x_init,
                                      steps=self.config.sample.sample_steps,
                                      eps=1. / self.N,
                                      T=1.)
            if mode == 'joint':
                return self.split_joint(x)
            if mode in ('t2i', 'i'):
                return self.split(x)
            return x

        return sample_fn

    @staticmethod
    def to_pil(tensor: torch.Tensor) -> PIL.Image.Image:
        """Convert a ``(C, H, W)`` tensor with values in ``[0, 1]`` to an 8-bit image."""
        return PIL.Image.fromarray(
            tensor.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to(
                'cpu', torch.uint8).numpy())

    def run(self, mode: str, prompt: str, image_path: str, seed: int,
            num_steps: int,
            guidance_scale: float) -> tuple[PIL.Image.Image | None, str]:
        """Generate an image and/or a caption.

        Args:
            mode: one of ``'joint'``, ``'t2i'``, ``'i'``, ``'i2t2i'``,
                ``'i2t'``, ``'t'`` or ``'t2i2t'``.
            prompt: text prompt, used by ``'t2i'`` and ``'t2i2t'``.
            image_path: input image path, used by ``'i2t'`` and ``'i2t2i'``.
            seed: torch RNG seed; ``-1`` picks a random seed in
                ``[0, 1000000]``. The seed actually used is stored in
                ``config.seed``.
            num_steps: number of DPM-Solver steps.
            guidance_scale: classifier-free guidance scale (0 disables it).

        Returns:
            ``(image, caption)``. ``image`` is ``None`` for text-output
            modes, and ``caption`` is ``''`` for image-output modes.

        Raises:
            ValueError: if ``mode`` is not supported. This is checked before
                any work is done or ``config`` is touched.
        """
        if mode not in self._RUN_MODES:
            raise ValueError(f'Unknown mode: {mode!r}')
        if seed == -1:
            seed = random.randint(0, 1000000)

        self.config.mode = mode
        self.config.prompt = prompt
        self.config.img = image_path
        self.config.seed = seed
        self.config.sample.sample_steps = num_steps
        self.config.sample.scale = guidance_scale
        self.config.n_samples = 1
        torch.manual_seed(seed)

        contexts, img_contexts, clip_imgs = self.prepare_contexts()
        if self.use_caption_decoder:
            contexts_low_dim = self.caption_decoder.encode_prefix(contexts)
        else:
            contexts_low_dim = contexts
        z_img = self.autoencoder.sample(img_contexts)
        if mode in ('t2i', 't2i2t'):
            _n_samples = contexts_low_dim.size(0)
        elif mode in ('i2t', 'i2t2i'):
            _n_samples = img_contexts.size(0)
        else:
            _n_samples = self.config.n_samples
        sample_fn = self.get_sample_fn(_n_samples)

        if mode == 'joint':
            _z, _clip_img, _text = sample_fn(mode)
            samples = self.unpreprocess(self.decode(_z))
            prompts = self.caption_decoder.generate_captions(_text)
            return self.to_pil(samples[0]), prompts[0]

        if mode in ('t2i', 'i', 'i2t2i'):
            if mode == 't2i':
                # Conditioned on the text embedding.
                _z, _clip_img = sample_fn(mode, text=contexts_low_dim)
            elif mode == 'i':
                _z, _clip_img = sample_fn(mode)
            else:  # 'i2t2i': caption the image, then draw from the caption.
                _text = sample_fn('i2t', z=z_img, clip_img=clip_imgs)
                _z, _clip_img = sample_fn('t2i', text=_text)
            samples = self.unpreprocess(self.decode(_z))
            return self.to_pil(samples[0]), ''

        # Remaining text-output modes: 'i2t', 't', 't2i2t'.
        if mode == 'i2t':
            # Conditioned on the image embedding.
            _text = sample_fn(mode, z=z_img, clip_img=clip_imgs)
        elif mode == 't':
            _text = sample_fn(mode)
        else:  # 't2i2t': draw an image from the prompt, then caption it.
            _z, _clip_img = sample_fn('t2i', text=contexts_low_dim)
            _text = sample_fn('i2t', z=_z, clip_img=_clip_img)
        prompts = self.caption_decoder.generate_captions(_text)
        return None, prompts[0]