# Copyright 2023-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from .config import TRANSFORMERS_MODEL_CONFIG

class AdaptedAttention(nn.Module):
    """This module wraps a LlamaAttention module and injects adaption prompts."""

    def __init__(self, model_type: str, adapter_len: int, model):
        """
        Initialize object.

        Args:
            model_type: The transformer model type. This is used to retrieve the right method to
                compute query states.
            adapter_len: The length of the adaption prompt to insert.
            model: The original transformer attention module that is being wrapped.
        """
        assert not isinstance(model, AdaptedAttention)
        super().__init__()
        self.model_type = model_type
        self.model = model
        self.adapter_len = adapter_len

        # Assume all parameters of the attention model we are wrapping are on the same device.
        device = next(model.parameters()).device

        # Don't think this was specified in the paper, but we follow the official repo which used an Embedding
        # which initializes the tokens with standard normal values.
        # https://github.com/ZrrSkywalker/LLaMA-Adapter/blob/41c3546fe1997ab8a65809dc8d8f9252b19d9faf/llama/model.py#L234
        # (bsz, adapter_len, hidden_size)
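        # If the wrapped projection weights are quantized to int8/uint8 (e.g. the base model was
        # loaded in 8-bit), fall back to float32 for the new trainable parameters, since integer
        # dtypes cannot be used for parameters that require gradients.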
        target_dtype = (
            model.q_proj.weight.dtype if model.q_proj.weight.dtype not in [torch.int8, torch.uint8] else torch.float32
        )
        self.adaption_prompt = nn.Parameter(
            torch.empty(1, adapter_len, self.model.hidden_size, device=device, dtype=target_dtype).normal_()
        )
        # Initialize the gate to 0 as this is "zero-init".
        self.adaption_gate = nn.Parameter(torch.zeros(1, device=device, dtype=target_dtype))

    def forward(self, **kwargs):
        """
        Forward pass for the adapter which wraps the original LlamaAttention module.

        "Official" paper implementation:
        https://github.com/ZrrSkywalker/LLaMA-Adapter/blob/41c3546fe1997ab8a65809dc8d8f9252b19d9faf/llama/model.py#L141

        Args:
            kwargs: See the original LlamaAttention module.
        """
        if kwargs.get("output_attentions", False):
            raise NotImplementedError("output_attentions is not currently supported.")

        output, _, past_key_value = self.model(**kwargs)
        bsz = output.shape[0]
        q_len = output.shape[1]
        embed_dim = output.shape[2]
        k_proj_layer = TRANSFORMERS_MODEL_CONFIG[self.model_type].k_proj_layer
        v_proj_layer = TRANSFORMERS_MODEL_CONFIG[self.model_type].v_proj_layer
        o_proj_layer = TRANSFORMERS_MODEL_CONFIG[self.model_type].o_proj_layer
        factor = (
            self.model.k_proj.in_features // self.model.k_proj.out_features
        )  # Mistral has different input and output dimensions for its k_proj and v_proj layers
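        # Illustrative arithmetic: with grouped-query attention, e.g. 32 query heads but 8 key/value
        # heads of head_dim 128 (hidden_size 4096), k_proj maps 4096 -> 1024, so
        # factor = 4096 // 1024 = 4, i.e. num_heads // num_key_value_heads.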
        if k_proj_layer == v_proj_layer:
            _, key, value = getattr(self.model, k_proj_layer)(self.adaption_prompt).split(embed_dim, dim=2)
        else:
            key = getattr(self.model, k_proj_layer)(self.adaption_prompt)
            value = getattr(self.model, v_proj_layer)(self.adaption_prompt)

        # (bsz, num_key_value_heads, adapter_len, head_dim)
        adapter_k = (
            key.view(1, self.adapter_len, (self.model.num_heads // factor), self.model.head_dim)
            .repeat(bsz, 1, 1, 1)
            .transpose(1, 2)
        )
        adapter_v = (
            value.view(1, self.adapter_len, (self.model.num_heads // factor), self.model.head_dim)
            .repeat(bsz, 1, 1, 1)
            .transpose(1, 2)
        )

        # Below is taken from https://github.com/huggingface/transformers/blob/e547458c43dfdbbb8f6a7757237e234c44e20a8f/src/transformers/models/mistral/modeling_mistral.py#L181
        # (bsz, num_heads, adapter_len, head_dim)
        adapter_k = torch.repeat_interleave(adapter_k, repeats=factor, dim=1)
        adapter_v = torch.repeat_interleave(adapter_v, repeats=factor, dim=1)
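        # This mirrors the repeat_kv step used for grouped-query attention: each key/value head is
        # duplicated `factor` times so that every query head has a matching adapter key and value.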

        # Recompute query states.
        compute_query_states = TRANSFORMERS_MODEL_CONFIG[self.model_type].compute_query_states
        # (bsz, num_heads, q_len, head_dim)
        query_states = compute_query_states(model=self.model, **kwargs)

        previous_dtype = query_states.dtype

        # (bsz, num_heads, q_len, adapter_len)
        scores = torch.matmul(query_states, adapter_k.transpose(2, 3).to(previous_dtype)) / math.sqrt(
            self.model.head_dim
        )
        # Upcast attention to fp32
        # (bsz, num_heads, q_len, adapter_len)
        scores = self.adaption_gate * F.softmax(scores, dim=-1, dtype=torch.float32).to(previous_dtype)
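        # Because adaption_gate is zero-initialized ("zero-init"), these scores, and hence the
        # adapter output below, start at zero: training begins from the unmodified behavior of the
        # wrapped attention module, and the adapter's influence is learned gradually.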

        # (bsz, q_len, num_heads * head_dim)
        adapter_output = torch.matmul(scores, adapter_v).transpose(1, 2).reshape(bsz, q_len, -1)

        # (bsz, q_len, hidden_size)
        if o_proj_layer is not None:
            adapter_output = getattr(self.model, o_proj_layer)(adapter_output)

        # Add adaption prompt output to original output.
        output = output + adapter_output

        # Restore original dtype.
        output = output.to(previous_dtype)
        return output, None, past_key_value
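

# Minimal usage sketch (illustrative only, not part of this module). It assumes a Llama-style
# checkpoint whose decoder layers expose a `self_attn` submodule and whose config.model_type
# ("llama") is a key in TRANSFORMERS_MODEL_CONFIG; adjust the attribute paths for your setup.
#
#     from transformers import AutoModelForCausalLM
#
#     base = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
#     layer = base.model.layers[0]
#     layer.self_attn = AdaptedAttention(
#         model_type=base.config.model_type,  # "llama"
#         adapter_len=10,
#         model=layer.self_attn,
#     )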