import os

import fire
import torch
from lora_diffusion import (
    DEFAULT_TARGET_REPLACE,
    TEXT_ENCODER_DEFAULT_TARGET_REPLACE,
    UNET_DEFAULT_TARGET_REPLACE,
    convert_loras_to_safeloras_with_embeds,
    safetensors_available,
)

# Default target modules per recognised model name; any other name falls back
# to DEFAULT_TARGET_REPLACE (see the lookup in convert()).
_target_by_name = {
    "unet": UNET_DEFAULT_TARGET_REPLACE,
    "text_encoder": TEXT_ENCODER_DEFAULT_TARGET_REPLACE,
}


def convert(*paths, outpath, overwrite=False, **settings):
    """
    Converts one or more pytorch Lora and/or Textual Embedding pytorch files
    into a safetensor file.

    Pass all the input paths as arguments. Whether they are Textual Embedding
    or Lora models will be auto-detected: a file whose loaded content is a
    dict is treated as Textual Embeddings, anything else as a Lora model.

    For Lora models, their name will be taken from the path, i.e.
        "lora_weight.pt" => unet
        "lora_weight.text_encoder.pt" => text_encoder

    If several Lora files resolve to the same name, the last one wins and a
    warning is printed.

    You can also set target_modules and/or rank by providing an argument
    prefixed by the name.

    So a complete example might be something like:

    ```
    python -m lora_diffusion.cli_pt_to_safetensors lora_weight.* --outpath lora_weight.safetensor --unet.rank 8
    ```

    Args:
        *paths: Input files to convert.
        outpath: Path of the output file.
        overwrite: Allow replacing an existing file at outpath.
        **settings: Per-model overrides keyed as "<name>.<setting>", e.g.
            "unet.rank" or "text_encoder.target_modules". Keys whose prefix
            matches no loaded Lora model are ignored.

    Raises:
        ValueError: If outpath already exists and overwrite is not set.
    """
    modelmap = {}
    embeds = {}

    # Refuse to clobber an existing output before doing any loading work.
    if os.path.exists(outpath) and not overwrite:
        raise ValueError(
            f"Output path {outpath} already exists, and overwrite is not True"
        )

    for path in paths:
        # Load onto the CPU so that files saved from a GPU can still be
        # inspected on a machine without one.
        # NOTE(security): torch.load relies on pickle; only pass files that
        # come from a source you trust.
        data = torch.load(path, map_location="cpu")

        if isinstance(data, dict):
            print(f"Loading textual inversion embeds {data.keys()} from {path}")
            embeds.update(data)
            continue

        # "<stem>.<name>.<ext>" => <name>; a plain "<stem>.<ext>" => "unet".
        name_parts = os.path.basename(path).split(".")
        name = name_parts[-2] if len(name_parts) > 2 else "unet"

        model_settings = {
            "target_modules": _target_by_name.get(name, DEFAULT_TARGET_REPLACE),
            "rank": 4,
        }

        # Overlay any "<name>.<key>" command line settings on the defaults.
        prefix = f"{name}."
        model_settings.update(
            {
                key[len(prefix):]: value
                for key, value in settings.items()
                if key.startswith(prefix)
            }
        )

        print(f"Loading Lora for {name} from {path} with settings {model_settings}")

        if name in modelmap:
            # Only one entry per name can be stored; make the replacement
            # visible instead of dropping the earlier file silently.
            print(
                f"Warning: {path} replaces {modelmap[name][0]} "
                f"as the Lora for {name}"
            )

        modelmap[name] = (
            path,
            model_settings["target_modules"],
            model_settings["rank"],
        )

    convert_loras_to_safeloras_with_embeds(modelmap, embeds, outpath)


def main():
    """Expose convert() as a command line interface through fire."""
    fire.Fire(convert)


if __name__ == "__main__":
    main()