File size: 749 Bytes
69a6cef
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
import logging

from hcpdiff.ckpt_manager import auto_manager
from hcpdiff.tools.lora_convert import LoraConverter


def convert_to_webui_lora(lora_path, lora_path_TE, dump_path, auto_scale_alpha: bool = True):
    converter = LoraConverter()

    # load lora model
    logging.info(f'Converting lora model {lora_path!r} and {lora_path_TE!r} to {dump_path!r} ...')
    ckpt_manager = auto_manager(lora_path)()

    sd_unet = ckpt_manager.load_ckpt(lora_path)
    sd_TE = ckpt_manager.load_ckpt(lora_path_TE)
    state = converter.convert_to_webui(sd_unet['lora'], sd_TE['lora'])
    if auto_scale_alpha:
        state = {k: (v * v.shape[1] if 'lora_up' in k else v) for k, v in state.items()}
    ckpt_manager._save_ckpt(state, save_path=dump_path)