from typing import *

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.linalg as linalg

from tqdm import tqdm


def make_sparse(t: torch.Tensor, sparsity=0.95):
    # Keep only the largest-magnitude (1 - sparsity) fraction of entries; zero the rest.
    abs_t = torch.abs(t)
    np_array = abs_t.detach().cpu().numpy()
    quan = float(np.quantile(np_array, sparsity))
    sparse_t = t.masked_fill(abs_t < quan, 0)
    return sparse_t


def extract_conv(
    weight: Union[torch.Tensor, nn.Parameter],
    mode = 'fixed',
    mode_param = 0,
    device = 'cpu',
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    weight = weight.to(device)
    out_ch, in_ch, kernel_size, _ = weight.shape

    U, S, Vh = linalg.svd(weight.reshape(out_ch, -1))

    # Choose the extraction rank from the singular values according to the requested mode.
    if mode == 'fixed':
        lora_rank = mode_param
    elif mode == 'threshold':
        assert mode_param >= 0
        lora_rank = torch.sum(S > mode_param)
    elif mode == 'ratio':
        assert 1 >= mode_param >= 0
        min_s = torch.max(S) * mode_param
        lora_rank = torch.sum(S > min_s)
    elif mode == 'quantile' or mode == 'percentile':
        assert 1 >= mode_param >= 0
        s_cum = torch.cumsum(S, dim=0)
        min_cum_sum = mode_param * torch.sum(S)
        lora_rank = torch.sum(s_cum < min_cum_sum)
    else:
        raise NotImplementedError('Extract mode should be "fixed", "threshold", "ratio" or "quantile"')
    lora_rank = max(1, lora_rank)
    lora_rank = min(out_ch, in_ch, lora_rank)

    # Truncated SVD: keep the top lora_rank components and fold S into U.
    U = U[:, :lora_rank]
    S = S[:lora_rank]
    U = U @ torch.diag(S)
    Vh = Vh[:lora_rank, :]

    diff = (weight - (U @ Vh).reshape(out_ch, in_ch, kernel_size, kernel_size)).detach()
    extract_weight_A = Vh.reshape(lora_rank, in_ch, kernel_size, kernel_size).detach()
    extract_weight_B = U.reshape(out_ch, lora_rank, 1, 1).detach()
    del U, S, Vh, weight
    return extract_weight_A, extract_weight_B, diff


def extract_linear(
    weight: Union[torch.Tensor, nn.Parameter],
    mode = 'fixed',
    mode_param = 0,
    device = 'cpu',
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    weight = weight.to(device)
    out_ch, in_ch = weight.shape

    U, S, Vh = linalg.svd(weight)

    if mode == 'fixed':
        lora_rank = mode_param
    elif mode == 'threshold':
        assert mode_param >= 0
        lora_rank = torch.sum(S > mode_param)
    elif mode == 'ratio':
        assert 1 >= mode_param >= 0
        min_s = torch.max(S) * mode_param
        lora_rank = torch.sum(S > min_s)
    elif mode == 'quantile' or mode == 'percentile':
        assert 1 >= mode_param >= 0
        s_cum = torch.cumsum(S, dim=0)
        min_cum_sum = mode_param * torch.sum(S)
        lora_rank = torch.sum(s_cum < min_cum_sum)
    else:
        raise NotImplementedError('Extract mode should be "fixed", "threshold", "ratio" or "quantile"')
    lora_rank = max(1, lora_rank)
    lora_rank = min(out_ch, in_ch, lora_rank)

    U = U[:, :lora_rank]
    S = S[:lora_rank]
    U = U @ torch.diag(S)
    Vh = Vh[:lora_rank, :]

    diff = (weight - U @ Vh).detach()
    extract_weight_A = Vh.reshape(lora_rank, in_ch).detach()
    extract_weight_B = U.reshape(out_ch, lora_rank).detach()
    del U, S, Vh, weight
    return extract_weight_A, extract_weight_B, diff


def extract_diff(
    base_model,
    db_model,
    mode = 'fixed',
    linear_mode_param = 0,
    conv_mode_param = 0,
    extract_device = 'cpu',
    use_bias = False,
    sparsity = 0.98,
    small_conv = True,
):
    UNET_TARGET_REPLACE_MODULE = [
        "Transformer2DModel",
        "Attention",
        "ResnetBlock2D",
        "Downsample2D",
        "Upsample2D"
    ]
    TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
    LORA_PREFIX_UNET = 'lora_unet'
    LORA_PREFIX_TEXT_ENCODER = 'lora_te'
    def make_state_dict(
        prefix,
        root_module: torch.nn.Module,
        target_module: torch.nn.Module,
        target_replace_modules
    ):
        loras = {}
        temp = {}

        # Collect the base-model weights of every Linear/Conv2d inside the target blocks.
        for name, module in root_module.named_modules():
            if module.__class__.__name__ in target_replace_modules:
                temp[name] = {}
                for child_name, child_module in module.named_modules():
                    if child_module.__class__.__name__ not in {'Linear', 'Conv2d'}:
                        continue
                    temp[name][child_name] = child_module.weight

        # Factor the weight difference (tuned minus base) for each collected layer.
        for name, module in tqdm(list(target_module.named_modules())):
            if name in temp:
                weights = temp[name]
                for child_name, child_module in module.named_modules():
                    lora_name = prefix + '.' + name + '.' + child_name
                    lora_name = lora_name.replace('.', '_')
                    layer = child_module.__class__.__name__

                    if layer == 'Linear':
                        extract_a, extract_b, diff = extract_linear(
                            (child_module.weight - weights[child_name]),
                            mode,
                            linear_mode_param,
                            device = extract_device,
                        )
                    elif layer == 'Conv2d':
                        is_linear = (
                            child_module.weight.shape[2] == 1
                            and child_module.weight.shape[3] == 1
                        )
                        extract_a, extract_b, diff = extract_conv(
                            (child_module.weight - weights[child_name]),
                            mode,
                            linear_mode_param if is_linear else conv_mode_param,
                            device = extract_device,
                        )
                        if small_conv and not is_linear:
                            # Split the k*k down factor again into a 1x1 down conv plus a k*k mid conv.
                            dim = extract_a.size(0)
                            extract_c, extract_a, _ = extract_conv(
                                extract_a.transpose(0, 1),
                                'fixed', dim,
                                extract_device,
                            )
                            extract_a = extract_a.transpose(0, 1)
                            extract_c = extract_c.transpose(0, 1)
                            loras[f'{lora_name}.lora_mid.weight'] = extract_c.detach().cpu().contiguous().half()
                            # Recompute the residual of the (down, mid, up) reconstruction for the optional sparse bias.
                            diff = (child_module.weight - weights[child_name]) - torch.einsum(
                                'i j k l, j r, p i -> p r k l',
                                extract_c, extract_a.flatten(1, -1), extract_b.flatten(1, -1)
                            ).detach().cpu().contiguous()
                            del extract_c
                    else:
                        continue
                    loras[f'{lora_name}.lora_down.weight'] = extract_a.detach().cpu().contiguous().half()
                    loras[f'{lora_name}.lora_up.weight'] = extract_b.detach().cpu().contiguous().half()
                    loras[f'{lora_name}.alpha'] = torch.Tensor([extract_a.shape[0]]).half()

                    if use_bias:
                        # Store the largest residual entries as a sparse "bias" correction.
                        diff = diff.detach().cpu().reshape(extract_b.size(0), -1)
                        sparse_diff = make_sparse(diff, sparsity).to_sparse().coalesce()

                        indices = sparse_diff.indices().to(torch.int16)
                        values = sparse_diff.values().half()
                        loras[f'{lora_name}.bias_indices'] = indices
                        loras[f'{lora_name}.bias_values'] = values
                        loras[f'{lora_name}.bias_size'] = torch.tensor(diff.shape).to(torch.int16)
                    del extract_a, extract_b, diff
        return loras

    text_encoder_loras = make_state_dict(
        LORA_PREFIX_TEXT_ENCODER,
        base_model[0], db_model[0],
        TEXT_ENCODER_TARGET_REPLACE_MODULE
    )
    unet_loras = make_state_dict(
        LORA_PREFIX_UNET,
        base_model[2], db_model[2],
        UNET_TARGET_REPLACE_MODULE
    )
    print(len(text_encoder_loras), len(unet_loras))
    return text_encoder_loras | unet_loras
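# merge_locon() below calls merge_conv() and merge_linear(), which are not included in
# this excerpt. The two helpers here are minimal sketches written so the merge code can
# run: they simply recombine the LoRA down/up factors into a full weight delta. Treat
# them as assumptions about the missing helpers, not as the upstream implementations.
def merge_conv(
    weight_a: Union[torch.Tensor, nn.Parameter],   # lora_down: (rank, in_ch, k, k)
    weight_b: Union[torch.Tensor, nn.Parameter],   # lora_up:   (out_ch, rank, 1, 1)
    device = 'cpu'
):
    rank, in_ch, kernel_size, _ = weight_a.shape
    out_ch, rank_, _, _ = weight_b.shape
    assert rank == rank_, 'rank of lora_down and lora_up must match'

    wa = weight_a.to(device).float()
    wb = weight_b.to(device).float()

    # (out_ch, rank) @ (rank, in_ch*k*k) gives the full conv weight delta.
    merged = wb.reshape(out_ch, -1) @ wa.reshape(rank, -1)
    weight = merged.reshape(out_ch, in_ch, kernel_size, kernel_size)
    del wa, wb
    return weight


def merge_linear(
    weight_a: Union[torch.Tensor, nn.Parameter],   # lora_down: (rank, in_ch)
    weight_b: Union[torch.Tensor, nn.Parameter],   # lora_up:   (out_ch, rank)
    device = 'cpu'
):
    rank, _ = weight_a.shape
    out_ch, rank_ = weight_b.shape
    assert rank == rank_, 'rank of lora_down and lora_up must match'

    wa = weight_a.to(device).float()
    wb = weight_b.to(device).float()

    # (out_ch, rank) @ (rank, in_ch) gives the full linear weight delta.
    weight = wb @ wa
    del wa, wb
    return weight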
def merge_locon(
    base_model,
    locon_state_dict: Dict[str, torch.TensorType],
    scale: float = 1.0,
    device = 'cpu'
):
    UNET_TARGET_REPLACE_MODULE = [
        "Transformer2DModel",
        "Attention",
        "ResnetBlock2D",
        "Downsample2D",
        "Upsample2D"
    ]
    TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
    LORA_PREFIX_UNET = 'lora_unet'
    LORA_PREFIX_TEXT_ENCODER = 'lora_te'
    def merge(
        prefix,
        root_module: torch.nn.Module,
        target_replace_modules
    ):
        temp = {}

        for name, module in tqdm(list(root_module.named_modules())):
            if module.__class__.__name__ in target_replace_modules:
                temp[name] = {}
                for child_name, child_module in module.named_modules():
                    layer = child_module.__class__.__name__
                    if layer not in {'Linear', 'Conv2d'}:
                        continue
                    lora_name = prefix + '.' + name + '.' + child_name
                    lora_name = lora_name.replace('.', '_')

                    down = locon_state_dict[f'{lora_name}.lora_down.weight'].float()
                    up = locon_state_dict[f'{lora_name}.lora_up.weight'].float()
                    alpha = locon_state_dict[f'{lora_name}.alpha'].float()
                    rank = down.shape[0]

                    # Add the scaled delta directly into the base weight.
                    if layer == 'Conv2d':
                        delta = merge_conv(down, up, device)
                        child_module.weight.requires_grad_(False)
                        child_module.weight += (alpha.to(device) / rank * scale * delta).cpu()
                        del delta
                    elif layer == 'Linear':
                        delta = merge_linear(down, up, device)
                        child_module.weight.requires_grad_(False)
                        child_module.weight += (alpha.to(device) / rank * scale * delta).cpu()
                        del delta

    merge(
        LORA_PREFIX_TEXT_ENCODER,
        base_model[0],
        TEXT_ENCODER_TARGET_REPLACE_MODULE
    )
    merge(
        LORA_PREFIX_UNET,
        base_model[2],
        UNET_TARGET_REPLACE_MODULE
    )


def merge_loha(
    base_model,
    loha_state_dict: Dict[str, torch.TensorType],
    scale: float = 1.0,
    device = 'cpu'
):
    UNET_TARGET_REPLACE_MODULE = [
        "Transformer2DModel",
        "Attention",
        "ResnetBlock2D",
        "Downsample2D",
        "Upsample2D"
    ]
    TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
    LORA_PREFIX_UNET = 'lora_unet'
    LORA_PREFIX_TEXT_ENCODER = 'lora_te'
    def merge(
        prefix,
        root_module: torch.nn.Module,
        target_replace_modules
    ):
        temp = {}

        for name, module in tqdm(list(root_module.named_modules())):
            if module.__class__.__name__ in target_replace_modules:
                temp[name] = {}
                for child_name, child_module in module.named_modules():
                    layer = child_module.__class__.__name__
                    if layer not in {'Linear', 'Conv2d'}:
                        continue
                    lora_name = prefix + '.' + name + '.' + child_name
                    lora_name = lora_name.replace('.', '_')

                    w1a = loha_state_dict[f'{lora_name}.hada_w1_a'].float().to(device)
                    w1b = loha_state_dict[f'{lora_name}.hada_w1_b'].float().to(device)
                    w2a = loha_state_dict[f'{lora_name}.hada_w2_a'].float().to(device)
                    w2b = loha_state_dict[f'{lora_name}.hada_w2_b'].float().to(device)
                    alpha = loha_state_dict[f'{lora_name}.alpha'].float().to(device)
                    dim = w1b.shape[0]

                    # Hadamard product of the two low-rank factor pairs gives the weight delta.
                    delta = (w1a @ w1b) * (w2a @ w2b)
                    delta = delta.reshape(child_module.weight.shape)

                    if layer == 'Conv2d':
                        child_module.weight.requires_grad_(False)
                        child_module.weight += (alpha.to(device) / dim * scale * delta).cpu()
                    elif layer == 'Linear':
                        child_module.weight.requires_grad_(False)
                        child_module.weight += (alpha.to(device) / dim * scale * delta).cpu()
                    del delta

    merge(
        LORA_PREFIX_TEXT_ENCODER,
        base_model[0],
        TEXT_ENCODER_TARGET_REPLACE_MODULE
    )
    merge(
        LORA_PREFIX_UNET,
        base_model[2],
        UNET_TARGET_REPLACE_MODULE
    )
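# A self-contained sanity check (not part of the original script; the sizes and the
# rank-8 choice below are arbitrary): build a random rank-16 "weight difference",
# factor it with extract_linear, and confirm that the low-rank factors, plus the
# sparse residual kept by make_sparse as the use_bias path does, approximately rebuild it.
if __name__ == '__main__':
    torch.manual_seed(0)
    delta = torch.randn(64, 16) @ torch.randn(16, 128)        # true rank 16
    down, up, resid = extract_linear(delta, 'fixed', 8)
    approx = up @ down                                        # same product merge_linear computes
    print('rank-8 relative error:', ((delta - approx).norm() / delta.norm()).item())
    sparse_resid = make_sparse(resid, sparsity=0.9)           # keep the largest 10% of residual entries
    print('with sparse residual :', ((delta - approx - sparse_resid).norm() / delta.norm()).item())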