AnimateLCM / animatelcm /utils /lcm_utils.py
wangfuyun's picture
Upload 42 files
c0fdaf5 verified
raw
history blame
No virus
9.53 kB
import torch
import numpy as np
from safetensors import safe_open
def guidance_scale_embedding(w, embedding_dim=512, dtype=torch.float32):
"""
See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
Args:
timesteps (`torch.Tensor`):
generate embedding vectors at these timesteps
embedding_dim (`int`, *optional*, defaults to 512):
dimension of the embeddings to generate
dtype:
data type of the generated embeddings
Returns:
`torch.FloatTensor`: Embedding vectors with shape `(len(timesteps), embedding_dim)`
"""
assert len(w.shape) == 1
w = w * 1000.0
half_dim = embedding_dim // 2
emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
emb = w.to(dtype)[:, None] * emb[None, :]
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
if embedding_dim % 2 == 1: # zero pad
emb = torch.nn.functional.pad(emb, (0, 1))
assert emb.shape == (w.shape[0], embedding_dim)
return emb
def append_dims(x, target_dims):
"""Appends dimensions to the end of a tensor until it has target_dims dimensions."""
dims_to_append = target_dims - x.ndim
if dims_to_append < 0:
raise ValueError(
f"input has {x.ndim} dims but target_dims is {target_dims}, which is less")
return x[(...,) + (None,) * dims_to_append]
# From LCMScheduler.get_scalings_for_boundary_condition_discrete
def scalings_for_boundary_conditions(timestep, sigma_data=0.5, timestep_scaling=10.0):
c_skip = sigma_data**2 / ((timestep / 0.1) ** 2 + sigma_data**2)
c_out = (timestep / 0.1) / ((timestep / 0.1) ** 2 + sigma_data**2) ** 0.5
return c_skip, c_out
# Compare LCMScheduler.step, Step 4
def predicted_origin(model_output, timesteps, sample, prediction_type, alphas, sigmas):
if prediction_type == "epsilon":
sigmas = extract_into_tensor(sigmas, timesteps, sample.shape)
alphas = extract_into_tensor(alphas, timesteps, sample.shape)
pred_x_0 = (sample - sigmas * model_output) / alphas
elif prediction_type == "v_prediction":
sigmas = extract_into_tensor(sigmas, timesteps, sample.shape)
alphas = extract_into_tensor(alphas, timesteps, sample.shape)
pred_x_0 = alphas * sample - sigmas * model_output
else:
raise ValueError(
f"Prediction type {prediction_type} currently not supported.")
return pred_x_0
def scale_for_loss(timesteps, sample, prediction_type, alphas, sigmas):
if prediction_type == "epsilon":
sigmas = extract_into_tensor(sigmas, timesteps, sample.shape)
alphas = extract_into_tensor(alphas, timesteps, sample.shape)
sample = sample * alphas / sigmas
else:
raise ValueError(
f"Prediction type {prediction_type} currently not supported.")
return sample
def extract_into_tensor(a, t, x_shape):
b, *_ = t.shape
out = a.gather(-1, t)
return out.reshape(b, *((1,) * (len(x_shape) - 1)))
class DDIMSolver:
def __init__(self, alpha_cumprods, timesteps=1000, ddim_timesteps=50):
# DDIM sampling parameters
step_ratio = timesteps // ddim_timesteps
self.ddim_timesteps = (
np.arange(1, ddim_timesteps + 1) * step_ratio).round().astype(np.int64) - 1
# self.ddim_timesteps = (torch.linspace(100**2,1000**2,30)**0.5).round().numpy().astype(np.int64) - 1
self.ddim_timesteps_prev = np.asarray(
[0] + self.ddim_timesteps[:-1].tolist()
)
self.ddim_alpha_cumprods = alpha_cumprods[self.ddim_timesteps]
self.ddim_alpha_cumprods_prev = np.asarray(
[alpha_cumprods[0]] + alpha_cumprods[self.ddim_timesteps[:-1]].tolist()
)
self.ddim_alpha_cumprods = alpha_cumprods[self.ddim_timesteps]
self.ddim_alpha_cumprods_prev = np.asarray(
[alpha_cumprods[0]] + alpha_cumprods[self.ddim_timesteps[:-1]].tolist()
)
# convert to torch tensors
self.ddim_timesteps = torch.from_numpy(self.ddim_timesteps).long()
self.ddim_timesteps_prev = torch.from_numpy(
self.ddim_timesteps_prev).long()
self.ddim_alpha_cumprods = torch.from_numpy(self.ddim_alpha_cumprods)
self.ddim_alpha_cumprods_prev = torch.from_numpy(
self.ddim_alpha_cumprods_prev)
def to(self, device):
self.ddim_timesteps = self.ddim_timesteps.to(device)
self.ddim_timesteps_prev = self.ddim_timesteps_prev.to(device)
self.ddim_alpha_cumprods = self.ddim_alpha_cumprods.to(device)
self.ddim_alpha_cumprods_prev = self.ddim_alpha_cumprods_prev.to(
device)
return self
def ddim_step(self, pred_x0, pred_noise, timestep_index):
alpha_cumprod_prev = extract_into_tensor(
self.ddim_alpha_cumprods_prev, timestep_index, pred_x0.shape)
dir_xt = (1.0 - alpha_cumprod_prev).sqrt() * pred_noise
x_prev = alpha_cumprod_prev.sqrt() * pred_x0 + dir_xt
return x_prev
@torch.no_grad()
def update_ema(target_params, source_params, rate=0.99):
"""
Update target parameters to be closer to those of source parameters using
an exponential moving average.
:param target_params: the target parameter sequence.
:param source_params: the source parameter sequence.
:param rate: the EMA rate (closer to 1 means slower).
"""
for targ, src in zip(target_params, source_params):
targ.detach().mul_(rate).add_(src, alpha=1 - rate)
def convert_lcm_lora(unet, path, alpha=1.0):
if path.endswith(("ckpt",)):
state_dict = torch.load(path, map_location="cpu")
else:
state_dict = {}
with safe_open(path, framework="pt", device="cpu") as f:
for key in f.keys():
state_dict[key] = f.get_tensor(key)
num_alpha = 0
for key in state_dict.keys():
if "alpha" in key:
num_alpha += 1
lora_keys = [k for k in state_dict.keys(
) if k.endswith("lora_down.weight")]
updated_state_dict = {}
for key in lora_keys:
lora_name = key.split(".")[0]
if lora_name.startswith("lora_unet_"):
diffusers_name = key.replace("lora_unet_", "").replace("_", ".")
if "input.blocks" in diffusers_name:
diffusers_name = diffusers_name.replace(
"input.blocks", "down_blocks")
else:
diffusers_name = diffusers_name.replace(
"down.blocks", "down_blocks")
if "middle.block" in diffusers_name:
diffusers_name = diffusers_name.replace(
"middle.block", "mid_block")
else:
diffusers_name = diffusers_name.replace(
"mid.block", "mid_block")
if "output.blocks" in diffusers_name:
diffusers_name = diffusers_name.replace(
"output.blocks", "up_blocks")
else:
diffusers_name = diffusers_name.replace(
"up.blocks", "up_blocks")
diffusers_name = diffusers_name.replace(
"transformer.blocks", "transformer_blocks")
diffusers_name = diffusers_name.replace("to.q.lora", "to_q_lora")
diffusers_name = diffusers_name.replace("to.k.lora", "to_k_lora")
diffusers_name = diffusers_name.replace("to.v.lora", "to_v_lora")
diffusers_name = diffusers_name.replace(
"to.out.0.lora", "to_out_lora")
diffusers_name = diffusers_name.replace("proj.in", "proj_in")
diffusers_name = diffusers_name.replace("proj.out", "proj_out")
diffusers_name = diffusers_name.replace(
"time.emb.proj", "time_emb_proj")
diffusers_name = diffusers_name.replace(
"conv.shortcut", "conv_shortcut")
updated_state_dict[diffusers_name] = state_dict[key]
up_diffusers_name = diffusers_name.replace(".down.", ".up.")
up_key = key.replace("lora_down.weight", "lora_up.weight")
updated_state_dict[up_diffusers_name] = state_dict[up_key]
state_dict = updated_state_dict
num_lora = 0
for key in state_dict:
if "up." in key:
continue
up_key = key.replace(".down.", ".up.")
model_key = key.replace("processor.", "").replace("_lora", "").replace(
"down.", "").replace("up.", "").replace(".lora", "")
model_key = model_key.replace("to_out.", "to_out.0.")
layer_infos = model_key.split(".")[:-1]
curr_layer = unet
while len(layer_infos) > 0:
temp_name = layer_infos.pop(0)
curr_layer = curr_layer.__getattr__(temp_name)
weight_down = state_dict[key].to(
curr_layer.weight.data.device, curr_layer.weight.data.dtype)
weight_up = state_dict[up_key].to(
curr_layer.weight.data.device, curr_layer.weight.data.dtype)
if weight_up.ndim == 2:
curr_layer.weight.data += 1/8 * alpha * \
torch.mm(weight_up, weight_down)
else:
assert weight_up.ndim == 4
curr_layer.weight.data += 1/8 * alpha * torch.mm(weight_up.flatten(
start_dim=1), weight_down.flatten(start_dim=1)).reshape(curr_layer.weight.data.shape)
num_lora += 1
return unet