| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | """ |
| | PEFT utilities: Utilities related to peft library |
| | """ |
| |
|
| | import collections |
| | import importlib |
| | from typing import Optional |
| |
|
| | from packaging import version |
| |
|
| | from .import_utils import is_peft_available, is_torch_available |
| |
|
| |
|
| | if is_torch_available(): |
| | import torch |
| |
|
| |
|
def recurse_remove_peft_layers(model):
    r"""
    Recursively replace all PEFT tuner layers in `model` with plain (non-LoRA) layers.

    Two strategies are used, chosen by inspecting the first `BaseTunerLayer` found in `model`:

    - If that layer exposes a `base_layer` attribute, every module that has a `base_layer` is replaced in its parent
      by the result of `get_base_layer()`.
    - Otherwise, each `LoraLayer` that is also a `torch.nn.Linear` or `torch.nn.Conv2d` is rebuilt as a fresh torch
      layer that shares the original `weight` (and `bias`, when present).

    Args:
        model (`torch.nn.Module`):
            The model to strip. It is modified in place.

    Returns:
        `torch.nn.Module`: The same `model` object that was passed in.
    """
    from peft.tuners.tuners_utils import BaseTunerLayer

    # Only the first tuner layer is inspected; the result decides the strategy for the whole model.
    has_base_layer_pattern = False
    for module in model.modules():
        if isinstance(module, BaseTunerLayer):
            has_base_layer_pattern = hasattr(module, "base_layer")
            break

    if has_base_layer_pattern:
        from peft.utils import _get_submodules

        # Snapshot the keys first: the module tree is mutated while we walk it.
        key_list = [key for key, _ in model.named_modules() if "lora" not in key]
        for key in key_list:
            try:
                parent, target, target_name = _get_submodules(model, key)
            except AttributeError:
                # The key no longer resolves (e.g. an ancestor was already replaced); nothing to do.
                continue
            if hasattr(target, "base_layer"):
                setattr(parent, target_name, target.get_base_layer())
    else:
        # Legacy path for tuner layers without `base_layer`.
        # NOTE(review): presumably only needed for older PEFT releases - confirm before removing.
        from peft.tuners.lora import LoraLayer

        for name, module in model.named_children():
            if len(list(module.children())) > 0:
                # Compound module: handle its children first.
                recurse_remove_peft_layers(module)

            module_replaced = False

            if isinstance(module, LoraLayer) and isinstance(module, torch.nn.Linear):
                new_module = torch.nn.Linear(
                    module.in_features,
                    module.out_features,
                    bias=module.bias is not None,
                ).to(module.weight.device)
                new_module.weight = module.weight
                if module.bias is not None:
                    new_module.bias = module.bias

                module_replaced = True
            elif isinstance(module, LoraLayer) and isinstance(module, torch.nn.Conv2d):
                # `bias` must be passed explicitly: Conv2d defaults to bias=True, which would give a
                # bias-free source layer a randomly initialised bias. `padding_mode` is forwarded so the
                # rebuilt layer pads exactly like the original.
                new_module = torch.nn.Conv2d(
                    module.in_channels,
                    module.out_channels,
                    module.kernel_size,
                    module.stride,
                    module.padding,
                    module.dilation,
                    module.groups,
                    bias=module.bias is not None,
                    padding_mode=module.padding_mode,
                ).to(module.weight.device)

                new_module.weight = module.weight
                if module.bias is not None:
                    new_module.bias = module.bias

                module_replaced = True

            if module_replaced:
                setattr(model, name, new_module)
                del module

                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
    return model
| |
|
| |
|
def scale_lora_layers(model, weight):
    """
    Apply a scaling weight to every PEFT tuner layer found in `model`.

    Args:
        model (`torch.nn.Module`):
            The model whose tuner layers are scaled.
        weight (`float`):
            The weight passed to each layer's `scale_layer`. A value of exactly `1.0` leaves the model untouched.
    """
    from peft.tuners.tuners_utils import BaseTunerLayer

    if weight != 1.0:
        tuner_layers = (layer for layer in model.modules() if isinstance(layer, BaseTunerLayer))
        for layer in tuner_layers:
            layer.scale_layer(weight)
