# adaface/unet_teachers.py
import torch
import numpy as np
import pytorch_lightning as pl
from diffusers import UNet2DConditionModel
from adaface.util import UNetEnsemble, create_consistentid_pipeline
from omegaconf.listconfig import ListConfig
def create_unet_teacher(teacher_type, device='cpu', **kwargs):
# If teacher_type is a list with only one element, we dereference it.
if isinstance(teacher_type, (tuple, list, ListConfig)) and len(teacher_type) == 1:
teacher_type = teacher_type[0]
if teacher_type == "arc2face":
return Arc2FaceTeacher(**kwargs)
elif teacher_type == "unet_ensemble":
        # unets, extra_unet_dirpaths and unet_weights are passed in kwargs.
        # Even if we distill from unet_ensemble, we still need to load arc2face for generating
        # arc2face embeddings.
        # The first (optional) ctor param of UNetEnsembleTeacher is an instantiated unet,
        # in our case, the ddpm unet. Ideally we should reuse it to save GPU RAM.
        # However, since the __call__ method of the ddpm unet takes its params in a different
        # format, for simplicity we still use the diffusers unet.
# unet_teacher is put on CPU first, then moved to GPU when DDPM is moved to GPU.
return UNetEnsembleTeacher(device=device, **kwargs)
elif teacher_type == "consistentID":
return ConsistentIDTeacher(**kwargs)
elif teacher_type == "simple_unet":
return SimpleUNetTeacher(**kwargs)
    # Since a single-element list was already dereferenced above, reaching this branch
    # means teacher_type is a list of multiple teacher types, so we build a UNetEnsembleTeacher.
    elif isinstance(teacher_type, (tuple, list, ListConfig)):
return UNetEnsembleTeacher(unet_types=teacher_type, device=device, **kwargs)
else:
raise NotImplementedError(f"Teacher type {teacher_type} not implemented.")
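# Usage sketch (illustrative only; the kwargs below are assumptions based on the defaults
# in this file, not the exact training config):
#   unet_teacher = create_unet_teacher("simple_unet", device='cpu',
#                                      p_uses_cfg=0.5, cfg_scale_range=[1.3, 2])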
class UNetTeacher(pl.LightningModule):
def __init__(self, **kwargs):
super().__init__()
self.name = None
# self.unet will be initialized in the child class.
self.unet = None
self.p_uses_cfg = kwargs.get("p_uses_cfg", 0)
# self.cfg_scale will be randomly sampled from cfg_scale_range.
self.cfg_scale_range = kwargs.get("cfg_scale_range", [1.3, 2])
# Initialize cfg_scale to 1. It will be randomly sampled during forward pass.
self.cfg_scale = 1
if self.p_uses_cfg > 0:
print(f"Using CFG with probability {self.p_uses_cfg} and scale range {self.cfg_scale_range}.")
else:
print(f"Never using CFG.")
    # We pass in ddpm_model to use its q_sample and predict_start_from_noise methods.
    # We don't reimplement these two functions here, because they rely on several schedule
    # tensors that would have to be initialized, which would unnecessarily complicate the code.
    # noise: the initial noise for the first iteration.
    # t: the initial t. We will sample (num_denoising_steps - 1) additional, smaller t values.
    # uses_same_t: when sampling t, use the same t for all instances.
def forward(self, ddpm_model, x_start, noise, t, teacher_context,
num_denoising_steps=1, uses_same_t=False):
assert num_denoising_steps <= 10
if self.p_uses_cfg > 0:
self.uses_cfg = np.random.rand() < self.p_uses_cfg
if self.uses_cfg:
# Randomly sample a cfg_scale from cfg_scale_range.
self.cfg_scale = np.random.uniform(*self.cfg_scale_range)
if self.cfg_scale == 1:
self.uses_cfg = False
if self.uses_cfg:
print(f"Teacher samples CFG scale {self.cfg_scale:.1f}.")
else:
self.cfg_scale = 1
print("Teacher does not use CFG.")
# If p_uses_cfg > 0, we always pass both pos_context and neg_context to the teacher.
# But the neg_context is only used when self.uses_cfg is True and cfg_scale > 1.
# So we manually split the teacher_context into pos_context and neg_context, and only keep pos_context.
if self.name == 'unet_ensemble':
teacher_pos_contexts = []
# teacher_context is a list of teacher contexts.
for teacher_context_i in teacher_context:
pos_context, neg_context = torch.chunk(teacher_context_i, 2, dim=0)
if pos_context.shape[0] != x_start.shape[0]:
breakpoint()
teacher_pos_contexts.append(pos_context)
teacher_context = teacher_pos_contexts
else:
pos_context, neg_context = torch.chunk(teacher_context, 2, dim=0)
if pos_context.shape[0] != x_start.shape[0]:
breakpoint()
teacher_context = pos_context
else:
# p_uses_cfg = 0. Never use CFG.
self.uses_cfg = False
# In this case, the student only passes pos_context to the teacher,
# so no need to split teacher_context into pos_context and neg_context.
# self.cfg_scale will be accessed by the student,
# so we need to make sure it is always set correctly,
# in case someday we want to switch from CFG to non-CFG during runtime.
self.cfg_scale = 1
if self.name == 'unet_ensemble':
# teacher_context is a list of teacher contexts.
for teacher_context_i in teacher_context:
if teacher_context_i.shape[0] != x_start.shape[0] * (1 + self.uses_cfg):
breakpoint()
else:
if teacher_context.shape[0] != x_start.shape[0] * (1 + self.uses_cfg):
breakpoint()
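        # Shape sketch (illustrative numbers, assuming the [BS=6, 21, 768] context mentioned
        # in the comment below): with p_uses_cfg > 0 and CFG actually used, teacher_context for
        # a single teacher stays [12, 21, 768] (pos and neg stacked along dim 0); otherwise it
        # has been reduced to the positive half, [6, 21, 768], matching x_start's batch size.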
# Initially, x_starts only contains the original x_start.
x_starts = [ x_start ]
noises = [ noise ]
ts = [ t ]
noise_preds = []
with torch.autocast(device_type='cuda', dtype=torch.float16):
for i in range(num_denoising_steps):
x_start = x_starts[i]
t = ts[i]
noise = noises[i]
# sqrt_alphas_cumprod[t] * x_start + sqrt_one_minus_alphas_cumprod[t] * noise
x_noisy = ddpm_model.q_sample(x_start, t, noise)
if self.uses_cfg:
x_noisy2 = x_noisy.repeat(2, 1, 1, 1)
t2 = t.repeat(2)
else:
x_noisy2 = x_noisy
t2 = t
# If do_arc2face_distill, then pos_context is [BS=6, 21, 768].
noise_pred = self.unet(sample=x_noisy2, timestep=t2, encoder_hidden_states=teacher_context,
return_dict=False)[0]
if self.uses_cfg and self.cfg_scale > 1:
pos_noise_pred, neg_noise_pred = torch.chunk(noise_pred, 2, dim=0)
noise_pred = pos_noise_pred * self.cfg_scale - neg_noise_pred * (self.cfg_scale - 1)
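                    # This is the standard classifier-free guidance combination, algebraically
                    # equal to neg_noise_pred + cfg_scale * (pos_noise_pred - neg_noise_pred).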
# sqrt_recip_alphas_cumprod[t] * x_t - sqrt_recipm1_alphas_cumprod[t] * noise
pred_x0 = ddpm_model.predict_start_from_noise(x_noisy, t, noise_pred)
noise_preds.append(noise_pred)
# The predicted x0 is used as the x_start for the next denoising step.
x_starts.append(pred_x0)
# Sample an earlier timestep for the next denoising step.
if i < num_denoising_steps - 1:
# NOTE: rand_like() samples from U(0, 1), not like randn_like().
relative_ts = torch.rand_like(t.float())
                    # Make sure that at the middle step (i = sqrt(num_denoising_steps - 1)), the timestep
                    # is between 50% and 70% of the current timestep. So if num_denoising_steps = 5,
                    # we take timesteps within [0.5^0.66, 0.7^0.66] = [0.63, 0.79] of the current timestep.
                    # If num_denoising_steps = 4, we take timesteps within [0.5^0.72, 0.7^0.72] = [0.61, 0.77]
                    # of the current timestep.
t_lb = t * np.power(0.5, np.power(num_denoising_steps - 1, -0.3))
t_ub = t * np.power(0.7, np.power(num_denoising_steps - 1, -0.3))
earlier_timesteps = (t_ub - t_lb) * relative_ts + t_lb
earlier_timesteps = earlier_timesteps.long()
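                    # Worked example (illustrative): with num_denoising_steps = 5 and t = 600,
                    # t_lb is about 380 and t_ub about 474, so the next timestep is sampled
                    # uniformly from roughly [380, 474].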
if uses_same_t:
# If uses_same_t, we use the same earlier_timesteps for all instances.
earlier_timesteps = earlier_timesteps[0].repeat(x_start.shape[0])
                    # earlier_timesteps becomes ts[i+1], and ts[i+1] < ts[i].
ts.append(earlier_timesteps)
noise = torch.randn_like(pred_x0)
noises.append(noise)
return noise_preds, x_starts, noises, ts
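    # Call sketch (illustrative; ddpm_model, x_start, noise, t and teacher_context come from
    # the student's training step, and the variable names here are assumptions):
    #   noise_preds, x_starts, noises, ts = unet_teacher(ddpm_model, x_start, noise, t,
    #                                                    teacher_context, num_denoising_steps=3)
    #   # x_starts[1:] are the teacher-predicted x0's; ts is a decreasing sequence of timesteps.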
class Arc2FaceTeacher(UNetTeacher):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.name = "arc2face"
self.unet = UNet2DConditionModel.from_pretrained(
#"runwayml/stable-diffusion-v1-5", subfolder="unet"
'models/arc2face', subfolder="arc2face", torch_dtype=torch.float16
)
# Disable CFG. Even if p_uses_cfg > 0, the randomly drawn cfg_scale is still 1,
# so the CFG is effectively disabled.
self.cfg_scale_range = [1, 1]
class UNetEnsembleTeacher(UNetTeacher):
# unet_weights are not model weights, but scalar weights for individual unets.
def __init__(self, unets, unet_types, extra_unet_dirpaths, unet_weights=None, device='cuda', **kwargs):
super().__init__(**kwargs)
self.name = "unet_ensemble"
self.unet = UNetEnsemble(unets, unet_types, extra_unet_dirpaths, unet_weights, device)
class ConsistentIDTeacher(UNetTeacher):
def __init__(self, base_model_path="models/sd15-dste8-vae.safetensors", **kwargs):
super().__init__(**kwargs)
self.name = "consistentID"
### Load base model
        # In contrast to Arc2FaceTeacher or UNetEnsembleTeacher, ConsistentIDPipeline is not a torch.nn.Module.
        # We can't initialize the ConsistentIDPipeline on CPU first and rely on it being moved
        # to the GPU automatically. Instead, we have to initialize it on the GPU directly.
pipe = create_consistentid_pipeline(base_model_path)
# Compatible with the UNetTeacher interface.
self.unet = pipe.unet
        # Release the VAE and text_encoder to save memory. The UNet is still needed for denoising
        # (the diffusers UNet runs in fp16, so it's probably faster than the LDM UNet).
pipe.release_components(["vae", "text_encoder"])
# We use the default cfg_scale_range=[1.3, 2] for SimpleUNetTeacher.
# Note that p_uses_cfg=0.5 will also be passed in via kwargs.
class SimpleUNetTeacher(UNetTeacher):
def __init__(self, unet_dirpath='models/ensemble/sd15-unet',
torch_dtype=torch.float16, **kwargs):
super().__init__(**kwargs)
self.name = "simple_unet"
self.unet = UNet2DConditionModel.from_pretrained(
unet_dirpath, torch_dtype=torch_dtype
)