|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | """ | 
					
						
						|  | State dict utilities: utility methods for converting state dicts easily | 
					
						
						|  | """ | 
					
						
						|  |  | 
					
						
						|  | import enum | 
					
						
						|  |  | 
					
						
						|  | from .logging import get_logger | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | logger = get_logger(__name__) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | class StateDictType(enum.Enum): | 
					
						
						|  | """ | 
					
						
						|  | The mode to use when converting state dicts. | 
					
						
						|  | """ | 
					
						
						|  |  | 
					
						
						|  | DIFFUSERS_OLD = "diffusers_old" | 
					
						
						|  | KOHYA_SS = "kohya_ss" | 
					
						
						|  | PEFT = "peft" | 
					
						
						|  | DIFFUSERS = "diffusers" | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | UNET_TO_DIFFUSERS = { | 
					
						
						|  | ".to_out_lora.up": ".to_out.0.lora_B", | 
					
						
						|  | ".to_out_lora.down": ".to_out.0.lora_A", | 
					
						
						|  | ".to_q_lora.down": ".to_q.lora_A", | 
					
						
						|  | ".to_q_lora.up": ".to_q.lora_B", | 
					
						
						|  | ".to_k_lora.down": ".to_k.lora_A", | 
					
						
						|  | ".to_k_lora.up": ".to_k.lora_B", | 
					
						
						|  | ".to_v_lora.down": ".to_v.lora_A", | 
					
						
						|  | ".to_v_lora.up": ".to_v.lora_B", | 
					
						
						|  | ".lora.up": ".lora_B", | 
					
						
						|  | ".lora.down": ".lora_A", | 
					
						
						|  | ".to_out.lora_magnitude_vector": ".to_out.0.lora_magnitude_vector", | 
					
						
						|  | } | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | DIFFUSERS_TO_PEFT = { | 
					
						
						|  | ".q_proj.lora_linear_layer.up": ".q_proj.lora_B", | 
					
						
						|  | ".q_proj.lora_linear_layer.down": ".q_proj.lora_A", | 
					
						
						|  | ".k_proj.lora_linear_layer.up": ".k_proj.lora_B", | 
					
						
						|  | ".k_proj.lora_linear_layer.down": ".k_proj.lora_A", | 
					
						
						|  | ".v_proj.lora_linear_layer.up": ".v_proj.lora_B", | 
					
						
						|  | ".v_proj.lora_linear_layer.down": ".v_proj.lora_A", | 
					
						
						|  | ".out_proj.lora_linear_layer.up": ".out_proj.lora_B", | 
					
						
						|  | ".out_proj.lora_linear_layer.down": ".out_proj.lora_A", | 
					
						
						|  | ".lora_linear_layer.up": ".lora_B", | 
					
						
						|  | ".lora_linear_layer.down": ".lora_A", | 
					
						
						|  | "text_projection.lora.down.weight": "text_projection.lora_A.weight", | 
					
						
						|  | "text_projection.lora.up.weight": "text_projection.lora_B.weight", | 
					
						
						|  | } | 
					
						
						|  |  | 
					
						
						|  | DIFFUSERS_OLD_TO_PEFT = { | 
					
						
						|  | ".to_q_lora.up": ".q_proj.lora_B", | 
					
						
						|  | ".to_q_lora.down": ".q_proj.lora_A", | 
					
						
						|  | ".to_k_lora.up": ".k_proj.lora_B", | 
					
						
						|  | ".to_k_lora.down": ".k_proj.lora_A", | 
					
						
						|  | ".to_v_lora.up": ".v_proj.lora_B", | 
					
						
						|  | ".to_v_lora.down": ".v_proj.lora_A", | 
					
						
						|  | ".to_out_lora.up": ".out_proj.lora_B", | 
					
						
						|  | ".to_out_lora.down": ".out_proj.lora_A", | 
					
						
						|  | ".lora_linear_layer.up": ".lora_B", | 
					
						
						|  | ".lora_linear_layer.down": ".lora_A", | 
					
						
						|  | } | 
					
						
						|  |  | 
					
						
						|  | PEFT_TO_DIFFUSERS = { | 
					
						
						|  | ".q_proj.lora_B": ".q_proj.lora_linear_layer.up", | 
					
						
						|  | ".q_proj.lora_A": ".q_proj.lora_linear_layer.down", | 
					
						
						|  | ".k_proj.lora_B": ".k_proj.lora_linear_layer.up", | 
					
						
						|  | ".k_proj.lora_A": ".k_proj.lora_linear_layer.down", | 
					
						
						|  | ".v_proj.lora_B": ".v_proj.lora_linear_layer.up", | 
					
						
						|  | ".v_proj.lora_A": ".v_proj.lora_linear_layer.down", | 
					
						
						|  | ".out_proj.lora_B": ".out_proj.lora_linear_layer.up", | 
					
						
						|  | ".out_proj.lora_A": ".out_proj.lora_linear_layer.down", | 
					
						
						|  | "to_k.lora_A": "to_k.lora.down", | 
					
						
						|  | "to_k.lora_B": "to_k.lora.up", | 
					
						
						|  | "to_q.lora_A": "to_q.lora.down", | 
					
						
						|  | "to_q.lora_B": "to_q.lora.up", | 
					
						
						|  | "to_v.lora_A": "to_v.lora.down", | 
					
						
						|  | "to_v.lora_B": "to_v.lora.up", | 
					
						
						|  | "to_out.0.lora_A": "to_out.0.lora.down", | 
					
						
						|  | "to_out.0.lora_B": "to_out.0.lora.up", | 
					
						
						|  | } | 
					
						
						|  |  | 
					
						
						|  | DIFFUSERS_OLD_TO_DIFFUSERS = { | 
					
						
						|  | ".to_q_lora.up": ".q_proj.lora_linear_layer.up", | 
					
						
						|  | ".to_q_lora.down": ".q_proj.lora_linear_layer.down", | 
					
						
						|  | ".to_k_lora.up": ".k_proj.lora_linear_layer.up", | 
					
						
						|  | ".to_k_lora.down": ".k_proj.lora_linear_layer.down", | 
					
						
						|  | ".to_v_lora.up": ".v_proj.lora_linear_layer.up", | 
					
						
						|  | ".to_v_lora.down": ".v_proj.lora_linear_layer.down", | 
					
						
						|  | ".to_out_lora.up": ".out_proj.lora_linear_layer.up", | 
					
						
						|  | ".to_out_lora.down": ".out_proj.lora_linear_layer.down", | 
					
						
						|  | ".to_k.lora_magnitude_vector": ".k_proj.lora_magnitude_vector", | 
					
						
						|  | ".to_v.lora_magnitude_vector": ".v_proj.lora_magnitude_vector", | 
					
						
						|  | ".to_q.lora_magnitude_vector": ".q_proj.lora_magnitude_vector", | 
					
						
						|  | ".to_out.lora_magnitude_vector": ".out_proj.lora_magnitude_vector", | 
					
						
						|  | } | 
					
						
						|  |  | 
					
						
						|  | PEFT_TO_KOHYA_SS = { | 
					
						
						|  | "lora_A": "lora_down", | 
					
						
						|  | "lora_B": "lora_up", | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | } | 
					
						
						|  |  | 
					
						
						|  | PEFT_STATE_DICT_MAPPINGS = { | 
					
						
						|  | StateDictType.DIFFUSERS_OLD: DIFFUSERS_OLD_TO_PEFT, | 
					
						
						|  | StateDictType.DIFFUSERS: DIFFUSERS_TO_PEFT, | 
					
						
						|  | } | 
					
						
						|  |  | 
					
						
						|  | DIFFUSERS_STATE_DICT_MAPPINGS = { | 
					
						
						|  | StateDictType.DIFFUSERS_OLD: DIFFUSERS_OLD_TO_DIFFUSERS, | 
					
						
						|  | StateDictType.PEFT: PEFT_TO_DIFFUSERS, | 
					
						
						|  | } | 
					
						
						|  |  | 
					
						
						|  | KOHYA_STATE_DICT_MAPPINGS = {StateDictType.PEFT: PEFT_TO_KOHYA_SS} | 
					
						
						|  |  | 
					
						
						|  | KEYS_TO_ALWAYS_REPLACE = { | 
					
						
						|  | ".processor.": ".", | 
					
						
						|  | } | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def convert_state_dict(state_dict, mapping): | 
					
						
						|  | r""" | 
					
						
						|  | Simply iterates over the state dict and replaces the patterns in `mapping` with the corresponding values. | 
					
						
						|  |  | 
					
						
						|  | Args: | 
					
						
						|  | state_dict (`dict[str, torch.Tensor]`): | 
					
						
						|  | The state dict to convert. | 
					
						
						|  | mapping (`dict[str, str]`): | 
					
						
						|  | The mapping to use for conversion, the mapping should be a dictionary with the following structure: | 
					
						
						|  | - key: the pattern to replace | 
					
						
						|  | - value: the pattern to replace with | 
					
						
						|  |  | 
					
						
						|  | Returns: | 
					
						
						|  | converted_state_dict (`dict`) | 
					
						
						|  | The converted state dict. | 
					
						
						|  | """ | 
					
						
						|  | converted_state_dict = {} | 
					
						
						|  | for k, v in state_dict.items(): | 
					
						
						|  |  | 
					
						
						|  | for pattern in KEYS_TO_ALWAYS_REPLACE.keys(): | 
					
						
						|  | if pattern in k: | 
					
						
						|  | new_pattern = KEYS_TO_ALWAYS_REPLACE[pattern] | 
					
						
						|  | k = k.replace(pattern, new_pattern) | 
					
						
						|  |  | 
					
						
						|  | for pattern in mapping.keys(): | 
					
						
						|  | if pattern in k: | 
					
						
						|  | new_pattern = mapping[pattern] | 
					
						
						|  | k = k.replace(pattern, new_pattern) | 
					
						
						|  | break | 
					
						
						|  | converted_state_dict[k] = v | 
					
						
						|  | return converted_state_dict | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def convert_state_dict_to_peft(state_dict, original_type=None, **kwargs): | 
					
						
						|  | r""" | 
					
						
						|  | Converts a state dict to the PEFT format The state dict can be from previous diffusers format (`OLD_DIFFUSERS`), or | 
					
						
						|  | new diffusers format (`DIFFUSERS`). The method only supports the conversion from diffusers old/new to PEFT for now. | 
					
						
						|  |  | 
					
						
						|  | Args: | 
					
						
						|  | state_dict (`dict[str, torch.Tensor]`): | 
					
						
						|  | The state dict to convert. | 
					
						
						|  | original_type (`StateDictType`, *optional*): | 
					
						
						|  | The original type of the state dict, if not provided, the method will try to infer it automatically. | 
					
						
						|  | """ | 
					
						
						|  | if original_type is None: | 
					
						
						|  |  | 
					
						
						|  | if any("to_out_lora" in k for k in state_dict.keys()): | 
					
						
						|  | original_type = StateDictType.DIFFUSERS_OLD | 
					
						
						|  | elif any("lora_linear_layer" in k for k in state_dict.keys()): | 
					
						
						|  | original_type = StateDictType.DIFFUSERS | 
					
						
						|  | else: | 
					
						
						|  | raise ValueError("Could not automatically infer state dict type") | 
					
						
						|  |  | 
					
						
						|  | if original_type not in PEFT_STATE_DICT_MAPPINGS.keys(): | 
					
						
						|  | raise ValueError(f"Original type {original_type} is not supported") | 
					
						
						|  |  | 
					
						
						|  | mapping = PEFT_STATE_DICT_MAPPINGS[original_type] | 
					
						
						|  | return convert_state_dict(state_dict, mapping) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def convert_state_dict_to_diffusers(state_dict, original_type=None, **kwargs): | 
					
						
						|  | r""" | 
					
						
						|  | Converts a state dict to new diffusers format. The state dict can be from previous diffusers format | 
					
						
						|  | (`OLD_DIFFUSERS`), or PEFT format (`PEFT`) or new diffusers format (`DIFFUSERS`). In the last case the method will | 
					
						
						|  | return the state dict as is. | 
					
						
						|  |  | 
					
						
						|  | The method only supports the conversion from diffusers old, PEFT to diffusers new for now. | 
					
						
						|  |  | 
					
						
						|  | Args: | 
					
						
						|  | state_dict (`dict[str, torch.Tensor]`): | 
					
						
						|  | The state dict to convert. | 
					
						
						|  | original_type (`StateDictType`, *optional*): | 
					
						
						|  | The original type of the state dict, if not provided, the method will try to infer it automatically. | 
					
						
						|  | kwargs (`dict`, *args*): | 
					
						
						|  | Additional arguments to pass to the method. | 
					
						
						|  |  | 
					
						
						|  | - **adapter_name**: For example, in case of PEFT, some keys will be pre-pended | 
					
						
						|  | with the adapter name, therefore needs a special handling. By default PEFT also takes care of that in | 
					
						
						|  | `get_peft_model_state_dict` method: | 
					
						
						|  | https://github.com/huggingface/peft/blob/ba0477f2985b1ba311b83459d29895c809404e99/src/peft/utils/save_and_load.py#L92 | 
					
						
						|  | but we add it here in case we don't want to rely on that method. | 
					
						
						|  | """ | 
					
						
						|  | peft_adapter_name = kwargs.pop("adapter_name", None) | 
					
						
						|  | if peft_adapter_name is not None: | 
					
						
						|  | peft_adapter_name = "." + peft_adapter_name | 
					
						
						|  | else: | 
					
						
						|  | peft_adapter_name = "" | 
					
						
						|  |  | 
					
						
						|  | if original_type is None: | 
					
						
						|  |  | 
					
						
						|  | if any("to_out_lora" in k for k in state_dict.keys()): | 
					
						
						|  | original_type = StateDictType.DIFFUSERS_OLD | 
					
						
						|  | elif any(f".lora_A{peft_adapter_name}.weight" in k for k in state_dict.keys()): | 
					
						
						|  | original_type = StateDictType.PEFT | 
					
						
						|  | elif any("lora_linear_layer" in k for k in state_dict.keys()): | 
					
						
						|  |  | 
					
						
						|  | return state_dict | 
					
						
						|  | else: | 
					
						
						|  | raise ValueError("Could not automatically infer state dict type") | 
					
						
						|  |  | 
					
						
						|  | if original_type not in DIFFUSERS_STATE_DICT_MAPPINGS.keys(): | 
					
						
						|  | raise ValueError(f"Original type {original_type} is not supported") | 
					
						
						|  |  | 
					
						
						|  | mapping = DIFFUSERS_STATE_DICT_MAPPINGS[original_type] | 
					
						
						|  | return convert_state_dict(state_dict, mapping) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def convert_unet_state_dict_to_peft(state_dict): | 
					
						
						|  | r""" | 
					
						
						|  | Converts a state dict from UNet format to diffusers format - i.e. by removing some keys | 
					
						
						|  | """ | 
					
						
						|  | mapping = UNET_TO_DIFFUSERS | 
					
						
						|  | return convert_state_dict(state_dict, mapping) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def convert_all_state_dict_to_peft(state_dict): | 
					
						
						|  | r""" | 
					
						
						|  | Attempts to first `convert_state_dict_to_peft`, and if it doesn't detect `lora_linear_layer` for a valid | 
					
						
						|  | `DIFFUSERS` LoRA for example, attempts to exclusively convert the Unet `convert_unet_state_dict_to_peft` | 
					
						
						|  | """ | 
					
						
						|  | try: | 
					
						
						|  | peft_dict = convert_state_dict_to_peft(state_dict) | 
					
						
						|  | except Exception as e: | 
					
						
						|  | if str(e) == "Could not automatically infer state dict type": | 
					
						
						|  | peft_dict = convert_unet_state_dict_to_peft(state_dict) | 
					
						
						|  | else: | 
					
						
						|  | raise | 
					
						
						|  |  | 
					
						
						|  | if not any("lora_A" in key or "lora_B" in key for key in peft_dict.keys()): | 
					
						
						|  | raise ValueError("Your LoRA was not converted to PEFT") | 
					
						
						|  |  | 
					
						
						|  | return peft_dict | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def convert_state_dict_to_kohya(state_dict, original_type=None, **kwargs): | 
					
						
						|  | r""" | 
					
						
						|  | Converts a `PEFT` state dict to `Kohya` format that can be used in AUTOMATIC1111, ComfyUI, SD.Next, InvokeAI, etc. | 
					
						
						|  | The method only supports the conversion from PEFT to Kohya for now. | 
					
						
						|  |  | 
					
						
						|  | Args: | 
					
						
						|  | state_dict (`dict[str, torch.Tensor]`): | 
					
						
						|  | The state dict to convert. | 
					
						
						|  | original_type (`StateDictType`, *optional*): | 
					
						
						|  | The original type of the state dict, if not provided, the method will try to infer it automatically. | 
					
						
						|  | kwargs (`dict`, *args*): | 
					
						
						|  | Additional arguments to pass to the method. | 
					
						
						|  |  | 
					
						
						|  | - **adapter_name**: For example, in case of PEFT, some keys will be pre-pended | 
					
						
						|  | with the adapter name, therefore needs a special handling. By default PEFT also takes care of that in | 
					
						
						|  | `get_peft_model_state_dict` method: | 
					
						
						|  | https://github.com/huggingface/peft/blob/ba0477f2985b1ba311b83459d29895c809404e99/src/peft/utils/save_and_load.py#L92 | 
					
						
						|  | but we add it here in case we don't want to rely on that method. | 
					
						
						|  | """ | 
					
						
						|  | try: | 
					
						
						|  | import torch | 
					
						
						|  | except ImportError: | 
					
						
						|  | logger.error("Converting PEFT state dicts to Kohya requires torch to be installed.") | 
					
						
						|  | raise | 
					
						
						|  |  | 
					
						
						|  | peft_adapter_name = kwargs.pop("adapter_name", None) | 
					
						
						|  | if peft_adapter_name is not None: | 
					
						
						|  | peft_adapter_name = "." + peft_adapter_name | 
					
						
						|  | else: | 
					
						
						|  | peft_adapter_name = "" | 
					
						
						|  |  | 
					
						
						|  | if original_type is None: | 
					
						
						|  | if any(f".lora_A{peft_adapter_name}.weight" in k for k in state_dict.keys()): | 
					
						
						|  | original_type = StateDictType.PEFT | 
					
						
						|  |  | 
					
						
						|  | if original_type not in KOHYA_STATE_DICT_MAPPINGS.keys(): | 
					
						
						|  | raise ValueError(f"Original type {original_type} is not supported") | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | kohya_ss_partial_state_dict = convert_state_dict(state_dict, KOHYA_STATE_DICT_MAPPINGS[StateDictType.PEFT]) | 
					
						
						|  | kohya_ss_state_dict = {} | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | for kohya_key, weight in kohya_ss_partial_state_dict.items(): | 
					
						
						|  | if "text_encoder_2." in kohya_key: | 
					
						
						|  | kohya_key = kohya_key.replace("text_encoder_2.", "lora_te2.") | 
					
						
						|  | elif "text_encoder." in kohya_key: | 
					
						
						|  | kohya_key = kohya_key.replace("text_encoder.", "lora_te1.") | 
					
						
						|  | elif "unet" in kohya_key: | 
					
						
						|  | kohya_key = kohya_key.replace("unet", "lora_unet") | 
					
						
						|  | elif "lora_magnitude_vector" in kohya_key: | 
					
						
						|  | kohya_key = kohya_key.replace("lora_magnitude_vector", "dora_scale") | 
					
						
						|  |  | 
					
						
						|  | kohya_key = kohya_key.replace(".", "_", kohya_key.count(".") - 2) | 
					
						
						|  | kohya_key = kohya_key.replace(peft_adapter_name, "") | 
					
						
						|  | kohya_ss_state_dict[kohya_key] = weight | 
					
						
						|  | if "lora_down" in kohya_key: | 
					
						
						|  | alpha_key = f'{kohya_key.split(".")[0]}.alpha' | 
					
						
						|  | kohya_ss_state_dict[alpha_key] = torch.tensor(len(weight)) | 
					
						
						|  |  | 
					
						
						|  | return kohya_ss_state_dict | 
					
						
						|  |  |