# Copyright 2023-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, List

import torch.nn as nn

from peft.utils import _freeze_adapter, _get_submodules

from .config import AdaptionPromptConfig, prepare_config
from .layer import AdaptedAttention
from .utils import is_adaption_prompt_trainable


class AdaptionPromptModel(nn.Module):
""" | |
Implements adaption prompts as described in https://arxiv.org/pdf/2303.16199.pdf. | |
The top L attention modules are replaced with AdaptedAttention modules that wrap the original ones, but insert | |
trainable prompts with gates (for zero init). | |
Notes on the multi-adapter pattern: | |
- We store the states of different adapters by keeping a dictionary of AdaptedAttention modules indexed by adapter | |
name. | |
- Every time we switch adapters, we remove the modules of the currently active adapter from the model, store them | |
in the dictionary, and replace them with the modules of the new adapter. | |
- To avoid duplicated and potentially inconsistent state, the currently active adapter is always removed from the | |
dictionary. | |
- Disabling the adapter would also result in the modules being removed from the model. | |
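
    Example (a minimal usage sketch; the checkpoint name and hyperparameter values are illustrative, and in normal use
    this class is built for you by `get_peft_model` rather than instantiated directly):

        >>> from transformers import AutoModelForCausalLM
        >>> from peft import AdaptionPromptConfig, get_peft_model

        >>> model = AutoModelForCausalLM.from_pretrained("huggyllama/llama-7b")
        >>> config = AdaptionPromptConfig(adapter_len=10, adapter_layers=30, task_type="CAUSAL_LM")
        >>> model = get_peft_model(model, config)  # wraps the base model in an AdaptionPromptModel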
""" | |

    def __init__(self, model, configs: Dict, adapter_name: str):
        super().__init__()
        self.model = model
        # Store adapter configs by name.
        self.peft_config: Dict[str, AdaptionPromptConfig] = {}
        # Store lists of the parents of the affected attention modules by adapter name.
        # We keep references to the parents so we can swap the adapters in and out of the model.
        self._parents: Dict[str, List[nn.Module]] = {}
        # Store lists of cached AdaptedAttention modules by name.
        self._cached_adapters: Dict[str, List] = {}
        # The name of the currently active adapter.
        self._active_adapter = None
        # Whether the adapter is enabled.
        self._enabled = True
        self.forward = self.model.forward
        self.add_adapter(adapter_name, configs[adapter_name])
        self._mark_only_adaption_prompts_as_trainable(self.model)

    def add_adapter(self, adapter_name: str, config: AdaptionPromptConfig) -> None:
        """Add an adapter with the given name and config."""
        config = prepare_config(config, self.model)
        if adapter_name in self.peft_config:
            raise ValueError(f"Adapter with name '{adapter_name}' already exists.")

        parents = []
        for name, _ in self.model.named_modules():
            if name.endswith(config.target_modules):
                par, _, _ = _get_submodules(self.model, name)
                parents.append(par)
        if len(parents) < config.adapter_layers:
            raise ValueError(
                f"Config specifies more adapter layers '{config.adapter_layers}'"
                f" than the model has '{len(parents)}'."
            )
        # Note that if the target modules are not in Sequential, ModuleList, or
        # some other PyTorch ordered container, the behavior is undefined as we
        # assume here that the order of the modules is the same as the order of
        # the transformer decoder layers.
        parents = parents[-config.adapter_layers :]
        self._parents[adapter_name] = parents

        # It is only None during initialization.
        # If it is disabled, we don't have to remove the modules.
        if self._active_adapter is not None and self._enabled:
            self._remove_adapted_attentions(self._active_adapter)
        self._active_adapter = adapter_name
        self.peft_config[adapter_name] = config
        self._create_adapted_attentions(config, parents)
        if not self._enabled:
            self._remove_adapted_attentions(self._active_adapter)

        if config.inference_mode:
            _freeze_adapter(self.model, adapter_name)

    def set_adapter(self, adapter_name: str) -> None:
        """Set the model to use the adapter with the given name."""
        if self._active_adapter == adapter_name:
            return
        if adapter_name not in self.peft_config:
            raise ValueError(f"Adapter with name '{adapter_name}' does not exist.")

        if self._enabled:
            self._remove_adapted_attentions(self._active_adapter)
            self._set_adapted_attentions(adapter_name)

        self._active_adapter = adapter_name

    def enable_adapter_layers(self):
        """Enable adapter layers by swapping in cached AdaptedAttention modules."""
        self._enabled = True
        self._set_adapted_attentions(self._active_adapter)

    def disable_adapter_layers(self):
        """Disable adapter layers by swapping out AdaptedAttention modules."""
        self._enabled = False
        self._remove_adapted_attentions(self._active_adapter)

    def _create_adapted_attentions(self, config: AdaptionPromptConfig, parents: List[nn.Module]) -> None:
        """Wrap LlamaAttention modules with newly created AdaptedAttention modules."""
        for par in parents:
            attn = AdaptedAttention(
                model_type=self.model.config.model_type,
                adapter_len=config.adapter_len,
                model=getattr(par, config.target_modules),
            )
            setattr(par, config.target_modules, attn)

    def _set_adapted_attentions(self, adapter_name: str) -> None:
        """Replace LlamaAttention modules with cached AdaptedAttention modules."""
        cached = self._cached_adapters[adapter_name]
        del self._cached_adapters[adapter_name]
        config = self.peft_config[adapter_name]
        for i, par in enumerate(self._parents[adapter_name]):
            setattr(par, config.target_modules, cached[i])

    def _remove_adapted_attentions(self, adapter_name: str) -> None:
        """Remove AdaptedAttention modules from the model and store them in the cache."""
        config = self.peft_config[adapter_name]
        adapted_attentions = []
        for par in self._parents[adapter_name]:
            attn = getattr(par, config.target_modules)
            adapted_attentions.append(attn)
            setattr(par, config.target_modules, attn.model)
        self._cached_adapters[adapter_name] = adapted_attentions

    def _mark_only_adaption_prompts_as_trainable(self, model: nn.Module) -> None:
        """Freeze all parameters of the model except the adaption prompts."""
        for n, p in model.named_parameters():
            if not is_adaption_prompt_trainable(n):
                p.requires_grad = False

    def __getattr__(self, name: str):
        """Forward missing attributes to the wrapped module."""
        try:
            return super().__getattr__(name)  # defer to nn.Module's logic
        except AttributeError:
            # This is necessary as e.g. causal models have various methods that we
            # don't want to re-implement here.
            return getattr(self.model, name)
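

# Illustrative multi-adapter flow (a sketch only; `base_model`, the configs, and the adapter names
# below are hypothetical). In normal use this class is constructed by `get_peft_model`, but the
# swapping behavior described in the class docstring maps directly onto the methods defined above:
#
#   model = AdaptionPromptModel(base_model, {"default": default_config}, adapter_name="default")
#   model.add_adapter("summarization", summarization_config)  # "summarization" becomes the active adapter
#   model.set_adapter("default")    # caches "summarization" and swaps the "default" modules back in
#   model.disable_adapter_layers()  # unwraps the AdaptedAttention modules and caches them
#   model.enable_adapter_layers()   # restores the cached modules into the model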