Spaces:
Runtime error
Runtime error
import os | |
import fire | |
import torch | |
from lora_diffusion import ( | |
DEFAULT_TARGET_REPLACE, | |
TEXT_ENCODER_DEFAULT_TARGET_REPLACE, | |
UNET_DEFAULT_TARGET_REPLACE, | |
convert_loras_to_safeloras_with_embeds, | |
safetensors_available, | |
) | |
# Default Lora injection targets, keyed by the model name that `convert`
# derives from each input filename. Names not listed here fall back to
# DEFAULT_TARGET_REPLACE inside `convert`.
_target_by_name = dict(
    unet=UNET_DEFAULT_TARGET_REPLACE,
    text_encoder=TEXT_ENCODER_DEFAULT_TARGET_REPLACE,
)
def convert(*paths, outpath, overwrite=False, **settings):
    """
    Convert one or more PyTorch Lora and/or Textual Inversion embedding files
    into a single safetensors file.

    Pass all the input paths as positional arguments. Each file is loaded with
    ``torch.load``. A ``dict`` is treated as Textual Inversion embeddings
    (token -> tensor) and merged. Anything else is treated as Lora weights.

    For Lora models, the model name is taken from the second-to-last
    dot-separated part of the filename, defaulting to "unet", e.g.

        "lora_weight.pt"              => unet
        "lora_weight.text_encoder.pt" => text_encoder

    You can also set ``target_modules`` and/or ``rank`` per model by passing a
    keyword argument prefixed with the model name. The default rank is 4.
    Default target modules come from ``_target_by_name``, falling back to
    ``DEFAULT_TARGET_REPLACE``. A complete example:

    ```
    python -m lora_diffusion.cli_pt_to_safetensors lora_weight.* --outpath lora_weight.safetensors --unet.rank 8
    ```

    Args:
        *paths: Input files (Lora weights or embedding dicts).
        outpath: Destination safetensors file.
        overwrite: Allow replacing an existing ``outpath``.
        **settings: Per-model overrides named ``<model>.<setting>``. Keys that
            match no loaded Lora model are reported and ignored.

    Raises:
        ValueError: if no input paths are given, if ``outpath`` exists and
            ``overwrite`` is false, if two Lora files resolve to the same
            model name, or if two embedding files define the same token.
    """
    if not paths:
        # Without inputs the conversion would silently produce an empty file.
        raise ValueError("No input paths given; nothing to convert")

    if os.path.exists(outpath) and not overwrite:
        raise ValueError(
            f"Output path {outpath} already exists, and overwrite is not True"
        )

    modelmap = {}
    embeds = {}
    used_settings = set()  # settings keys consumed by some loaded Lora

    for path in paths:
        # NOTE(review): torch.load unpickles the file, which can execute
        # arbitrary code -- only convert files from trusted sources.
        # map_location="cpu" lets GPU-saved files load on CPU-only hosts; the
        # loaded tensors are not used for any computation here.
        data = torch.load(path, map_location="cpu")

        if isinstance(data, dict):
            # The output can hold only one tensor per token, so a clash would
            # silently drop an embedding.
            clashes = embeds.keys() & data.keys()
            if clashes:
                raise ValueError(
                    f"Embedding token(s) {sorted(clashes, key=str)} from {path} "
                    "were already loaded from another file"
                )
            print(f"Loading textual inversion embeds {data.keys()} from {path}")
            embeds.update(data)
            continue

        # "name.model.pt" -> "model"; filenames with fewer than three
        # dot-separated parts default to "unet".
        name_parts = os.path.split(path)[1].split(".")
        name = name_parts[-2] if len(name_parts) > 2 else "unet"
        if name in modelmap:
            # Only one Lora per model name fits in the output; keeping just
            # the last file would silently lose weights.
            raise ValueError(
                f"Lora for {name} given twice: {modelmap[name][0]} and {path}"
            )

        prefix = f"{name}."
        arg_settings = {
            k[len(prefix):]: v for k, v in settings.items() if k.startswith(prefix)
        }
        used_settings.update(prefix + k for k in arg_settings)
        # CLI overrides win over the defaults.
        model_settings = {
            "target_modules": _target_by_name.get(name, DEFAULT_TARGET_REPLACE),
            "rank": 4,
            **arg_settings,
        }
        print(f"Loading Lora for {name} from {path} with settings {model_settings}")

        # `data` is not passed on for Lora files: only the path and settings
        # are handed to the converter, which presumably re-reads the weights
        # from `path` -- confirm in convert_loras_to_safeloras_with_embeds.
        modelmap[name] = (
            path,
            model_settings["target_modules"],
            model_settings["rank"],
        )

    # A typo such as "--unte.rank" or an unprefixed "--rank" would otherwise
    # be dropped without any feedback.
    unused_settings = sorted(set(settings) - used_settings)
    if unused_settings:
        print(
            "Warning: ignoring settings that match no loaded Lora model: "
            f"{unused_settings}"
        )

    convert_loras_to_safeloras_with_embeds(modelmap, embeds, outpath)
def main():
    """Command-line entry point: expose ``convert`` as a python-fire CLI."""
    fire.Fire(component=convert)


# Start the CLI only when this module is executed directly (for example via
# ``python -m``), not when it is imported.
if __name__ == "__main__":
    main()