Spaces:
Runtime error
Runtime error
File size: 4,250 Bytes
510ee71 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 |
import glob
from os import path
from paths import get_file_name, FastStableDiffusionPaths
from pathlib import Path
# Bookkeeping record for one loaded LoRA. diffusers' get_active_adapters()
# returns adapter names but not their weights, so the weight of every
# loaded LoRA is tracked here and re-applied whenever a new LoRA is loaded.
class _lora_info:
    """Record of a loaded LoRA: file path, adapter name and current weight.

    Attributes:
        path: Path of the LoRA weights file, as given by the caller.
        adapter_name: Name derived from ``path`` via ``get_file_name``;
            used as the adapter name when loading into the pipeline.
        weight: Weight currently applied to this adapter.
    """

    def __init__(
        self,
        path: str,
        weight: float,
    ):
        self.path = path
        self.adapter_name = get_file_name(path)
        self.weight = weight

    def __repr__(self) -> str:
        return (
            f"{type(self).__name__}(path={self.path!r}, "
            f"adapter_name={self.adapter_name!r}, weight={self.weight!r})"
        )
_loaded_loras = []
_current_pipeline = None
def load_lora_weight(
    pipeline,
    lcm_diffusion_setting,
):
    """Load the LoRA named by ``lcm_diffusion_setting.lora`` into ``pipeline``.

    Call this repeatedly with different ``lora.path`` values to stack several
    LoRAs on the same pipeline. If you plan to load multiple LoRAs and change
    their weights dynamically, you may want to set ``lora.fuse`` to False.

    Args:
        pipeline: Pipeline exposing ``load_lora_weights``, ``set_adapters``
            and ``fuse_lora``. If it is not the pipeline seen on the previous
            call, the tracked LoRA list is reset.
        lcm_diffusion_setting: Settings object. Reads ``lora.path``,
            ``lora.weight``, ``lora.enabled`` and ``lora.fuse`` (and, through
            ``update_lora_weights``, ``use_lcm_lora``).

    Raises:
        Exception: If ``lora.path`` is empty or does not exist.
    """
    lora_setting = lcm_diffusion_setting.lora
    if not lora_setting.path:
        raise Exception("Empty lora model path")
    if not path.exists(lora_setting.path):
        raise Exception("Lora model path is invalid")

    global _loaded_loras
    global _current_pipeline
    # A different pipeline object means it was rebuilt; the LoRAs tracked so
    # far belong to the old one, so start a fresh list.
    if pipeline is not _current_pipeline:
        _loaded_loras = []
        _current_pipeline = pipeline

    if not lora_setting.enabled:
        return

    current_lora = _lora_info(
        lora_setting.path,
        lora_setting.weight,
    )
    print(f"LoRA adapter name : {current_lora.adapter_name}")
    # Load from the file's own directory so LoRAs in subdirectories (which
    # get_lora_models finds via a recursive search) can be loaded too.
    lora_file = Path(lora_setting.path)
    pipeline.load_lora_weights(
        str(lora_file.parent),
        weight_name=lora_file.name,
        local_files_only=True,
        adapter_name=current_lora.adapter_name,
    )
    # Track the adapter only once it is actually loaded: update_lora_weights
    # passes every tracked name to set_adapters.
    _loaded_loras.append(current_lora)
    update_lora_weights(
        pipeline,
        lcm_diffusion_setting,
    )

    if lora_setting.fuse:
        pipeline.fuse_lora()
def get_lora_models(root_dir: str):
    """Map LoRA names to ``.safetensors`` file paths found under ``root_dir``.

    The search is recursive. Names come from ``get_file_name``, and files for
    which it returns None are skipped. When two files map to the same name,
    the one whose path sorts last wins.

    Args:
        root_dir: Directory to search.

    Returns:
        dict mapping LoRA name to file path.
    """
    # Escape root_dir so characters such as "[" in the directory name are
    # matched literally rather than as glob wildcards.
    pattern = f"{glob.escape(root_dir)}/**/*.safetensors"
    lora_models_map = {}
    # Sorted so the result (and which duplicate wins) does not depend on
    # filesystem enumeration order.
    for file_path in sorted(glob.glob(pattern, recursive=True)):
        lora_name = get_file_name(file_path)
        if lora_name is not None:
            lora_models_map[lora_name] = file_path
    return lora_models_map
def get_active_lora_weights():
    """Return ``(adapter_name, weight)`` tuples for the tracked LoRAs.

    The tuples follow the order in which the LoRAs were recorded, and a new
    list is built on every call.
    """
    return [(info.adapter_name, info.weight) for info in _loaded_loras]
def update_lora_weights(
    pipeline,
    lcm_diffusion_setting,
    lora_weights=None,
):
    """Activate the tracked LoRAs on ``pipeline`` with their current weights.

    Args:
        pipeline: Must be the pipeline the tracked LoRAs were loaded into;
            otherwise a message is printed and nothing changes.
        lcm_diffusion_setting: Settings object. If ``use_lcm_lora`` is true,
            the "lcm" adapter is activated with weight 1.0 ahead of the
            tracked LoRAs.
        lora_weights: Optional sequence of ``(adapter_name, weight)`` pairs,
            matched by position against the tracked LoRAs (the same shape
            ``get_active_lora_weights`` returns). A pair whose name does not
            match the tracked LoRA at that position is reported and skipped.
            Extra pairs beyond the tracked LoRAs are reported and ignored.
    """
    if pipeline is not _current_pipeline:
        print("Wrong pipeline when trying to update LoRA weights")
        return

    if lora_weights:
        for idx, lora in enumerate(lora_weights):
            # Guard against more pairs than tracked LoRAs (was an IndexError).
            if idx >= len(_loaded_loras):
                print("Too many LoRA weights given; ignoring the rest")
                break
            if _loaded_loras[idx].adapter_name != lora[0]:
                print("Wrong adapter name in LoRA enumeration!")
                continue
            _loaded_loras[idx].weight = lora[1]

    adapter_names = []
    adapter_weights = []
    if lcm_diffusion_setting.use_lcm_lora:
        adapter_names.append("lcm")
        adapter_weights.append(1.0)
    for lora in _loaded_loras:
        adapter_names.append(lora.adapter_name)
        adapter_weights.append(lora.weight)

    pipeline.set_adapters(
        adapter_names,
        adapter_weights=adapter_weights,
    )
    print(f"Adapters: {list(zip(adapter_names, adapter_weights))}")
|