| |
|
| |
|
def unscale_lora_layers(model, weight: Optional[float] = None):
    """
    Undo a weight previously applied to the LoRA layers of the model.

    Args:
        model (`torch.nn.Module`):
            The model whose tuner layers are unscaled.
        weight (`float`, *optional*):
            The weight that was previously applied. If `None` or `1.0`, the function is a no-op. If `0.0`, the
            scale of every active adapter is set back to `1.0` via `set_scale`. Any other value is forwarded to
            each layer's `unscale_layer`.
    """
    # Nothing to undo in these cases; checked before importing PEFT so the no-op path does no work at all.
    if weight is None or weight == 1.0:
        return

    from peft.tuners.tuners_utils import BaseTunerLayer

    for module in model.modules():
        if not isinstance(module, BaseTunerLayer):
            continue

        if weight != 0:
            module.unscale_layer(weight)
        else:
            # A zero weight cannot be divided out again, so reset the scale of each active adapter explicitly.
            for adapter_name in module.active_adapters:
                module.set_scale(adapter_name, 1.0)
| |
|
| |
|
def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict, is_unet=True):
    """
    Derive LoRA config keyword arguments from per-key ranks, per-key alphas and a PEFT state dict.

    Args:
        rank_dict (`dict`):
            Maps state-dict keys to ranks. Must be non-empty. Keys are cut at `".lora_B."` when building
            `rank_pattern`.
        network_alpha_dict (`dict`, *optional*):
            Maps keys to alpha values. When `None` or empty, `lora_alpha` falls back to the first rank.
        peft_state_dict (`dict`):
            The state dict; only its keys are read, to derive `target_modules` and `use_dora`.
        is_unet (`bool`, defaults to `True`):
            Selects how alpha keys are turned into module names: cut at `".lora_A."` and drop `".alpha"` when
            `True`; cut at `".down."` and drop the last dotted component when `False`.

    Returns:
        `dict`: With keys `r`, `lora_alpha`, `rank_pattern`, `alpha_pattern`, `target_modules` and `use_dora`.
    """
    ranks = list(rank_dict.values())
    # Both values start from the first rank; they are refined below when more information is available.
    r = lora_alpha = ranks[0]
    rank_pattern = {}
    alpha_pattern = {}

    if len(set(ranks)) > 1:
        # The most frequent rank becomes the default; every other rank is recorded per module.
        r = collections.Counter(ranks).most_common(1)[0][0]
        rank_pattern = {name.split(".lora_B.")[0]: rank for name, rank in rank_dict.items() if rank != r}

    if network_alpha_dict is not None and len(network_alpha_dict) > 0:
        alphas = list(network_alpha_dict.values())
        distinct_alphas = set(alphas)
        if len(distinct_alphas) > 1:
            # Same scheme as for ranks: most frequent alpha is the default, the rest go into the pattern.
            lora_alpha = collections.Counter(alphas).most_common(1)[0][0]
            outliers = [(name, alpha) for name, alpha in network_alpha_dict.items() if alpha != lora_alpha]
            if is_unet:
                alpha_pattern = {name.split(".lora_A.")[0].replace(".alpha", ""): alpha for name, alpha in outliers}
            else:
                for name, alpha in outliers:
                    prefix = name.split(".down.")[0]
                    alpha_pattern[".".join(prefix.split(".")[:-1])] = alpha
        else:
            lora_alpha = distinct_alphas.pop()

    # Everything before the first ".lora" in a state-dict key names the targeted module.
    target_modules = list({key.split(".lora")[0] for key in peft_state_dict})
    use_dora = any("lora_magnitude_vector" in key for key in peft_state_dict)

    return {
        "r": r,
        "lora_alpha": lora_alpha,
        "rank_pattern": rank_pattern,
        "alpha_pattern": alpha_pattern,
        "target_modules": target_modules,
        "use_dora": use_dora,
    }
| |
|
| |
|
def get_adapter_name(model):
    """
    Build an adapter name of the form `default_<n>` for `model`.

    `<n>` is the length of the `r` attribute of the first `BaseTunerLayer` found in `model`, or `0` when the model
    contains no tuner layer.

    Args:
        model (`torch.nn.Module`):
            The model to inspect.

    Returns:
        `str`: The generated adapter name.
    """
    from peft.tuners.tuners_utils import BaseTunerLayer

    adapter_count = 0
    for layer in model.modules():
        if isinstance(layer, BaseTunerLayer):
            # Only the first tuner layer is consulted.
            adapter_count = len(layer.r)
            break
    return f"default_{adapter_count}"
| |
|
| |
|
def set_adapter_layers(model, enabled=True):
    """
    Enable or disable the adapters of every PEFT tuner layer in `model`.

    Args:
        model (`torch.nn.Module`):
            The model whose tuner layers are toggled.
        enabled (`bool`, defaults to `True`):
            Whether the adapters should be enabled.
    """
    from peft.tuners.tuners_utils import BaseTunerLayer

    for layer in model.modules():
        if not isinstance(layer, BaseTunerLayer):
            continue

        if not hasattr(layer, "enable_adapters"):
            # Layers without an `enable_adapters` method are toggled through the `disable_adapters` flag.
            layer.disable_adapters = not enabled
            continue

        layer.enable_adapters(enabled=enabled)
