"""Command-line script that merges a LyCORIS model into a Stable Diffusion checkpoint.

Usage: ``<script> base_model lycoris_model output_name [--is_v2]
[--device DEVICE] [--dtype DTYPE] [--weight WEIGHT]``
"""
import os
import sys

# Make modules in the current working directory importable (the project
# imports further down are resolved after this insert).
sys.path.insert(0, os.getcwd())

import argparse


def get_args() -> argparse.Namespace:
    """Parse the command-line arguments of the merge script.

    NOTE: the ``default`` values given to the three positionals are never
    used, because argparse treats positionals without ``nargs`` as required.

    Returns:
        The parsed namespace with ``base_model``, ``lycoris_model``,
        ``output_name``, ``is_v2``, ``device``, ``dtype`` and ``weight``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "base_model",
        help="The model you want to merge with loha",
        default='',
        type=str,
    )
    parser.add_argument(
        "lycoris_model",
        help="the lyco model you want to merge into sd model",
        default='',
        type=str,
    )
    parser.add_argument(
        "output_name",
        help="the output model",
        default='./out.pt',
        type=str,
    )
    parser.add_argument(
        "--is_v2",
        help="Your base model is sd v2 or not",
        default=False,
        action="store_true",
    )
    parser.add_argument(
        "--device",
        help="Which device you want to use to merge the weight",
        default='cpu',
        type=str,
    )
    parser.add_argument(
        "--dtype",
        help='dtype to save',
        default='float',
        type=str,
    )
    parser.add_argument(
        "--weight",
        help='weight for the lyco model to merge',
        default=1.0,
        type=float,
    )
    return parser.parse_args()


# Arguments are parsed before the imports below, so ``--help`` and usage
# errors exit before those modules are loaded.
ARGS = get_args()


from lycoris_utils import merge
from lycoris.kohya_model_utils import (
    load_models_from_stable_diffusion_checkpoint,
    save_stable_diffusion_checkpoint,
    load_file
)

import torch


def _resolve_dtype(name: str) -> torch.dtype:
    """Translate a ``--dtype`` string into the matching ``torch.dtype``.

    Accepts the long names (``float``, ``float16``, ``float32``, ``float64``,
    ``bfloat``, ``bfloat16``) and the ``fp*`` / ``bf*`` abbreviations
    (``fp16``, ``bf16``, ...).

    Args:
        name: The dtype name as given on the command line.

    Returns:
        The corresponding torch dtype.

    Raises:
        ValueError: If ``name`` does not correspond to a supported dtype.
    """
    table = {
        'float': torch.float,
        'float16': torch.float16,
        'float32': torch.float32,
        'float64': torch.float64,
        'bfloat': torch.bfloat16,
        'bfloat16': torch.bfloat16,
    }
    key = name
    if key.startswith('fp'):
        key = 'float' + key[2:]
    elif key.startswith('bf') and not key.startswith('bfloat'):
        # Expand only the short form: blindly replacing 'bf' would turn the
        # already-long 'bfloat16' into 'bfloatloat16'.
        key = 'bfloat' + key[2:]

    dtype = table.get(key)
    if dtype is None:
        raise ValueError(f'Cannot Find the dtype "{name}"')
    return dtype


def main() -> None:
    """Load the base and LyCORIS models, merge them, and save the result.

    Raises:
        ValueError: If ``ARGS.dtype`` is not a supported dtype name.
    """
    # Validate the requested dtype first so a typo fails before any model
    # is read from disk.
    dtype = _resolve_dtype(ARGS.dtype)

    base = load_models_from_stable_diffusion_checkpoint(
        ARGS.is_v2, ARGS.base_model
    )

    if ARGS.lycoris_model.rsplit('.', 1)[-1] == 'safetensors':
        lyco = load_file(ARGS.lycoris_model)
    else:
        # NOTE(review): torch.load unpickles the file, which can execute
        # arbitrary code; only use it on model files from a trusted source.
        lyco = torch.load(ARGS.lycoris_model)

    merge(
        base,
        lyco,
        ARGS.weight,
        ARGS.device
    )

    # NOTE(review): `base` is indexed as a 3-item sequence; presumably
    # (text_encoder, vae, unet) -- confirm against kohya_model_utils.
    save_stable_diffusion_checkpoint(
        ARGS.is_v2, ARGS.output_name,
        base[0], base[2],
        None, 0, 0, dtype,
        base[1]
    )


if __name__ == '__main__':
    main()