# NOTE: the lines below appear to be file-viewer page residue, not source code:
# haodongli's picture
# init
# 916b126
# raw
# history blame
# 4.42 kB
import fire
from diffusers import StableDiffusionPipeline
import torch
import torch.nn as nn
from .lora import (
save_all,
_find_modules,
LoraInjectedConv2d,
LoraInjectedLinear,
inject_trainable_lora,
inject_trainable_lora_extended,
)
def _iter_lora(model):
    """Lazily yield every LoRA-injected layer inside ``model``.

    Traverses ``model.modules()`` in its own order and yields each module that
    is an instance of ``LoraInjectedConv2d`` or ``LoraInjectedLinear``.
    """
    lora_types = (LoraInjectedConv2d, LoraInjectedLinear)
    yield from (layer for layer in model.modules() if isinstance(layer, lora_types))
def _low_rank_factors(residual, rank, clamp_quantile):
    """Factor a 2-D ``residual`` into clamped rank-``rank`` matrices ``(up, down)``.

    ``up`` is ``U[:, :rank] @ diag(S[:rank])`` and ``down`` is ``Vh[:rank, :]``,
    so ``up @ down`` is the truncated-SVD approximation of ``residual``. Both
    factors are then clamped to ``[-q, q]``, where ``q`` is the
    ``clamp_quantile`` quantile of their concatenated (signed) entries.
    """
    # The reduced SVD yields the same leading singular triplets as the full
    # decomposition without materialising square (m, m) / (n, n) factors.
    U, S, Vh = torch.linalg.svd(residual, full_matrices=False)
    up = U[:, :rank] @ torch.diag(S[:rank])
    down = Vh[:rank, :]
    # Symmetric clip around zero to suppress outliers. The threshold is a
    # quantile of the signed values, not of their magnitudes.
    hi_val = torch.quantile(torch.cat([up.flatten(), down.flatten()]), clamp_quantile)
    low_val = -hi_val
    return up.clamp(low_val, hi_val), down.clamp(low_val, hi_val)


def overwrite_base(base_model, tuned_model, rank, clamp_quantile):
    """Write a low-rank approximation of ``tuned - base`` into ``base_model``'s LoRA layers.

    LoRA layers of both models are paired in ``_iter_lora`` (``modules()``)
    order. For each pair, the residual between the wrapped ``linear`` / ``conv``
    weights (tuned minus base) is decomposed by truncated SVD to ``rank``
    components. Singular values are folded into the up factor, and both factors
    are clamped at the ``clamp_quantile`` quantile. The factors are then
    assigned to the base layer's ``lora_up`` / ``lora_down`` weights, cast to
    ``base_model.device`` / ``base_model.dtype``. The shape of every distilled
    residual is printed.

    For conv layers, ``lora_up`` receives shape ``(out_channels, rank, 1, 1)``
    and ``lora_down`` receives ``(rank, in_channels, kh, kw)``.

    Args:
        base_model: Model whose LoRA layers are overwritten in place. Presumably
            a diffusers/transformers model (it must expose ``.device`` and
            ``.dtype``).
        tuned_model: Model with the same LoRA injection, providing target weights.
        rank: Number of singular components kept. It must match the injected
            LoRA rank and not exceed the smaller residual dimension.
        clamp_quantile: Quantile (in ``[0, 1]``) used as the clamp threshold.

    Raises:
        ValueError: If the two models have different numbers of LoRA layers, or
            a distilled factor does not match the shape of the LoRA weight it
            replaces.
        TypeError: If paired LoRA layers are of different types.
    """
    device = base_model.device
    dtype = base_model.dtype

    base_loras = list(_iter_lora(base_model))
    tuned_loras = list(_iter_lora(tuned_model))
    # zip() would silently stop at the shorter list and leave some base LoRA
    # layers untouched, so a count mismatch is reported instead.
    if len(base_loras) != len(tuned_loras):
        raise ValueError(
            f"LoRA layer count mismatch: base has {len(base_loras)}, "
            f"tuned has {len(tuned_loras)}"
        )

    for lor_base, lor_tune in zip(base_loras, tuned_loras):
        if type(lor_base) is not type(lor_tune):
            raise TypeError(
                f"Mismatched LoRA layer types: {type(lor_base).__name__} "
                f"vs {type(lor_tune).__name__}"
            )

        if isinstance(lor_base, LoraInjectedLinear):
            # Subtract in float32 so the difference is not first rounded to
            # the (possibly half-precision) model dtype.
            residual = (
                lor_tune.linear.weight.data.float() - lor_base.linear.weight.data.float()
            )
            print("Distill Linear shape ", residual.shape)
            up, down = _low_rank_factors(residual, rank, clamp_quantile)
        elif isinstance(lor_base, LoraInjectedConv2d):
            residual = (
                lor_tune.conv.weight.data.float() - lor_base.conv.weight.data.float()
            )
            print("Distill Conv shape ", residual.shape)
            # Flatten to (out_channels, in_channels * kh * kw) for the SVD.
            up, down = _low_rank_factors(
                residual.flatten(start_dim=1), rank, clamp_quantile
            )
            # lora_up is a 1x1 conv: (out_channels, rank) -> (out_channels, rank, 1, 1).
            up = up.reshape(up.shape[0], up.shape[1], 1, 1)
            # lora_down carries the spatial kernel: (rank, in_channels, kh, kw).
            down = down.reshape(
                down.shape[0],
                lor_base.conv.in_channels,
                lor_base.conv.kernel_size[0],
                lor_base.conv.kernel_size[1],
            )
        else:
            continue

        # Explicit checks (unlike assert, these survive `python -O`). A
        # mismatch typically means `rank` differs from the injected LoRA rank
        # or exceeds the smaller residual dimension.
        if lor_base.lora_up.weight.shape != up.shape:
            raise ValueError(
                f"lora_up shape {tuple(lor_base.lora_up.weight.shape)} does not "
                f"match distilled factor shape {tuple(up.shape)}"
            )
        if lor_base.lora_down.weight.shape != down.shape:
            raise ValueError(
                f"lora_down shape {tuple(lor_base.lora_down.weight.shape)} does not "
                f"match distilled factor shape {tuple(down.shape)}"
            )
        lor_base.lora_up.weight.data = up.to(device=device, dtype=dtype)
        lor_base.lora_down.weight.data = down.to(device=device, dtype=dtype)