| |
|
| |
|
def delete_adapter_layers(model, adapter_name):
    """
    Delete the adapter `adapter_name` from every PEFT tuner layer in `model` and from its PEFT config.

    Args:
        model (`torch.nn.Module`):
            The model to remove the adapter from.
        adapter_name (`str`):
            Name of the adapter to delete.

    Raises:
        ValueError: If a tuner layer has no `delete_adapter` method.
    """
    from peft.tuners.tuners_utils import BaseTunerLayer

    for layer in model.modules():
        if not isinstance(layer, BaseTunerLayer):
            continue
        if not hasattr(layer, "delete_adapter"):
            raise ValueError(
                "The version of PEFT you are using is not compatible, please use a version that is greater than 0.6.1"
            )
        layer.delete_adapter(adapter_name)

    # The remaining bookkeeping only applies when the model carries a loaded PEFT config.
    config_loaded = getattr(model, "_hf_peft_config_loaded", False)
    if not (config_loaded and hasattr(model, "peft_config")):
        return

    model.peft_config.pop(adapter_name, None)

    # Once the last adapter is gone, drop the config and reset the loaded flag (to `None`).
    if len(model.peft_config) == 0:
        del model.peft_config
        model._hf_peft_config_loaded = None
| |
|
| |
|
def set_weights_and_activate_adapters(model, adapter_names, weights):
    """
    Set the scaling weight of each adapter on every PEFT tuner layer, then activate all given adapters.

    Args:
        model (`torch.nn.Module`):
            The model whose tuner layers are updated.
        adapter_names (`list`):
            Names of the adapters to configure and activate. Paired with `weights` via `zip`.
        weights (`list`):
            One entry per adapter. Each entry is either a single value applied to every layer, or a `dict` mapping
            a layer-name fragment to a weight (see `get_module_weight` below for the lookup rules).
    """
    from peft.tuners.tuners_utils import BaseTunerLayer

    def get_module_weight(weight_for_adapter, module_name):
        """Resolve the weight for `module_name` from a scalar or a per-layer dict."""
        if not isinstance(weight_for_adapter, dict):
            # A single value applies to every module.
            return weight_for_adapter

        # First choice: any dict key that occurs as a substring of the module name.
        for layer_name, weight_ in weight_for_adapter.items():
            if layer_name in module_name:
                return weight_

        parts = module_name.split(".")
        if len(parts) < 4:
            # The name is too short to form the "<p0>.<p1>.attentions.<p3>" key below. Use the same
            # default the dict lookup would yield instead of raising IndexError.
            return 1.0

        key = f"{parts[0]}.{parts[1]}.attentions.{parts[3]}"
        return weight_for_adapter.get(key, 1.0)

    # Collect the tuner layers once; the loops below only change adapter state, not the module tree.
    tuner_layers = [(name, module) for name, module in model.named_modules() if isinstance(module, BaseTunerLayer)]

    # Make each adapter active in turn and set its scaling weight on every tuner layer.
    for adapter_name, weight in zip(adapter_names, weights):
        for module_name, module in tuner_layers:
            # Layers without `set_adapter` are switched through the `active_adapter` attribute.
            if hasattr(module, "set_adapter"):
                module.set_adapter(adapter_name)
            else:
                module.active_adapter = adapter_name
            module.set_scale(adapter_name, get_module_weight(weight, module_name))

    # Finally activate all requested adapters together.
    for _, module in tuner_layers:
        if hasattr(module, "set_adapter"):
            module.set_adapter(adapter_names)
        else:
            module.active_adapter = adapter_names
| |
|
| |
|
def check_peft_version(min_version: str) -> None:
    r"""
    Checks if the installed version of PEFT is compatible.

    Args:
        min_version (`str`):
            The version of PEFT to check against. The installed version must be strictly greater than this.

    Raises:
        ValueError: If PEFT is not installed, or if the installed version is not greater than `min_version`.
    """
    # Import the submodule explicitly: a bare `import importlib` does not guarantee that
    # `importlib.metadata` has been loaded. Aliased so it does not shadow `packaging.version`.
    from importlib.metadata import version as installed_version

    if not is_peft_available():
        raise ValueError("PEFT is not installed. Please install it with `pip install peft`")

    is_peft_version_compatible = version.parse(installed_version("peft")) > version.parse(min_version)

    if not is_peft_version_compatible:
        raise ValueError(
            f"The version of PEFT you are using is not compatible, please use a version that is greater"
            f" than {min_version}"
        )
| |
|