import torch
import numpy as np
import pytorch_lightning as pl
from diffusers import UNet2DConditionModel
from adaface.util import UNetEnsemble, create_consistentid_pipeline
from omegaconf.listconfig import ListConfig
def create_unet_teacher(teacher_type, device='cpu', **kwargs):
    # If teacher_type is a single-element list, dereference it.
    if isinstance(teacher_type, (tuple, list, ListConfig)) and len(teacher_type) == 1:
        teacher_type = teacher_type[0]

    if teacher_type == "arc2face":
        return Arc2FaceTeacher(**kwargs)
    elif teacher_type == "unet_ensemble":
        # unets, extra_unet_dirpaths and unet_weights are passed in kwargs.
        # Even if we distill from unet_ensemble, we still need to load arc2face to generate
        # arc2face embeddings.
        # The first (optional) ctor param of UNetEnsembleTeacher is an instantiated unet
        # (in our case, the ddpm unet). Ideally we would reuse it to save GPU RAM.
        # However, since the __call__ method of the ddpm unet takes differently formatted params,
        # for simplicity we still use the diffusers unet.
        # unet_teacher is placed on the CPU first, then moved to the GPU when the DDPM model is.
        return UNetEnsembleTeacher(device=device, **kwargs)
    elif teacher_type == "consistentID":
        return ConsistentIDTeacher(**kwargs)
    elif teacher_type == "simple_unet":
        return SimpleUNetTeacher(**kwargs)
    elif isinstance(teacher_type, (tuple, list, ListConfig)):
        # Since a single-element list has been dereferenced above, reaching this branch
        # implies the list has more than one element, i.e., it is a list of teacher types.
        # Therefore it must be a UNetEnsembleTeacher.
        return UNetEnsembleTeacher(unet_types=teacher_type, device=device, **kwargs)
    else:
        raise NotImplementedError(f"Teacher type {teacher_type} not implemented.")
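
# A minimal usage sketch (hypothetical kwargs, for illustration only; real call sites
# pass these values from the training config):
#
#   unet_teacher = create_unet_teacher(
#       "unet_ensemble",
#       device="cpu",
#       unets=None,
#       unet_types=["arc2face", "consistentID"],
#       extra_unet_dirpaths=None,
#       unet_weights=[0.5, 0.5],
#       p_uses_cfg=0.5,
#   )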
class UNetTeacher(pl.LightningModule):
    def __init__(self, **kwargs):
        super().__init__()
        self.name = None
        # self.unet will be initialized in the child class.
        self.unet = None
        self.p_uses_cfg = kwargs.get("p_uses_cfg", 0)
        # self.cfg_scale will be randomly sampled from cfg_scale_range.
        self.cfg_scale_range = kwargs.get("cfg_scale_range", [1.3, 2])
        # Initialize cfg_scale to 1. It will be randomly sampled during the forward pass.
        self.cfg_scale = 1

        if self.p_uses_cfg > 0:
            print(f"Using CFG with probability {self.p_uses_cfg} and scale range {self.cfg_scale_range}.")
        else:
            print("Never using CFG.")

    # ddpm_model is passed in so that we can use its q_sample and predict_start_from_noise methods.
    # We don't implement these two functions here, because they involve a few tensors
    # that would have to be initialized, which would unnecessarily complicate the code.
    # noise: the initial noise for the first iteration.
    # t:     the initial t. We will sample an additional (num_denoising_steps - 1) smaller t's.
    # uses_same_t: when sampling t, use the same t for all instances.
    def forward(self, ddpm_model, x_start, noise, t, teacher_context,
                num_denoising_steps=1, uses_same_t=False):
        assert num_denoising_steps <= 10

        if self.p_uses_cfg > 0:
            self.uses_cfg = np.random.rand() < self.p_uses_cfg
            if self.uses_cfg:
                # Randomly sample a cfg_scale from cfg_scale_range.
                self.cfg_scale = np.random.uniform(*self.cfg_scale_range)
                if self.cfg_scale == 1:
                    self.uses_cfg = False

            if self.uses_cfg:
                print(f"Teacher samples CFG scale {self.cfg_scale:.1f}.")
            else:
                self.cfg_scale = 1
                print("Teacher does not use CFG.")

                # When p_uses_cfg > 0, both pos_context and neg_context are always passed in,
                # but neg_context is only used when self.uses_cfg is True and cfg_scale > 1.
                # Since CFG is disabled here, we manually split teacher_context into
                # pos_context and neg_context, and only keep pos_context.
                if self.name == 'unet_ensemble':
                    # teacher_context is a list of per-unet teacher contexts.
                    teacher_pos_contexts = []
                    for teacher_context_i in teacher_context:
                        pos_context, neg_context = torch.chunk(teacher_context_i, 2, dim=0)
                        assert pos_context.shape[0] == x_start.shape[0], \
                            "Batch size mismatch between pos_context and x_start."
                        teacher_pos_contexts.append(pos_context)
                    teacher_context = teacher_pos_contexts
                else:
                    pos_context, neg_context = torch.chunk(teacher_context, 2, dim=0)
                    assert pos_context.shape[0] == x_start.shape[0], \
                        "Batch size mismatch between pos_context and x_start."
                    teacher_context = pos_context
        else:
            # p_uses_cfg = 0. Never use CFG.
            # In this case, the student only passes pos_context to the teacher,
            # so there is no need to split teacher_context into pos_context and neg_context.
            self.uses_cfg = False
            # self.cfg_scale will be accessed by the student, so we make sure it is always
            # set correctly, in case we ever switch between CFG and non-CFG at runtime.
            self.cfg_scale = 1

        # Sanity-check the batch size: it should be doubled when CFG is used.
        if self.name == 'unet_ensemble':
            # teacher_context is a list of per-unet teacher contexts.
            for teacher_context_i in teacher_context:
                assert teacher_context_i.shape[0] == x_start.shape[0] * (1 + self.uses_cfg), \
                    "Batch size mismatch between teacher_context and x_start."
        else:
            assert teacher_context.shape[0] == x_start.shape[0] * (1 + self.uses_cfg), \
                "Batch size mismatch between teacher_context and x_start."
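        # E.g., with batch size B: if uses_cfg is True, teacher_context keeps batch size 2*B
        # ([pos_context; neg_context] concatenated along dim 0), and x_noisy is repeated to
        # match in the loop below; if uses_cfg is False, it has batch size B after the split above.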
        # Initially, x_starts only contains the original x_start.
        x_starts    = [ x_start ]
        noises      = [ noise ]
        ts          = [ t ]
        noise_preds = []

        with torch.autocast(device_type='cuda', dtype=torch.float16):
            for i in range(num_denoising_steps):
                x_start = x_starts[i]
                t       = ts[i]
                noise   = noises[i]
                # x_noisy = sqrt_alphas_cumprod[t] * x_start + sqrt_one_minus_alphas_cumprod[t] * noise
                x_noisy = ddpm_model.q_sample(x_start, t, noise)

                if self.uses_cfg:
                    # Double the batch, so the positive and negative contexts are denoised
                    # in a single unet call.
                    x_noisy2 = x_noisy.repeat(2, 1, 1, 1)
                    t2 = t.repeat(2)
                else:
                    x_noisy2 = x_noisy
                    t2 = t

                # If do_arc2face_distill, then pos_context is [BS=6, 21, 768].
                noise_pred = self.unet(sample=x_noisy2, timestep=t2, encoder_hidden_states=teacher_context,
                                       return_dict=False)[0]
                if self.uses_cfg and self.cfg_scale > 1:
                    pos_noise_pred, neg_noise_pred = torch.chunk(noise_pred, 2, dim=0)
                    # Standard CFG combination: neg_pred + cfg_scale * (pos_pred - neg_pred).
                    noise_pred = pos_noise_pred * self.cfg_scale - neg_noise_pred * (self.cfg_scale - 1)

                # pred_x0 = sqrt_recip_alphas_cumprod[t] * x_noisy - sqrt_recipm1_alphas_cumprod[t] * noise_pred
                pred_x0 = ddpm_model.predict_start_from_noise(x_noisy, t, noise_pred)

                noise_preds.append(noise_pred)
                # The predicted x0 is used as the x_start of the next denoising step.
                x_starts.append(pred_x0)

                # Sample an earlier timestep for the next denoising step.
                if i < num_denoising_steps - 1:
                    # NOTE: rand_like() samples from U(0, 1), unlike randn_like().
                    relative_ts = torch.rand_like(t.float())
                    # Make sure that at the middle step (i = sqrt(num_denoising_steps - 1)), the timestep
                    # is between 50% and 70% of the current timestep. So if num_denoising_steps = 5,
                    # we take timesteps within [0.5^0.66, 0.7^0.66] = [0.63, 0.79] of the current timestep.
                    # If num_denoising_steps = 4, we take timesteps within [0.5^0.72, 0.7^0.72] = [0.61, 0.77]
                    # of the current timestep.
                    t_lb = t * np.power(0.5, np.power(num_denoising_steps - 1, -0.3))
                    t_ub = t * np.power(0.7, np.power(num_denoising_steps - 1, -0.3))
                    earlier_timesteps = (t_ub - t_lb) * relative_ts + t_lb
                    earlier_timesteps = earlier_timesteps.long()

                    if uses_same_t:
                        # Use the same earlier_timesteps for all instances.
                        earlier_timesteps = earlier_timesteps[0].repeat(x_start.shape[0])

                    # earlier_timesteps = ts[i+1] < ts[i].
                    ts.append(earlier_timesteps)

                    noise = torch.randn_like(pred_x0)
                    noises.append(noise)

        return noise_preds, x_starts, noises, ts
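
# A minimal sketch of how a student would typically drive the teacher (illustrative
# names only; ddpm_model must provide q_sample and predict_start_from_noise):
#
#   noise = torch.randn_like(x_start)
#   noise_preds, x_starts, noises, ts = \
#       unet_teacher(ddpm_model, x_start, noise, t, teacher_context,
#                    num_denoising_steps=3)
#   # noise_preds[i] is the teacher's prediction at timestep ts[i]; the student
#   # can be supervised against the prediction of each denoising step.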
class Arc2FaceTeacher(UNetTeacher):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.name = "arc2face"
        self.unet = UNet2DConditionModel.from_pretrained(
                        #"runwayml/stable-diffusion-v1-5", subfolder="unet"
                        'models/arc2face', subfolder="arc2face", torch_dtype=torch.float16
                    )
        # Disable CFG: with cfg_scale_range = [1, 1], even if p_uses_cfg > 0,
        # the randomly drawn cfg_scale is always 1, so CFG is effectively disabled.
        self.cfg_scale_range = [1, 1]
class UNetEnsembleTeacher(UNetTeacher):
    # unet_weights are not model weights, but scalar mixing weights for the individual unets.
    def __init__(self, unets, unet_types, extra_unet_dirpaths, unet_weights=None, device='cuda', **kwargs):
        super().__init__(**kwargs)
        self.name = "unet_ensemble"
        self.unet = UNetEnsemble(unets, unet_types, extra_unet_dirpaths, unet_weights, device)
class ConsistentIDTeacher(UNetTeacher):
    def __init__(self, base_model_path="models/sd15-dste8-vae.safetensors", **kwargs):
        super().__init__(**kwargs)
        self.name = "consistentID"
        ### Load the base model.
        # In contrast to Arc2FaceTeacher or UNetEnsembleTeacher, ConsistentIDPipeline is not a torch.nn.Module,
        # so we cannot initialize it on the CPU and wait for it to be automatically moved to the GPU.
        # Instead, we have to initialize it on the GPU directly.
        pipe = create_consistentid_pipeline(base_model_path)
        # Only keep pipe.unet, to stay compatible with the UNetTeacher interface.
        self.unet = pipe.unet
        # Release the VAE and text_encoder to save memory. The UNet is still needed for denoising
        # (the diffusers unet runs in fp16, so it is probably faster than the LDM unet).
        pipe.release_components(["vae", "text_encoder"])
# We use the default cfg_scale_range=[1.3, 2] for SimpleUNetTeacher.
# Note that p_uses_cfg=0.5 will also be passed in through kwargs.
class SimpleUNetTeacher(UNetTeacher):
    def __init__(self, unet_dirpath='models/ensemble/sd15-unet',
                 torch_dtype=torch.float16, **kwargs):
        super().__init__(**kwargs)
        self.name = "simple_unet"
        self.unet = UNet2DConditionModel.from_pretrained(
                        unet_dirpath, torch_dtype=torch_dtype
                    )
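
# For example (a hypothetical call mirroring the comment above):
#
#   teacher = SimpleUNetTeacher(p_uses_cfg=0.5)
#   # cfg_scale_range defaults to [1.3, 2]; on each forward pass, CFG is enabled
#   # with probability 0.5 and a scale is drawn uniformly from that range.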