sharpenb's picture
c4271efb8b34156b208ac18688c9bb1865f22810ee14f973918b406b89986d17
9043ce3 verified
raw
history blame
24.9 kB
import math
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss
from transformers.modeling_outputs import (
BaseModelOutputWithPastAndCrossAttentions,
CausalLMOutputWithCrossAttentions,
)
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import (
logging,
)
from typing import List, Optional, Tuple, Union
from .configuration_gpt_refact import GPTRefactConfig
logger = logging.get_logger(__name__)
@torch.jit.script
def upcast_masked_softmax(
    x: torch.Tensor, mask: torch.Tensor, mask_value: torch.Tensor, softmax_dtype: torch.dtype
):
    """Masked softmax computed in `softmax_dtype`, with the result cast back
    to the dtype of `x`.

    Positions where `mask` is False are replaced by `mask_value` (typically
    the dtype's minimum) before the softmax, which drives them to ~0 weight.
    """
    orig_dtype = x.dtype
    scores = torch.where(mask, x.to(softmax_dtype), mask_value)
    return torch.nn.functional.softmax(scores, dim=-1).to(orig_dtype)
@torch.jit.script
def upcast_softmax(x: torch.Tensor, softmax_dtype: torch.dtype):
    """Softmax computed in `softmax_dtype` (e.g. fp32 for numerical
    stability), with the result cast back to the dtype of `x`."""
    orig_dtype = x.dtype
    return torch.nn.functional.softmax(x.to(softmax_dtype), dim=-1).to(orig_dtype)
@torch.jit.script
def _get_slopes(attn_heads: int, dev: torch.device) -> torch.Tensor:
"""
## Get head-specific slope $m$ for each head
* `n_heads` is the number of heads in the attention layer $n$
The slope for first head is
$$\frac{1}{2^{\frac{8}{n}}} = 2^{-\frac{8}{n}}$$
The slopes for the rest of the heads are in a geometric series with a ratio same as above.
For instance when the number of heads is $8$ the slopes are
$$\frac{1}{2^1}, \frac{1}{2^2}, \dots, \frac{1}{2^8}$$
"""
# Get the closest power of 2 to `n_heads`.
# If `n_heads` is not a power of 2, then we first calculate slopes to the closest (smaller) power of 2,
# and then add the remaining slopes.
n = 2 ** math.floor(math.log(attn_heads, 2))
# $2^{-\frac{8}{n}}$
m_0 = 2.0 ** (-8.0 / n)
# $2^{-1\frac{8}{n}}, 2^{-2 \frac{8}{n}}, 2^{-3 \frac{8}{n}}, \dots$
m = torch.pow(m_0, torch.arange(1, 1 + n, device=dev))
# If `n_heads` is not a power of 2, then we add the remaining slopes.
# We calculate the remaining slopes for $n * 2$ (avoiding slopes added previously).
# And pick the slopes upto `n_heads`.
if n < attn_heads:
# $2^{-\frac{8}{2n}}$
m_hat_0 = 2.0 ** (-4.0 / n)
# $2^{-1\frac{8}{2n}}, 2^{-3 \frac{8}{2n}}, 2^{-5 \frac{8}{2n}}, \dots$
# Note that we take steps by $2$ to avoid slopes added previously.
m_hat = torch.pow(m_hat_0, torch.arange(1, 1 + 2 * (attn_heads - n), 2, device=dev))
# Concatenate the slopes with the remaining slopes.
m = torch.cat([m, m_hat])
return m
@torch.jit.script
def get_alibi_biases(
        B: int,
        T: int,
        attn_heads: int,
        dev: torch.device,
        dtype: torch.dtype) -> torch.Tensor:
    """
    Build the ALiBi attention-bias tensor.

    Returns a contiguous tensor of shape `[1, attn_heads, T, T]`; the leading
    singleton dimension broadcasts over the batch.  NOTE(review): the `B`
    argument is not used in the computation — kept for interface
    compatibility with callers.
    """
    slopes = _get_slopes(attn_heads, dev).to(dtype)
    # Position index along the key dimension: cumsum over an all-ones mask
    # makes every row equal to [1, 2, ..., T].
    ones = torch.ones((T, T), device=dev, dtype=torch.bool)
    positions = ones.cumsum(dim=-1).to(dtype)
    # Outer product of positions and per-head slopes -> [T, T, heads].
    bias = positions[:, :, None] * slopes[None, None, :]
    # Rearrange to [1, heads, T, T] for broadcasting against attention scores.
    bias = bias.permute(2, 0, 1)[None, :, :T, :T]
    return bias.contiguous()
