"""CLI that distills the weight difference between a fine-tuned Stable
Diffusion model and its base model into a low-rank LoRA via truncated SVD.

For each injected module, the residual dW = W_tuned - W_base is factored as
dW ~= (U_r diag(S_r)) @ V_r^T, and the two factors are written into the LoRA
up/down weights of the base pipeline, which is then saved.
"""

import fire
import torch
from diffusers import StableDiffusionPipeline

from .lora import (
    save_all,
    LoraInjectedConv2d,
    LoraInjectedLinear,
    inject_trainable_lora,
    inject_trainable_lora_extended,
)


def _iter_lora(model):
    """Yield every LoRA-injected module of `model` in traversal order."""
    for module in model.modules():
        if isinstance(module, (LoraInjectedConv2d, LoraInjectedLinear)):
            yield module


def overwrite_base(base_model, tuned_model, rank, clamp_quantile):
    """Write a rank-`rank` SVD of (tuned - base) weights into the base model's
    LoRA factors.

    Assumes `base_model` and `tuned_model` share the same architecture, so
    `_iter_lora` yields matching modules in the same order.
    """
    device = base_model.device
    dtype = base_model.dtype

    for lor_base, lor_tune in zip(_iter_lora(base_model), _iter_lora(tuned_model)):
        if isinstance(lor_base, LoraInjectedLinear):
            residual = lor_tune.linear.weight.data - lor_base.linear.weight.data
            print("Distill Linear shape:", residual.shape)

            # Truncated SVD of the residual; fold the singular values into U.
            residual = residual.float()
            U, S, Vh = torch.linalg.svd(residual)
            U = U[:, :rank] @ torch.diag(S[:rank])
            Vh = Vh[:rank, :]

            # Clamp outliers so extreme values do not dominate the factors.
            dist = torch.cat([U.flatten(), Vh.flatten()])
            hi_val = torch.quantile(dist, clamp_quantile)
            low_val = -hi_val
            U = U.clamp(low_val, hi_val)
            Vh = Vh.clamp(low_val, hi_val)

            assert lor_base.lora_up.weight.shape == U.shape
            assert lor_base.lora_down.weight.shape == Vh.shape

            lor_base.lora_up.weight.data = U.to(device=device, dtype=dtype)
            lor_base.lora_down.weight.data = Vh.to(device=device, dtype=dtype)

        elif isinstance(lor_base, LoraInjectedConv2d):
            residual = lor_tune.conv.weight.data - lor_base.conv.weight.data
            print("Distill Conv shape:", residual.shape)

            # Flatten (out, in, k1, k2) -> (out, in * k1 * k2) so the conv
            # kernel can be factored like a linear layer.
            residual = residual.float().flatten(start_dim=1)

            U, S, Vh = torch.linalg.svd(residual)
            U = U[:, :rank] @ torch.diag(S[:rank])
            Vh = Vh[:rank, :]

            dist = torch.cat([U.flatten(), Vh.flatten()])
            hi_val = torch.quantile(dist, clamp_quantile)
            low_val = -hi_val
            U = U.clamp(low_val, hi_val)
            Vh = Vh.clamp(low_val, hi_val)

            # lora_up is a 1x1 conv, so U becomes (out_channels, rank, 1, 1).
            U = U.reshape(U.shape[0], U.shape[1], 1, 1)
            # lora_down keeps the original kernel, so Vh is reshaped from
            # (rank, in_channels * kernel_size1 * kernel_size2) back to
            # (rank, in_channels, kernel_size1, kernel_size2).
            Vh = Vh.reshape(
                Vh.shape[0],
                lor_base.conv.in_channels,
                lor_base.conv.kernel_size[0],
                lor_base.conv.kernel_size[1],
            )

            assert lor_base.lora_up.weight.shape == U.shape
            assert lor_base.lora_down.weight.shape == Vh.shape

            lor_base.lora_up.weight.data = U.to(device=device, dtype=dtype)
            lor_base.lora_down.weight.data = Vh.to(device=device, dtype=dtype)


def svd_distill(
    target_model: str,
    base_model: str,
    rank: int = 4,
    clamp_quantile: float = 0.99,
    device: str = "cuda:0",
    save_path: str = "svd_distill.safetensors",
):
    pipe_base = StableDiffusionPipeline.from_pretrained(
        base_model, torch_dtype=torch.float16
    ).to(device)

    pipe_tuned = StableDiffusionPipeline.from_pretrained(
        target_model, torch_dtype=torch.float16
    ).to(device)

    # Inject LoRA into both unets, then distill the unet weight residual.
    _ = inject_trainable_lora_extended(pipe_base.unet, r=rank)
    _ = inject_trainable_lora_extended(pipe_tuned.unet, r=rank)

    overwrite_base(
        pipe_base.unet, pipe_tuned.unet, rank=rank, clamp_quantile=clamp_quantile
    )

    # Same for the text encoders, targeting the CLIP attention blocks.
    _ = inject_trainable_lora(
        pipe_base.text_encoder, r=rank, target_replace_module={"CLIPAttention"}
    )
    _ = inject_trainable_lora(
        pipe_tuned.text_encoder, r=rank, target_replace_module={"CLIPAttention"}
    )

    overwrite_base(
        pipe_base.text_encoder,
        pipe_tuned.text_encoder,
        rank=rank,
        clamp_quantile=clamp_quantile,
    )

    # Only the LoRA weights of the (now distilled) base pipeline are saved;
    # textual-inversion embeddings are not part of this distillation.
    save_all(
        unet=pipe_base.unet,
        text_encoder=pipe_base.text_encoder,
        placeholder_token_ids=None,
        placeholder_tokens=None,
        save_path=save_path,
        save_lora=True,
        save_ti=False,
    )


def main():
    fire.Fire(svd_distill)


if __name__ == "__main__":
    main()
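

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only; not part of the original module). It
# assumes this file is importable as `lora_diffusion.cli_svd` and that the
# model paths below exist -- the package/module name and both checkpoint
# paths are hypothetical placeholders. Because `main()` routes through
# `fire.Fire(svd_distill)`, every parameter of `svd_distill` is also
# available as a command-line flag:
#
#   python -m lora_diffusion.cli_svd \
#       --target_model=./my-finetuned-checkpoint \
#       --base_model=runwayml/stable-diffusion-v1-5 \
#       --rank=4 \
#       --clamp_quantile=0.99 \
#       --save_path=distilled_lora.safetensors
#
# The equivalent programmatic call:
#
#   from lora_diffusion.cli_svd import svd_distill
#
#   svd_distill(
#       target_model="./my-finetuned-checkpoint",  # hypothetical path
#       base_model="runwayml/stable-diffusion-v1-5",
#       rank=4,
#       save_path="distilled_lora.safetensors",
#   )
# ---------------------------------------------------------------------------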