# Snapshot metadata: revision 7865a06, file size 15,132 bytes.
"""PyTorch OpenAI GPT-2 model modified with MultiQuery attention"""
import math
import os
from dataclasses import dataclass
from typing import Optional, Tuple, Union
import torch
import torch.utils.checkpoint
from torch import nn
from torch.cuda.amp import autocast
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from transformers.activations import ACT2FN
from transformers.modeling_outputs import (
BaseModelOutputWithPastAndCrossAttentions,
CausalLMOutputWithCrossAttentions,
SequenceClassifierOutputWithPast,
TokenClassifierOutput,
)
from transformers.modeling_utils import PreTrainedModel, SequenceSummary
from transformers.pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer
from transformers.utils import (
ModelOutput,
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
logging,
replace_return_docstrings,
)
from transformers.utils.model_parallel_utils import assert_device_map, get_device_map
from transformers.models.gpt2.modeling_gpt2 import GPT2Model, GPT2Block, GPT2PreTrainedModel, GPT2LMHeadModel
from .configuration_gpt2_mq import GPT2CustomConfig, MULTI_QUERY, MULTI_HEAD
class GPT2MQAttention(nn.Module):
    """GPT-2 self-attention with multi-query (shared key/value) heads.

    Each of the ``num_heads`` heads has its own query projection (``q_attn``),
    while a single key head and a single value head of width ``head_dim`` are
    produced by ``kv_attn`` and shared by every head.  Only causal
    self-attention is supported: cross-attention and the "reorder and upcast"
    attention path raise ``NotImplementedError``.
    """

    def __init__(self, config, is_cross_attention=False, layer_idx=None):
        """Build the projections, dropout layers and the causal-mask buffer.

        Args:
            config: model configuration; must have
                ``attention_head_type == MULTI_QUERY``.
            is_cross_attention: must be ``False``; cross-attention is not
                implemented for multi-query attention.
            layer_idx: index of the enclosing block, used only when
                ``config.scale_attn_by_inverse_layer_idx`` is set.

        Raises:
            ValueError: if the config is not a multi-query config, or if
                ``hidden_size`` is not divisible by ``num_attention_heads``.
            NotImplementedError: if ``is_cross_attention`` is true.
        """
        super().__init__()
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if config.attention_head_type != MULTI_QUERY:
            raise ValueError(
                "GPT2MQAttention requires `config.attention_head_type` to be MULTI_QUERY "
                f"(got {config.attention_head_type!r})."
            )

        max_positions = config.max_position_embeddings
        # Lower-triangular causal mask, shaped (1, 1, max_positions, max_positions).
        self.register_buffer(
            "bias",
            torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8)).view(
                1, 1, max_positions, max_positions
            ),
        )
        # Not read anywhere in this class; presumably kept so the buffer set
        # matches the stock GPT-2 attention module -- TODO confirm.
        self.register_buffer("masked_bias", torch.tensor(-1e4))

        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        self.split_size = self.embed_dim
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f"`embed_dim` must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads})."
            )

        self.scale_attn_weights = config.scale_attn_weights
        if is_cross_attention:
            raise NotImplementedError("Cross-attention not implemented for MQA")
        self.is_cross_attention = is_cross_attention

        # Layer-wise attention scaling, reordering, and upcasting
        self.scale_attn_by_inverse_layer_idx = config.scale_attn_by_inverse_layer_idx
        self.layer_idx = layer_idx
        self.reorder_and_upcast_attn = config.reorder_and_upcast_attn

        # One query projection per head ...
        self.q_attn = Conv1D(self.embed_dim, self.embed_dim)
        # ... but keys and values are shared across heads: one head_dim-wide
        # key plus one head_dim-wide value.
        self.kv_attn = Conv1D(2 * self.head_dim, self.embed_dim)

        self.c_proj = Conv1D(self.embed_dim, self.embed_dim)

        self.attn_dropout = nn.Dropout(config.attn_pdrop)
        self.resid_dropout = nn.Dropout(config.resid_pdrop)

        self.pruned_heads = set()

    def prune_heads(self, heads):
        """Remove the given query heads from this layer.

        Only the per-head parameters are pruned: the output columns of
        ``q_attn`` and the matching input rows of ``c_proj``.  The key/value
        projection is a single head shared by all query heads, so it is left
        untouched.

        Args:
            heads: collection of head indices to remove.
        """
        if len(heads) == 0:
            return
        heads, index = find_pruneable_heads_and_indices(heads, self.num_heads, self.head_dim, self.pruned_heads)

        # Prune conv1d layers (this module has no fused `c_attn` projection).
        self.q_attn = prune_conv1d_layer(self.q_attn, index, dim=1)
        self.c_proj = prune_conv1d_layer(self.c_proj, index, dim=0)

        # Update hyper params
        self.split_size = (self.split_size // self.num_heads) * (self.num_heads - len(heads))
        self.num_heads = self.num_heads - len(heads)
        self.pruned_heads = self.pruned_heads.union(heads)

    def _attn(self, query, key, value, attention_mask=None, head_mask=None):
        """Compute causal multi-query attention with batched matmuls.

        Args:
            query: ``(batch, num_heads * query_length, head_dim)``.
            key: ``(batch, head_dim, key_length)``.
            value: ``(batch, key_length, head_dim)``.
            attention_mask: optional additive mask, added to the scores
                before the softmax.
            head_mask: optional multiplicative mask applied to the
                attention probabilities.

        Returns:
            ``(attn_output, attn_weights)`` shaped
            ``(batch, num_heads, query_length, head_dim)`` and
            ``(batch, num_heads, query_length, key_length)``.
        """
        batch_size = query.size(0)
        query_length = query.size(1) // self.num_heads
        key_length = key.size(2)

        # (b, num_heads * sq, head_dim) x (b, head_dim, sk) -> (b, num_heads * sq, sk)
        attn_weights = torch.bmm(query, key)
        # -> (b, num_heads, sq, sk)
        attn_weights = attn_weights.view(batch_size, self.num_heads, query_length, key_length)

        if self.scale_attn_weights:
            attn_weights = attn_weights / torch.tensor(
                value.size(-1) ** 0.5, dtype=attn_weights.dtype, device=attn_weights.device
            )

        # Layer-wise attention scaling
        if self.scale_attn_by_inverse_layer_idx:
            attn_weights = attn_weights / float(self.layer_idx + 1)

        if not self.is_cross_attention:
            # Only the "normal" (self-)attention layer applies the causal mask.
            causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].to(torch.bool)
            mask_value = torch.finfo(attn_weights.dtype).min
            # Needs to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
            # Needs to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
            mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
            attn_weights = torch.where(causal_mask, attn_weights, mask_value)

        if attention_mask is not None:
            # Apply the attention mask
            attn_weights = attn_weights + attention_mask

        attn_weights = nn.functional.softmax(attn_weights, dim=-1)

        # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op otherwise
        attn_weights = attn_weights.type(value.dtype)
        attn_weights = self.attn_dropout(attn_weights)

        # Mask heads if we want to
        if head_mask is not None:
            attn_weights = attn_weights * head_mask

        # (b, num_heads, sq, sk) -> (b, num_heads * sq, sk); `reshape` so a
        # non-contiguous intermediate cannot break the flattening.
        flat_weights = attn_weights.reshape(batch_size, self.num_heads * query_length, key_length)
        # (b, num_heads * sq, sk) x (b, sk, head_dim) -> (b, num_heads * sq, head_dim)
        attn_output = torch.bmm(flat_weights, value)
        attn_output = attn_output.view(batch_size, self.num_heads, query_length, self.head_dim)

        return attn_output, attn_weights

    def _upcast_and_reordered_attn(self, query, key, value, attention_mask=None, head_mask=None):
        """Attention with float32 score computation via ``torch.baddbmm``.

        This method unpacks ``query`` and ``key`` as 4-D
        ``(batch, num_heads, seq_len, head_dim)`` tensors, which is not the
        layout ``forward`` produces for multi-query attention; ``forward``
        raises ``NotImplementedError`` rather than dispatching here.

        Returns:
            ``(attn_output, attn_weights)``.

        Raises:
            RuntimeError: if the softmax output is not float32.
        """
        # Use `torch.baddbmm` (a bit more efficient w/ alpha param for scaling -- from Megatron-LM)
        bsz, num_heads, q_seq_len, dk = query.size()
        _, _, k_seq_len, _ = key.size()

        # Preallocate attn_weights for `baddbmm`
        attn_weights = torch.empty(bsz * num_heads, q_seq_len, k_seq_len, dtype=torch.float32, device=query.device)

        # Compute Scale Factor
        scale_factor = 1.0
        if self.scale_attn_weights:
            scale_factor /= float(value.size(-1)) ** 0.5

        if self.scale_attn_by_inverse_layer_idx:
            scale_factor /= float(self.layer_idx + 1)

        # Upcast (turn off autocast) and reorder (Scale K by 1 / root(dk))
        with autocast(enabled=False):
            q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(-1, dk, k_seq_len)
            attn_weights = torch.baddbmm(attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor)
            attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len)

        if not self.is_cross_attention:
            # Only the "normal" (self-)attention layer applies the causal mask.
            query_length, key_length = query.size(-2), key.size(-2)
            causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].bool()
            mask_value = torch.finfo(attn_weights.dtype).min
            # Needs to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
            # Needs to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
            mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
            attn_weights = torch.where(causal_mask, attn_weights, mask_value)

        if attention_mask is not None:
            # Apply the attention mask
            attn_weights = attn_weights + attention_mask

        attn_weights = nn.functional.softmax(attn_weights, dim=-1)

        # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op if otherwise
        if attn_weights.dtype != torch.float32:
            raise RuntimeError("Error with upcasting, attn_weights does not have dtype torch.float32")
        attn_weights = attn_weights.type(value.dtype)
        attn_weights = self.attn_dropout(attn_weights)

        # Mask heads if we want to
        if head_mask is not None:
            attn_weights = attn_weights * head_mask

        attn_output = torch.matmul(attn_weights, value)

        return attn_output, attn_weights

    def _split_heads(self, tensor, num_heads, attn_head_size):
        """
        Splits hidden_size dim into attn_head_size and num_heads
        """
        new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
        tensor = tensor.view(new_shape)
        return tensor.permute(0, 2, 1, 3)  # (batch, head, seq_length, head_features)

    def _merge_heads(self, tensor, num_heads, attn_head_size):
        """
        Merges attn_head_size dim and num_attn_heads dim into hidden_size
        """
        tensor = tensor.permute(0, 2, 1, 3).contiguous()
        new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,)
        return tensor.view(new_shape)

    def forward(
        self,
        hidden_states: Optional[Tuple[torch.FloatTensor]],
        layer_past: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
        """Run causal multi-query self-attention over ``hidden_states``.

        Args:
            hidden_states: input of shape ``(batch, seq_length, hidden_size)``.
            layer_past: optional cached ``(key, value)`` pair from earlier
                steps, shaped ``(batch, head_dim, past_length)`` and
                ``(batch, past_length, head_dim)``.
            attention_mask: optional additive mask added to the scores.
            head_mask: optional multiplicative mask on the probabilities.
            encoder_hidden_states: must be ``None`` (no cross-attention).
            encoder_attention_mask: unused.
            use_cache: when ``True``, return the updated ``(key, value)``.
            output_attentions: when truthy, also return the attention
                probabilities.

        Returns:
            ``(attn_output, present)`` or, with ``output_attentions``,
            ``(attn_output, present, attn_weights)``.  ``present`` is ``None``
            unless ``use_cache`` is ``True``.

        Raises:
            NotImplementedError: if ``encoder_hidden_states`` is given or
                ``reorder_and_upcast_attn`` is enabled.
        """
        if encoder_hidden_states is not None:
            raise NotImplementedError("Cross-attention not implemented for MQA")
        # Fail fast: reject the unsupported mode before doing any projection work.
        if self.reorder_and_upcast_attn:
            raise NotImplementedError("Reorder and upcast attention not implemented for MQA")

        query = self.q_attn(hidden_states)
        # The shared projection yields one key and one value, each head_dim wide.
        key, value = self.kv_attn(hidden_states).split(self.head_dim, dim=2)

        batch_size, seq_length = query.shape[:2]
        # (batch, query_length, hidden_size) -> (batch, num_heads, query_length, head_dim)
        query = query.view(batch_size, seq_length, self.num_heads, self.head_dim).permute([0, 2, 1, 3])
        # -> (batch, num_heads * query_length, head_dim), so every head can be
        # matched against the single shared key with one `bmm`.
        query = query.reshape(batch_size, self.num_heads * seq_length, self.head_dim)
        key = key.permute(0, 2, 1)  # (batch_size, head_dim, seq_length)
        # value stays (batch_size, seq_length, head_dim)

        if layer_past is not None:
            past_key, past_value = layer_past
            # Concatenate on the sequence dimension (last for key, second to last for value).
            key = torch.cat((past_key, key), dim=-1)
            value = torch.cat((past_value, value), dim=-2)

        present = (key, value) if use_cache is True else None

        attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)

        attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
        attn_output = self.c_proj(attn_output)
        attn_output = self.resid_dropout(attn_output)

        outputs = (attn_output, present)
        if output_attentions:
            outputs += (attn_weights,)

        return outputs  # a, present, (attentions)
