"""Helpers for loading LoRA adapters into a diffusers pipeline and for
keeping track of the weight assigned to each loaded adapter."""

import glob
from os import path
from pathlib import Path

from paths import get_file_name, FastStableDiffusionPaths


class _lora_info:
    """Record of one LoRA that has been loaded into the current pipeline.

    The diffusers ``get_active_adapters()`` call returns adapter names but
    not their weights, so the weight of every loaded LoRA is tracked here
    and re-applied whenever a new LoRA is loaded.
    """

    def __init__(
        self,
        path: str,
        weight: float,
    ):
        self.path = path
        # Name the adapter is registered under in the pipeline
        self.adapter_name = get_file_name(path)
        self.weight = weight


# LoRAs successfully loaded into _current_pipeline, in load order
_loaded_loras = []
# Pipeline the entries of _loaded_loras belong to; a different pipeline
# object means it was rebuilt and the tracked LoRAs are stale
_current_pipeline = None


def load_lora_weight(
    pipeline,
    lcm_diffusion_setting,
):
    """Load the LoRA at ``lcm_diffusion_setting.lora.path`` into *pipeline*.

    Multiple LoRAs can be stacked by calling this function once per LoRA
    path. If you plan to load multiple LoRAs and change their weights
    dynamically, you might want to set the LoRA fuse option to False.

    The LoRA is only loaded (and tracked) when ``lora.enabled`` is set.

    Raises:
        Exception: if the LoRA path is empty or does not exist.
    """
    lora_setting = lcm_diffusion_setting.lora
    if not lora_setting.path:
        raise Exception("Empty lora model path")

    if not path.exists(lora_setting.path):
        raise Exception("Lora model path is invalid")

    global _loaded_loras
    global _current_pipeline
    # If the pipeline has been rebuilt since the last call, forget all
    # previously loaded LoRAs and remember the new pipeline.
    if pipeline is not _current_pipeline:
        _loaded_loras = []
        _current_pipeline = pipeline

    if not lora_setting.enabled:
        return

    current_lora = _lora_info(
        lora_setting.path,
        lora_setting.weight,
    )
    print(f"LoRA adapter name : {current_lora.adapter_name}")
    lora_file = Path(lora_setting.path)
    # Load from the file's own (already validated) directory rather than
    # the LoRA root: get_lora_models() also lists LoRAs in subdirectories,
    # which a bare file name under the root would not resolve.
    pipeline.load_lora_weights(
        str(lora_file.parent),
        weight_name=lora_file.name,
        local_files_only=True,
        adapter_name=current_lora.adapter_name,
    )
    # Track the LoRA only once it has actually been loaded, so that
    # update_lora_weights() never hands set_adapters() an adapter name the
    # pipeline does not know about.
    _loaded_loras.append(current_lora)
    update_lora_weights(
        pipeline,
        lcm_diffusion_setting,
    )

    if lora_setting.fuse:
        pipeline.fuse_lora()


def get_lora_models(root_dir: str):
    """Return a ``{lora_name: file_path}`` map of all ``.safetensors``
    files found recursively under *root_dir*.

    Files for which ``get_file_name()`` returns None are skipped; when two
    files map to the same name, the one found later wins.
    """
    # Escape root_dir so glob metacharacters in it ("[", "*", "?") are
    # matched literally instead of being interpreted as a pattern.
    pattern = f"{glob.escape(root_dir)}/**/*.safetensors"
    lora_models_map = {}
    for file_path in glob.glob(pattern, recursive=True):
        lora_name = get_file_name(file_path)
        if lora_name is not None:
            lora_models_map[lora_name] = file_path
    return lora_models_map


def get_active_lora_weights():
    """Return a list of ``(adapter_name, weight)`` tuples for the currently
    loaded LoRAs, in load order."""
    return [
        (lora_info.adapter_name, lora_info.weight)
        for lora_info in _loaded_loras
    ]


def update_lora_weights(
    pipeline,
    lcm_diffusion_setting,
    lora_weights=None,
):
    """Apply the tracked LoRA weights (plus the LCM LoRA, if used) to
    *pipeline* through ``set_adapters()``.

    Args:
        pipeline: must be the pipeline the LoRAs were loaded into; any
            other pipeline is rejected with a message and nothing is done.
        lcm_diffusion_setting: settings object; when ``use_lcm_lora`` is
            set, an adapter named "lcm" with weight 1.0 is included first.
        lora_weights: optional list of ``(adapter_name, weight)`` tuples in
            the order returned by get_active_lora_weights(); matching
            entries replace the tracked weights before they are applied.
    """
    if pipeline is not _current_pipeline:
        print("Wrong pipeline when trying to update LoRA weights")
        return

    if lora_weights:
        for idx, lora in enumerate(lora_weights):
            # Guard against callers passing more entries than are loaded
            if idx >= len(_loaded_loras):
                print("More LoRA weights given than LoRAs loaded!")
                break
            if _loaded_loras[idx].adapter_name != lora[0]:
                print("Wrong adapter name in LoRA enumeration!")
                continue
            _loaded_loras[idx].weight = lora[1]

    adapter_names = []
    adapter_weights = []
    if lcm_diffusion_setting.use_lcm_lora:
        adapter_names.append("lcm")
        adapter_weights.append(1.0)
    for lora in _loaded_loras:
        adapter_names.append(lora.adapter_name)
        adapter_weights.append(lora.weight)
    pipeline.set_adapters(
        adapter_names,
        adapter_weights=adapter_weights,
    )
    print(f"Adapters: {list(zip(adapter_names, adapter_weights))}")