|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import copy |
|
import math |
|
import warnings |
|
from typing import Optional, Union, List |
|
|
|
import torch |
|
import torch.nn as nn |
|
|
|
from .bert_padding import unpad_input, pad_input |
|
|
|
from .activation import get_act_fn |
|
from .attention import FlexBertAttentionBase, BertAlibiUnpadAttention, get_attention_layer |
|
from .mlp import FlexBertMLPBase, BertResidualGLU, get_mlp_layer |
|
from .configuration_bert import FlexBertConfig, maybe_add_padding |
|
from .normalization import get_norm_layer |
|
from .initialization import ModuleType, init_weights |
|
|
|
|
|
class BertAlibiLayer(nn.Module): |
|
"""Composes the Mosaic BERT attention and FFN blocks into a single layer.""" |
|
|
|
def __init__(self, config): |
|
super().__init__() |
|
self.attention = BertAlibiUnpadAttention(config) |
|
self.mlp = BertResidualGLU(config) |
|
|
|
def forward( |
|
self, |
|
hidden_states: torch.Tensor, |
|
cu_seqlens: torch.Tensor, |
|
seqlen: int, |
|
subset_idx: Optional[torch.Tensor] = None, |
|
indices: Optional[torch.Tensor] = None, |
|
attn_mask: Optional[torch.Tensor] = None, |
|
bias: Optional[torch.Tensor] = None, |
|
slopes: Optional[torch.Tensor] = None, |
|
) -> torch.Tensor: |
|
"""Forward pass for a BERT layer, including both attention and MLP. |
|
|
|
Args: |
|
hidden_states: (total_nnz, dim) |
|
cu_seqlens: (batch + 1,) |
|
seqlen: int |
|
subset_idx: () set of indices whose values we care about at the end of the layer |
|
(e.g., the masked tokens, if this is the final layer). |
|
indices: None or (total_nnz,) |
|
attn_mask: None or (batch, max_seqlen_in_batch) |
|
bias: None or (batch, heads, max_seqlen_in_batch, max_seqlen_in_batch) |
|
slopes: None or (batch, heads) or (heads,) |
|
""" |
|
assert (bias is None) == (slopes is None), f"{bias=}, {slopes=}" |
|
attention_output = self.attention( |
|
hidden_states, cu_seqlens, seqlen, subset_idx, indices, attn_mask, bias, slopes |
|
) |
|
layer_output = self.mlp(attention_output) |
|
return layer_output |
|
|
|
|
|
class BertAlibiEncoder(nn.Module): |
|
"""A stack of BERT layers providing the backbone of Mosaic BERT. |
|
|
|
    This module is modeled after Hugging Face's :class:`~transformers.models.bert.modeling_bert.BertEncoder`,
|
but with substantial modifications to implement unpadding and ALiBi. |
|
|
|
Compared to the analogous Hugging Face BERT module, this module handles unpadding to reduce unnecessary computation |
|
at padded tokens, and pre-computes attention biases to implement ALiBi. |
|
""" |
|
|
|
def __init__(self, config): |
|
super().__init__() |
|
layer = BertAlibiLayer(config) |
|
self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.num_hidden_layers)]) |
|
|
|
self.num_attention_heads = config.num_attention_heads |
|
|
|
|
|
|
|
|
|
|
|
self._current_alibi_size = int(config.alibi_starting_size) |
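        # The ALiBi bias tensor is built once at `alibi_starting_size` and grown on demand in `forward`
        # whenever a longer sequence arrives; it is never rebuilt at a smaller size.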
|
self.alibi = torch.zeros((1, self.num_attention_heads, self._current_alibi_size, self._current_alibi_size)) |
|
self.rebuild_alibi_tensor(size=config.alibi_starting_size) |
|
|
|
def rebuild_alibi_tensor(self, size: int, device: Optional[Union[torch.device, str]] = None): |
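        """Recompute the (1, n_heads, size, size) ALiBi bias for sequences of up to `size` tokens.

        Slopes follow the construction from the ALiBi paper (Press et al., "Train Short, Test Long"), and the
        bias uses the symmetric distance |i - j| suited to bidirectional attention.
        """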
|
|
|
|
|
|
|
|
|
|
|
n_heads = self.num_attention_heads |
|
|
|
def _get_alibi_head_slopes(n_heads: int) -> List[float]: |
|
def get_slopes_power_of_2(n_heads: int) -> List[float]: |
|
start = 2 ** (-(2 ** -(math.log2(n_heads) - 3))) |
|
ratio = start |
|
return [start * ratio**i for i in range(n_heads)] |
|
|
|
|
|
|
|
|
|
|
|
if math.log2(n_heads).is_integer(): |
|
return get_slopes_power_of_2(n_heads) |
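            # For head counts that are not a power of two, follow the ALiBi reference implementation:
            # take the slopes for the closest lower power of two, then fill the remainder with every
            # other slope from the 2x larger power-of-two set.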
|
|
|
closest_power_of_2 = 2 ** math.floor(math.log2(n_heads)) |
|
slopes_a = get_slopes_power_of_2(closest_power_of_2) |
|
slopes_b = _get_alibi_head_slopes(2 * closest_power_of_2) |
|
slopes_b = slopes_b[0::2][: n_heads - closest_power_of_2] |
|
return slopes_a + slopes_b |
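        # Example (illustrative): for n_heads == 8 the slopes are the geometric sequence
        # [2**-1, 2**-2, ..., 2**-8] == [0.5, 0.25, ..., 0.00390625].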
|
|
|
context_position = torch.arange(size, device=device)[:, None] |
|
memory_position = torch.arange(size, device=device)[None, :] |
|
relative_position = torch.abs(memory_position - context_position) |
|
|
|
relative_position = relative_position.unsqueeze(0).expand(n_heads, -1, -1) |
|
slopes = torch.Tensor(_get_alibi_head_slopes(n_heads)).to(device) |
|
self.slopes = slopes |
|
alibi = slopes.unsqueeze(1).unsqueeze(1) * -relative_position |
|
|
|
alibi = alibi.unsqueeze(0) |
|
assert alibi.shape == torch.Size([1, n_heads, size, size]) |
|
|
|
self._current_alibi_size = size |
|
self.alibi = alibi |
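        # The resulting bias satisfies alibi[0, head, i, j] == -slopes[head] * |i - j|, so attention scores
        # decay linearly with token distance at a per-head rate.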
|
|
|
def forward( |
|
self, |
|
hidden_states: torch.Tensor, |
|
attention_mask: torch.Tensor, |
|
output_all_encoded_layers: Optional[bool] = True, |
|
subset_mask: Optional[torch.Tensor] = None, |
|
) -> List[torch.Tensor]: |
|
extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) |
|
extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) |
|
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 |
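        # Additive mask: positions with attention_mask == 1 contribute 0.0, padded positions contribute
        # -10000.0 when added to the attention scores.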
|
|
|
attention_mask_bool = attention_mask.bool() |
|
batch, seqlen = hidden_states.shape[:2] |
|
|
|
|
|
|
|
|
|
|
|
hidden_states, indices, cu_seqlens, _ = unpad_input(hidden_states, attention_mask_bool) |
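        # unpad_input concatenates the non-padding tokens into (total_nnz, dim) and returns their flat indices
        # into the (batch * seqlen) view plus cu_seqlens of shape (batch + 1,); the max-seqlen return value is
        # discarded here since the padded `seqlen` is used instead.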
|
|
|
|
|
if self._current_alibi_size < seqlen: |
|
|
|
warnings.warn(f"Increasing alibi size from {self._current_alibi_size} to {seqlen}") |
|
self.rebuild_alibi_tensor(size=seqlen, device=hidden_states.device) |
|
elif self.alibi.device != hidden_states.device: |
|
|
|
self.alibi = self.alibi.to(hidden_states.device) |
|
self.slopes = self.slopes.to(hidden_states.device) |
|
alibi_bias = self.alibi[:, :, :seqlen, :seqlen] |
|
attn_bias = extended_attention_mask[:, :, :seqlen, :seqlen] |
|
alibi_attn_mask = attn_bias + alibi_bias |
|
|
|
all_encoder_layers = [] |
|
if subset_mask is None: |
|
for layer_module in self.layer: |
|
hidden_states = layer_module( |
|
hidden_states, |
|
cu_seqlens, |
|
seqlen, |
|
None, |
|
indices, |
|
attn_mask=attention_mask, |
|
bias=alibi_attn_mask, |
|
slopes=self.slopes, |
|
) |
|
if output_all_encoded_layers: |
|
all_encoder_layers.append(hidden_states) |
|
|
|
|
|
|
|
|
|
|
|
hidden_states = pad_input(hidden_states, indices, batch, seqlen) |
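            # pad_input scatters the (total_nnz, dim) tokens back into a zero-padded (batch, seqlen, dim) tensor.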
|
else: |
|
for i in range(len(self.layer) - 1): |
|
layer_module = self.layer[i] |
|
hidden_states = layer_module( |
|
hidden_states, |
|
cu_seqlens, |
|
seqlen, |
|
None, |
|
indices, |
|
attn_mask=attention_mask, |
|
bias=alibi_attn_mask, |
|
slopes=self.slopes, |
|
) |
|
if output_all_encoded_layers: |
|
all_encoder_layers.append(hidden_states) |
|
subset_idx = torch.nonzero(subset_mask[attention_mask_bool], as_tuple=False).flatten() |
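            # Only the positions selected by subset_mask (e.g. the MLM-masked tokens) are produced by the final
            # layer, avoiding computation for tokens the prediction head will ignore.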
|
hidden_states = self.layer[-1]( |
|
hidden_states, |
|
cu_seqlens, |
|
seqlen, |
|
subset_idx=subset_idx, |
|
indices=indices, |
|
attn_mask=attention_mask, |
|
bias=alibi_attn_mask, |
|
slopes=self.slopes, |
|
) |
|
|
|
if not output_all_encoded_layers: |
|
all_encoder_layers.append(hidden_states) |
|
return all_encoder_layers |
|
|
|
|
|
class BertPooler(nn.Module): |
|
def __init__(self, config): |
|
        super().__init__()
|
self.dense = nn.Linear(config.hidden_size, config.hidden_size) |
|
self.activation = nn.Tanh() |
|
|
|
def forward(self, hidden_states: torch.Tensor, pool: Optional[bool] = True) -> torch.Tensor: |
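        # Pool by taking the hidden state of the first token (the [CLS] position) when `pool=True`;
        # otherwise run the full sequence through the dense + tanh projection.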
|
|
|
|
|
first_token_tensor = hidden_states[:, 0] if pool else hidden_states |
|
pooled_output = self.dense(first_token_tensor) |
|
pooled_output = self.activation(pooled_output) |
|
return pooled_output |
|
|
|
|
|
class BertPredictionHeadTransform(nn.Module): |
|
def __init__(self, config): |
|
super().__init__() |
|
self.dense = nn.Linear(config.hidden_size, config.hidden_size) |
|
if isinstance(config.hidden_act, str): |
|
            self.transform_act_fn = get_act_fn(config.hidden_act)
|
else: |
|
self.transform_act_fn = config.hidden_act |
|
self.LayerNorm = get_norm_layer(config) |
|
|
|
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: |
|
hidden_states = self.dense(hidden_states) |
|
hidden_states = self.transform_act_fn(hidden_states) |
|
hidden_states = self.LayerNorm(hidden_states) |
|
return hidden_states |
|
|
|
|
|
class FlexBertLayerBase(nn.Module): |
|
"""A FlexBERT Layer base class for type hints.""" |
|
|
|
attn: FlexBertAttentionBase |
|
mlp: FlexBertMLPBase |
|
|
|
def __init__(self, config: FlexBertConfig, layer_id: Optional[int] = None): |
|
super().__init__() |
|
self.config = config |
|
self.layer_id = layer_id |
|
|
|
def _init_weights(self, reset_params: bool = False): |
|
if hasattr(self, "attn"): |
|
self.attn._init_weights(reset_params) |
|
if hasattr(self, "mlp"): |
|
self.mlp._init_weights(reset_params) |
|
|
|
def reset_parameters(self): |
|
self._init_weights(reset_params=True) |
|
|
|
def forward(self, hidden_states: torch.Tensor, attn_mask: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor: |
|
raise NotImplementedError("This is a base class and should not be used directly.") |
|
|
|
|
|
class FlexBertCompileUnpadPreNormLayer(FlexBertLayerBase): |
|
"""Composes the FlexBERT attention and MLP blocks into a single layer using pre-normalization.""" |
|
|
|
def __init__(self, config: FlexBertConfig, layer_id: Optional[int] = None): |
|
super().__init__(config=config, layer_id=layer_id) |
|
if config.skip_first_prenorm and config.embed_norm and layer_id == 0: |
|
self.attn_norm = nn.Identity() |
|
else: |
|
self.attn_norm = get_norm_layer(config) |
|
self.attn = get_attention_layer(config, layer_id=layer_id) |
|
self.mlp_norm = get_norm_layer(config, compiled_norm=config.compile_model) |
|
self.mlp = get_mlp_layer(config, layer_id=layer_id) |
|
self.compile_model = config.compile_model |
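        # Only the norm + MLP path is wrapped in ``torch.compile`` (see ``compiled_mlp`` below); the attention
        # call runs eagerly. ``dynamic=True`` avoids recompiling for every new unpadded sequence length.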
|
|
|
def _init_weights(self, reset_params: bool = False): |
|
super()._init_weights(reset_params) |
|
if reset_params: |
|
self.attn_norm.reset_parameters() |
|
self.mlp_norm.reset_parameters() |
|
|
|
@torch.compile(dynamic=True) |
|
def compiled_mlp(self, hidden_states: torch.Tensor) -> torch.Tensor: |
|
return self.mlp(self.mlp_norm(hidden_states)) |
|
|
|
def forward( |
|
self, |
|
hidden_states: torch.Tensor, |
|
cu_seqlens: torch.Tensor, |
|
max_seqlen: int, |
|
indices: Optional[torch.Tensor] = None, |
|
attn_mask: Optional[torch.Tensor] = None, |
|
) -> torch.Tensor: |
|
"""Forward pass for a BERT layer, including both attention and MLP. |
|
|
|
Args: |
|
hidden_states: (total_nnz, dim) |
|
cu_seqlens: (batch + 1,) |
|
max_seqlen: int |
|
indices: None or (total_nnz,) |
|
attn_mask: None or (batch, max_seqlen) |
|
""" |
|
attn_out = hidden_states + self.attn(self.attn_norm(hidden_states), cu_seqlens, max_seqlen, indices, attn_mask) |
|
return attn_out + self.compiled_mlp(attn_out) |
|
|
|
|
|
class FlexBertUnpadPreNormLayer(FlexBertLayerBase): |
|
"""Composes the FlexBERT attention and MLP blocks into a single layer using pre-normalization.""" |
|
|
|
def __init__(self, config: FlexBertConfig, layer_id: Optional[int] = None): |
|
super().__init__(config=config, layer_id=layer_id) |
|
if config.skip_first_prenorm and config.embed_norm and layer_id == 0: |
|
self.attn_norm = nn.Identity() |
|
else: |
|
self.attn_norm = get_norm_layer(config) |
|
self.attn = get_attention_layer(config, layer_id=layer_id) |
|
self.mlp_norm = get_norm_layer(config) |
|
self.mlp = get_mlp_layer(config, layer_id=layer_id) |
|
|
|
def _init_weights(self, reset_params: bool = False): |
|
super()._init_weights(reset_params) |
|
if reset_params: |
|
self.attn_norm.reset_parameters() |
|
self.mlp_norm.reset_parameters() |
|
|
|
def forward( |
|
self, |
|
hidden_states: torch.Tensor, |
|
cu_seqlens: torch.Tensor, |
|
max_seqlen: int, |
|
indices: Optional[torch.Tensor] = None, |
|
attn_mask: Optional[torch.Tensor] = None, |
|
) -> torch.Tensor: |
|
"""Forward pass for a BERT layer, including both attention and MLP. |
|
|
|
Args: |
|
hidden_states: (total_nnz, dim) |
|
cu_seqlens: (batch + 1,) |
|
max_seqlen: int |
|
indices: None or (total_nnz,) |
|
attn_mask: None or (batch, max_seqlen) |
|
""" |
|
attn_out = hidden_states + self.attn(self.attn_norm(hidden_states), cu_seqlens, max_seqlen, indices, attn_mask) |
|
return attn_out + self.mlp(self.mlp_norm(attn_out)) |
|
|
|
|
|
class FlexBertUnpadParallelPreNormLayer(FlexBertLayerBase): |
|
"""Composes the FlexBERT parallel attention and MLP blocks into a single layer using pre-normalization.""" |
|
|
|
def __init__(self, config: FlexBertConfig, layer_id: Optional[int] = None): |
|
super().__init__(config=config, layer_id=layer_id) |
|
self.attn_size = config.hidden_size * 3 |
|
self.mlp_size = config.intermediate_size * 2 |
|
|
|
self.Wqkvff = nn.Linear(config.hidden_size, self.attn_size + self.mlp_size, bias=config.attn_qkv_bias) |
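        # Fused projection: a single matmul produces both the attention QKV (3 * hidden_size) and the
        # GLU MLP input (2 * intermediate_size); `forward` splits the result into the two blocks.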
|
if config.skip_first_prenorm and config.embed_norm and layer_id == 0: |
|
self.norm = nn.Identity() |
|
else: |
|
self.norm = get_norm_layer(config) |
|
self.attn = get_attention_layer(config, layer_id=layer_id) |
|
self.mlp = get_mlp_layer(config, layer_id=layer_id) |
|
|
|
def _init_weights(self, reset_params: bool = False): |
|
super()._init_weights(reset_params) |
|
if reset_params and hasattr(self.norm, "reset_parameters"): |
|
self.norm.reset_parameters() |
|
|
|
init_weights( |
|
self.config, |
|
self.Wqkvff, |
|
layer_dim=self.config.hidden_size, |
|
layer_id=None, |
|
type_of_module=ModuleType.in_module, |
|
) |
|
|
|
def forward( |
|
self, |
|
hidden_states: torch.Tensor, |
|
cu_seqlens: torch.Tensor, |
|
max_seqlen: int, |
|
indices: Optional[torch.Tensor] = None, |
|
attn_mask: Optional[torch.Tensor] = None, |
|
) -> torch.Tensor: |
|
"""Forward pass for a BERT layer, including both attention and MLP. |
|
|
|
Args: |
|
hidden_states: (total_nnz, dim) |
|
attn_mask: None or (batch, max_seqlen) |
|
""" |
|
|
|
qkv, intermediate_ff = self.Wqkvff(self.norm(hidden_states)).split([self.attn_size, self.mlp_size], dim=1) |
|
return hidden_states + self.attn(qkv, cu_seqlens, max_seqlen, indices, attn_mask) + self.mlp(intermediate_ff) |
|
|
|
|
|
class FlexBertPaddedPreNormLayer(FlexBertLayerBase): |
|
"""Composes the FlexBERT attention and MLP blocks into a single layer using pre-normalization.""" |
|
|
|
def __init__(self, config: FlexBertConfig, layer_id: Optional[int] = None): |
|
super().__init__(config=config, layer_id=layer_id) |
|
if config.skip_first_prenorm and config.embed_norm and layer_id == 0: |
|
self.attn_norm = nn.Identity() |
|
else: |
|
self.attn_norm = get_norm_layer(config) |
|
self.attn = get_attention_layer(config, layer_id=layer_id) |
|
self.mlp_norm = get_norm_layer(config) |
|
self.mlp = get_mlp_layer(config, layer_id=layer_id) |
|
|
|
def _init_weights(self, reset_params: bool = False): |
|
super()._init_weights(reset_params) |
|
if reset_params: |
|
self.attn_norm.reset_parameters() |
|
self.mlp_norm.reset_parameters() |
|
|
|
def forward( |
|
self, |
|
hidden_states: torch.Tensor, |
|
attn_mask: Optional[torch.Tensor] = None, |
|
) -> torch.Tensor: |
|
"""Forward pass for a BERT layer, including both attention and MLP. |
|
|
|
Args: |
|
hidden_states: (batch, max_seqlen, dim) |
|
attn_mask: None or (batch, max_seqlen) |
|
""" |
|
attn_out = hidden_states + self.attn(self.attn_norm(hidden_states), attn_mask) |
|
return attn_out + self.mlp(self.mlp_norm(attn_out)) |
|
|
|
|
|
class FlexBertPaddedParallelPreNormLayer(FlexBertLayerBase): |
|
"""Composes the FlexBERT attention and MLP blocks into a single layer using pre-normalization.""" |
|
|
|
def __init__(self, config: FlexBertConfig, layer_id: Optional[int] = None): |
|
super().__init__(config=config, layer_id=layer_id) |
|
self.attn_size = config.hidden_size * 3 |
|
self.mlp_size = config.intermediate_size * 2 |
|
|
|
self.Wqkvff = nn.Linear(config.hidden_size, self.attn_size + self.mlp_size, bias=config.attn_qkv_bias) |
|
if config.skip_first_prenorm and config.embed_norm and layer_id == 0: |
|
self.norm = nn.Identity() |
|
else: |
|
self.norm = get_norm_layer(config) |
|
self.attn = get_attention_layer(config, layer_id=layer_id) |
|
self.mlp = get_mlp_layer(config, layer_id=layer_id) |
|
|
|
def _init_weights(self, reset_params: bool = False): |
|
super()._init_weights(reset_params) |
|
if reset_params: |
|
self.norm.reset_parameters() |
|
|
|
init_weights( |
|
self.config, |
|
self.Wqkvff, |
|
layer_dim=self.config.hidden_size, |
|
layer_id=None, |
|
type_of_module=ModuleType.in_module, |
|
) |
|
|
|
def forward( |
|
self, |
|
hidden_states: torch.Tensor, |
|
attn_mask: Optional[torch.Tensor] = None, |
|
) -> torch.Tensor: |
|
"""Forward pass for a BERT layer, including both attention and MLP. |
|
|
|
Args: |
|
hidden_states: (batch, max_seqlen, dim) |
|
attn_mask: None or (batch, max_seqlen) |
|
""" |
|
|
|
qkv, intermediate_ff = self.Wqkvff(self.norm(hidden_states)).split([self.attn_size, self.mlp_size], dim=2) |
|
return hidden_states + self.attn(qkv, attn_mask) + self.mlp(intermediate_ff) |
|
|
|
|
|
class FlexBertUnpadPostNormLayer(FlexBertLayerBase): |
|
"""Composes the FlexBERT attention and MLP blocks into a single layer using post-normalization.""" |
|
|
|
def __init__(self, config: FlexBertConfig, layer_id: Optional[int] = None): |
|
super().__init__(config=config, layer_id=layer_id) |
|
self.attn = get_attention_layer(config, layer_id=layer_id) |
|
self.attn_norm = get_norm_layer(config) |
|
self.mlp = get_mlp_layer(config, layer_id=layer_id) |
|
self.mlp_norm = get_norm_layer(config) |
|
|
|
def _init_weights(self, reset_params: bool = False): |
|
super()._init_weights(reset_params) |
|
if reset_params: |
|
self.attn_norm.reset_parameters() |
|
self.mlp_norm.reset_parameters() |
|
|
|
def forward( |
|
self, |
|
hidden_states: torch.Tensor, |
|
cu_seqlens: torch.Tensor, |
|
max_seqlen: int, |
|
indices: Optional[torch.Tensor] = None, |
|
attn_mask: Optional[torch.Tensor] = None, |
|
) -> torch.Tensor: |
|
"""Forward pass for a BERT layer, including both attention and MLP. |
|
|
|
Args: |
|
hidden_states: (total_nnz, dim) |
|
cu_seqlens: (batch + 1,) |
|
max_seqlen: int |
|
indices: None or (total_nnz,) |
|
attn_mask: None or (batch, max_seqlen) |
|
""" |
|
attn_out = self.attn_norm(hidden_states + self.attn(hidden_states, cu_seqlens, max_seqlen, indices, attn_mask)) |
|
return self.mlp_norm(attn_out + self.mlp(attn_out)) |
|
|
|
|
|
class FlexBertPaddedPostNormLayer(FlexBertLayerBase): |
|
"""Composes the FlexBERT attention and MLP blocks into a single layer using post-normalization.""" |
|
|
|
def __init__(self, config: FlexBertConfig, layer_id: Optional[int] = None): |
|
super().__init__(config=config, layer_id=layer_id) |
|
self.attn = get_attention_layer(config, layer_id=layer_id) |
|
self.attn_norm = get_norm_layer(config) |
|
self.mlp = get_mlp_layer(config, layer_id=layer_id) |
|
self.mlp_norm = get_norm_layer(config) |
|
|
|
def _init_weights(self, reset_params: bool = False): |
|
super()._init_weights(reset_params) |
|
        if reset_params:

            self.attn_norm.reset_parameters()

            self.mlp_norm.reset_parameters()
|
|
|
def forward( |
|
self, |
|
hidden_states: torch.Tensor, |
|
attn_mask: Optional[torch.Tensor] = None, |
|
) -> torch.Tensor: |
|
"""Forward pass for a BERT layer, including both attention and MLP. |
|
|
|
Args: |
|
hidden_states: (batch, max_seqlen, dim) |
|
attn_mask: None or (batch, max_seqlen) |
|
""" |
|
attn_out = self.attn_norm(hidden_states + self.attn(hidden_states, attn_mask)) |
|
return self.mlp_norm(attn_out + self.mlp(attn_out)) |
|
|
|
|
|
LAYER2CLS = { |
|
"unpadded_prenorm": FlexBertUnpadPreNormLayer, |
|
"unpadded_compile_prenorm": FlexBertCompileUnpadPreNormLayer, |
|
"unpadded_parallel_prenorm": FlexBertUnpadParallelPreNormLayer, |
|
"unpadded_postnorm": FlexBertUnpadPostNormLayer, |
|
"padded_prenorm": FlexBertPaddedPreNormLayer, |
|
"padded_parallel_prenorm": FlexBertPaddedParallelPreNormLayer, |
|
"padded_postnorm": FlexBertPaddedPostNormLayer, |
|
} |
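# Illustrative mapping only: with `config.padding == "unpadded"` and `config.bert_layer == "parallel_prenorm"`,
# `maybe_add_padding` would yield the key "unpadded_parallel_prenorm" and thus FlexBertUnpadParallelPreNormLayer.
# The exact config values are hypothetical; see `get_bert_layer` below for the actual resolution logic.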
|
|
|
|
|
def get_bert_layer(config: FlexBertConfig, layer_id: Optional[int] = None) -> FlexBertLayerBase: |
|
try: |
|
bert_layer = ( |
|
config.initial_bert_layer |
|
if layer_id < config.num_initial_layers and getattr(config, "initial_bert_layer", None) is not None |
|
else config.bert_layer |
|
) |
|
bert_layer = maybe_add_padding(config, bert_layer) |
|
if config.compile_model and bert_layer == "unpadded_prenorm": |
|
bert_layer = "unpadded_compile_prenorm" |
|
return LAYER2CLS[bert_layer](config, layer_id=layer_id) |
|
except KeyError: |
|
if layer_id < config.num_initial_layers and getattr(config, "initial_bert_layer", None) is not None: |
|
raise ValueError( |
|
f"Invalid BERT layer type: {config.initial_bert_layer=}, must be one of {LAYER2CLS.keys()}." |
|
f"{config.padding=} will be automatically prepended to `config.bert_layer` if unspecified." |
|
) |
|
else: |
|
raise ValueError( |
|
f"Invalid BERT layer type: {config.bert_layer=}, must be one of {LAYER2CLS.keys()}. " |
|
f"{config.padding=} will be automatically prepended to `config.bert_layer` if unspecified." |
|
) |
|
|
|
|
|
class FlexBertEncoderBase(nn.Module): |
|
"""A FlexBERT base class for type hints.""" |
|
|
|
layers: nn.ModuleList |
|
|
|
def _init_weights(self, reset_params: bool = False): |
|
if hasattr(self, "layers"): |
|
for layer in self.layers: |
|
layer._init_weights(reset_params=reset_params) |
|
|
|
def reset_parameters(self): |
|
self._init_weights(reset_params=True) |
|
|
|
def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor: |
|
raise NotImplementedError("This is a base class and should not be used directly.") |
|
|
|
|
|
class FlexBertUnpadEncoder(FlexBertEncoderBase): |
|
"""A stack of BERT layers providing the backbone of FlexBERT. |
|
|
|
    This module is modeled after Hugging Face's :class:`~transformers.models.bert.modeling_bert.BertEncoder`,

    but with substantial modifications to operate on unpadded inputs.

    Compared to the analogous Hugging Face BERT module, this module handles unpadding to reduce unnecessary
    computation at padded tokens; positional encodings are applied inside the attention layers rather than
    pre-computed by the encoder.
|
""" |
|
|
|
def __init__(self, config: FlexBertConfig): |
|
super().__init__() |
|
self.layers = nn.ModuleList([get_bert_layer(config, layer_id=i) for i in range(config.num_hidden_layers)]) |
|
self.num_attention_heads = config.num_attention_heads |
|
|
|
def forward( |
|
self, |
|
hidden_states: torch.Tensor, |
|
attention_mask: torch.Tensor, |
|
indices: Optional[torch.Tensor] = None, |
|
cu_seqlens: Optional[torch.Tensor] = None, |
|
max_seqlen: Optional[int] = None, |
|
) -> torch.Tensor: |
|
if indices is None and cu_seqlens is None and max_seqlen is None: |
|
attention_mask_bool = attention_mask.bool() |
|
batch, seqlen = hidden_states.shape[:2] |
|
hidden_states, indices, cu_seqlens, max_seqlen = unpad_input( |
|
hidden_states, attention_mask_bool |
|
) |
|
|
|
for layer_module in self.layers: |
|
hidden_states = layer_module( |
|
hidden_states, |
|
cu_seqlens, |
|
max_seqlen, |
|
indices, |
|
attn_mask=attention_mask, |
|
) |
|
|
|
return pad_input(hidden_states, indices, batch, seqlen) |
|
else: |
|
for layer_module in self.layers: |
|
hidden_states = layer_module( |
|
hidden_states, |
|
cu_seqlens, |
|
max_seqlen, |
|
indices, |
|
attn_mask=attention_mask, |
|
) |
|
return hidden_states |
|
|
|
|
|
class FlexBertPaddedEncoder(FlexBertEncoderBase): |
|
"""A stack of BERT layers providing the backbone of FlexBERT. |
|
|
|
    This module is modeled after Hugging Face's :class:`~transformers.models.bert.modeling_bert.BertEncoder`.

    Unlike :class:`FlexBertUnpadEncoder`, this variant keeps the standard padded (batch, max_seqlen, dim) layout
    and passes the attention mask to each layer rather than unpadding the inputs.
|
""" |
|
|
|
def __init__(self, config: FlexBertConfig): |
|
super().__init__() |
|
self.layers = nn.ModuleList([get_bert_layer(config, layer_id=i) for i in range(config.num_hidden_layers)]) |
|
self.num_attention_heads = config.num_attention_heads |
|
|
|
def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> torch.Tensor: |
|
for layer_module in self.layers: |
|
hidden_states = layer_module(hidden_states, attn_mask=attention_mask) |
|
|
|
return hidden_states |
|
|
|
|
|
ENC2CLS = { |
|
"unpadded_base": FlexBertUnpadEncoder, |
|
"padded_base": FlexBertPaddedEncoder, |
|
} |
|
|
|
|
|
def get_encoder_layer(config: FlexBertConfig) -> FlexBertEncoderBase: |
|
try: |
|
return ENC2CLS[maybe_add_padding(config, config.encoder_layer)](config) |
|
except KeyError: |
|
raise ValueError( |
|
f"Invalid encoder layer type: {config.encoder_layer=}, must be one of {ENC2CLS.keys()}. " |
|
f"{config.padding=} will be automatically prepended to `config.encoder_layer` if unspecified." |
|
) |
|
|