Spaces:
Running
Running
# currently only works with flux as support is not quite there yet
"""Convert a flux LoRA with lora_down/lora_up (+ alpha) keys to peft-style keys.

Usage: python <this script> input_path output_path

The input is a safetensors file whose keys look like
``lora_transformer_..._attn_to_q.lora_down.weight``; the output is a
safetensors file whose keys look like
``transformer....attn.to_q.lora_A.weight`` with the alpha scale folded into
the weights.
"""

import argparse
import os.path
from collections import OrderedDict

# peft doesnt have an alpha so we need to scale the weights
ALPHA_KEYS = [
    'lora_transformer_single_transformer_blocks_0_attn_to_q.alpha',  # flux
]

# keys where the rank is in the first dimension
# (peft equivalent: 'transformer.single_transformer_blocks.0.attn.to_q.lora_A.weight')
RANK_IDX0_KEYS = [
    'lora_transformer_single_transformer_blocks_0_attn_to_q.lora_down.weight',
]

# Substring substitutions applied, in this exact order, after the block index
# has been rewritten. The order matters: later entries act on the output of
# earlier ones.
# NOTE(review): names such as `ff_context` or `norm1_linear` are not handled
# specially here - confirm the produced keys against the target model.
KEY_REPLACEMENTS = (
    ('lora_down', 'lora_A'),
    ('lora_up', 'lora_B'),
    ('_lora', '.lora'),
    ('attn_', 'attn.'),
    ('ff_', 'ff.'),
    ('context_net_', 'context.net.'),
    ('0_proj', '0.proj'),
    ('norm_linear', 'norm.linear'),
    ('norm_out_linear', 'norm_out.linear'),
    ('to_out_', 'to_out.'),
)

# Block indices 0..99 are rewritten from `transformer_blocks_<i>_` to
# `transformer_blocks.<i>.`; built once instead of on every key.
BLOCK_INDEX_REPLACEMENTS = tuple(
    (f'transformer_blocks_{i}_', f'transformer_blocks.{i}.')
    for i in range(100)
)


def get_up_multiplier(state_dict) -> float:
    """Return ``alpha / rank`` for the LoRA stored in ``state_dict``.

    The rank is read from the first dimension of the first key of
    ``RANK_IDX0_KEYS`` that is present. The alpha is read from the first key
    of ``ALPHA_KEYS`` that is present and defaults to the rank (multiplier
    of 1.0) when none is found.

    Raises:
        ValueError: if none of ``RANK_IDX0_KEYS`` is in ``state_dict``.
    """
    rank = None
    for key in RANK_IDX0_KEYS:
        if key in state_dict:
            rank = int(state_dict[key].shape[0])
            break
    if rank is None:
        raise ValueError('Could not find rank in state dict')

    # set to rank if not found
    alpha = rank
    for key in ALPHA_KEYS:
        if key in state_dict:
            # float(), not int(): a fractional alpha must not be truncated
            alpha = float(state_dict[key])
            break

    return alpha / rank


def convert_key(key: str) -> str:
    """Rename one lora_down/lora_up style key to its peft-style name."""
    new_key = key.replace('lora_transformer_', 'transformer.')
    for old, new in BLOCK_INDEX_REPLACEMENTS:
        new_key = new_key.replace(old, new)
    for old, new in KEY_REPLACEMENTS:
        new_key = new_key.replace(old, new)
    return new_key


def convert_state_dict(state_dict) -> dict:
    """Return a new state dict with peft-style keys and alpha folded in.

    ``.alpha`` entries are dropped. The ``alpha / rank`` multiplier is applied
    to the ``lora_up`` tensors only: the LoRA update is the product of the up
    and down matrices, so scaling both factors would scale the update by the
    multiplier squared. All other tensors are passed through unchanged.
    """
    up_multiplier = get_up_multiplier(state_dict)

    new_state_dict = {}
    for key, value in state_dict.items():
        if key.endswith('.alpha'):
            continue

        new_val = value
        if 'lora_up' in key:
            # scale in float32, then return to the tensor's original dtype
            new_val = (value.float() * up_multiplier).to(value.dtype)

        new_state_dict[convert_key(key)] = new_val
    return new_state_dict


def _parse_args(argv=None):
    """Parse the command line and make both paths absolute."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        'input_path',
        type=str,
        help='Path to original flux lora'
    )
    parser.add_argument(
        'output_path',
        type=str,
        help='output path'
    )
    args = parser.parse_args(argv)
    args.input_path = os.path.abspath(args.input_path)
    args.output_path = os.path.abspath(args.output_path)
    return args


def main():
    """Load the input LoRA, convert it and save it to the output path."""
    args = _parse_args()

    # kept after argument parsing, matching the original statement order
    from safetensors.torch import load_file, save_file

    meta = OrderedDict()
    meta['format'] = 'pt'

    state_dict = load_file(args.input_path)
    new_state_dict = convert_state_dict(state_dict)

    save_file(new_state_dict, args.output_path, meta)
    print(f'Saved to {args.output_path}')


if __name__ == '__main__':
    main()