class Attention(nn.Module):
    """Multi-query self-attention with ALiBi positional biases.

    The query projection produces `num_heads` heads while the key/value
    projection produces a single head shared by all query heads
    (`kv_attn_heads == 1`), i.e. multi-query attention (MQA).
    """

    def __init__(self, config, layer_idx=None):
        super().__init__()
        # Cached fill tensor for masked positions; lazily built in _get_mask_value.
        self.mask_value = None
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        # MQA: a single key/value head is shared by every query head.
        self.kv_attn_heads = 1
        # Standard 1/sqrt(head_dim) attention scaling.
        self.scale_factor = self.head_dim ** -0.5
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f"`embed_dim` must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads})."
            )
        self.layer_idx = layer_idx
        self.attention_softmax_in_fp32 = config.attention_softmax_in_fp32
        self.scale_attention_softmax_in_fp32 = (
            config.scale_attention_softmax_in_fp32 and config.attention_softmax_in_fp32
        )
        self.attention_bias_in_fp32 = config.attention_bias_in_fp32
        self.q = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
        # Fused key/value projection for the single shared KV head (head_dim each).
        self.kv = nn.Linear(self.embed_dim, self.head_dim * 2, bias=False)
        self.c_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)

    def _get_mask_value(self, device, dtype):
        # torch.where expects a tensor. We use a cache to avoid recreating it every time.
        if self.mask_value is None or self.mask_value.dtype != dtype or self.mask_value.device != device:
            self.mask_value = torch.full([], torch.finfo(dtype).min, dtype=dtype, device=device)
        return self.mask_value

    def _attn(self, query, key, value, attention_mask=None, alibi=None):
        """Core attention: scores = alibi + scale * (q @ k), masked softmax, @ v.

        `key` is expected pre-transposed to (batch, head_dim, key_length) —
        see the call site in `forward`.  `alibi` is the [1, heads, q, k]
        bias tensor from `get_alibi_biases`.
        """
        dtype = query.dtype
        softmax_dtype = torch.float32 if self.attention_softmax_in_fp32 else dtype
        mask_value = self._get_mask_value(query.device, softmax_dtype)
        upcast = dtype != softmax_dtype
        query_shape = query.shape
        batch_size = query_shape[0]
        key_length = key.size(-1)
        # (batch_size, query_length, num_heads, head_dim) x (batch_size, head_dim, key_length)
        # -> (batch_size, query_length, num_heads, key_length)
        query_length = query_shape[1]
        attn_shape = (batch_size, query_length, self.num_heads, key_length)
        attn_view = (batch_size, query_length * self.num_heads, key_length)
        # No copy needed for MQA 2, or when layer_past is provided.
        # Fold heads into the sequence dim so a single bmm covers all heads
        # against the one shared KV head.
        query = query.reshape(batch_size, query_length * self.num_heads, self.head_dim)
        # Flatten alibi from (B, heads, q, k) to (B, q*heads, k) to match the
        # folded query layout so it can seed baddbmm as the additive bias.
        alibi = alibi.transpose(2, 1).reshape(alibi.shape[0], -1, alibi.shape[-1])
        initial_dtype = query.dtype
        new_dtype = torch.float32 if self.attention_bias_in_fp32 else initial_dtype
        # Fused: attn_weights = 1*alibi + scale_factor * (query @ key).
        attn_weights = alibi.baddbmm(
            batch1=query.to(new_dtype),
            batch2=key.to(new_dtype),
            beta=1,
            alpha=self.scale_factor
        ).view(attn_shape).to(initial_dtype)
        if upcast:
            # Use a fused kernel to prevent a large overhead from casting and scaling.
            # Sub-optimal when the key length is not a multiple of 8.
            if attention_mask is None:
                attn_weights = upcast_softmax(attn_weights, softmax_dtype)
            else:
                attn_weights = upcast_masked_softmax(attn_weights, attention_mask, mask_value, softmax_dtype)
        else:
            if attention_mask is not None:
                # The fused kernel is very slow when the key length is not a multiple of 8, so we skip fusion.
                attn_weights = torch.where(attention_mask, attn_weights, mask_value)
            attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)
        attn_output = torch.bmm(attn_weights.view(attn_view), value).view(query_shape)
        return attn_output, attn_weights

    def forward(
        self,
        hidden_states: torch.Tensor,
        layer_past: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        alibi: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
    ) -> Union[
        Tuple[torch.Tensor, Optional[torch.Tensor]],
        Tuple[torch.Tensor, Optional[torch.Tensor], Tuple[torch.Tensor, ...]],
    ]:
        """Project q/kv, append the KV cache, attend, and project the output.

        Returns (attn_output, present[, attn_weights]) where `present` is the
        (key, value) cache tuple or None when `use_cache` is False.
        """
        query = self.q(hidden_states)
        kv = self.kv(hidden_states)
        # Split the fused projection into key and value (head_dim each).
        key, value = kv.split(self.head_dim, dim=-1)
        if layer_past is not None:
            past_key, past_value = layer_past
            # Append the new keys/values along the sequence dimension.
            key = torch.cat((past_key, key), dim=-2)
            value = torch.cat((past_value, value), dim=-2)
        if use_cache is True:
            present = (key, value)
        else:
            present = None
        # Key is transposed here so _attn receives (batch, head_dim, key_length).
        attn_output, attn_weights = self._attn(query, key.transpose(-1, -2), value, attention_mask, alibi)
        attn_output = self.c_proj(attn_output)
        outputs = (attn_output, present)
        if output_attentions:
            # (batch, q, heads, k) -> (batch, heads, q, k) for the standard layout.
            attn_weights = attn_weights.transpose(1, 2)
            outputs += (attn_weights,)
        return outputs  # a, present, (attentions)
class MLP(nn.Module):
    """SwiGLU feed-forward block (gated MLP).

    The requested intermediate size is scaled by 2/3 and rounded up to the
    nearest multiple of `multiple_of`; gate and up projections are fused into
    a single linear layer and split at forward time.
    """

    def __init__(self, intermediate_size, config, multiple_of: int = 256):
        super().__init__()
        embed_dim = config.hidden_size
        # Scale by 2/3, then round up to a multiple of `multiple_of`.
        scaled = int(2 * intermediate_size / 3)
        self.hidden_dim = ((scaled + multiple_of - 1) // multiple_of) * multiple_of
        # Fused gate + up projection (2 * hidden_dim outputs).
        self.gate_up_proj = nn.Linear(embed_dim, self.hidden_dim * 2, bias=False)
        self.c_proj = nn.Linear(self.hidden_dim, embed_dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate, up = self.gate_up_proj(x).split(self.hidden_dim, dim=-1)
        return self.c_proj(F.silu(gate) * up)
class LayerNormNoBias(nn.Module):
    """Layer normalization with a learnable scale but no bias term."""

    def __init__(self, shape: int, eps: float = 1e-5):
        super().__init__()
        # Stored as a 1-tuple: the normalized_shape argument of F.layer_norm.
        self.shape = (shape,)
        self.eps = eps
        # Initialized by the model's _init_weights (filled with 1.0).
        self.weight = nn.Parameter(torch.empty(self.shape))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return F.layer_norm(x, self.shape, weight=self.weight, bias=None, eps=self.eps)
class GPTRefactBlock(nn.Module):
    """One transformer layer: pre-norm MQA attention + pre-norm SwiGLU MLP,
    each wrapped in a residual connection."""

    def __init__(self, config, layer_idx=None):
        super().__init__()
        hidden_size = config.hidden_size
        # Default MLP width is 4x hidden size when n_inner is not configured.
        self.inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size
        self.ln_1 = LayerNormNoBias(hidden_size, eps=config.layer_norm_epsilon)
        self.attn = Attention(config, layer_idx=layer_idx)
        self.ln_2 = LayerNormNoBias(hidden_size, eps=config.layer_norm_epsilon)
        self.mlp = MLP(self.inner_dim, config)

    def forward(
        self,
        hidden_states: Optional[Tuple[torch.Tensor]],
        layer_past: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        alibi: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
    ) -> Union[
        Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
    ]:
        """Run one layer.

        Returns (hidden_states[, present][, attentions]); `present` is
        included only when `use_cache` is True.
        """
        hidden_states_norm = self.ln_1(hidden_states)
        attn_outputs = self.attn(
            hidden_states_norm,
            layer_past=layer_past,
            attention_mask=attention_mask,
            alibi=alibi,
            use_cache=use_cache,
            output_attentions=output_attentions,
        )
        attn_output = attn_outputs[0]  # output_attn: a, present, (attentions)
        # Remaining items: (present, [attn_weights]).
        outputs = attn_outputs[1:]
        # residual connection
        mix = attn_output + hidden_states
        norm_mix = self.ln_2(mix)
        feed_forward_hidden_states = self.mlp(norm_mix)
        # residual connection
        hidden_states = mix + feed_forward_hidden_states
        if use_cache:
            outputs = (hidden_states,) + outputs
        else:
            # Drop the (None) present entry when caching is disabled.
            outputs = (hidden_states,) + outputs[1:]
        return outputs  # hidden_states, present, (attentions, cross_attentions)
class GPTRefactPreTrainedModel(PreTrainedModel):
    """Base class wiring GPT-Refact into the transformers PreTrainedModel
    machinery (config class, weight init, checkpointing support)."""

    config_class = GPTRefactConfig
    base_model_prefix = "transformer"
    supports_gradient_checkpointing = True
    _no_split_modules = ["GPTRefactBlock"]
    _skip_keys_device_placement = "past_key_values"

    def __init__(self, *inputs, **kwargs):
        super().__init__(*inputs, **kwargs)

    def _init_weights(self, module):
        """Initialize weights following the GPT-2 / Megatron-LM scheme."""
        if isinstance(module, (MLP, Attention)):
            # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
            # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
            # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
            # > -- GPT-2 :: https://openai.com/blog/better-language-models/
            #
            # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
            module.c_proj.weight.data.normal_(
                mean=0.0, std=(self.config.initializer_range / math.sqrt(2 * self.config.n_layer))
            )
            # Prevent transformers from re-initializing c_proj a second time.
            module.c_proj._is_hf_initialized = True
        elif isinstance(module, nn.Linear):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, LayerNormNoBias):
            # Scale-only layer norm starts as the identity.
            module.weight.data.fill_(1.0)
class GPTRefactModel(GPTRefactPreTrainedModel):
    """Bare GPT-Refact transformer: token embedding + stack of
    GPTRefactBlock layers, returning hidden states (no LM head, and —
    unusually — no final layer norm here; the CausalLM wrapper applies it).
    """

    def __init__(self, config):
        super().__init__(config)
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.multi_query = config.multi_query
        self.wte = nn.Embedding(config.vocab_size, self.embed_dim)
        self.h = nn.ModuleList([GPTRefactBlock(config, layer_idx=i) for i in range(config.num_hidden_layers)])
        self.max_positions = config.max_position_embeddings
        self.attention_bias_in_fp32 = config.attention_bias_in_fp32
        # Causal (lower-triangular) mask, precomputed up to max_positions.
        self.register_buffer(
            "bias", torch.tril(torch.ones((self.max_positions, self.max_positions), dtype=torch.bool)),
            persistent=False
        )
        self.gradient_checkpointing = False
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.wte

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.Tensor]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
        """Run the transformer stack.

        Exactly one of `input_ids` / `inputs_embeds` must be given.
        `attention_mask`, when provided, is combined with the causal mask;
        a True/1 entry means "attend".
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = input_ids.size()
            input_ids = input_ids.view(-1, input_shape[-1])
            batch_size = input_ids.shape[0]
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
            batch_size = inputs_embeds.shape[0]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")
        if batch_size <= 0:
            raise ValueError("batch_size has to be defined and > 0")
        device = input_ids.device if input_ids is not None else inputs_embeds.device
        if past_key_values is None:
            past_length = 0
            past_key_values = tuple([None] * len(self.h))
        else:
            # Cached key tensors are (..., past_len, head_dim).
            past_length = past_key_values[0][0].size(-2)
        query_length = input_shape[-1]
        seq_length_with_past = past_length + query_length
        # Self-attention mask.
        key_length = past_length + query_length
        # Slice the causal mask to the current query rows / key columns.
        self_attention_mask = self.bias[None, key_length - query_length : key_length, :key_length]
        if attention_mask is not None:
            # Combine with the (padding) attention mask supplied by the caller.
            self_attention_mask = self_attention_mask * attention_mask.view(batch_size, 1, -1).to(
                dtype=torch.bool, device=self_attention_mask.device
            )
        # MQA models: (batch_size, query_length, n_heads, key_length)
        # unsqueeze(2) adds the broadcast head dimension.
        attention_mask = self_attention_mask.unsqueeze(2)
        hidden_states = self.wte(input_ids) if inputs_embeds is None else inputs_embeds
        alibi_dtype = torch.float32 if self.attention_bias_in_fp32 else self.wte.weight.dtype
        # Build biases for the full (past + current) length, then keep only
        # the rows corresponding to the current query positions.
        alibi = get_alibi_biases(hidden_states.shape[0], seq_length_with_past,
                                 self.num_heads, device, alibi_dtype)[:, :, -query_length:, :]
        output_shape = input_shape + (hidden_states.size(-1),)
        presents = [] if use_cache else None
        all_self_attentions = () if output_attentions else None
        all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
        all_hidden_states = () if output_hidden_states else None
        for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)
            if self.gradient_checkpointing and self.training:
                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        # None for past_key_value
                        return module(*inputs, use_cache, output_attentions)
                    return custom_forward
                # NOTE: the checkpointed path passes layer_past=None — the KV
                # cache is not used while gradient checkpointing is active.
                outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    hidden_states,
                    None,
                    attention_mask,
                    alibi
                )
            else:
                outputs = block(
                    hidden_states,
                    layer_past=layer_past,
                    attention_mask=attention_mask,
                    alibi=alibi,
                    use_cache=use_cache,
                    output_attentions=output_attentions,
                )
            hidden_states = outputs[0]
            if use_cache:
                presents.append(outputs[1])
            if output_attentions:
                # Index shifts by one when `present` occupies slot 1.
                all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
                if self.config.add_cross_attention:
                    all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],)
        hidden_states = hidden_states.view(output_shape)
        # Add last hidden state
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)
        if not return_dict:
            return tuple(
                v
                for v in [hidden_states, presents, all_hidden_states, all_self_attentions, all_cross_attentions]
                if v is not None
            )
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=presents,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
            cross_attentions=all_cross_attentions,
        )
class GPTRefactForCausalLM(GPTRefactPreTrainedModel):
    """GPT-Refact with a language-modeling head.

    The final layer norm (`ln_f`) lives here rather than inside
    GPTRefactModel and is applied before the LM head.
    """

    # NOTE(review): listing "ln_f.weight" among tied-weight keys is unusual
    # (ties are normally embedding<->lm_head only) — verify against the
    # checkpoint's weight-tying expectations.
    _tied_weights_keys = ["lm_head.weight", "ln_f.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.transformer = GPTRefactModel(config)
        self.ln_f = LayerNormNoBias(self.transformer.embed_dim, eps=config.layer_norm_epsilon)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        # Initialize weights and apply final processing
        self.post_init()
        # gradient checkpointing support for lower versions of transformers
        import transformers
        from packaging import version
        def _set_gradient_checkpointing(module, value=False):
            if isinstance(module, GPTRefactModel):
                module.gradient_checkpointing = value
        v = version.parse(transformers.__version__)
        # Older transformers (< 4.35) expect this hook on the instance.
        if v.major <= 4 and v.minor < 35:
            self._set_gradient_checkpointing = _set_gradient_checkpointing

    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
        """Build the model inputs for one generation step.

        With a KV cache present, only the last token is fed.
        NOTE(review): `attention_mask` from kwargs is not forwarded —
        confirm padding-aware generation is not required here.
        """
        if inputs_embeds is not None and past_key_values is None:
            # inputs_embeds can only seed the first step (no cache yet).
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            if past_key_values is not None:
                model_inputs = {"input_ids": input_ids[..., -1:]}
            else:
                model_inputs = {"input_ids": input_ids}
        model_inputs.update(
            {
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
            }
        )
        return model_inputs

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
        r"""
        labels (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
            `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
            are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        transformer_outputs = self.transformer(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = transformer_outputs[0]
        # Final norm is applied here, outside the base transformer.
        x = self.ln_f(hidden_states)
        lm_logits = self.lm_head(x)
        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous().to(shift_logits.device)
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
        if not return_dict:
            output = (lm_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output
        return CausalLMOutputWithCrossAttentions(
            loss=loss,
            logits=lm_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
            cross_attentions=transformer_outputs.cross_attentions,
        )

    @staticmethod
    def _reorder_cache(
        past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
    ) -> Tuple[Tuple[torch.Tensor]]:
        """
        This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
        [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
        beam_idx at every generation step.
        """
        return tuple(layer_past.index_select(0, beam_idx.to(layer_past.device)) for layer_past in past_key_values)