# Reuses the stock GPT-2 block and only swaps out its `attn` sub-module.
class GPT2CustomBlock(GPT2Block):
    """GPT-2 transformer block whose self-attention can be multi-query."""

    def __init__(self, config: GPT2CustomConfig, layer_idx=None):
        """Build the parent block, then replace `attn` for multi-query configs.

        Raises:
            NotImplementedError: if a multi-query config also requests
                cross-attention.
        """
        super().__init__(config, layer_idx)

        # Any other head type keeps the attention module built by the parent.
        if config.attention_head_type != MULTI_QUERY:
            return

        self.attn = GPT2MQAttention(config, layer_idx=layer_idx)
        if config.add_cross_attention:
            raise NotImplementedError("Cross-attention not implemented for MQA")
# Subclass of the stock GPT-2 model with a re-implemented `__init__`, so the
# layer stack is assembled from `GPT2CustomBlock` instances.
class GPT2CustomModel(GPT2Model):
    """GPT-2 backbone built from `GPT2CustomBlock` layers."""

    config_class = GPT2CustomConfig

    def __init__(self, config):
        """Create embeddings, the block stack and the final layer norm."""
        # Calls the pre-trained base initializer directly rather than
        # `GPT2Model.__init__`; the sub-modules are created below instead.
        GPT2PreTrainedModel.__init__(self, config)

        hidden_size = config.hidden_size
        self.embed_dim = hidden_size

        self.wte = nn.Embedding(config.vocab_size, hidden_size)
        self.wpe = nn.Embedding(config.max_position_embeddings, hidden_size)

        self.drop = nn.Dropout(config.embd_pdrop)
        blocks = [GPT2CustomBlock(config, layer_idx=idx) for idx in range(config.num_hidden_layers)]
        self.h = nn.ModuleList(blocks)
        self.ln_f = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)

        # Model-parallel and gradient-checkpointing state start switched off.
        self.model_parallel = False
        self.device_map = None
        self.gradient_checkpointing = False

        # Initialize weights and apply final processing
        self.post_init()
class GPT2LMHeadCustomModel(GPT2LMHeadModel):
    """GPT-2 language-modeling head on top of `GPT2CustomModel`."""

    config_class = GPT2CustomConfig

    def __init__(self, config):
        """Create the custom transformer backbone and the bias-free LM head."""
        # Calls the pre-trained base initializer directly rather than
        # `GPT2LMHeadModel.__init__`, so `transformer` can be the custom model.
        GPT2PreTrainedModel.__init__(self, config)
        self.transformer = GPT2CustomModel(config)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        # Model parallel
        self.model_parallel = False
        self.device_map = None

        # Initialize weights and apply final processing
        self.post_init()