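# Provenance (Hugging Face file-viewer header, kept as comments so the module stays importable):
# granite-3b-code-base / modeling_granite.py
# Mayank Mishra
# update script
# 6bb0180
# raw
# history blame
# No virus
# 58.3 kB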
import math
import numbers
import warnings
from enum import Enum
from typing import Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import DynamicCache, PreTrainedModel
from transformers.activations import get_activation as get_base_activation
from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions
from transformers.utils import is_flash_attn_2_available
from .configuration_granite import GraniteConfig
class PositionEmbeddingType(Enum):
    """Supported ways of injecting token-position information into the model.

    The value string is what ``config.position_embedding_type`` holds.
    """

    # trainable absolute position embedding table
    learned_absolute = "learned_absolute"
    # linear attention biases (ALiBi)
    alibi = "alibi"
    # rotary position embeddings applied to query and key
    rope = "rope"
class AttentionHeadType(Enum):
    """Layout of key/value heads relative to query heads.

    The value string is what ``config.attention_head_type`` holds.
    """

    # multi-head: one key/value head per query head
    mha = "mha"
    # multi-query: a single key/value head shared by all query heads
    mqa = "mqa"
    # grouped-query: each key/value head serves a group of query heads
    gqa = "gqa"
if is_flash_attn_2_available():
from flash_attn.bert_padding import IndexFirstAxis, pad_input, unpad_input
from flash_attn.flash_attn_interface import flash_attn_varlen_func
# Copied from transformers.models.llama.modeling_llama._get_unpad_data
def get_unpad_data(attention_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, int]:
    """Compute the bookkeeping that variable-length (unpadded) attention kernels need.

    Args:
        attention_mask: (batch_size, seq_len) mask, expected to hold 1 for real tokens and 0 for padding.
            Per-sequence lengths are obtained by summing it.

    Returns:
        indices: 1-D positions of the nonzero entries in the flattened (batch_size * seq_len) mask.
        cu_seqlens: int32 cumulative sequence lengths with a leading 0, shape (batch_size + 1,).
        max_seqlen_in_batch: the largest per-sequence length, as a Python ``int`` (from ``.item()``).
    """
    seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
    max_seqlen_in_batch = seqlens_in_batch.max().item()
    # prepend 0 so sequence i spans cu_seqlens[i]:cu_seqlens[i + 1] in the packed layout
    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
    return indices, cu_seqlens, max_seqlen_in_batch
def repeat_key_value(x: torch.Tensor, num_heads: int, num_key_value_heads: int) -> torch.Tensor:
    """Broadcast key/value heads (dim 1) so there is one per query head.

    * MHA (one query head per KV head): ``x`` is returned unchanged.
    * MQA (a single KV head): ``x`` is expanded as a view along dim 1.
    * GQA: each KV head is repeated ``num_heads // num_key_value_heads`` times, consecutively.
    """
    heads_per_group = num_heads // num_key_value_heads

    if heads_per_group == 1:
        result = x
    elif num_key_value_heads == 1:
        # expand avoids materializing num_heads copies of the single head
        result = x.expand(-1, num_heads, -1, -1)
    else:
        result = x.repeat_interleave(heads_per_group, dim=1)
    return result
##################################################
# activation functions

# Gated-activation aliases whose base activation cannot be obtained by dropping a "_glu" suffix.
# Each value is the name later passed to `get_base_activation`.
_GLU_BASE_MAPPING = dict(
    geglu="gelu",
    miglu="mish",
    mishglu="mish",
    swiglu="swish",
)
class GLUActivation(nn.Module):
    """Gated linear unit with a pluggable gate activation.

    The input's last dimension is split into two chunks. The output is
    ``first_chunk * base_activation(second_chunk)``.
    """

    def __init__(self, base_activation: nn.Module) -> None:
        super().__init__()
        self.base_activation = base_activation

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        halves = torch.chunk(x, 2, dim=-1)
        value, gate = halves[0], halves[1]
        return value * self.base_activation(gate)
def is_glu(name: str) -> bool:
    """Return True when ``name`` denotes a gated (GLU-family) activation, i.e. it ends in ``"glu"``."""
    return name[-3:] == "glu"
def get_activation_function(name: str) -> nn.Module:
    """Build the activation module called ``name``.

    Names for which ``is_glu`` is true produce a gated activation:

    * ``"glu"`` / ``"sigmoid_glu"``: PyTorch's built-in ``nn.GLU``.
    * an alias listed in ``_GLU_BASE_MAPPING`` (e.g. ``"swiglu"``) or any ``"<base>_glu"`` name:
      a ``GLUActivation`` whose gate uses ``get_base_activation(<base>)``.

    Every other name is resolved directly with ``get_base_activation``.

    Raises:
        ValueError: for a GLU-family name that is neither a known alias nor of the form ``"<base>_glu"``.
    """
    if not is_glu(name):
        return get_base_activation(name)

    # PyTorch ships the sigmoid-gated variant directly
    if name in ("glu", "sigmoid_glu"):
        return nn.GLU()

    suffix = "_glu"
    if name in _GLU_BASE_MAPPING:
        base_name = _GLU_BASE_MAPPING[name]
    elif name.endswith(suffix):
        # Slice the suffix off. `str.rstrip("_glu")` strips any trailing run of the characters
        # '_', 'g', 'l', 'u' and would turn e.g. "gelu_glu" into "ge" and "relu_glu" into "re".
        base_name = name[: -len(suffix)]
    else:
        raise ValueError("invalid activation function")

    return GLUActivation(get_base_activation(base_name))
##################################################
# normalization functions


class RMSNorm(nn.Module):
    """Root-mean-square normalization with a learnable elementwise scale (no centering, no bias).

    The input is divided by the RMS over its trailing ``normalized_shape`` dimensions (computed in
    float32, plus ``eps``), cast back to the input dtype and multiplied by ``weight``.
    """

    def __init__(self, normalized_shape: Union[int, Tuple[int, ...]], eps: float = 1e-6) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.eps = eps

        if isinstance(normalized_shape, numbers.Integral):
            normalized_shape = (normalized_shape,)
        self.normalized_shape = normalized_shape

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        input_dtype = input.dtype
        # compute the statistic in fp32 for numerical stability
        input = input.to(torch.float32)

        # Reduce over exactly the trailing dims that `weight` covers. For the common int
        # `normalized_shape` this is just the last dimension.
        norm_dims = tuple(range(-len(self.normalized_shape), 0))
        variance = input.pow(2).mean(norm_dims, keepdim=True)
        input = input * torch.rsqrt(variance + self.eps)

        return self.weight * input.to(input_dtype)

    def extra_repr(self) -> str:
        return f"{self.normalized_shape}, eps={self.eps}"

    def reset_parameters(self) -> None:
        """Reset the scale to all ones."""
        nn.init.ones_(self.weight)
# config.normalization_function name -> normalization module class
_NORMALIZATION_FUNCTIONS = dict(
    layernorm=nn.LayerNorm,
    rmsnorm=RMSNorm,
)
def get_normalization_function(name: str, normalized_shape: int, eps: float = 1e-5) -> nn.Module:
    """Instantiate the normalization layer registered under ``name``.

    Raises:
        ValueError: if ``name`` is not a key of ``_NORMALIZATION_FUNCTIONS``.
    """
    norm_class = _NORMALIZATION_FUNCTIONS.get(name)
    if norm_class is None:
        raise ValueError(f"unexpected `normalization_function` {name}")
    return norm_class(normalized_shape, eps=eps)
##################################################
# attention modules


class GraniteAttention(nn.Module):
    """Eager multi-head attention for Granite: explicit ``baddbmm`` scores, softmax, then ``matmul``.

    A single fused projection ``c_attn`` produces query, key and value, and ``c_proj`` maps the attention
    output back to ``hidden_size``. MHA, MQA and GQA layouts are selected by ``config.attention_head_type``.
    ``GraniteSDPA`` and ``GraniteFlashAttention2`` subclass this class and override only ``forward``.
    """

    def __init__(self, config: GraniteConfig, causal: bool, layer_idx: Optional[int] = None) -> None:
        """Read attention hyper-parameters from ``config`` and build the projection / dropout modules.

        Args:
            config: model configuration. Reads ``n_embd``, ``n_head``, ``num_key_value_heads``, ``add_bias``,
                ``attention_head_type``, ``position_embedding_type``, ``scale_attn_weights``,
                ``attention_multiplier``, ``attention_softmax_in_fp32``, ``scale_attention_softmax_in_fp32``,
                ``attn_pdrop`` and ``resid_pdrop``.
            causal: stored as ``self.causal``. The SDPA and FlashAttention-2 subclasses pass it to their
                kernels; the eager ``forward`` in this class does not read it.
            layer_idx: index passed to ``past_key_values.update`` to select this layer's cache entry.

        Raises:
            AssertionError: if ``n_embd`` is not divisible by ``n_head``, or the key/value head count does not
                fit the chosen attention head type.
            ValueError: if ``attention_head_type`` or ``position_embedding_type`` is not a valid enum value.
        """
        super().__init__()

        self.causal = causal
        self.hidden_size = config.n_embd
        self.num_heads = config.n_head
        self.num_key_value_heads = config.num_key_value_heads
        self.add_bias = config.add_bias

        assert (
            self.hidden_size % self.num_heads == 0
        ), f"`hidden_size` ({self.hidden_size}) must be divisible by `num_heads` ({self.num_heads})"

        self.head_dim = self.hidden_size // self.num_heads
        self.attention_head_type = AttentionHeadType(config.attention_head_type)

        self.position_embedding_type = PositionEmbeddingType(config.position_embedding_type)
        self.scale_attn_weights = config.scale_attn_weights
        self.attention_multiplier = config.attention_multiplier

        self.layer_idx = layer_idx
        self.attention_softmax_in_fp32 = config.attention_softmax_in_fp32
        # Only true when fp32 softmax is also enabled. None of the forward passes in this file read it.
        self.scale_attention_softmax_in_fp32 = (
            config.scale_attention_softmax_in_fp32 and config.attention_softmax_in_fp32
        )

        # Resolve and validate the number of key/value heads for the chosen layout:
        #   mha -> one key/value head per query head (defaults to num_heads)
        #   gqa -> must be given explicitly and must divide num_heads
        #   mqa -> a single key/value head shared by all query heads (defaults to 1)
        if self.attention_head_type == AttentionHeadType.mha:
            if self.num_key_value_heads is None:
                self.num_key_value_heads = self.num_heads

            assert (
                self.num_heads == self.num_key_value_heads
            ), f"{self.__class__.__name__} should have same number of heads for query, keys and values"
        elif self.attention_head_type == AttentionHeadType.gqa:
            assert (
                self.num_key_value_heads is not None
            ), "`num_key_value_heads` needs to be specified with GroupedQueryAttention"

            assert self.num_heads % self.num_key_value_heads == 0, (
                f"`num_heads` ({self.num_heads}) should be a multiple of `num_key_value_heads` "
                f"({self.num_key_value_heads})"
            )
        elif self.attention_head_type == AttentionHeadType.mqa:
            if self.num_key_value_heads is None:
                self.num_key_value_heads = 1

            assert self.num_key_value_heads == 1, f"{self.__class__.__name__} should have 1 head for keys and values"
        else:
            # defensive only: AttentionHeadType(...) above already rejects unknown values
            raise ValueError(f"unexpected attention_head_type ({self.attention_head_type})")

        # How c_attn's output features are ordered depends on the head type; the
        # _prepare_qkv_for_forward_{mha,gqa,mqa} helpers below each split one layout.
        # The total is hidden_size query features + 2 * num_key_value_heads * head_dim key/value features.
        self.c_attn = nn.Linear(
            self.hidden_size, self.hidden_size + 2 * self.num_key_value_heads * self.head_dim, bias=self.add_bias
        )
        self.c_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=self.add_bias)

        self.attn_pdrop = config.attn_pdrop
        self.resid_pdrop = config.resid_pdrop

        # a probability of 0 uses a no-op module instead of nn.Dropout
        self.attn_dropout = nn.Identity() if self.attn_pdrop == 0 else nn.Dropout(self.attn_pdrop)
        self.resid_dropout = nn.Identity() if self.resid_pdrop == 0 else nn.Dropout(self.resid_pdrop)

    def _prepare_qkv_for_forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Project ``hidden_states`` with ``c_attn`` and split the result into per-head query, key and value."""
        # ==========================================================================================
        # hidden_states -> (batch_size, query_length, num_heads * head_dim)
        # ==========================================================================================

        # single fused projection producing query, key and value features together
        hidden_states = self.c_attn(hidden_states)

        # ==========================================================================================
        # hidden_states -> (batch_size, query_length, [num_heads + num_key_value_heads * 2] * head_dim)
        # ==========================================================================================

        # for MHA, we can get away with doing just 1 transpose which is not true for GQA
        if self.attention_head_type == AttentionHeadType.mha:
            query, key, value = self._prepare_qkv_for_forward_mha(hidden_states)
        elif self.attention_head_type == AttentionHeadType.gqa:
            query, key, value = self._prepare_qkv_for_forward_gqa(hidden_states)
        elif self.attention_head_type == AttentionHeadType.mqa:
            query, key, value = self._prepare_qkv_for_forward_mqa(hidden_states)
        else:
            raise ValueError(f"unexpected attention_head_type ({self.attention_head_type})")

        # ==========================================================================================
        # query -> (batch_size, num_heads, query_length, head_dim)
        # key -> (batch_size, num_key_value_heads, query_length, head_dim)
        # value -> (batch_size, num_key_value_heads, query_length, head_dim)
        # ==========================================================================================

        return query, key, value

    def _prepare_qkv_for_forward_mha(
        self, hidden_states: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Split the MHA projection, where each head's features are laid out as [query | key | value]."""
        batch_size, query_length = hidden_states.shape[:-1]

        # (batch_size, query_length, num_heads, 3 * head_dim) -> (batch_size, num_heads, query_length, 3 * head_dim)
        hidden_states = hidden_states.view(batch_size, query_length, self.num_heads, -1)
        hidden_states = hidden_states.transpose(1, 2)

        query, key, value = hidden_states.chunk(3, dim=-1)

        return query, key, value

    def _prepare_qkv_for_forward_gqa(
        self, hidden_states: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Split the GQA projection.

        Each key/value group's features are laid out as
        [group_size * head_dim queries | head_dim key | head_dim value].
        """
        batch_size, query_length = hidden_states.shape[:-1]

        hidden_states = hidden_states.view(batch_size, query_length, self.num_key_value_heads, -1)

        query, key, value = hidden_states.split(
            ((self.num_heads // self.num_key_value_heads) * self.head_dim, self.head_dim, self.head_dim), dim=-1
        )

        # this needs to be a reshape instead of view sadly
        # (the split slice is not contiguous, so merging group and head dims can require a copy)
        query = query.reshape(batch_size, query_length, -1, self.head_dim)

        query = query.transpose(1, 2)
        key = key.transpose(1, 2)
        value = value.transpose(1, 2)

        return query, key, value

    def _prepare_qkv_for_forward_mqa(
        self, hidden_states: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Split the MQA projection, laid out as [all queries (hidden_size) | one key | one value]."""
        batch_size, query_length = hidden_states.shape[:-1]

        query, key, value = hidden_states.split((self.hidden_size, self.head_dim, self.head_dim), dim=-1)

        query = query.view(batch_size, query_length, self.num_heads, -1)

        query = query.transpose(1, 2)
        # add the size-1 key/value head axis
        key = key.unsqueeze(1)
        value = value.unsqueeze(1)

        return query, key, value

    def forward(
        self,
        hidden_states: torch.Tensor,
        past_key_values: Optional[DynamicCache] = None,
        attention_mask: Optional[torch.Tensor] = None,
        rope_cos_sin: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    ) -> torch.Tensor:
        """Compute attention with explicit batched matmuls.

        Args:
            hidden_states: (batch_size, query_length, hidden_size) input.
            past_key_values: optional cache. ``update`` is given this step's key/value for ``self.layer_idx``,
                and the key/value it returns are attended over.
            attention_mask: optional additive mask. It is added to the scaled scores, after they are cast to the
                softmax dtype, and must broadcast against (batch_size, num_heads, query_length, key_length).
                This method applies no causal masking of its own, so the mask must encode causality if that
                is required. NOTE(review): presumably the model always builds one for eager attention — confirm.
            rope_cos_sin: ``(cos, sin)`` tables, used only when ``position_embedding_type`` is ``rope``.

        Returns:
            (batch_size, query_length, hidden_size) tensor after ``c_proj`` and residual dropout.
        """
        # ==========================================================================================
        # hidden_states -> (batch_size, query_length, num_heads * head_dim)
        # ==========================================================================================

        query, key, value = self._prepare_qkv_for_forward(hidden_states)

        # ==========================================================================================
        # query -> (batch_size, num_heads, query_length, head_dim)
        # key -> (batch_size, num_key_value_heads, query_length, head_dim)
        # value -> (batch_size, num_key_value_heads, query_length, head_dim)
        # ==========================================================================================

        if self.position_embedding_type == PositionEmbeddingType.rope:
            query = apply_rotary_pos_emb(query, rope_cos_sin)
            key = apply_rotary_pos_emb(key, rope_cos_sin)

        if past_key_values is not None:
            key, value = past_key_values.update(key, value, self.layer_idx)

        # ==========================================================================================
        # query -> (batch_size, num_heads, query_length, head_dim)
        # key -> (batch_size, num_key_value_heads, key_length, head_dim)
        # value -> (batch_size, num_key_value_heads, key_length, head_dim)
        # ==========================================================================================

        key = key.transpose(-1, -2)
        dtype = query.dtype
        softmax_dtype = torch.float32 if self.attention_softmax_in_fp32 else dtype

        # attention_multiplier, when set, replaces the default 1/sqrt(head_dim);
        # scale_attn_weights=False disables scaling entirely
        if self.scale_attn_weights:
            if self.attention_multiplier is None:
                scale_factor = 1 / self.head_dim**0.5
            else:
                scale_factor = self.attention_multiplier
        else:
            scale_factor = 1

        # ==========================================================================================
        # query -> (batch_size, num_heads, query_length, head_dim)
        # key -> (batch_size, num_key_value_heads, head_dim, key_length)
        # value -> (batch_size, num_key_value_heads, key_length, head_dim)
        # ==========================================================================================

        batch_size = query.shape[0]
        query_length = query.shape[2]
        key_length = key.shape[-1]

        key = repeat_key_value(key, self.num_heads, self.num_key_value_heads)
        value = repeat_key_value(value, self.num_heads, self.num_key_value_heads)

        # Merge batch and head dims for baddbmm. query is a transposed view, so this generally copies.
        query = query.reshape(batch_size * self.num_heads, query_length, self.head_dim)
        # Merge batch and head dims; this is a view or a copy depending on key's memory layout.
        key = key.reshape(batch_size * self.num_heads, self.head_dim, key_length)

        # ==========================================================================================
        # query -> (batch_size * num_heads, query_length, head_dim)
        # key -> (batch_size * num_heads, head_dim, key_length)
        # value -> (batch_size, num_heads, key_length, head_dim)
        # ==========================================================================================

        # With beta=0 the uninitialized `attn_weights` buffer contributes nothing; it only supplies
        # shape/dtype/device, so baddbmm computes scale_factor * (query @ key).
        attn_weights = torch.empty(
            (batch_size * self.num_heads, query_length, key_length), device=query.device, dtype=query.dtype
        )

        attn_weights = torch.baddbmm(attn_weights, query, key, beta=0, alpha=scale_factor).view(
            batch_size, self.num_heads, query_length, key_length
        )

        # ==========================================================================================
        # attn_weights -> (batch_size, num_heads, query_length, key_length)
        # ==========================================================================================

        attn_weights = attn_weights.to(softmax_dtype)

        if attention_mask is not None:
            attn_weights = attn_weights + attention_mask

        # softmax may run in fp32 (softmax_dtype); cast back to the query dtype afterwards
        attn_weights = F.softmax(attn_weights, dim=-1).to(dtype)

        attn_weights = self.attn_dropout(attn_weights)

        # ==========================================================================================
        # value -> (batch_size, num_heads, key_length, head_dim)
        # attn_weights -> (batch_size, num_heads, query_length, key_length)
        # ==========================================================================================

        attn_output = torch.matmul(attn_weights, value)

        # ==========================================================================================
        # attn_output -> (batch_size, num_heads, query_length, head_dim)
        # ==========================================================================================

        attn_output = attn_output.transpose(1, 2)
        attn_output = attn_output.reshape(batch_size, -1, self.num_heads * self.head_dim)

        # ==========================================================================================
        # attn_output -> (batch_size, query_length, num_heads * head_dim)
        # ==========================================================================================

        attn_output = self.c_proj(attn_output)
        attn_output = self.resid_dropout(attn_output)
        return attn_output
class GraniteSDPA(GraniteAttention):
    """``GraniteAttention`` computed with ``torch.nn.functional.scaled_dot_product_attention``.

    Construction and QKV preparation are inherited; only ``forward`` differs from the eager class.
    """

    def forward(
        self,
        hidden_states: torch.Tensor,
        past_key_values: Optional[DynamicCache] = None,
        attention_mask: Optional[torch.Tensor] = None,
        rope_cos_sin: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    ) -> torch.Tensor:
        """Run SDPA attention over ``hidden_states``.

        Args:
            hidden_states: (batch_size, query_length, hidden_size) input.
            past_key_values: optional cache. ``update`` is given this step's key/value for ``self.layer_idx``,
                and the key/value it returns are attended over.
            attention_mask: optional mask passed as SDPA's ``attn_mask``. When it is given, no implicit causal
                mask is added, so it must already encode causality if that is wanted.
            rope_cos_sin: ``(cos, sin)`` tables, used only when ``position_embedding_type`` is ``rope``.

        Returns:
            (batch_size, query_length, hidden_size) tensor after ``c_proj`` and residual dropout.
        """
        # ==========================================================================================
        # hidden_states -> (batch_size, query_length, num_heads * head_dim)
        # ==========================================================================================

        query, key, value = self._prepare_qkv_for_forward(hidden_states)

        # ==========================================================================================
        # query -> (batch_size, num_heads, query_length, head_dim)
        # key -> (batch_size, num_key_value_heads, query_length, head_dim)
        # value -> (batch_size, num_key_value_heads, query_length, head_dim)
        # ==========================================================================================

        if self.position_embedding_type == PositionEmbeddingType.rope:
            query = apply_rotary_pos_emb(query, rope_cos_sin)
            key = apply_rotary_pos_emb(key, rope_cos_sin)

        if past_key_values is not None:
            key, value = past_key_values.update(key, value, self.layer_idx)

        # ==========================================================================================
        # query -> (batch_size, num_heads, query_length, head_dim)
        # key -> (batch_size, num_key_value_heads, key_length, head_dim)
        # value -> (batch_size, num_key_value_heads, key_length, head_dim)
        # ==========================================================================================

        key = repeat_key_value(key, self.num_heads, self.num_key_value_heads)
        value = repeat_key_value(value, self.num_heads, self.num_key_value_heads)

        # ==========================================================================================
        # query -> (batch_size, num_heads, query_length, head_dim)
        # key -> (batch_size, num_heads, key_length, head_dim)
        # value -> (batch_size, num_heads, key_length, head_dim)
        # ==========================================================================================

        query_length = query.shape[2]

        # SDPA's `is_causal` builds a top-left aligned lower-triangular (query_length, key_length) mask.
        # For a single query token decoded against a longer KV cache, that mask would let the token see
        # only the first cached key, while causally it may attend to every key. So the implicit causal
        # mask is requested only when there are several query positions.
        # NOTE(review): with 1 < query_length < key_length and no explicit mask, top-left alignment is
        # still wrong; presumably callers pass an explicit mask in that case — confirm.
        is_causal = self.causal and attention_mask is None and query_length > 1

        attn_output = F.scaled_dot_product_attention(
            query,
            key,
            value,
            attn_mask=attention_mask,
            dropout_p=self.attn_pdrop if self.training else 0,
            is_causal=is_causal,
            # scale=None (attention_multiplier unset) selects SDPA's default 1 / sqrt(head_dim)
            scale=self.attention_multiplier if self.scale_attn_weights else 1,
        )

        # ==========================================================================================
        # attn_output -> (batch_size, num_heads, query_length, head_dim)
        # ==========================================================================================

        batch_size = attn_output.shape[0]
        attn_output = attn_output.transpose(1, 2)
        attn_output = attn_output.reshape(batch_size, -1, self.num_heads * self.head_dim)

        # ==========================================================================================
        # attn_output -> (batch_size, query_length, num_heads * head_dim)
        # ==========================================================================================

        attn_output = self.c_proj(attn_output)
        attn_output = self.resid_dropout(attn_output)
        return attn_output
class GraniteFlashAttention2(GraniteAttention):
    """``GraniteAttention`` computed with FlashAttention-2's variable-length kernel.

    Padding tokens are removed before ``flash_attn_varlen_func`` runs and re-inserted with ``pad_input``
    afterwards.
    """

    def forward(
        self,
        hidden_states: torch.Tensor,
        past_key_values: Optional[DynamicCache] = None,
        attention_mask: Optional[torch.Tensor] = None,
        rope_cos_sin: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    ) -> torch.Tensor:
        """Run FlashAttention-2 over ``hidden_states``.

        Args:
            hidden_states: (batch_size, query_length, hidden_size) input.
            past_key_values: optional cache. ``update`` is given this step's key/value for ``self.layer_idx``,
                and the key/value it returns are attended over.
            attention_mask: (batch_size, key_length) padding mask whose nonzero entries mark real tokens
                (it is flattened and matched against the flattened key positions). ``None`` means no padding.
                When query_length is neither key_length nor 1, its last query_length columns describe the
                queries, which assumes left padding.
            rope_cos_sin: ``(cos, sin)`` tables, used only when ``position_embedding_type`` is ``rope``.

        Returns:
            (batch_size, query_length, hidden_size) tensor after ``c_proj`` and residual dropout.
        """
        # ==========================================================================================
        # hidden_states -> (batch_size, query_length, num_heads * head_dim)
        # ==========================================================================================

        query, key, value = self._prepare_qkv_for_forward(hidden_states)

        # ==========================================================================================
        # query -> (batch_size, num_heads, query_length, head_dim)
        # key -> (batch_size, num_key_value_heads, query_length, head_dim)
        # value -> (batch_size, num_key_value_heads, query_length, head_dim)
        # ==========================================================================================

        if self.position_embedding_type == PositionEmbeddingType.rope:
            query = apply_rotary_pos_emb(query, rope_cos_sin)
            key = apply_rotary_pos_emb(key, rope_cos_sin)

        if past_key_values is not None:
            key, value = past_key_values.update(key, value, self.layer_idx)

        # ==========================================================================================
        # query -> (batch_size, num_heads, query_length, head_dim)
        # key -> (batch_size, num_key_value_heads, key_length, head_dim)
        # value -> (batch_size, num_key_value_heads, key_length, head_dim)
        # ==========================================================================================

        # flash-attn takes (batch, seq, heads, head_dim); TODO avoid this extra transpose
        query = query.transpose(1, 2)
        if self.attention_head_type == AttentionHeadType.mqa:
            # the single key/value head axis can be moved with squeeze/unsqueeze instead of a transpose
            key = key.squeeze(1).unsqueeze(2)
            value = value.squeeze(1).unsqueeze(2)
        else:
            key = key.transpose(1, 2)
            value = value.transpose(1, 2)

        # ==========================================================================================
        # query -> (batch_size, query_length, num_heads, head_dim)
        # key -> (batch_size, key_length, num_key_value_heads, head_dim)
        # value -> (batch_size, key_length, num_key_value_heads, head_dim)
        # ==========================================================================================

        batch_size, query_length = query.shape[:2]
        key_length = key.shape[1]

        if attention_mask is None:
            # No padding information: treat every key position as a real token.
            # get_unpad_data cannot handle None.
            attention_mask = torch.ones(batch_size, key_length, dtype=torch.int32, device=query.device)

        indices_k, cu_seqlens_k, max_seqlen_k = get_unpad_data(attention_mask)

        key = IndexFirstAxis.apply(
            key.reshape(batch_size * key_length, self.num_key_value_heads, self.head_dim), indices_k
        )
        value = IndexFirstAxis.apply(
            value.reshape(batch_size * key_length, self.num_key_value_heads, self.head_dim), indices_k
        )

        if query_length == key_length:
            # no earlier cached positions: queries share the keys' padding layout
            query = IndexFirstAxis.apply(
                query.reshape(batch_size * key_length, self.num_heads, self.head_dim), indices_k
            )
            cu_seqlens_q = cu_seqlens_k
            max_seqlen_q = max_seqlen_k
            indices_q = indices_k
        elif query_length == 1:
            # one query token per sequence
            max_seqlen_q = 1
            cu_seqlens_q = torch.arange(
                batch_size + 1, dtype=torch.int32, device=query.device
            )  # There is a memcpy here, that is very bad.
            indices_q = cu_seqlens_q[:-1]
            query = query.squeeze(1)
        else:
            # The -q_len: slice assumes left padding.
            attention_mask = attention_mask[:, -query_length:]
            # Take the first four results only. Some flash-attn releases return an extra trailing value
            # here (NOTE(review): believed to be >= 2.7 — confirm), which would break a 4-way unpack.
            query, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(query, attention_mask)[:4]

        # ==========================================================================================
        # query -> (total_q, num_heads, head_dim)
        # key -> (total_k, num_key_value_heads, head_dim)
        # value -> (total_k, num_key_value_heads, head_dim)
        # ==========================================================================================

        # NOTE(review): when query_length < key_length, `causal=True` is only correct if the installed
        # flash-attn aligns the causal mask bottom-right (flash-attn >= 2.1) — confirm.
        attn_output = flash_attn_varlen_func(
            query,
            key,
            value,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_k=cu_seqlens_k,
            max_seqlen_q=max_seqlen_q,
            max_seqlen_k=max_seqlen_k,
            dropout_p=self.attn_pdrop if self.training else 0,
            softmax_scale=self.attention_multiplier if self.scale_attn_weights else 1,
            causal=self.causal,
        )

        # ==========================================================================================
        # attn_output -> (total_q, num_heads, head_dim)
        # ==========================================================================================

        attn_output = pad_input(attn_output, indices_q, batch_size, query_length)
        attn_output = attn_output.view(batch_size, query_length, -1)

        # ==========================================================================================
        # attn_output -> (batch_size, query_length, num_heads * head_dim)
        # ==========================================================================================

        attn_output = self.c_proj(attn_output)
        attn_output = self.resid_dropout(attn_output)
        return attn_output
# config._attn_implementation name -> attention module class
_ATTENTION_MODULES = dict(
    eager=GraniteAttention,
    sdpa=GraniteSDPA,
    flash_attention_2=GraniteFlashAttention2,
)
def get_attention_module(
    config: GraniteConfig, causal: bool, attention_implementation: str, layer_idx: int
) -> GraniteAttention:
    """Instantiate the attention class registered for ``attention_implementation``.

    Raises:
        ValueError: if ``attention_implementation`` is not a key of ``_ATTENTION_MODULES``.
    """
    if attention_implementation not in _ATTENTION_MODULES:
        raise ValueError(f"unexpected `attention_implementation` {attention_implementation}")

    attention_class = _ATTENTION_MODULES[attention_implementation]
    return attention_class(config, causal=causal, layer_idx=layer_idx)
##################################################
# position embeddings


class Alibi(nn.Module):
    """ALiBi linear attention biases (https://arxiv.org/abs/2108.12409).

    Holds one slope per head in the non-persistent ``slopes`` buffer and produces a per-key bias
    ``slope * key_position``.
    """

    def __init__(self, num_heads: int) -> None:
        super().__init__()
        self.num_heads = num_heads
        self.reset_parameters()

    def forward(
        self, attention_mask: torch.Tensor, batch_size: int, key_length: int, device: torch.device, dtype: torch.dtype
    ) -> torch.Tensor:
        """Build the ALiBi bias for every (batch, head, key) position.

        The bias is not causal on its own. Because softmax is invariant to adding a constant per row,
        ``slope * key_position`` gives the same attention as the paper's relative bias
        ``slope * (key_position - query_position)``. Based on
        https://github.com/ofirpress/attention_with_linear_biases/blob/a35aaca144e0eb6b789dfcb46784c4b8e31b7983/fairseq/models/transformer.py#L742
        TODO @thomasw21 this doesn't work as nicely due to the masking strategy, and so masking varies slightly.

        Args:
            attention_mask: (batch_size, key_length) mask with 0 at padding positions, or ``None``.
                When given, positions count only non-padding tokens (cumulative sum minus one), and padding
                positions get position 0.
            batch_size: batch size, used when ``attention_mask`` is ``None``.
            key_length: number of key positions, used when ``attention_mask`` is ``None``.
            device: device for the position indices built when ``attention_mask`` is ``None``.
            dtype: dtype of the returned tensor.

        Returns:
            Tensor of shape (batch_size, num_heads, key_length). Callers broadcast it over the query axis.
        """
        if attention_mask is None:
            positions = torch.arange(key_length, device=device).view(1, 1, key_length).expand(batch_size, -1, -1)
        else:
            positions = attention_mask.cumsum(dim=-1) - 1
            positions = positions.masked_fill(attention_mask == 0, 0).unsqueeze(1)

        # (num_heads, 1) * (batch_size, 1, key_length) -> (batch_size, num_heads, key_length)
        return (self.slopes[:, None] * positions).to(dtype)

    def reset_parameters(self) -> None:
        """(Re)compute the per-head slopes, a geometric sequence as in the ALiBi reference code."""
        n = self.num_heads
        p2 = 2 ** math.floor(math.log2(n))

        base = torch.tensor(2 ** (-(2 ** -(math.log2(p2) - 3))), dtype=torch.float32)
        slopes = torch.pow(base, torch.arange(1, 1 + p2, dtype=torch.int32))

        if p2 != n:
            # head counts that are not a power of two take every other slope of the next power of two
            extra_base = torch.tensor(2 ** (-(2 ** -(math.log2(2 * p2) - 3))), dtype=torch.float32)
            n_extra = min(p2, n - p2)
            extra_slopes = torch.pow(extra_base, torch.arange(1, 1 + 2 * n_extra, 2, dtype=torch.int32))
            slopes = torch.cat([slopes, extra_slopes], dim=0)

        self.register_buffer("slopes", slopes, persistent=False)
class RoPE(nn.Module):
    """Rotary position embedding tables with a lazily grown cos/sin cache.

    ``inv_freq``, ``cos_cached`` and ``sin_cached`` are non-persistent buffers. The cache is built for
    ``max_position_embeddings`` positions up front and rebuilt when a longer sequence is requested.
    """

    def __init__(
        self,
        head_dim: int,
        max_position_embeddings: int = 2048,
        base: int = 10000,
    ) -> None:
        super().__init__()
        self.head_dim = head_dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        self.mscale = 1
        self.reset_parameters()

    def forward(self, seq_len: int, dtype: torch.dtype, device: torch.device) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(cos, sin)`` tables of shape (seq_len, head_dim), cast to ``dtype``."""
        needs_rebuild = seq_len > self.max_seq_len_cached
        if needs_rebuild:
            self._set_cos_sin_cache(seq_len=seq_len, device=device, dtype=dtype)

        return self.cos_cached[:seq_len].to(dtype), self.sin_cached[:seq_len].to(dtype)

    def reset_parameters(self) -> None:
        """Recompute the inverse frequencies and rebuild the cache for ``max_position_embeddings``."""
        exponents = torch.arange(0, self.head_dim, 2).float() / self.head_dim
        self.register_buffer("inv_freq", 1.0 / (self.base**exponents), persistent=False)

        # Build here to make `torch.jit.trace` work.
        self._set_cos_sin_cache(
            seq_len=self.max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
        )

    @torch.no_grad()
    def _set_cos_sin_cache(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> None:
        self.max_seq_len_cached = seq_len

        positions = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)
        angles = torch.outer(positions, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        angles = torch.cat((angles, angles), dim=-1)

        for buffer_name, trig in (("cos_cached", torch.cos), ("sin_cached", torch.sin)):
            self.register_buffer(buffer_name, (trig(angles) * self.mscale).to(dtype), persistent=False)
def apply_rotary_pos_emb(x: torch.Tensor, cos_sin: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
    """Rotate ``x`` by the rotary tables: ``x * cos + rotate_half(x) * sin``.

    ``rotate_half`` splits the last dimension into halves ``(a, b)`` and forms ``(-b, a)``.
    """
    cos, sin = cos_sin
    first_half, second_half = torch.chunk(x, 2, dim=-1)
    rotated = torch.cat((-second_half, first_half), dim=-1)
    return (x * cos) + (rotated * sin)
def _rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Split the last dimension into halves ``(a, b)`` and return ``(-b, a)``."""
    front, back = x.chunk(2, dim=-1)
    return torch.cat((back.neg(), front), dim=-1)
##################################################
# MLP


class GraniteMLP(nn.Module):
    """Feed-forward sub-layer: ``c_fc`` -> activation -> ``c_proj`` -> residual dropout.

    For GLU-family activations ``c_fc`` emits ``2 * n_inner`` features, which the gated activation
    reduces back to ``n_inner``.
    """

    def __init__(self, config: GraniteConfig) -> None:
        super().__init__()
        activation_name = config.activation_function
        width = config.n_embd
        inner = config.n_inner
        use_bias = config.add_bias
        dropout_p = config.resid_pdrop

        fc_features = inner * 2 if is_glu(activation_name) else inner

        # submodules are registered in the same order as before: c_fc, act, c_proj, dropout
        self.c_fc = nn.Linear(width, fc_features, bias=use_bias)
        self.act = get_activation_function(activation_name)
        self.c_proj = nn.Linear(inner, width, bias=use_bias)
        self.dropout = nn.Dropout(dropout_p) if dropout_p != 0 else nn.Identity()

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        for layer in (self.c_fc, self.act, self.c_proj, self.dropout):
            hidden_states = layer(hidden_states)
        return hidden_states
##################################################
# transformer layer


class GraniteBlock(nn.Module):
    """Pre-norm decoder layer: ``h = attn(ln_1(x)) + x`` followed by ``out = h + mlp(ln_2(h))``."""

    def __init__(
        self,
        config: GraniteConfig,
        attention_implementation: str,
        layer_idx: Optional[int] = None,
    ) -> None:
        super().__init__()
        norm_name = config.normalization_function
        width = config.hidden_size
        norm_eps = config.layer_norm_epsilon

        self.inner_dim = config.n_inner
        self.layer_idx = layer_idx

        # keep the original creation order (ln_1, attn, ln_2, mlp) for identical registration/initialisation
        self.ln_1 = get_normalization_function(norm_name, width, eps=norm_eps)
        self.attn = get_attention_module(config, True, attention_implementation, layer_idx)
        self.ln_2 = get_normalization_function(norm_name, width, eps=norm_eps)
        self.mlp = GraniteMLP(config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        past_key_values: Optional[DynamicCache] = None,
        attention_mask: Optional[torch.Tensor] = None,
        rope_cos_sin: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    ) -> torch.Tensor:
        attn_output = self.attn(
            self.ln_1(hidden_states),
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            rope_cos_sin=rope_cos_sin,
        )
        # first residual connection
        hidden_states = attn_output + hidden_states

        mlp_output = self.mlp(self.ln_2(hidden_states))
        # second residual connection
        return hidden_states + mlp_output
##################################################
# model classes


class GranitePreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = GraniteConfig
    base_model_prefix = "transformer"
    causal = True
    _no_split_modules = ["GraniteBlock"]
    _skip_keys_device_placement = "past_key_values"
    _supports_sdpa = True
    _supports_flash_attn_2 = True

    def __init__(self, config: GraniteConfig, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)

        # cache the selected attention backend as convenience flags
        self.attention_implementation = self.config._attn_implementation
        self._use_eager_attention = self.attention_implementation == "eager"
        self._use_sdpa = self.attention_implementation == "sdpa"
        self._use_flash_attention_2 = self.attention_implementation == "flash_attention_2"

        self.initializer_range = config.initializer_range

    @torch.no_grad()
    def _init_weights(self, module: nn.Module) -> None:
        """Initialize ``module``'s parameters.

        * Norm layers, ``Alibi`` and ``RoPE``: their own ``reset_parameters``.
        * ``nn.Linear`` / ``nn.Embedding``: weights ~ N(0, initializer_range). Linear biases and the
          embedding's ``padding_idx`` row are zeroed.

        Runs under ``torch.no_grad`` because ``bias.zero_()`` and ``weight[padding_idx].zero_()`` are in-place
        writes to leaf Parameters that require grad, which autograd rejects outside ``no_grad``.
        """
        if isinstance(module, (nn.LayerNorm, RMSNorm, Alibi, RoPE)):
            module.reset_parameters()
        elif isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0, std=self.initializer_range)
            if module.bias is not None:
                module.bias.zero_()
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0, std=self.initializer_range)
            if module.padding_idx is not None:
                module.weight[module.padding_idx].zero_()
class GraniteModel(GranitePreTrainedModel):
    """Granite decoder stack.

    The stack consists of a token embedding, an optional position embedding, `GraniteBlock` layers and a final norm.
    The position scheme comes from `config.position_embedding_type` (learned absolute, alibi or rope). The attention
    mask handed to the blocks is shaped for the configured attention implementation (eager, sdpa or flash).
    """

    _keys_to_ignore_on_load_missing = ["attn.masked_bias"]
    # cached scalar tensor used as the fill value for masked-out positions, see `_get_mask_value`
    mask_value = None

    def __init__(self, config: GraniteConfig, **kwargs) -> None:
        super().__init__(config, **kwargs)

        self.attention_head_type = AttentionHeadType(config.attention_head_type)
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.num_key_value_heads = config.num_key_value_heads

        assert (
            self.embed_dim % self.num_heads == 0
        ), f"`embed_dim` ({self.embed_dim}) must be divisible by `num_heads` ({self.num_heads})"

        self.head_dim = self.embed_dim // self.num_heads

        self.wte = nn.Embedding(config.vocab_size, self.embed_dim)

        # skip the dropout module entirely when it would be a no-op
        self.drop = nn.Identity() if config.embd_pdrop == 0 else nn.Dropout(config.embd_pdrop)
        self.h = nn.ModuleList(
            [GraniteBlock(config, self.attention_implementation, layer_idx=i) for i in range(config.num_hidden_layers)]
        )
        self.ln_f = get_normalization_function(
            config.normalization_function,
            self.embed_dim,
            eps=config.layer_norm_epsilon,
        )

        self.position_embedding_type = PositionEmbeddingType(config.position_embedding_type)

        if self.position_embedding_type == PositionEmbeddingType.learned_absolute:
            self.wpe = nn.Embedding(config.n_positions, self.embed_dim)
        elif self.position_embedding_type == PositionEmbeddingType.alibi:
            assert not self._use_flash_attention_2, "alibi is not implemented with FlashAttention"

            self.alibi = Alibi(self.num_heads)
        elif self.position_embedding_type == PositionEmbeddingType.rope:
            self.rope = RoPE(self.head_dim, max_position_embeddings=config.n_positions, base=config.rope_theta)
        else:
            raise NotImplementedError()

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self) -> nn.Embedding:
        return self.wte

    def set_input_embeddings(self, new_embeddings: nn.Embedding) -> None:
        self.wte = new_embeddings

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        past_key_values: Optional[DynamicCache] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
        """Run the decoder stack.

        Exactly one of `input_ids` / `inputs_embeds` must be given. A fresh `DynamicCache` is created when
        `use_cache` is set and no cache was passed.

        Returns:
            `BaseModelOutputWithPastAndCrossAttentions`, or, when `return_dict` is False, a tuple of its non-None
            values in the order (last hidden state, cache, all hidden states).
        """
        (
            output_hidden_states,
            use_cache,
            return_dict,
            input_shape,
            hidden_states,
            attention_mask,
            position_ids,
            rope_cos_sin,
            past_key_values,
        ) = self._prepare_a_bunch_of_stuff(
            input_ids=input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # ==========================================================================================
        # flash:
        #     attention_mask -> (batch_size, key_length)
        # else:
        #     attention_mask -> None or (batch_size, 1 or num_heads, query_length, key_length)
        # ==========================================================================================

        output_shape = input_shape + (hidden_states.size(-1),)

        past_key_values = DynamicCache() if use_cache and past_key_values is None else past_key_values
        all_hidden_states = () if output_hidden_states else None
        for block in self.h:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            hidden_states = block(
                hidden_states,
                past_key_values=past_key_values,
                attention_mask=attention_mask,
                rope_cos_sin=rope_cos_sin,
            )

        hidden_states = self.ln_f(hidden_states)

        hidden_states = hidden_states.view(output_shape)
        # Add last hidden state
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, past_key_values, all_hidden_states] if v is not None)

        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values,
            hidden_states=all_hidden_states,
        )

    def _get_position_ids(
        self, attention_mask: torch.Tensor, past_length: int, query_length: int, key_length: int, device: torch.device
    ) -> torch.Tensor:
        """Build position ids for the new tokens when the caller did not supply any.

        With a 2D padding mask, positions count only unmasked tokens, and masked slots get position 0. The result
        is sliced to the new tokens when a cache is present. Without a mask, positions are
        `past_length .. key_length - 1` with shape (1, query_length).
        """
        if attention_mask is not None and len(attention_mask.shape) == 2:
            # create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 0)
            if past_length > 0:
                position_ids = position_ids[:, past_length:key_length]
        else:
            position_ids = torch.arange(past_length, key_length, dtype=torch.long, device=device)
            position_ids = position_ids.unsqueeze(0).view(-1, query_length)

        return position_ids

    def _get_alibi_bias(
        self,
        attention_mask: torch.Tensor,
        batch_size: int,
        query_length: int,
        key_length: int,
        device: torch.device,
        dtype: torch.dtype,
    ) -> Optional[torch.Tensor]:
        """Return the alibi bias as (batch_size, num_heads, query_length, key_length), or None when alibi is unused."""
        if self.position_embedding_type != PositionEmbeddingType.alibi:
            return None

        alibi_bias = self.alibi(attention_mask, batch_size, key_length, device, dtype)

        # ==========================================================================================
        # alibi_bias -> (batch_size, num_heads, key_length)
        # ==========================================================================================

        alibi_bias = alibi_bias.unsqueeze(2)
        if query_length != 1:
            alibi_bias = alibi_bias.expand(-1, -1, query_length, -1)

        # ==========================================================================================
        # alibi_bias -> (batch_size, num_heads, query_length, key_length)
        # ==========================================================================================

        return alibi_bias

    def _get_rope_cos_sin(
        self, key_length: int, position_ids: torch.Tensor, dtype: torch.dtype, device: torch.device
    ) -> Optional[Tuple[torch.Tensor, torch.Tensor]]:
        """Gather the RoPE cos/sin tables at `position_ids`, with a singleton head axis at dim 1.

        Returns None when rope is not the configured position embedding.
        """
        if self.position_embedding_type != PositionEmbeddingType.rope:
            return None

        cos, sin = self.rope(key_length, dtype=dtype, device=device)
        cos = cos[position_ids].unsqueeze(1)
        sin = sin[position_ids].unsqueeze(1)
        return cos, sin

    def _prepare_causal_attention_mask(
        self, attention_mask: torch.Tensor, batch_size: int, query_length: int, key_length: int, device: torch.device
    ) -> torch.Tensor:
        """Combine causality with an optional 2D padding mask into a boolean mask.

        The result has shape (batch_size, 1, query_length, key_length). True means "may attend". New tokens see
        every cached token plus the new tokens up to and including themselves.
        """
        past_length = key_length - query_length

        # ==========================================================================================
        # attention_mask -> (batch_size, key_length)
        # ==========================================================================================

        if query_length > 1:
            # (query_length, key_length)
            causal_mask = torch.empty((query_length, key_length), dtype=torch.bool, device=device)
            causal_mask[:, past_length:] = torch.tril(
                torch.ones(query_length, query_length, dtype=torch.bool, device=device)
            )

            if past_length > 0:
                causal_mask[:, :past_length] = True

            # (query_length, key_length) -> (1, query_length, key_length)
            causal_mask = causal_mask.unsqueeze(0)

            if attention_mask is None:
                # (1, query_length, key_length) -> (batch_size, query_length, key_length)
                causal_mask = causal_mask.expand(batch_size, -1, -1)
            else:
                # (1, query_length, key_length) & (batch_size, 1, key_length) -> (batch_size, query_length, key_length)
                causal_mask = causal_mask & attention_mask.unsqueeze(1).to(torch.bool)
        else:
            if attention_mask is None:
                # (batch_size, query_length, key_length)
                causal_mask = torch.ones(batch_size, query_length, key_length, dtype=torch.bool, device=device)
            else:
                # (batch_size, query_length, key_length)
                causal_mask = attention_mask.unsqueeze(1).to(dtype=torch.bool, device=device)

        # ==========================================================================================
        # attention_mask -> (batch_size, query_length, key_length)
        # ==========================================================================================

        causal_mask = causal_mask.unsqueeze(1)

        # ==========================================================================================
        # attention_mask -> (batch_size, 1, query_length, key_length)
        # ==========================================================================================

        return causal_mask

    def _get_initial_hidden_state(
        self,
        input_ids: torch.Tensor,
        inputs_embeds: torch.Tensor,
        position_ids: torch.Tensor,
        token_type_ids: torch.Tensor,
    ) -> torch.Tensor:
        """Embed the inputs (unless `inputs_embeds` is given) and add the optional extra embeddings.

        The learned absolute position embedding is added when configured. Token-type ids are looked up in the same
        `wte` table as the tokens. Embedding dropout is applied last.
        """
        if inputs_embeds is None:
            inputs_embeds = self.wte(input_ids)

        if self.position_embedding_type == PositionEmbeddingType.learned_absolute:
            inputs_embeds = inputs_embeds + self.wpe(position_ids)

        if token_type_ids is not None:
            inputs_embeds = inputs_embeds + self.wte(token_type_ids)

        inputs_embeds = self.drop(inputs_embeds)
        return inputs_embeds

    def _prepare_a_bunch_of_stuff(
        self,
        input_ids: torch.Tensor,
        past_key_values: DynamicCache,
        attention_mask: torch.Tensor,
        token_type_ids: torch.Tensor,
        position_ids: torch.Tensor,
        inputs_embeds: torch.Tensor,
        use_cache: bool,
        output_hidden_states: bool,
        return_dict: bool,
    ) -> Tuple[
        bool,
        bool,
        bool,
        torch.Size,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        Optional[Tuple[torch.Tensor, torch.Tensor]],
        DynamicCache,
    ]:
        """Resolve config defaults and build every per-forward tensor the blocks need.

        Returns, in order: `output_hidden_states`, `use_cache`, `return_dict`, `input_shape`, the initial
        `hidden_states`, the attention mask in the format of the configured attention implementation,
        `position_ids`, `rope_cos_sin` (None unless rope is used) and `past_key_values` (returned unchanged).

        Raises:
            ValueError: if both or neither of `input_ids` and `inputs_embeds` are given.
        """
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = self.config.use_cache if use_cache is None else use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = input_ids.size()
        elif inputs_embeds is not None:
            # TODO special handling for padding free transformer needed here if we support inputs_embeds argument
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        batch_size = input_shape[0]
        device = input_ids.device if input_ids is not None else inputs_embeds.device

        if self.position_embedding_type == PositionEmbeddingType.alibi:
            if position_ids is not None:
                warnings.warn("`position_ids` have no functionality with Alibi.", FutureWarning)

        if token_type_ids is not None:
            token_type_ids = token_type_ids.view(-1, input_shape[-1])

        # ==========================================================================================
        # input_ids -> (batch_size, query_length)
        # attention_mask -> None or (batch_size, key_length)
        # position_ids -> None or (batch_size, query_length)
        # ==========================================================================================

        past_length = 0 if past_key_values is None else past_key_values.get_seq_length()
        query_length = input_shape[-1]
        key_length = past_length + query_length

        if position_ids is None:
            position_ids = self._get_position_ids(attention_mask, past_length, query_length, key_length, device)

        # ==========================================================================================
        # input_ids -> (batch_size, query_length)
        # attention_mask -> None or (batch_size, key_length)
        # position_ids -> (batch_size or 1, query_length)
        # ==========================================================================================

        hidden_states = self._get_initial_hidden_state(input_ids, inputs_embeds, position_ids, token_type_ids)

        # ==========================================================================================
        # hidden_states -> (batch_size, query_length, num_heads * head_dim)
        # ==========================================================================================

        alibi_bias = self._get_alibi_bias(
            attention_mask, batch_size, query_length, key_length, device, hidden_states.dtype
        )

        # ==========================================================================================
        # alibi_bias -> None or (batch_size, num_heads, query_length, key_length)
        # ==========================================================================================

        rope_cos_sin = self._get_rope_cos_sin(
            key_length, position_ids, dtype=hidden_states.dtype, device=hidden_states.device
        )

        # ==========================================================================================
        # rope_cos_sin -> None or 2 * (batch_size or 1, 1, query_length, ...): the rope tables gathered at
        #                 position_ids (last dim presumably head_dim -- depends on RoPE, TODO confirm)
        # ==========================================================================================

        if self._use_flash_attention_2:
            # flash attention consumes a (batch_size, key_length) padding mask covering cached + new tokens. When
            # none is given, build an all-ones one of that shape. `input_ids` may be None when `inputs_embeds` is
            # used, so it cannot serve as a shape template.
            if attention_mask is None:
                attention_mask = torch.ones(batch_size, key_length, dtype=torch.long, device=device)
        elif not self._use_sdpa or attention_mask is not None or alibi_bias is not None:
            # eager attention always needs an explicit mask. SDPA needs one whenever there is padding or an alibi
            # bias to apply; otherwise it keeps `attention_mask=None` and relies on its causal/non-causal argument.
            # The float mask holds 0 (or the alibi bias) where attention is allowed and the mask value elsewhere.
            attention_mask = self._prepare_causal_attention_mask(
                attention_mask, batch_size, query_length, key_length, device
            )

            attention_mask = torch.where(
                attention_mask,
                ~attention_mask if alibi_bias is None else alibi_bias,
                self._get_mask_value(attention_mask.device, hidden_states.dtype),
            )

        return (
            output_hidden_states,
            use_cache,
            return_dict,
            input_shape,
            hidden_states,
            attention_mask,
            position_ids,
            rope_cos_sin,
            past_key_values,
        )

    def _get_mask_value(self, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
        """Return a scalar tensor holding the fill value for masked-out positions.

        The value is float16's most negative finite number, stored in `dtype` on `device`. It stays finite in
        float16, bfloat16 and float32 alike.
        """
        # torch.where expects a tensor. We use a cache to avoid recreating it every time.
        if self.mask_value is None or self.mask_value.dtype != dtype or self.mask_value.device != device:
            self.mask_value = torch.full([], torch.finfo(torch.float16).min, dtype=dtype, device=device)
        return self.mask_value
class GraniteForCausalLM(GranitePreTrainedModel):
    """`GraniteModel` with a bias-free linear language-modeling head on top."""

    _keys_to_ignore_on_load_missing = ["lm_head.weight"]

    def __init__(self, config: GraniteConfig, **kwargs) -> None:
        super().__init__(config, **kwargs)

        self.transformer = GraniteModel(config, **kwargs)
        # `GraniteModel` emits `config.hidden_size`-wide states, so size the head from the same attribute
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self) -> nn.Embedding:
        return self.transformer.wte

    def set_input_embeddings(self, value: nn.Embedding) -> None:
        self.transformer.wte = value

    def get_output_embeddings(self) -> nn.Linear:
        return self.lm_head

    def set_output_embeddings(self, new_embeddings: nn.Linear) -> None:
        self.lm_head = new_embeddings

    # FIXME typing
    def prepare_inputs_for_generation(
        self,
        input_ids: torch.Tensor,
        past_key_values: Optional[DynamicCache] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> dict:
        """Build the keyword arguments for one `forward` call during generation.

        When the cache is non-empty, tokens it already covers are dropped from `input_ids` / `token_type_ids`.
        Position ids are derived from `attention_mask` when not supplied. `inputs_embeds` is only used on the
        first step, i.e. while the cache is None or empty.
        """
        token_type_ids = kwargs.get("token_type_ids", None)
        # Omit tokens covered by past_key_values
        if past_key_values:
            past_length = past_key_values.get_seq_length()

            # Some generation methods already pass only the last input ID
            if input_ids.shape[1] > past_length:
                remove_prefix_length = past_length
            else:
                # Default to old behavior: keep only final ID
                remove_prefix_length = input_ids.shape[1] - 1

            input_ids = input_ids[:, remove_prefix_length:]
            if token_type_ids is not None:
                token_type_ids = token_type_ids[:, -input_ids.shape[1] :]

        attention_mask = kwargs.get("attention_mask", None)
        position_ids = kwargs.get("position_ids", None)

        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 0)
            if past_key_values:
                position_ids = position_ids[:, -input_ids.shape[1] :]
        else:
            position_ids = None

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step; test the cache the
        # same way as above so an empty cache also counts as "first step"
        if inputs_embeds is not None and not past_key_values:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        model_inputs.update(
            {
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
                "position_ids": position_ids,
                "attention_mask": attention_mask,
                "token_type_ids": token_type_ids,
            }
        )
        return model_inputs

    def forward(
        self,
        input_ids: Optional[Union[torch.Tensor]] = None,
        past_key_values: Optional[DynamicCache] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[Union[torch.Tensor]] = None,
        position_ids: Optional[Union[torch.Tensor]] = None,
        inputs_embeds: Optional[Union[torch.Tensor]] = None,
        labels: Optional[Union[torch.Tensor]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
        """Run the decoder and project its final hidden states to vocabulary logits.

        When `labels` is given, the mean next-token cross-entropy is computed: logits at position t are scored
        against `labels` at t + 1. `output_attentions` is accepted for API compatibility but is not used.

        Returns:
            `CausalLMOutputWithCrossAttentions`, or, when `return_dict` is False, a tuple
            `(loss?, logits, *transformer_outputs[1:])`.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # ==========================================================================================
        # input_ids -> (batch_size, query_length)
        # attention_mask -> None or (batch_size, key_length)
        # position_ids -> None or (batch_size, query_length)
        # ==========================================================================================

        transformer_outputs = self.transformer(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = transformer_outputs[0]

        lm_logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous().to(shift_logits.device)
            # Flatten the tokens
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

        if not return_dict:
            output = (lm_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return CausalLMOutputWithCrossAttentions(
            loss=loss,
            logits=lm_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )