|
import math |
|
import numbers |
|
import warnings |
|
from enum import Enum |
|
from typing import Optional, Tuple, Union |
|
|
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from transformers import DynamicCache, PreTrainedModel |
|
from transformers.activations import get_activation as get_base_activation |
|
from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions |
|
from transformers.utils import is_flash_attn_2_available |
|
|
|
from .configuration_granite import GraniteConfig |
|
|
|
|
|
class PositionEmbeddingType(Enum): |
|
learned_absolute = "learned_absolute" |
|
alibi = "alibi" |
|
rope = "rope" |
|
|
|
|
|
class AttentionHeadType(Enum): |
|
mha = "mha" |
|
mqa = "mqa" |
|
gqa = "gqa" |
|
|
|
|
|
if is_flash_attn_2_available(): |
|
from flash_attn.bert_padding import IndexFirstAxis, pad_input, unpad_input |
|
from flash_attn.flash_attn_interface import flash_attn_varlen_func |
|
|
|
|
|
|
|
def get_unpad_data(attention_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: |
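    """
    Computes the unpadding metadata FlashAttention's varlen kernels expect from a 2D padding mask:
    the flattened indices of non-padding tokens, the cumulative sequence lengths (`cu_seqlens`) with a
    leading 0, and the length of the longest sequence in the batch.
    """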
|
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) |
|
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() |
|
max_seqlen_in_batch = seqlens_in_batch.max().item() |
|
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) |
|
return indices, cu_seqlens, max_seqlen_in_batch |
|
|
|
|
|
def repeat_key_value(x: torch.Tensor, num_heads: int, num_key_value_heads: int) -> torch.Tensor: |
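    """
    Repeats key/value heads so that MQA/GQA tensors match `num_heads` query heads: a no-op for MHA
    (`num_groups == 1`), a broadcasting `expand` for MQA (single key/value head), and a
    `repeat_interleave` over the head dimension for GQA.
    """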
|
num_groups = num_heads // num_key_value_heads |
|
|
|
|
|
if num_groups == 1: |
|
return x |
|
|
|
|
|
if num_key_value_heads == 1: |
|
return x.expand(-1, num_heads, -1, -1) |
|
|
|
|
|
return x.repeat_interleave(num_groups, dim=1) |
|
|
|
|
|
|
|
|
|
|
|
|
|
_GLU_BASE_MAPPING = { |
|
"geglu": "gelu", |
|
"miglu": "mish", |
|
"mishglu": "mish", |
|
"swiglu": "swish", |
|
} |
|
|
|
|
|
class GLUActivation(nn.Module): |
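    """
    Gated linear unit: splits the input in half along the last dimension and gates one half with the
    base activation applied to the other, so the output feature dimension is half of the input's.
    """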
|
def __init__(self, base_activation: nn.Module) -> None: |
|
super().__init__() |
|
self.base_activation = base_activation |
|
|
|
def forward(self, x: torch.Tensor) -> torch.Tensor: |
|
x = x.chunk(2, dim=-1) |
|
return x[0] * self.base_activation(x[1]) |
|
|
|
|
|
def is_glu(name: str) -> bool: |
|
return name.endswith("glu") |
|
|
|
|
|
def get_activation_function(name: str) -> nn.Module: |
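    """
    Returns the activation module for `name`. GLU variants ("geglu", "swiglu", ...) wrap their base
    activation in `GLUActivation`, which halves the last dimension of its input.

    Illustrative usage (a minimal sketch, assuming `torch` is available in the calling scope):

        act = get_activation_function("geglu")
        y = act(torch.randn(2, 8))  # y.shape == (2, 4)
    """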
|
if is_glu(name): |
|
|
|
if name in ["glu", "sigmoid_glu"]: |
|
activation_function = nn.modules.GLU() |
|
else: |
|
if name in _GLU_BASE_MAPPING: |
|
name = _GLU_BASE_MAPPING[name] |
|
elif name.endswith("_glu"): |
|
name = name.rstrip("_glu") |
|
else: |
|
raise ValueError("invalid activation function") |
|
|
|
base_activation = get_base_activation(name) |
|
activation_function = GLUActivation(base_activation) |
|
else: |
|
activation_function = get_base_activation(name) |
|
|
|
return activation_function |
|
|
|
|
|
|
|
|
|
|
|
|
|
class RMSNorm(nn.Module): |
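    """
    Root-mean-square layer normalization: `y = weight * x / sqrt(mean(x ** 2) + eps)`, i.e. LayerNorm
    without mean subtraction or bias. The statistics are computed in float32 and the result is cast
    back to the input dtype.
    """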
|
def __init__(self, normalized_shape: int, eps: float = 1e-6) -> None: |
|
super().__init__() |
|
|
|
self.weight = nn.Parameter(torch.ones(normalized_shape)) |
|
self.eps = eps |
|
|
|
if isinstance(normalized_shape, numbers.Integral): |
|
normalized_shape = (normalized_shape,) |
|
self.normalized_shape = normalized_shape |
|
|
|
def forward(self, input: torch.Tensor) -> torch.Tensor: |
|
input_dtype = input.dtype |
|
|
|
input = input.to(torch.float32) |
|
variance = input.pow(2).mean(-1, keepdim=True) |
|
input = input * torch.rsqrt(variance + self.eps) |
|
|
|
return self.weight * input.to(input_dtype) |
|
|
|
def extra_repr(self) -> str: |
|
return f"{self.normalized_shape}, eps={self.eps}" |
|
|
|
def reset_parameters(self) -> None: |
|
nn.init.ones_(self.weight) |
|
|
|
|
|
_NORMALIZATION_FUNCTIONS = { |
|
"layernorm": nn.LayerNorm, |
|
"rmsnorm": RMSNorm, |
|
} |
|
|
|
|
|
def get_normalization_function(name: str, normalized_shape: int, eps: float = 1e-5) -> nn.Module: |
|
if name in _NORMALIZATION_FUNCTIONS: |
|
return _NORMALIZATION_FUNCTIONS[name](normalized_shape, eps=eps) |
|
|
|
raise ValueError(f"unexpected `normalization_function` {name}") |
|
|
|
|
|
|
|
|
|
|
|
|
|
class GraniteAttention(nn.Module): |
|
def __init__(self, config: GraniteConfig, causal: bool, layer_idx: Optional[int] = None) -> None: |
|
super().__init__() |
|
|
|
self.causal = causal |
|
self.hidden_size = config.n_embd |
|
self.num_heads = config.n_head |
|
self.num_key_value_heads = config.num_key_value_heads |
|
self.add_bias = config.add_bias |
|
|
|
assert ( |
|
self.hidden_size % self.num_heads == 0 |
|
), f"`hidden_size` ({self.hidden_size}) must be divisible by `num_heads` ({self.num_heads})" |
|
|
|
self.head_dim = self.hidden_size // self.num_heads |
|
self.attention_head_type = AttentionHeadType(config.attention_head_type) |
|
|
|
self.position_embedding_type = PositionEmbeddingType(config.position_embedding_type) |
|
self.scale_attn_weights = config.scale_attn_weights |
|
self.attention_multiplier = config.attention_multiplier |
|
|
|
self.layer_idx = layer_idx |
|
self.attention_softmax_in_fp32 = config.attention_softmax_in_fp32 |
|
self.scale_attention_softmax_in_fp32 = ( |
|
config.scale_attention_softmax_in_fp32 and config.attention_softmax_in_fp32 |
|
) |
|
|
|
if self.attention_head_type == AttentionHeadType.mha: |
|
if self.num_key_value_heads is None: |
|
self.num_key_value_heads = self.num_heads |
|
|
|
assert ( |
|
self.num_heads == self.num_key_value_heads |
|
), f"{self.__class__.__name__} should have same number of heads for query, keys and values" |
|
elif self.attention_head_type == AttentionHeadType.gqa: |
|
assert ( |
|
self.num_key_value_heads is not None |
|
), "`num_key_value_heads` needs to be specified with GroupedQueryAttention" |
|
|
|
assert self.num_heads % self.num_key_value_heads == 0, ( |
|
f"`num_heads` ({self.num_heads}) should be a multiple of `num_key_value_heads` " |
|
f"({self.num_key_value_heads})" |
|
) |
|
elif self.attention_head_type == AttentionHeadType.mqa: |
|
if self.num_key_value_heads is None: |
|
self.num_key_value_heads = 1 |
|
|
|
assert self.num_key_value_heads == 1, f"{self.__class__.__name__} should have 1 head for keys and values" |
|
else: |
|
raise ValueError(f"unexpected attention_head_type ({self.attention_head_type})") |
|
|
|
|
|
|
|
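        # fused QKV projection: `num_heads` query heads plus `num_key_value_heads` key and value heads,
        # each of size `head_dim`, concatenated along the output dimension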
self.c_attn = nn.Linear( |
|
self.hidden_size, self.hidden_size + 2 * self.num_key_value_heads * self.head_dim, bias=self.add_bias |
|
) |
|
self.c_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=self.add_bias) |
|
|
|
self.attn_pdrop = config.attn_pdrop |
|
self.resid_pdrop = config.resid_pdrop |
|
|
|
self.attn_dropout = nn.Identity() if self.attn_pdrop == 0 else nn.Dropout(self.attn_pdrop) |
|
self.resid_dropout = nn.Identity() if self.resid_pdrop == 0 else nn.Dropout(self.resid_pdrop) |
|
|
|
def _prepare_qkv_for_forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: |
|
|
|
|
|
|
|
|
|
|
|
hidden_states = self.c_attn(hidden_states) |
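        # hidden_states -> (batch_size, query_length, (num_heads + 2 * num_key_value_heads) * head_dim)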
|
|
|
|
|
|
|
|
|
|
|
|
|
if self.attention_head_type == AttentionHeadType.mha: |
|
query, key, value = self._prepare_qkv_for_forward_mha(hidden_states) |
|
elif self.attention_head_type == AttentionHeadType.gqa: |
|
query, key, value = self._prepare_qkv_for_forward_gqa(hidden_states) |
|
elif self.attention_head_type == AttentionHeadType.mqa: |
|
query, key, value = self._prepare_qkv_for_forward_mqa(hidden_states) |
|
else: |
|
raise ValueError(f"unexpected attention_head_type ({self.attention_head_type})") |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return query, key, value |
|
|
|
def _prepare_qkv_for_forward_mha( |
|
self, hidden_states: torch.Tensor |
|
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: |
|
batch_size, query_length = hidden_states.shape[:-1] |
|
|
|
hidden_states = hidden_states.view(batch_size, query_length, self.num_heads, -1) |
|
hidden_states = hidden_states.transpose(1, 2) |
|
|
|
query, key, value = hidden_states.chunk(3, dim=-1) |
|
|
|
return query, key, value |
|
|
|
def _prepare_qkv_for_forward_gqa( |
|
self, hidden_states: torch.Tensor |
|
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: |
|
batch_size, query_length = hidden_states.shape[:-1] |
|
|
|
hidden_states = hidden_states.view(batch_size, query_length, self.num_key_value_heads, -1) |
|
|
|
query, key, value = hidden_states.split( |
|
((self.num_heads // self.num_key_value_heads) * self.head_dim, self.head_dim, self.head_dim), dim=-1 |
|
) |
|
|
|
|
|
query = query.reshape(batch_size, query_length, -1, self.head_dim) |
|
|
|
query = query.transpose(1, 2) |
|
key = key.transpose(1, 2) |
|
value = value.transpose(1, 2) |
|
|
|
return query, key, value |
|
|
|
def _prepare_qkv_for_forward_mqa( |
|
self, hidden_states: torch.Tensor |
|
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: |
|
batch_size, query_length = hidden_states.shape[:-1] |
|
|
|
query, key, value = hidden_states.split((self.hidden_size, self.head_dim, self.head_dim), dim=-1) |
|
|
|
query = query.view(batch_size, query_length, self.num_heads, -1) |
|
|
|
query = query.transpose(1, 2) |
|
key = key.unsqueeze(1) |
|
value = value.unsqueeze(1) |
|
|
|
return query, key, value |
|
|
|
def forward( |
|
self, |
|
hidden_states: torch.Tensor, |
|
past_key_values: Optional[DynamicCache] = None, |
|
attention_mask: Optional[torch.Tensor] = None, |
|
rope_cos_sin: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, |
|
) -> torch.Tensor: |
|
|
|
|
|
|
|
|
|
query, key, value = self._prepare_qkv_for_forward(hidden_states) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if self.position_embedding_type == PositionEmbeddingType.rope: |
|
query = apply_rotary_pos_emb(query, rope_cos_sin) |
|
key = apply_rotary_pos_emb(key, rope_cos_sin) |
|
|
|
if past_key_values is not None: |
|
key, value = past_key_values.update(key, value, self.layer_idx) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
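        # transpose key so Q @ K^T can be computed with a batched matmul:
        # (batch_size, num_key_value_heads, head_dim, key_length)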
key = key.transpose(-1, -2) |
|
|
|
dtype = query.dtype |
|
softmax_dtype = torch.float32 if self.attention_softmax_in_fp32 else dtype |
|
|
|
if self.scale_attn_weights: |
|
if self.attention_multiplier is None: |
|
scale_factor = 1 / self.head_dim**0.5 |
|
else: |
|
scale_factor = self.attention_multiplier |
|
else: |
|
scale_factor = 1 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
batch_size = query.shape[0] |
|
query_length = query.shape[2] |
|
key_length = key.shape[-1] |
|
|
|
key = repeat_key_value(key, self.num_heads, self.num_key_value_heads) |
|
value = repeat_key_value(value, self.num_heads, self.num_key_value_heads) |
|
|
|
|
|
query = query.reshape(batch_size * self.num_heads, query_length, self.head_dim) |
|
|
|
key = key.reshape(batch_size * self.num_heads, self.head_dim, key_length) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
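        # baddbmm with beta=0 writes alpha * (query @ key) into a fresh buffer; the `empty` tensor only
        # supplies the output shape and dtype, its values are never read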
attn_weights = torch.empty( |
|
(batch_size * self.num_heads, query_length, key_length), device=query.device, dtype=query.dtype |
|
) |
|
|
|
attn_weights = torch.baddbmm(attn_weights, query, key, beta=0, alpha=scale_factor).view( |
|
batch_size, self.num_heads, query_length, key_length |
|
) |
|
|
|
|
|
|
|
|
|
|
|
attn_weights = attn_weights.to(softmax_dtype) |
|
|
|
if attention_mask is not None: |
|
attn_weights = attn_weights + attention_mask |
|
|
|
attn_weights = F.softmax(attn_weights, dim=-1).to(dtype) |
|
|
|
attn_weights = self.attn_dropout(attn_weights) |
|
|
|
|
|
|
|
|
|
|
|
|
|
attn_output = torch.matmul(attn_weights, value) |
|
|
|
|
|
|
|
|
|
|
|
attn_output = attn_output.transpose(1, 2) |
|
attn_output = attn_output.reshape(batch_size, -1, self.num_heads * self.head_dim) |
|
|
|
|
|
|
|
|
|
|
|
attn_output = self.c_proj(attn_output) |
|
attn_output = self.resid_dropout(attn_output) |
|
|
|
return attn_output |
|
|
|
|
|
class GraniteSDPA(GraniteAttention): |
|
def forward( |
|
self, |
|
hidden_states: torch.Tensor, |
|
past_key_values: Optional[DynamicCache] = None, |
|
attention_mask: Optional[torch.Tensor] = None, |
|
rope_cos_sin: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, |
|
) -> torch.Tensor: |
|
|
|
|
|
|
|
|
|
query, key, value = self._prepare_qkv_for_forward(hidden_states) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if self.position_embedding_type == PositionEmbeddingType.rope: |
|
query = apply_rotary_pos_emb(query, rope_cos_sin) |
|
key = apply_rotary_pos_emb(key, rope_cos_sin) |
|
|
|
if past_key_values is not None: |
|
key, value = past_key_values.update(key, value, self.layer_idx) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
key = repeat_key_value(key, self.num_heads, self.num_key_value_heads) |
|
value = repeat_key_value(value, self.num_heads, self.num_key_value_heads) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
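        # let PyTorch's SDPA pick the backend (flash / memory-efficient / math); the `is_causal` flag is
        # only used when no explicit attention mask is supplied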
attn_output = F.scaled_dot_product_attention( |
|
query, |
|
key, |
|
value, |
|
attn_mask=attention_mask, |
|
dropout_p=self.attn_pdrop if self.training else 0, |
|
is_causal=self.causal if attention_mask is None else False, |
|
scale=self.attention_multiplier if self.scale_attn_weights else 1, |
|
) |
|
|
|
|
|
|
|
|
|
|
|
batch_size = attn_output.shape[0] |
|
attn_output = attn_output.transpose(1, 2) |
|
attn_output = attn_output.reshape(batch_size, -1, self.num_heads * self.head_dim) |
|
|
|
|
|
|
|
|
|
|
|
attn_output = self.c_proj(attn_output) |
|
attn_output = self.resid_dropout(attn_output) |
|
|
|
return attn_output |
|
|
|
|
|
class GraniteFlashAttention2(GraniteAttention): |
|
def forward( |
|
self, |
|
hidden_states: torch.Tensor, |
|
past_key_values: Optional[DynamicCache] = None, |
|
attention_mask: Optional[torch.Tensor] = None, |
|
rope_cos_sin: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, |
|
) -> torch.Tensor: |
|
|
|
|
|
|
|
|
|
query, key, value = self._prepare_qkv_for_forward(hidden_states) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if self.position_embedding_type == PositionEmbeddingType.rope: |
|
query = apply_rotary_pos_emb(query, rope_cos_sin) |
|
key = apply_rotary_pos_emb(key, rope_cos_sin) |
|
|
|
if past_key_values is not None: |
|
key, value = past_key_values.update(key, value, self.layer_idx) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
query = query.transpose(1, 2) |
|
if self.attention_head_type == AttentionHeadType.mqa: |
|
key = key.squeeze(1).unsqueeze(2) |
|
value = value.squeeze(1).unsqueeze(2) |
|
else: |
|
key = key.transpose(1, 2) |
|
value = value.transpose(1, 2) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
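        # unpad: drop padding tokens and build cumulative sequence lengths so the packed sequences can be
        # fed to flash_attn_varlen_func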
batch_size, query_length = query.shape[:2] |
|
key_length = key.shape[1] |
|
indices_k, cu_seqlens_k, max_seqlen_k = get_unpad_data(attention_mask) |
|
|
|
key = IndexFirstAxis.apply( |
|
key.reshape(batch_size * key_length, self.num_key_value_heads, self.head_dim), indices_k |
|
) |
|
value = IndexFirstAxis.apply( |
|
value.reshape(batch_size * key_length, self.num_key_value_heads, self.head_dim), indices_k |
|
) |
|
|
|
if query_length == key_length: |
|
query = IndexFirstAxis.apply( |
|
query.reshape(batch_size * key_length, self.num_heads, self.head_dim), indices_k |
|
) |
|
cu_seqlens_q = cu_seqlens_k |
|
max_seqlen_q = max_seqlen_k |
|
indices_q = indices_k |
|
elif query_length == 1: |
|
max_seqlen_q = 1 |
|
cu_seqlens_q = torch.arange( |
|
batch_size + 1, dtype=torch.int32, device=query.device |
|
) |
|
indices_q = cu_seqlens_q[:-1] |
|
query = query.squeeze(1) |
|
else: |
|
|
|
attention_mask = attention_mask[:, -query_length:] |
|
query, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(query, attention_mask) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
attn_output = flash_attn_varlen_func( |
|
query, |
|
key, |
|
value, |
|
cu_seqlens_q=cu_seqlens_q, |
|
cu_seqlens_k=cu_seqlens_k, |
|
max_seqlen_q=max_seqlen_q, |
|
max_seqlen_k=max_seqlen_k, |
|
dropout_p=self.attn_pdrop if self.training else 0, |
|
softmax_scale=self.attention_multiplier if self.scale_attn_weights else 1, |
|
causal=self.causal, |
|
) |
|
|
|
|
|
|
|
|
|
|
|
attn_output = pad_input(attn_output, indices_q, batch_size, query_length) |
|
attn_output = attn_output.view(batch_size, query_length, -1) |
|
|
|
|
|
|
|
|
|
|
|
attn_output = self.c_proj(attn_output) |
|
attn_output = self.resid_dropout(attn_output) |
|
|
|
return attn_output |
|
|
|
|
|
_ATTENTION_MODULES = { |
|
"eager": GraniteAttention, |
|
"sdpa": GraniteSDPA, |
|
"flash_attention_2": GraniteFlashAttention2, |
|
} |
|
|
|
|
|
def get_attention_module( |
|
config: GraniteConfig, causal: bool, attention_implementation: str, layer_idx: int |
|
) -> GraniteAttention: |
|
if attention_implementation in _ATTENTION_MODULES: |
|
return _ATTENTION_MODULES[attention_implementation](config, causal=causal, layer_idx=layer_idx) |
|
raise ValueError(f"unexpected `attention_implementation` {attention_implementation}") |
|
|
|
|
|
|
|
|
|
|
|
|
|
class Alibi(nn.Module): |
|
def __init__(self, num_heads: int) -> None: |
|
super().__init__() |
|
self.num_heads = num_heads |
|
|
|
self.reset_parameters() |
|
|
|
def forward( |
|
self, attention_mask: torch.Tensor, batch_size: int, key_length: int, device: torch.device, dtype: torch.dtype |
|
) -> torch.Tensor: |
|
""" |
|
Link to paper: https://arxiv.org/abs/2108.12409 Alibi tensor is not causal as the original paper mentions, it |
|
relies on a translation invariance of softmax for quick implementation: with l being a tensor, and a fixed value |
|
`softmax(l+a) = softmax(l)`. Based on |
|
https://github.com/ofirpress/attention_with_linear_biases/blob/a35aaca144e0eb6b789dfcb46784c4b8e31b7983/fairseq/models/transformer.py#L742 |
|
TODO @thomasw21 this doesn't work as nicely due to the masking strategy, and so masking varies slightly. |
|
|
|
Args: |
|
attention_mask (torch.Tensor): attention_mask tensor of shape (`batch_size`, `key_length`) |
|
num_heads (int): `num_heads` for the model |
|
batch_size (int): `batch_size` |
|
key_length (int): `key_length` |
|
device (torch.device): device for the tensors |
|
dtype (torch.dtype): dtype to use for the tensors |
|
|
|
Returns: |
|
torch.Tensor: alibi tensor of shape (`batch_size`, `num_heads`, `key_length`) |
|
""" |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if attention_mask is None: |
|
arange_tensor = ( |
|
torch.arange(key_length, device=device).unsqueeze(0).unsqueeze(0).expand(batch_size, -1, -1) |
|
) |
|
else: |
|
arange_tensor = (attention_mask.cumsum(dim=-1) - 1).masked_fill_(attention_mask == 0, 0).unsqueeze(1) |
|
|
|
alibi = self.slopes.unsqueeze(1) * arange_tensor |
|
return alibi.to(dtype) |
|
|
|
def reset_parameters(self) -> None: |
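        # ALiBi head slopes form a geometric sequence with ratio 2^(-8 / closest_power_of_2); when
        # `num_heads` is not a power of two, the remaining heads take every other slope from the
        # sequence built for 2 * closest_power_of_2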
|
closest_power_of_2 = 2 ** math.floor(math.log2(self.num_heads)) |
|
base = torch.tensor(2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))), dtype=torch.float32) |
|
powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32) |
|
slopes = torch.pow(base, powers) |
|
|
|
if closest_power_of_2 != self.num_heads: |
|
extra_base = torch.tensor(2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))), dtype=torch.float32) |
|
num_remaining_heads = min(closest_power_of_2, self.num_heads - closest_power_of_2) |
|
extra_powers = torch.arange(1, 1 + 2 * num_remaining_heads, 2, dtype=torch.int32) |
|
slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0) |
|
|
|
self.register_buffer("slopes", slopes, persistent=False) |
|
|
|
|
|
class RoPE(nn.Module): |
|
def __init__( |
|
self, |
|
head_dim: int, |
|
max_position_embeddings: int = 2048, |
|
base: int = 10000, |
|
) -> None: |
|
super().__init__() |
|
|
|
self.head_dim = head_dim |
|
self.max_position_embeddings = max_position_embeddings |
|
self.base = base |
|
self.mscale = 1 |
|
|
|
self.reset_parameters() |
|
|
|
def forward(self, seq_len: int, dtype: torch.dtype, device: torch.device) -> Tuple[torch.Tensor, torch.Tensor]: |
|
if seq_len > self.max_seq_len_cached: |
|
self._set_cos_sin_cache(seq_len=seq_len, device=device, dtype=dtype) |
|
|
|
cos = self.cos_cached[:seq_len].to(dtype) |
|
sin = self.sin_cached[:seq_len].to(dtype) |
|
|
|
return cos, sin |
|
|
|
def reset_parameters(self) -> None: |
|
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.head_dim, 2).float() / self.head_dim)) |
|
self.register_buffer("inv_freq", inv_freq, persistent=False) |
|
|
|
|
|
self._set_cos_sin_cache( |
|
seq_len=self.max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype() |
|
) |
|
|
|
@torch.no_grad() |
|
def _set_cos_sin_cache(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> None: |
|
self.max_seq_len_cached = seq_len |
|
t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) |
|
|
|
freqs = torch.outer(t, self.inv_freq) |
|
|
|
emb = torch.cat((freqs, freqs), dim=-1) |
|
|
|
self.register_buffer("cos_cached", (emb.cos() * self.mscale).to(dtype), persistent=False) |
|
self.register_buffer("sin_cached", (emb.sin() * self.mscale).to(dtype), persistent=False) |
|
|
|
|
|
def apply_rotary_pos_emb(x: torch.Tensor, cos_sin: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor: |
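    # standard RoPE rotation over the head dimension: x * cos + rotate_half(x) * sin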
|
cos, sin = cos_sin |
|
x = (x * cos) + (_rotate_half(x) * sin) |
|
return x |
|
|
|
|
|
def _rotate_half(x: torch.Tensor) -> torch.Tensor: |
|
x1, x2 = torch.chunk(x, 2, dim=-1) |
|
return torch.cat((-x2, x1), dim=-1) |
|
|
|
|
|
|
|
|
|
|
|
|
|
class GraniteMLP(nn.Module): |
|
def __init__(self, config: GraniteConfig) -> None: |
|
super().__init__() |
|
|
|
hidden_size = config.n_embd |
|
intermediate_size = config.n_inner |
|
activation_function = config.activation_function |
|
add_bias = config.add_bias |
|
residual_dropout = config.resid_pdrop |
|
|
|
self.c_fc = nn.Linear( |
|
hidden_size, |
|
2 * intermediate_size if is_glu(activation_function) else intermediate_size, |
|
bias=add_bias, |
|
) |
|
self.act = get_activation_function(activation_function) |
|
self.c_proj = nn.Linear(intermediate_size, hidden_size, bias=add_bias) |
|
self.dropout = nn.Identity() if residual_dropout == 0 else nn.Dropout(residual_dropout) |
|
|
|
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: |
|
hidden_states = self.c_fc(hidden_states) |
|
hidden_states = self.act(hidden_states) |
|
hidden_states = self.c_proj(hidden_states) |
|
hidden_states = self.dropout(hidden_states) |
|
return hidden_states |
|
|
|
|
|
|
|
|
|
|
|
|
|
class GraniteBlock(nn.Module): |
|
def __init__( |
|
self, |
|
config: GraniteConfig, |
|
attention_implementation: str, |
|
layer_idx: Optional[int] = None, |
|
) -> None: |
|
super().__init__() |
|
|
|
hidden_size = config.hidden_size |
|
self.inner_dim = config.n_inner |
|
self.layer_idx = layer_idx |
|
|
|
self.ln_1 = get_normalization_function( |
|
config.normalization_function, |
|
hidden_size, |
|
eps=config.layer_norm_epsilon, |
|
) |
|
self.attn = get_attention_module(config, True, attention_implementation, layer_idx) |
|
self.ln_2 = get_normalization_function( |
|
config.normalization_function, |
|
hidden_size, |
|
eps=config.layer_norm_epsilon, |
|
) |
|
self.mlp = GraniteMLP(config) |
|
|
|
def forward( |
|
self, |
|
hidden_states: torch.Tensor, |
|
past_key_values: Optional[DynamicCache] = None, |
|
attention_mask: Optional[torch.Tensor] = None, |
|
rope_cos_sin: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, |
|
) -> torch.Tensor: |
|
residual = hidden_states |
|
hidden_states = self.ln_1(hidden_states) |
|
|
|
attn_output = self.attn( |
|
hidden_states, |
|
past_key_values=past_key_values, |
|
attention_mask=attention_mask, |
|
rope_cos_sin=rope_cos_sin, |
|
) |
|
|
|
|
|
hidden_states = attn_output + residual |
|
|
|
residual = hidden_states |
|
hidden_states = self.ln_2(hidden_states) |
|
|
|
feed_forward_hidden_states = self.mlp(hidden_states) |
|
|
|
|
|
hidden_states = residual + feed_forward_hidden_states |
|
|
|
return hidden_states |
|
|
|
|
|
|
|
|
|
|
|
|
|
class GranitePreTrainedModel(PreTrainedModel): |
|
""" |
|
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained |
|
models. |
|
""" |
|
|
|
config_class = GraniteConfig |
|
base_model_prefix = "transformer" |
|
causal = True |
|
_no_split_modules = ["GraniteBlock"] |
|
_skip_keys_device_placement = "past_key_values" |
|
_supports_sdpa = True |
|
_supports_flash_attn_2 = True |
|
|
|
def __init__(self, config: GraniteConfig, *inputs, **kwargs): |
|
super().__init__(config, *inputs, **kwargs) |
|
|
|
self.attention_implementation = self.config._attn_implementation |
|
self._use_eager_attention = self.attention_implementation == "eager" |
|
self._use_sdpa = self.attention_implementation == "sdpa" |
|
self._use_flash_attention_2 = self.attention_implementation == "flash_attention_2" |
|
|
|
self.initializer_range = config.initializer_range |
|
|
|
def _init_weights(self, module: nn.Module) -> None: |
|
if isinstance(module, (nn.LayerNorm, RMSNorm, Alibi, RoPE)): |
|
module.reset_parameters() |
|
elif isinstance(module, nn.Linear): |
|
nn.init.normal_(module.weight, mean=0, std=self.initializer_range) |
|
if module.bias is not None: |
|
module.bias.zero_() |
|
elif isinstance(module, nn.Embedding): |
|
nn.init.normal_(module.weight, mean=0, std=self.initializer_range) |
|
if module.padding_idx is not None: |
|
module.weight[module.padding_idx].zero_() |
|
|
|
|
|
class GraniteModel(GranitePreTrainedModel): |
|
_keys_to_ignore_on_load_missing = ["attn.masked_bias"] |
|
mask_value = None |
|
|
|
def __init__(self, config: GraniteConfig, **kwargs) -> None: |
|
super().__init__(config, **kwargs) |
|
|
|
self.attention_head_type = AttentionHeadType(config.attention_head_type) |
|
self.embed_dim = config.hidden_size |
|
self.num_heads = config.num_attention_heads |
|
self.num_key_value_heads = config.num_key_value_heads |
|
|
|
assert ( |
|
self.embed_dim % self.num_heads == 0 |
|
), f"`embed_dim` ({self.embed_dim}) must be divisible by `num_heads` ({self.num_heads})" |
|
|
|
self.head_dim = self.embed_dim // self.num_heads |
|
|
|
self.wte = nn.Embedding(config.vocab_size, self.embed_dim) |
|
|
|
self.drop = nn.Identity() if config.embd_pdrop == 0 else nn.Dropout(config.embd_pdrop) |
|
self.h = nn.ModuleList( |
|
[GraniteBlock(config, self.attention_implementation, layer_idx=i) for i in range(config.num_hidden_layers)] |
|
) |
|
self.ln_f = get_normalization_function( |
|
config.normalization_function, |
|
self.embed_dim, |
|
eps=config.layer_norm_epsilon, |
|
) |
|
|
|
self.position_embedding_type = PositionEmbeddingType(config.position_embedding_type) |
|
|
|
if self.position_embedding_type == PositionEmbeddingType.learned_absolute: |
|
self.wpe = nn.Embedding(config.n_positions, self.embed_dim) |
|
elif self.position_embedding_type == PositionEmbeddingType.alibi: |
|
assert not self._use_flash_attention_2, "alibi is not implemented with FlashAttention" |
|
|
|
self.alibi = Alibi(self.num_heads) |
|
elif self.position_embedding_type == PositionEmbeddingType.rope: |
|
self.rope = RoPE(self.head_dim, max_position_embeddings=config.n_positions, base=config.rope_theta) |
|
else: |
|
raise NotImplementedError() |
|
|
|
|
|
self.post_init() |
|
|
|
def get_input_embeddings(self) -> nn.Embedding: |
|
return self.wte |
|
|
|
def set_input_embeddings(self, new_embeddings: nn.Embedding) -> None: |
|
self.wte = new_embeddings |
|
|
|
def forward( |
|
self, |
|
input_ids: Optional[torch.Tensor] = None, |
|
past_key_values: Optional[DynamicCache] = None, |
|
attention_mask: Optional[torch.Tensor] = None, |
|
token_type_ids: Optional[torch.Tensor] = None, |
|
position_ids: Optional[torch.Tensor] = None, |
|
inputs_embeds: Optional[torch.Tensor] = None, |
|
use_cache: Optional[bool] = None, |
|
output_hidden_states: Optional[bool] = None, |
|
return_dict: Optional[bool] = None, |
|
) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]: |
|
( |
|
output_hidden_states, |
|
use_cache, |
|
return_dict, |
|
input_shape, |
|
hidden_states, |
|
attention_mask, |
|
position_ids, |
|
rope_cos_sin, |
|
past_key_values, |
|
) = self._prepare_a_bunch_of_stuff( |
|
input_ids=input_ids, |
|
past_key_values=past_key_values, |
|
attention_mask=attention_mask, |
|
token_type_ids=token_type_ids, |
|
position_ids=position_ids, |
|
inputs_embeds=inputs_embeds, |
|
use_cache=use_cache, |
|
output_hidden_states=output_hidden_states, |
|
return_dict=return_dict, |
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
output_shape = input_shape + (hidden_states.size(-1),) |
|
|
|
past_key_values = DynamicCache() if use_cache and past_key_values is None else past_key_values |
|
all_hidden_states = () if output_hidden_states else None |
|
for block in self.h: |
|
if output_hidden_states: |
|
all_hidden_states += (hidden_states,) |
|
|
|
hidden_states = block( |
|
hidden_states, |
|
past_key_values=past_key_values, |
|
attention_mask=attention_mask, |
|
rope_cos_sin=rope_cos_sin, |
|
) |
|
|
|
hidden_states = self.ln_f(hidden_states) |
|
|
|
hidden_states = hidden_states.view(output_shape) |
|
|
|
if output_hidden_states: |
|
all_hidden_states += (hidden_states,) |
|
|
|
if not return_dict: |
|
return tuple(v for v in [hidden_states, past_key_values, all_hidden_states] if v is not None) |
|
|
|
return BaseModelOutputWithPastAndCrossAttentions( |
|
last_hidden_state=hidden_states, |
|
past_key_values=past_key_values, |
|
hidden_states=all_hidden_states, |
|
) |
|
|
|
def _get_position_ids( |
|
self, attention_mask: torch.Tensor, past_length: int, query_length: int, key_length: int, device: torch.device |
|
) -> torch.Tensor: |
|
if attention_mask is not None and len(attention_mask.shape) == 2: |
|
|
|
position_ids = attention_mask.long().cumsum(-1) - 1 |
|
position_ids.masked_fill_(attention_mask == 0, 0) |
|
if past_length > 0: |
|
                position_ids = position_ids[:, past_length:key_length]
|
else: |
|
position_ids = torch.arange(past_length, key_length, dtype=torch.long, device=device) |
|
position_ids = position_ids.unsqueeze(0).view(-1, query_length) |
|
|
|
return position_ids |
|
|
|
def _get_alibi_bias( |
|
self, |
|
attention_mask: torch.Tensor, |
|
batch_size: int, |
|
query_length: int, |
|
key_length: int, |
|
device: torch.device, |
|
dtype: torch.dtype, |
|
) -> torch.Tensor: |
|
if self.position_embedding_type != PositionEmbeddingType.alibi: |
|
return None |
|
|
|
alibi_bias = self.alibi(attention_mask, batch_size, key_length, device, dtype) |
|
|
|
|
|
|
|
|
|
|
|
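        # (batch_size, num_heads, key_length) -> (batch_size, num_heads, query_length, key_length)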
alibi_bias = alibi_bias.unsqueeze(2) |
|
if query_length != 1: |
|
alibi_bias = alibi_bias.expand(-1, -1, query_length, -1) |
|
|
|
|
|
|
|
|
|
|
|
return alibi_bias |
|
|
|
def _get_rope_cos_sin( |
|
self, key_length: int, position_ids: torch.Tensor, dtype: torch.dtype, device: torch.device |
|
) -> Optional[Tuple[torch.Tensor, torch.Tensor]]: |
|
if self.position_embedding_type == PositionEmbeddingType.rope: |
|
cos, sin = self.rope(key_length, dtype=dtype, device=device) |
|
cos = cos[position_ids].unsqueeze(1) |
|
sin = sin[position_ids].unsqueeze(1) |
|
return cos, sin |
|
|
|
def _prepare_causal_attention_mask( |
|
self, attention_mask: torch.Tensor, batch_size: int, query_length: int, key_length: int, device: torch.device |
|
) -> torch.Tensor: |
|
past_length = key_length - query_length |
|
|
|
|
|
|
|
|
|
|
|
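        # build a boolean mask of shape (batch_size, query_length, key_length) where True marks positions a
        # query token may attend to: lower triangle for the current tokens, everything for cached past tokens,
        # optionally intersected with the padding mask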
if query_length > 1: |
|
|
|
causal_mask = torch.empty((query_length, key_length), dtype=torch.bool, device=device) |
|
causal_mask[:, past_length:] = torch.tril( |
|
torch.ones(query_length, query_length, dtype=torch.bool, device=device) |
|
) |
|
|
|
if past_length > 0: |
|
causal_mask[:, :past_length] = True |
|
|
|
|
|
causal_mask = causal_mask.unsqueeze(0) |
|
|
|
if attention_mask is None: |
|
|
|
causal_mask = causal_mask.expand(batch_size, -1, -1) |
|
else: |
|
|
|
causal_mask = causal_mask & attention_mask.unsqueeze(1).to(torch.bool) |
|
else: |
|
if attention_mask is None: |
|
|
|
causal_mask = torch.ones(batch_size, query_length, key_length, dtype=torch.bool, device=device) |
|
else: |
|
|
|
causal_mask = attention_mask.unsqueeze(1).to(dtype=torch.bool, device=device) |
|
|
|
|
|
|
|
|
|
|
|
causal_mask = causal_mask.unsqueeze(1) |
|
|
|
|
|
|
|
|
|
|
|
return causal_mask |
|
|
|
def _get_initial_hidden_state( |
|
self, |
|
input_ids: torch.Tensor, |
|
inputs_embeds: torch.Tensor, |
|
position_ids: torch.Tensor, |
|
token_type_ids: torch.Tensor, |
|
) -> torch.Tensor: |
|
if inputs_embeds is None: |
|
inputs_embeds = self.wte(input_ids) |
|
|
|
if self.position_embedding_type == PositionEmbeddingType.learned_absolute: |
|
inputs_embeds = inputs_embeds + self.wpe(position_ids) |
|
|
|
if token_type_ids is not None: |
|
inputs_embeds = inputs_embeds + self.wte(token_type_ids) |
|
|
|
inputs_embeds = self.drop(inputs_embeds) |
|
|
|
return inputs_embeds |
|
|
|
def _prepare_a_bunch_of_stuff( |
|
self, |
|
input_ids: torch.Tensor, |
|
past_key_values: DynamicCache, |
|
attention_mask: torch.Tensor, |
|
token_type_ids: torch.Tensor, |
|
position_ids: torch.Tensor, |
|
inputs_embeds: torch.Tensor, |
|
use_cache: bool, |
|
output_hidden_states: bool, |
|
return_dict: bool, |
|
) -> Tuple[ |
|
bool, |
|
bool, |
|
bool, |
|
torch.Size, |
|
torch.Tensor, |
|
torch.Tensor, |
|
torch.Tensor, |
|
Optional[Tuple[torch.Tensor, torch.Tensor]], |
|
DynamicCache, |
|
]: |
|
output_hidden_states = ( |
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states |
|
) |
|
|
|
use_cache = self.config.use_cache if use_cache is None else use_cache |
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict |
|
|
|
if input_ids is not None and inputs_embeds is not None: |
|
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") |
|
elif input_ids is not None: |
|
input_shape = input_ids.size() |
|
elif inputs_embeds is not None: |
|
|
|
input_shape = inputs_embeds.size()[:-1] |
|
else: |
|
raise ValueError("You have to specify either input_ids or inputs_embeds") |
|
|
|
batch_size = input_shape[0] |
|
device = input_ids.device if input_ids is not None else inputs_embeds.device |
|
|
|
if self.position_embedding_type == PositionEmbeddingType.alibi: |
|
if position_ids is not None: |
|
                warnings.warn("`position_ids` has no effect with ALiBi.", FutureWarning)
|
|
|
if token_type_ids is not None: |
|
token_type_ids = token_type_ids.view(-1, input_shape[-1]) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
past_length = 0 if past_key_values is None else past_key_values.get_seq_length() |
|
query_length = input_shape[-1] |
|
key_length = past_length + query_length |
|
|
|
if position_ids is None: |
|
position_ids = self._get_position_ids(attention_mask, past_length, query_length, key_length, device) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
hidden_states = self._get_initial_hidden_state(input_ids, inputs_embeds, position_ids, token_type_ids) |
|
|
|
|
|
|
|
|
|
|
|
alibi_bias = self._get_alibi_bias( |
|
attention_mask, batch_size, query_length, key_length, device, hidden_states.dtype |
|
) |
|
|
|
|
|
|
|
|
|
|
|
rope_cos_sin = self._get_rope_cos_sin( |
|
key_length, position_ids, dtype=hidden_states.dtype, device=hidden_states.device |
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
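        # flash attention consumes the raw 2D padding mask; for the SDPA and eager paths the boolean causal
        # mask is converted into an additive bias: allowed positions become 0 (or the ALiBi bias) and masked
        # positions become a large negative value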
if self._use_flash_attention_2: |
|
if attention_mask is None: |
|
attention_mask = torch.ones_like(input_ids) |
|
elif self._use_sdpa: |
|
|
|
if attention_mask is not None: |
|
attention_mask = self._prepare_causal_attention_mask( |
|
attention_mask, batch_size, query_length, key_length, device |
|
) |
|
|
|
attention_mask = torch.where( |
|
attention_mask, |
|
~attention_mask if alibi_bias is None else alibi_bias, |
|
self._get_mask_value(attention_mask.device, hidden_states.dtype), |
|
) |
|
else: |
|
attention_mask = self._prepare_causal_attention_mask( |
|
attention_mask, batch_size, query_length, key_length, device |
|
) |
|
|
|
attention_mask = torch.where( |
|
attention_mask, |
|
~attention_mask if alibi_bias is None else alibi_bias, |
|
self._get_mask_value(attention_mask.device, hidden_states.dtype), |
|
) |
|
|
|
return ( |
|
output_hidden_states, |
|
use_cache, |
|
return_dict, |
|
input_shape, |
|
hidden_states, |
|
attention_mask, |
|
position_ids, |
|
rope_cos_sin, |
|
past_key_values, |
|
) |
|
|
|
def _get_mask_value(self, device: torch.device, dtype: torch.dtype) -> torch.Tensor: |
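        # cache the additive-mask fill value so a new tensor is not allocated on every forward pass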
|
|
|
if self.mask_value is None or self.mask_value.dtype != dtype or self.mask_value.device != device: |
|
self.mask_value = torch.full([], torch.finfo(torch.float16).min, dtype=dtype, device=device) |
|
return self.mask_value |
|
|
|
|
|
class GraniteForCausalLM(GranitePreTrainedModel): |
|
_keys_to_ignore_on_load_missing = ["lm_head.weight"] |
|
|
|
def __init__(self, config: GraniteConfig, **kwargs) -> None: |
|
super().__init__(config, **kwargs) |
|
self.transformer = GraniteModel(config, **kwargs) |
|
self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) |
|
|
|
|
|
self.post_init() |
|
|
|
def get_input_embeddings(self) -> nn.Embedding: |
|
return self.transformer.wte |
|
|
|
def set_input_embeddings(self, value: nn.Embedding) -> None: |
|
self.transformer.wte = value |
|
|
|
def get_output_embeddings(self) -> nn.Linear: |
|
return self.lm_head |
|
|
|
def set_output_embeddings(self, new_embeddings: nn.Linear) -> None: |
|
self.lm_head = new_embeddings |
|
|
|
|
|
def prepare_inputs_for_generation( |
|
self, |
|
input_ids: torch.Tensor, |
|
past_key_values: Optional[DynamicCache] = None, |
|
inputs_embeds: Optional[torch.Tensor] = None, |
|
**kwargs, |
|
) -> dict: |
|
token_type_ids = kwargs.get("token_type_ids", None) |
|
|
|
if past_key_values: |
|
past_length = past_key_values.get_seq_length() |
|
|
|
|
|
if input_ids.shape[1] > past_length: |
|
remove_prefix_length = past_length |
|
else: |
|
|
|
remove_prefix_length = input_ids.shape[1] - 1 |
|
|
|
input_ids = input_ids[:, remove_prefix_length:] |
|
if token_type_ids is not None: |
|
token_type_ids = token_type_ids[:, -input_ids.shape[1] :] |
|
|
|
attention_mask = kwargs.get("attention_mask", None) |
|
position_ids = kwargs.get("position_ids", None) |
|
|
|
if attention_mask is not None and position_ids is None: |
|
|
|
position_ids = attention_mask.long().cumsum(-1) - 1 |
|
position_ids.masked_fill_(attention_mask == 0, 0) |
|
if past_key_values: |
|
position_ids = position_ids[:, -input_ids.shape[1] :] |
|
else: |
|
position_ids = None |
|
|
|
|
|
if inputs_embeds is not None and past_key_values is None: |
|
model_inputs = {"inputs_embeds": inputs_embeds} |
|
else: |
|
model_inputs = {"input_ids": input_ids} |
|
|
|
model_inputs.update( |
|
{ |
|
"past_key_values": past_key_values, |
|
"use_cache": kwargs.get("use_cache"), |
|
"position_ids": position_ids, |
|
"attention_mask": attention_mask, |
|
"token_type_ids": token_type_ids, |
|
} |
|
) |
|
return model_inputs |
|
|
|
def forward( |
|
self, |
|
        input_ids: Optional[torch.Tensor] = None,
        past_key_values: Optional[DynamicCache] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
|
use_cache: Optional[bool] = None, |
|
output_attentions: Optional[bool] = None, |
|
output_hidden_states: Optional[bool] = None, |
|
return_dict: Optional[bool] = None, |
|
) -> Union[Tuple, CausalLMOutputWithCrossAttentions]: |
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
transformer_outputs = self.transformer( |
|
input_ids, |
|
past_key_values=past_key_values, |
|
attention_mask=attention_mask, |
|
token_type_ids=token_type_ids, |
|
position_ids=position_ids, |
|
inputs_embeds=inputs_embeds, |
|
use_cache=use_cache, |
|
output_hidden_states=output_hidden_states, |
|
return_dict=return_dict, |
|
) |
|
hidden_states = transformer_outputs[0] |
|
|
|
lm_logits = self.lm_head(hidden_states) |
|
|
|
loss = None |
|
|
|
if labels is not None: |
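            # shift so that tokens < n predict token n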
|
shift_logits = lm_logits[..., :-1, :].contiguous() |
|
shift_labels = labels[..., 1:].contiguous().to(shift_logits.device) |
|
|
|
|
|
loss_fct = nn.CrossEntropyLoss() |
|
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) |
|
|
|
if not return_dict: |
|
output = (lm_logits,) + transformer_outputs[1:] |
|
return ((loss,) + output) if loss is not None else output |
|
|
|
return CausalLMOutputWithCrossAttentions( |
|
loss=loss, |
|
logits=lm_logits, |
|
past_key_values=transformer_outputs.past_key_values, |
|
hidden_states=transformer_outputs.hidden_states, |
|
attentions=transformer_outputs.attentions, |
|
) |
|
|