Spaces:
				
			
			
	
			
			
					
		Running
		
	
	
	
			
			
	
	
	
	
		
		
					
		Running
		
# Split a DyLoRA model into several fixed-rank LoRA models by truncating its weights (ranks can only go down, never up).
# This code is based on extract_lora_from_models.py, which is itself based on https://github.com/cloneofsimo/lora/blob/develop/lora_diffusion/cli_svd.py
# Thanks to cloneofsimo
| import argparse | |
| import math | |
| import os | |
| import torch | |
| from safetensors.torch import load_file, save_file, safe_open | |
| from tqdm import tqdm | |
| from library import train_util, model_util | |
| import numpy as np | |
| from library.utils import setup_logging | |
import logging

# Project-wide logging setup (library.utils.setup_logging); runs once at import time,
# before the module-level logger below is created.
setup_logging()
logger = logging.getLogger(__name__)
def load_state_dict(file_name):
    """Load a LoRA/DyLoRA state dict together with its file metadata.

    Args:
        file_name: Path to a ``.safetensors`` file or a torch checkpoint.

    Returns:
        A ``(state_dict, metadata)`` tuple. For safetensors files ``metadata`` is
        whatever ``safe_open(...).metadata()`` returns; for any other format it is
        ``None``.
    """
    if not model_util.is_safetensors(file_name):
        # NOTE(review): torch.load unpickles the file, which can run arbitrary code;
        # only load checkpoints from trusted sources.
        return torch.load(file_name, map_location="cpu"), None

    state_dict = load_file(file_name)
    # Re-open the file just to read the header metadata.
    with safe_open(file_name, framework="pt") as handle:
        header_metadata = handle.metadata()
    return state_dict, header_metadata
def save_to_file(file_name, model, metadata):
    """Write the state dict ``model`` to ``file_name``.

    Safetensors output stores ``metadata`` in the file header. For any other
    extension the state dict is written with ``torch.save`` and ``metadata`` is
    ignored.
    """
    if not model_util.is_safetensors(file_name):
        torch.save(model, file_name)
        return
    save_file(model, file_name, metadata)
def split_lora_model(lora_sd, unit):
    """Truncate a DyLoRA state dict into LoRA state dicts of increasing rank.

    The largest first dimension among the ``lora_down`` tensors is taken as the
    maximum rank. A new state dict is built for every rank ``unit, 2*unit, ...``
    strictly below that maximum:

    - ``lora_down`` tensors keep their first ``rank`` rows;
    - ``lora_up`` tensors keep their first ``rank`` columns;
    - all other entries (e.g. alpha) are copied unchanged.

    Args:
        lora_sd: Mapping of key -> tensor loaded from a DyLoRA file.
        unit: Rank step between consecutive split models; must be a positive int.

    Returns:
        ``(max_rank, split_models)``, where ``split_models`` is a list of
        ``(state_dict, rank, new_alpha)`` tuples. ``new_alpha`` is always ``None``.

    Raises:
        ValueError: If ``unit`` is ``None`` or not positive. A zero or negative
            step would otherwise never reach ``max_rank`` and loop forever.
    """
    if unit is None or unit <= 0:
        raise ValueError(f"unit must be a positive integer, got {unit!r}")

    # Maximum rank = largest lora_down first dimension (0 if there are none).
    max_rank = max(
        (value.size()[0] for key, value in lora_sd.items() if "lora_down" in key),
        default=0,
    )
    logger.info(f"Max rank: {max_rank}")

    split_models = []
    new_alpha = None  # alpha is never rescaled; see the note in the loop below
    # Only ranks strictly below max_rank are produced.
    for rank in range(unit, max_rank, unit):
        logger.info(f"Splitting rank {rank}")
        new_sd = {}
        for key, value in lora_sd.items():
            if "lora_down" in key:
                new_sd[key] = value[:rank].contiguous()
            elif "lora_up" in key:
                new_sd[key] = value[:, :rank].contiguous()
            else:
                # Copied as-is (e.g. alpha). The original author noted that, for an
                # unknown reason, rescaling alpha by sqrt(this_rank / rank) gave
                # broken results, so it is intentionally left unchanged.
                new_sd[key] = value
        split_models.append((new_sd, rank, new_alpha))

    if not split_models:
        logger.warning(f"No split models produced: unit {unit} is not smaller than max rank {max_rank}")

    return max_rank, split_models
def split(args):
    """Split the DyLoRA model at ``args.model`` into fixed-rank LoRA files.

    One file is written per rank returned by ``split_lora_model``, named
    ``<save_to base>-<rank:04d><save_to ext>``. Each file's metadata is a fresh
    copy of the source metadata (or an empty dict when there is none) with:

    - an updated training comment;
    - ``ss_network_dim``;
    - ``sshs_*`` hashes computed for that file.

    Args:
        args: Parsed CLI namespace with ``model``, ``save_to`` and ``unit``.

    Raises:
        ValueError: If ``--model`` or ``--save_to`` was not given. This is checked
            up front so the user is not told only after loading the model.
    """
    if args.model is None:
        raise ValueError("--model is required")
    if args.save_to is None:
        raise ValueError("--save_to is required")

    logger.info("loading Model...")
    lora_sd, metadata = load_state_dict(args.model)

    logger.info("Splitting Model...")
    original_rank, split_models = split_lora_model(lora_sd, args.unit)

    # metadata is None for torch checkpoints, so guard the lookup.
    comment = metadata.get("ss_training_comment", "") if metadata is not None else ""
    filename, ext = os.path.splitext(args.save_to)

    for state_dict, new_rank, _new_alpha in split_models:
        # Fresh copy per output file so edits never leak between ranks or back
        # into the source metadata.
        new_metadata = {} if metadata is None else metadata.copy()
        new_metadata["ss_training_comment"] = f"split from DyLoRA, rank {original_rank} to {new_rank}; {comment}"
        new_metadata["ss_network_dim"] = str(new_rank)
        # alpha is copied unchanged by split_lora_model, so ss_network_alpha is left as-is.

        # Hash this file's tensors with the metadata that is actually saved, and
        # store the hashes in that same metadata (not in the shared source dict).
        model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, new_metadata)
        new_metadata["sshs_model_hash"] = model_hash
        new_metadata["sshs_legacy_hash"] = legacy_hash

        model_file_name = filename + f"-{new_rank:04d}{ext}"
        logger.info(f"saving model to: {model_file_name}")
        save_to_file(model_file_name, state_dict, new_metadata)
def setup_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for splitting a DyLoRA model.

    Options, all of which default to ``None``:
        --unit     rank step between consecutive output models (int)
        --save_to  base output file name; ckpt or safetensors
        --model    input DyLoRA file; ckpt or safetensors
    """
    cli = argparse.ArgumentParser()
    # (flag, type, help) in the order they appear in --help.
    option_specs = (
        ("--unit", int, "size of rank to split into / rankを分割するサイズ"),
        (
            "--save_to",
            str,
            "destination base file name: ckpt or safetensors file / 保存先のファイル名のbase、ckptまたはsafetensors",
        ),
        (
            "--model",
            str,
            "DyLoRA model to resize at to new rank: ckpt or safetensors file / 読み込むDyLoRAモデル、ckptまたはsafetensors",
        ),
    )
    for flag, value_type, help_text in option_specs:
        cli.add_argument(flag, type=value_type, default=None, help=help_text)
    return cli
if __name__ == "__main__":
    # Command-line entry point: parse the CLI options and run the split.
    cli_parser = setup_parser()
    split(cli_parser.parse_args())