def svd_distill(
    target_model: str,
    base_model: str,
    rank: int = 4,
    clamp_quantile: float = 0.99,
    device: str = "cuda:0",
    save_path: str = "svd_distill.safetensors",
):
    """Distill the difference between two Stable Diffusion checkpoints into LoRA weights.

    Both pipelines are loaded in float16 and moved to ``device`` (base first,
    then tuned). LoRA layers of rank ``rank`` are injected into each UNet via
    ``inject_trainable_lora_extended``. They are also injected into the
    ``CLIPAttention`` modules of each text encoder via
    ``inject_trainable_lora``. ``overwrite_base`` then fills the base model's
    LoRA layers from the tuned-minus-base weight residuals. Finally, the base
    UNet and text encoder are passed to ``save_all`` with ``save_lora=True``
    and ``save_ti=False``.

    Args:
        target_model: Checkpoint passed to ``from_pretrained`` for the fine-tuned model.
        base_model: Checkpoint passed to ``from_pretrained`` for the reference model.
        rank: LoRA rank, also the number of singular components kept.
        clamp_quantile: Quantile forwarded to ``overwrite_base`` as the clamp threshold.
        device: Torch device both pipelines are moved to.
        save_path: Output path handed to ``save_all``.
    """
    pipe_base, pipe_tuned = (
        StableDiffusionPipeline.from_pretrained(
            checkpoint, torch_dtype=torch.float16
        ).to(device)
        for checkpoint in (base_model, target_model)
    )
    both = (pipe_base, pipe_tuned)

    # UNet: inject LoRA into both copies, then distill into the base copy.
    for pipe in both:
        inject_trainable_lora_extended(pipe.unet, r=rank)
    overwrite_base(
        pipe_base.unet, pipe_tuned.unet, rank=rank, clamp_quantile=clamp_quantile
    )

    # Text encoder: only CLIPAttention modules are targeted for injection.
    for pipe in both:
        inject_trainable_lora(
            pipe.text_encoder, r=rank, target_replace_module={"CLIPAttention"}
        )
    overwrite_base(
        pipe_base.text_encoder,
        pipe_tuned.text_encoder,
        rank=rank,
        clamp_quantile=clamp_quantile,
    )

    save_all(
        unet=pipe_base.unet,
        text_encoder=pipe_base.text_encoder,
        placeholder_token_ids=None,
        placeholder_tokens=None,
        save_path=save_path,
        save_lora=True,
        save_ti=False,
    )
def main():
    """Command-line entry point: hand ``svd_distill`` to python-fire.

    python-fire maps ``svd_distill``'s parameters to command-line arguments.
    """
    fire.Fire(component=svd_distill)