File size: 29,428 Bytes
c9e4fad |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 |
# Copyright 2024 **AUTHORS_TODO**
# License: Apache-2.0
# Copyright 2022 MosaicML Examples authors
# SPDX-License-Identifier: Apache-2.0
# Copyright 2023 MosaicML Examples authors
# SPDX-License-Identifier: Apache-2.0
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018-2021, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2023, Tri Dao.
import copy
import math
import warnings
from typing import Optional, Union, List
import torch
import torch.nn as nn
from .bert_padding import unpad_input, pad_input
from .activation import get_act_fn
from .attention import FlexBertAttentionBase, BertAlibiUnpadAttention, get_attention_layer
from .mlp import FlexBertMLPBase, BertResidualGLU, get_mlp_layer
from .configuration_bert import FlexBertConfig, maybe_add_padding
from .normalization import get_norm_layer
from .initialization import ModuleType, init_weights
class BertAlibiLayer(nn.Module):
    """A single Mosaic BERT block: unpadded ALiBi attention followed by the residual GLU feed-forward."""

    def __init__(self, config):
        super().__init__()
        self.attention = BertAlibiUnpadAttention(config)
        self.mlp = BertResidualGLU(config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
        seqlen: int,
        subset_idx: Optional[torch.Tensor] = None,
        indices: Optional[torch.Tensor] = None,
        attn_mask: Optional[torch.Tensor] = None,
        bias: Optional[torch.Tensor] = None,
        slopes: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run the attention block and feed its output through the MLP block.

        Args:
            hidden_states: (total_nnz, dim)
            cu_seqlens: (batch + 1,)
            seqlen: int
            subset_idx: () indices of the tokens whose outputs are needed after this layer
                (e.g., the masked tokens, if this is the final layer).
            indices: None or (total_nnz,)
            attn_mask: None or (batch, max_seqlen_in_batch)
            bias: None or (batch, heads, max_seqlen_in_batch, max_seqlen_in_batch)
            slopes: None or (batch, heads) or (heads,)

        Returns:
            The output of the MLP block applied to the attention output.
        """
        # `bias` and `slopes` must be supplied together or not at all.
        has_bias = bias is not None
        has_slopes = slopes is not None
        assert has_bias == has_slopes, f"{bias=}, {slopes=}"

        return self.mlp(
            self.attention(hidden_states, cu_seqlens, seqlen, subset_idx, indices, attn_mask, bias, slopes)
        )
class BertAlibiEncoder(nn.Module):
    """Stack of :class:`BertAlibiLayer` blocks that forms the Mosaic BERT backbone.

    Patterned on the Hugging Face BERT encoder, with two substantial differences: inputs are unpadded
    before the layers run (so no compute is spent on padding tokens) and the ALiBi attention bias is
    pre-computed and cached on the module.
    """

    def __init__(self, config):
        super().__init__()
        prototype = BertAlibiLayer(config)
        self.layer = nn.ModuleList([copy.deepcopy(prototype) for _ in range(config.num_hidden_layers)])

        self.num_attention_heads = config.num_attention_heads

        # The cached ALiBi tensor grows on demand when an input is longer than it, but starting
        # from a reasonably large size helps pre-allocate CUDA memory.
        # The default `alibi_starting_size` is 512.
        self._current_alibi_size = int(config.alibi_starting_size)
        self.alibi = torch.zeros((1, self.num_attention_heads, self._current_alibi_size, self._current_alibi_size))
        self.rebuild_alibi_tensor(size=config.alibi_starting_size)

    def rebuild_alibi_tensor(self, size: int, device: Optional[Union[torch.device, str]] = None):
        """Recompute ``self.slopes`` and ``self.alibi`` for sequences of up to ``size`` tokens.

        Follows https://github.com/ofirpress/attention_with_linear_biases/issues/5 (Implementation 1).
        With causal masking one could rely on softmax being invariant to a uniform shift of the
        logits; since no causal masking is assumed here, the full symmetric distance matrix is built.

        Args:
            size: Side length of the square bias matrix.
            device: Device on which the tensors are created (``None`` keeps the default).
        """
        n_heads = self.num_attention_heads

        def _pow2_slopes(count: int) -> List[float]:
            # Geometric sequence whose first term and ratio are both `base`.
            base = 2 ** (-(2 ** -(math.log2(count) - 3)))
            return [base * base**k for k in range(count)]

        def _head_slopes(count: int) -> List[float]:
            # The paper only trains models with 2^a heads, where the geometric sequence has nice
            # properties. For other head counts, take the sequence for the largest power of two
            # below `count` and fill the rest with every other slope of the next power of two.
            if math.log2(count).is_integer():
                return _pow2_slopes(count)
            lower_pow2 = 2 ** math.floor(math.log2(count))
            every_other = _head_slopes(2 * lower_pow2)[0::2]
            return _pow2_slopes(lower_pow2) + every_other[: count - lower_pow2]

        positions = torch.arange(size, device=device)
        distance = torch.abs(positions[None, :] - positions[:, None])
        # [n_heads, max_token_length, max_token_length]
        distance = distance.unsqueeze(0).expand(n_heads, -1, -1)

        slopes = torch.Tensor(_head_slopes(n_heads)).to(device)
        self.slopes = slopes

        # [1, n_heads, max_token_length, max_token_length]
        alibi = (slopes.unsqueeze(1).unsqueeze(1) * -distance).unsqueeze(0)
        assert alibi.shape == torch.Size([1, n_heads, size, size])

        self._current_alibi_size = size
        self.alibi = alibi

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        output_all_encoded_layers: Optional[bool] = True,
        subset_mask: Optional[torch.Tensor] = None,
    ) -> List[torch.Tensor]:
        """Unpad the inputs, run every layer, and collect the requested hidden states.

        Args:
            hidden_states: Padded input; its first two dimensions are read as (batch, seqlen).
            attention_mask: Padding mask; converted to bool and used to unpad ``hidden_states``.
            output_all_encoded_layers: If true, the output of each layer inside the loops is
                collected (these are unpadded; with ``subset_mask`` the final layer's output is
                not appended). If false, only the final hidden states are returned.
            subset_mask: Optional mask selecting the tokens whose outputs are needed from the final
                layer. When given, the final output is not re-padded.

        Returns:
            A list of hidden-state tensors as described above.
        """
        additive_mask = attention_mask.unsqueeze(1).unsqueeze(2)
        additive_mask = additive_mask.to(dtype=next(self.parameters()).dtype)  # fp16 compatibility
        additive_mask = (1.0 - additive_mask) * -10000.0

        attention_mask_bool = attention_mask.bool()
        batch, seqlen = hidden_states.shape[:2]
        # Unpadding drops the padded tokens:
        #     hidden_states[ntokens, hidden] -> hidden_states[ntokens_unpad, hidden]
        # where ntokens counts all tokens and ntokens_unpad only the non-padded ones.
        hidden_states, indices, cu_seqlens, _ = unpad_input(hidden_states, attention_mask_bool)

        if self._current_alibi_size < seqlen:
            # The cached bias is too small for this input: rebuild it at the new size.
            warnings.warn(f"Increasing alibi size from {self._current_alibi_size} to {seqlen}")
            self.rebuild_alibi_tensor(size=seqlen, device=hidden_states.device)
        elif self.alibi.device != hidden_states.device:
            # Device catch-up
            self.alibi = self.alibi.to(hidden_states.device)
            self.slopes = self.slopes.to(hidden_states.device)  # type: ignore

        # Combine the padding mask with the ALiBi bias.
        alibi_attn_mask = additive_mask[:, :, :seqlen, :seqlen] + self.alibi[:, :, :seqlen, :seqlen]

        def _apply(layer_module, states, subset_idx=None):
            return layer_module(
                states,
                cu_seqlens,
                seqlen,
                subset_idx,
                indices,
                attn_mask=attention_mask,
                bias=alibi_attn_mask,
                slopes=self.slopes,
            )

        all_encoder_layers = []
        if subset_mask is None:
            for layer_module in self.layer:
                hidden_states = _apply(layer_module, hidden_states)
                if output_all_encoded_layers:
                    all_encoder_layers.append(hidden_states)
            # Re-padding inserts the zero-padded tokens back:
            #     hidden_states[ntokens_unpad, hidden] -> hidden_states[ntokens, hidden]
            hidden_states = pad_input(hidden_states, indices, batch, seqlen)
        else:
            for layer_module in list(self.layer)[:-1]:
                hidden_states = _apply(layer_module, hidden_states)
                if output_all_encoded_layers:
                    all_encoder_layers.append(hidden_states)
            # Only the final layer is told which tokens are actually needed.
            subset_idx = torch.nonzero(subset_mask[attention_mask_bool], as_tuple=False).flatten()
            hidden_states = _apply(self.layer[-1], hidden_states, subset_idx)

        if not output_all_encoded_layers:
            all_encoder_layers.append(hidden_states)
        return all_encoder_layers
class BertPooler(nn.Module):
    """Dense projection followed by tanh, applied to the first token's hidden state."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states: torch.Tensor, pool: Optional[bool] = True) -> torch.Tensor:
        """Project and squash the pooled representation.

        Args:
            hidden_states: Input hidden states. When ``pool`` is true, index 0 along the second
                dimension is selected; otherwise the tensor is used as given.
            pool: Whether to select the first token before projecting.
        """
        # "Pooling" here is nothing more than picking the hidden state of the first token.
        if pool:
            hidden_states = hidden_states[:, 0]
        return self.activation(self.dense(hidden_states))
class BertPredictionHeadTransform(nn.Module):
    """Dense -> activation -> normalization transform applied ahead of the prediction head.

    The activation comes from ``config.head_pred_act``: a string is resolved through ``get_act_fn``,
    anything else is used directly as the activation callable. If ``config.head_pred_act`` is
    missing or ``None``, ``config.hidden_act`` is used in the same way.
    """

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)

        # Dispatch on the value that is actually consumed. Previously the type check looked at
        # `config.hidden_act` while the lookup used `config.head_pred_act`, so a callable
        # `head_pred_act` was handed to the string lookup and a callable `hidden_act` silently
        # replaced the configured head activation.
        head_act = getattr(config, "head_pred_act", None)
        if head_act is None:
            head_act = config.hidden_act
        self.transform_act_fn = get_act_fn(head_act) if isinstance(head_act, str) else head_act

        self.LayerNorm = get_norm_layer(config)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply the dense projection, the activation, and the normalization, in that order."""
        projected = self.dense(hidden_states)
        activated = self.transform_act_fn(projected)
        return self.LayerNorm(activated)
class FlexBertLayerBase(nn.Module):
    """Base class for FlexBERT layers; provides type hints and shared weight-initialization plumbing."""

    attn: FlexBertAttentionBase
    mlp: FlexBertMLPBase

    def __init__(self, config: FlexBertConfig, layer_id: Optional[int] = None):
        super().__init__()
        self.config = config
        self.layer_id = layer_id

    def _init_weights(self, reset_params: bool = False):
        """Forward weight initialization to whichever of ``attn`` / ``mlp`` the subclass defined."""
        for submodule_name in ("attn", "mlp"):
            if hasattr(self, submodule_name):
                getattr(self, submodule_name)._init_weights(reset_params)

    def reset_parameters(self):
        """Re-initialize the layer by running ``_init_weights`` with ``reset_params=True``."""
        self._init_weights(reset_params=True)

    def forward(self, hidden_states: torch.Tensor, attn_mask: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor:
        """Abstract; concrete layers must override this."""
        raise NotImplementedError("This is a base class and should not be used directly.")
class FlexBertCompileUnpadPreNormLayer(FlexBertLayerBase):
    """Pre-normalization FlexBERT layer for unpadded inputs whose norm + MLP half is ``torch.compile``d."""

    def __init__(self, config: FlexBertConfig, layer_id: Optional[int] = None):
        super().__init__(config=config, layer_id=layer_id)
        # Layer 0 skips its attention pre-norm when both `skip_first_prenorm` and `embed_norm` are set
        # (presumably because the embeddings are already normalized -- confirm against the embedding module).
        if config.skip_first_prenorm and config.embed_norm and layer_id == 0:
            self.attn_norm = nn.Identity()
        else:
            self.attn_norm = get_norm_layer(config)
        self.attn = get_attention_layer(config, layer_id=layer_id)
        self.mlp_norm = get_norm_layer(config, compiled_norm=config.compile_model)
        self.mlp = get_mlp_layer(config, layer_id=layer_id)
        self.compile_model = config.compile_model

    def _init_weights(self, reset_params: bool = False):
        """Initialize the sub-blocks and, when ``reset_params`` is set, reset the norm layers."""
        super()._init_weights(reset_params)
        if reset_params:
            # `attn_norm` may be `nn.Identity`, which has no `reset_parameters`; skip it in that
            # case (same guard as FlexBertUnpadParallelPreNormLayer).
            if hasattr(self.attn_norm, "reset_parameters"):
                self.attn_norm.reset_parameters()
            self.mlp_norm.reset_parameters()

    @torch.compile(dynamic=True)
    def compiled_mlp(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply the MLP pre-norm followed by the MLP (compiled with dynamic shapes)."""
        return self.mlp(self.mlp_norm(hidden_states))

    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
        max_seqlen: int,
        indices: Optional[torch.Tensor] = None,
        attn_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Forward pass for a BERT layer, including both attention and MLP.

        Args:
            hidden_states: (total_nnz, dim)
            cu_seqlens: (batch + 1,)
            max_seqlen: int
            indices: None or (total_nnz,)
            attn_mask: None or (batch, max_seqlen)

        Returns:
            ``hidden_states`` with the attention residual and then the MLP residual added.
        """
        attn_out = hidden_states + self.attn(self.attn_norm(hidden_states), cu_seqlens, max_seqlen, indices, attn_mask)
        return attn_out + self.compiled_mlp(attn_out)
class FlexBertUnpadPreNormLayer(FlexBertLayerBase):
    """Pre-normalization FlexBERT layer (attention then MLP, each with a residual) for unpadded inputs."""

    def __init__(self, config: FlexBertConfig, layer_id: Optional[int] = None):
        super().__init__(config=config, layer_id=layer_id)
        # Layer 0 skips its attention pre-norm when both `skip_first_prenorm` and `embed_norm` are set
        # (presumably because the embeddings are already normalized -- confirm against the embedding module).
        if config.skip_first_prenorm and config.embed_norm and layer_id == 0:
            self.attn_norm = nn.Identity()
        else:
            self.attn_norm = get_norm_layer(config)
        self.attn = get_attention_layer(config, layer_id=layer_id)
        self.mlp_norm = get_norm_layer(config)
        self.mlp = get_mlp_layer(config, layer_id=layer_id)

    def _init_weights(self, reset_params: bool = False):
        """Initialize the sub-blocks and, when ``reset_params`` is set, reset the norm layers."""
        super()._init_weights(reset_params)
        if reset_params:
            # `attn_norm` may be `nn.Identity`, which has no `reset_parameters`; skip it in that
            # case (same guard as FlexBertUnpadParallelPreNormLayer).
            if hasattr(self.attn_norm, "reset_parameters"):
                self.attn_norm.reset_parameters()
            self.mlp_norm.reset_parameters()

    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
        max_seqlen: int,
        indices: Optional[torch.Tensor] = None,
        attn_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Forward pass for a BERT layer, including both attention and MLP.

        Args:
            hidden_states: (total_nnz, dim)
            cu_seqlens: (batch + 1,)
            max_seqlen: int
            indices: None or (total_nnz,)
            attn_mask: None or (batch, max_seqlen)

        Returns:
            ``hidden_states`` with the attention residual and then the MLP residual added.
        """
        attn_out = hidden_states + self.attn(self.attn_norm(hidden_states), cu_seqlens, max_seqlen, indices, attn_mask)
        return attn_out + self.mlp(self.mlp_norm(attn_out))
class FlexBertUnpadParallelPreNormLayer(FlexBertLayerBase):
    """Pre-norm FlexBERT layer for unpadded inputs that runs attention and MLP in parallel.

    A single fused linear layer (``Wqkvff``) produces both the QKV projection and the MLP input
    from the same normalized hidden states.
    """

    def __init__(self, config: FlexBertConfig, layer_id: Optional[int] = None):
        super().__init__(config=config, layer_id=layer_id)
        self.attn_size = config.hidden_size * 3
        self.mlp_size = config.intermediate_size * 2
        # One projection yields the QKV slice and the feed-forward slice together.
        fused_out_features = self.attn_size + self.mlp_size
        self.Wqkvff = nn.Linear(config.hidden_size, fused_out_features, bias=config.attn_qkv_bias)
        skip_norm = config.skip_first_prenorm and config.embed_norm and layer_id == 0
        self.norm = nn.Identity() if skip_norm else get_norm_layer(config)
        self.attn = get_attention_layer(config, layer_id=layer_id)
        self.mlp = get_mlp_layer(config, layer_id=layer_id)

    def _init_weights(self, reset_params: bool = False):
        """Initialize the sub-blocks, optionally reset the norm, and initialize the fused projection."""
        super()._init_weights(reset_params)
        if reset_params:
            # `self.norm` can be `nn.Identity`, which offers no `reset_parameters`.
            if hasattr(self.norm, "reset_parameters"):
                self.norm.reset_parameters()

        init_weights(
            self.config,
            self.Wqkvff,
            layer_dim=self.config.hidden_size,
            layer_id=None,
            type_of_module=ModuleType.in_module,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
        max_seqlen: int,
        indices: Optional[torch.Tensor] = None,
        attn_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Forward pass for a BERT layer, including both attention and MLP.

        Args:
            hidden_states: (total_nnz, dim)
            cu_seqlens: passed through to the attention block
            max_seqlen: passed through to the attention block
            indices: passed through to the attention block
            attn_mask: None or (batch, max_seqlen)

        Returns:
            ``hidden_states`` plus the attention output plus the MLP output.
        """
        # Project once, then carve the result into the attention and feed-forward parts.
        fused = self.Wqkvff(self.norm(hidden_states))
        qkv, intermediate_ff = fused.split([self.attn_size, self.mlp_size], dim=1)
        attn_out = self.attn(qkv, cu_seqlens, max_seqlen, indices, attn_mask)
        return hidden_states + attn_out + self.mlp(intermediate_ff)
class FlexBertPaddedPreNormLayer(FlexBertLayerBase):
    """Pre-normalization FlexBERT layer (attention then MLP, each with a residual) for padded inputs."""

    def __init__(self, config: FlexBertConfig, layer_id: Optional[int] = None):
        super().__init__(config=config, layer_id=layer_id)
        # Layer 0 skips its attention pre-norm when both `skip_first_prenorm` and `embed_norm` are set
        # (presumably because the embeddings are already normalized -- confirm against the embedding module).
        if config.skip_first_prenorm and config.embed_norm and layer_id == 0:
            self.attn_norm = nn.Identity()
        else:
            self.attn_norm = get_norm_layer(config)
        self.attn = get_attention_layer(config, layer_id=layer_id)
        self.mlp_norm = get_norm_layer(config)
        self.mlp = get_mlp_layer(config, layer_id=layer_id)

    def _init_weights(self, reset_params: bool = False):
        """Initialize the sub-blocks and, when ``reset_params`` is set, reset the norm layers."""
        super()._init_weights(reset_params)
        if reset_params:
            # `attn_norm` may be `nn.Identity`, which has no `reset_parameters`; skip it in that
            # case (same guard as FlexBertUnpadParallelPreNormLayer).
            if hasattr(self.attn_norm, "reset_parameters"):
                self.attn_norm.reset_parameters()
            self.mlp_norm.reset_parameters()

    def forward(
        self,
        hidden_states: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Forward pass for a BERT layer, including both attention and MLP.

        Args:
            hidden_states: (batch, max_seqlen, dim)
            attn_mask: None or (batch, max_seqlen)

        Returns:
            ``hidden_states`` with the attention residual and then the MLP residual added.
        """
        attn_out = hidden_states + self.attn(self.attn_norm(hidden_states), attn_mask)
        return attn_out + self.mlp(self.mlp_norm(attn_out))
class FlexBertPaddedParallelPreNormLayer(FlexBertLayerBase):
    """Pre-norm FlexBERT layer for padded inputs that runs attention and MLP in parallel.

    A single fused linear layer (``Wqkvff``) produces both the QKV projection and the MLP input
    from the same normalized hidden states.
    """

    def __init__(self, config: FlexBertConfig, layer_id: Optional[int] = None):
        super().__init__(config=config, layer_id=layer_id)
        self.attn_size = config.hidden_size * 3
        self.mlp_size = config.intermediate_size * 2
        # Compute QKV and FF outputs at once
        self.Wqkvff = nn.Linear(config.hidden_size, self.attn_size + self.mlp_size, bias=config.attn_qkv_bias)
        if config.skip_first_prenorm and config.embed_norm and layer_id == 0:
            self.norm = nn.Identity()
        else:
            self.norm = get_norm_layer(config)
        self.attn = get_attention_layer(config, layer_id=layer_id)
        self.mlp = get_mlp_layer(config, layer_id=layer_id)

    def _init_weights(self, reset_params: bool = False):
        """Initialize the sub-blocks, optionally reset the norm, and initialize the fused projection."""
        super()._init_weights(reset_params)
        # `self.norm` may be `nn.Identity`, which has no `reset_parameters`; guard the call the same
        # way FlexBertUnpadParallelPreNormLayer does so resetting layer 0 does not raise.
        if reset_params and hasattr(self.norm, "reset_parameters"):
            self.norm.reset_parameters()

        init_weights(
            self.config,
            self.Wqkvff,
            layer_dim=self.config.hidden_size,
            layer_id=None,
            type_of_module=ModuleType.in_module,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Forward pass for a BERT layer, including both attention and MLP.

        Args:
            hidden_states: (batch, max_seqlen, dim)
            attn_mask: None or (batch, max_seqlen)

        Returns:
            ``hidden_states`` plus the attention output plus the MLP output.
        """
        # Compute QKV and FF outputs at once and split them along the feature dimension
        qkv, intermediate_ff = self.Wqkvff(self.norm(hidden_states)).split([self.attn_size, self.mlp_size], dim=2)
        return hidden_states + self.attn(qkv, attn_mask) + self.mlp(intermediate_ff)
class FlexBertUnpadPostNormLayer(FlexBertLayerBase):
    """Post-normalization FlexBERT layer for unpadded inputs: each residual sum is normalized afterwards."""

    def __init__(self, config: FlexBertConfig, layer_id: Optional[int] = None):
        super().__init__(config=config, layer_id=layer_id)
        self.attn = get_attention_layer(config, layer_id=layer_id)
        self.attn_norm = get_norm_layer(config)
        self.mlp = get_mlp_layer(config, layer_id=layer_id)
        self.mlp_norm = get_norm_layer(config)

    def _init_weights(self, reset_params: bool = False):
        """Initialize the sub-blocks and, when ``reset_params`` is set, reset both norm layers."""
        super()._init_weights(reset_params)
        if not reset_params:
            return
        for norm in (self.attn_norm, self.mlp_norm):
            norm.reset_parameters()

    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
        max_seqlen: int,
        indices: Optional[torch.Tensor] = None,
        attn_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Forward pass for a BERT layer, including both attention and MLP.

        Args:
            hidden_states: (total_nnz, dim)
            cu_seqlens: (batch + 1,)
            max_seqlen: int
            indices: None or (total_nnz,)
            attn_mask: None or (batch, max_seqlen)

        Returns:
            The normalized sum of the attention block's normalized output and its MLP transform.
        """
        attn_residual = hidden_states + self.attn(hidden_states, cu_seqlens, max_seqlen, indices, attn_mask)
        attn_out = self.attn_norm(attn_residual)
        mlp_residual = attn_out + self.mlp(attn_out)
        return self.mlp_norm(mlp_residual)
class FlexBertPaddedPostNormLayer(FlexBertLayerBase):
    """Post-normalization FlexBERT layer for padded inputs: each residual sum is normalized afterwards."""

    def __init__(self, config: FlexBertConfig, layer_id: Optional[int] = None):
        super().__init__(config=config, layer_id=layer_id)
        self.attn = get_attention_layer(config, layer_id=layer_id)
        self.attn_norm = get_norm_layer(config)
        self.mlp = get_mlp_layer(config, layer_id=layer_id)
        self.mlp_norm = get_norm_layer(config)

    def _init_weights(self, reset_params: bool = False):
        """Initialize the sub-blocks and, when ``reset_params`` is set, reset both norm layers."""
        super()._init_weights(reset_params)
        if reset_params:
            # Reset *both* norms, as FlexBertUnpadPostNormLayer does; previously `attn_norm` was
            # skipped here and kept its old parameters across `reset_parameters()`.
            self.attn_norm.reset_parameters()
            self.mlp_norm.reset_parameters()

    def forward(
        self,
        hidden_states: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Forward pass for a BERT layer, including both attention and MLP.

        Args:
            hidden_states: (batch, max_seqlen, dim)
            attn_mask: None or (batch, max_seqlen)

        Returns:
            The normalized sum of the attention block's normalized output and its MLP transform.
        """
        attn_out = self.attn_norm(hidden_states + self.attn(hidden_states, attn_mask))
        return self.mlp_norm(attn_out + self.mlp(attn_out))
# Registry mapping a fully-qualified layer name (padding mode + layer flavour) to its implementing
# class. `get_bert_layer` looks names up here; the keys are also listed in its error messages.
LAYER2CLS = {
    "unpadded_prenorm": FlexBertUnpadPreNormLayer,
    "unpadded_compile_prenorm": FlexBertCompileUnpadPreNormLayer,
    "unpadded_parallel_prenorm": FlexBertUnpadParallelPreNormLayer,
    "unpadded_postnorm": FlexBertUnpadPostNormLayer,
    "padded_prenorm": FlexBertPaddedPreNormLayer,
    "padded_parallel_prenorm": FlexBertPaddedParallelPreNormLayer,
    "padded_postnorm": FlexBertPaddedPostNormLayer,
}
def get_bert_layer(config: FlexBertConfig, layer_id: Optional[int] = None) -> FlexBertLayerBase:
    """Instantiate the FlexBERT layer class that ``config`` selects for position ``layer_id``.

    Layers with ``layer_id < config.num_initial_layers`` use ``config.initial_bert_layer`` when that
    option is set; every other layer, and any call made without a ``layer_id``, uses
    ``config.bert_layer``. The chosen name is passed through ``maybe_add_padding`` and, when
    ``config.compile_model`` is set, ``"unpadded_prenorm"`` is replaced by its compiled variant.

    Args:
        config: Model configuration.
        layer_id: Index of the layer being built, or ``None``.

    Returns:
        A new layer instance constructed as ``cls(config, layer_id=layer_id)``.

    Raises:
        ValueError: If the resolved layer name is not a key of ``LAYER2CLS``.
    """
    # `layer_id` defaults to None, which cannot be ordered against an int; such calls are treated
    # as "not an initial layer" rather than failing with a TypeError.
    use_initial = (
        layer_id is not None
        and getattr(config, "initial_bert_layer", None) is not None
        and layer_id < config.num_initial_layers
    )
    option_name = "initial_bert_layer" if use_initial else "bert_layer"
    requested = getattr(config, option_name)

    # Only name resolution sits inside the `try`: a KeyError raised while *constructing* the layer
    # must propagate as-is instead of being reported as an invalid layer type.
    try:
        bert_layer = maybe_add_padding(config, requested)
        if config.compile_model and bert_layer == "unpadded_prenorm":
            bert_layer = "unpadded_compile_prenorm"
        layer_cls = LAYER2CLS[bert_layer]
    except KeyError as err:
        raise ValueError(
            f"Invalid BERT layer type: config.{option_name}={requested!r}, must be one of {LAYER2CLS.keys()}. "
            f"{config.padding=} will be automatically prepended to `config.{option_name}` if unspecified."
        ) from err

    return layer_cls(config, layer_id=layer_id)
class FlexBertEncoderBase(nn.Module):
    """Base class for FlexBERT encoders; provides type hints and shared weight-initialization plumbing."""

    layers: nn.ModuleList

    def _init_weights(self, reset_params: bool = False):
        """Call ``_init_weights`` on every layer, if the subclass has defined ``layers``."""
        if not hasattr(self, "layers"):
            return
        for encoder_layer in self.layers:
            encoder_layer._init_weights(reset_params=reset_params)

    def reset_parameters(self):
        """Re-initialize the encoder by running ``_init_weights`` with ``reset_params=True``."""
        self._init_weights(reset_params=True)

    def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """Abstract; concrete encoders must override this."""
        raise NotImplementedError("This is a base class and should not be used directly.")
class FlexBertUnpadEncoder(FlexBertEncoderBase):
    """FlexBERT backbone: a stack of layers that operate on unpadded token sequences.

    The encoder accepts either padded inputs, which it unpads before the layers and re-pads
    afterwards, or inputs that the caller has already unpadded, which are returned unpadded.
    """

    def __init__(self, config: FlexBertConfig):
        super().__init__()
        self.layers = nn.ModuleList([get_bert_layer(config, layer_id=i) for i in range(config.num_hidden_layers)])
        self.num_attention_heads = config.num_attention_heads

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        indices: Optional[torch.Tensor] = None,
        cu_seqlens: Optional[torch.Tensor] = None,
        max_seqlen: Optional[int] = None,
    ) -> torch.Tensor:
        """Run every layer over the (possibly freshly unpadded) hidden states.

        Args:
            hidden_states: Padded input whose first two dimensions are (batch, seqlen) when none of
                the unpadding arguments is given; otherwise taken as already unpadded.
            attention_mask: Padding mask; forwarded to every layer as ``attn_mask`` and used for
                unpadding when that happens here.
            indices: Unpadding indices, or ``None``.
            cu_seqlens: Cumulative sequence lengths, or ``None``.
            max_seqlen: Maximum sequence length, or ``None``.

        Returns:
            Re-padded hidden states if unpadding was done here, else the unpadded hidden states.
        """
        # Unpadding happens here only when the caller supplied none of its three products.
        unpad_here = indices is None and cu_seqlens is None and max_seqlen is None
        if unpad_here:
            attention_mask_bool = attention_mask.bool()
            batch, seqlen = hidden_states.shape[:2]
            hidden_states, indices, cu_seqlens, max_seqlen = unpad_input(hidden_states, attention_mask_bool)

        for layer_module in self.layers:
            hidden_states = layer_module(
                hidden_states,
                cu_seqlens,
                max_seqlen,
                indices,
                attn_mask=attention_mask,
            )

        if unpad_here:
            return pad_input(hidden_states, indices, batch, seqlen)
        return hidden_states
class FlexBertPaddedEncoder(FlexBertEncoderBase):
    """FlexBERT backbone: a stack of layers applied directly to padded hidden states."""

    def __init__(self, config: FlexBertConfig):
        super().__init__()
        self.layers = nn.ModuleList([get_bert_layer(config, layer_id=i) for i in range(config.num_hidden_layers)])
        self.num_attention_heads = config.num_attention_heads

    def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> torch.Tensor:
        """Feed the hidden states through every layer in order.

        Args:
            hidden_states: Input hidden states.
            attention_mask: Mask forwarded to every layer as ``attn_mask``.
            **kwargs: Accepted and ignored.

        Returns:
            The output of the last layer.
        """
        output = hidden_states
        for block in self.layers:
            output = block(output, attn_mask=attention_mask)
        return output
# Registry mapping a fully-qualified encoder name (padding mode + "base") to its implementing class.
# `get_encoder_layer` looks names up here; the keys are also listed in its error message.
ENC2CLS = {
    "unpadded_base": FlexBertUnpadEncoder,
    "padded_base": FlexBertPaddedEncoder,
}
def get_encoder_layer(config: FlexBertConfig) -> FlexBertEncoderBase:
    """Instantiate the FlexBERT encoder class selected by ``config.encoder_layer``.

    The configured name is passed through ``maybe_add_padding`` before it is looked up in ``ENC2CLS``.

    Args:
        config: Model configuration.

    Returns:
        A new encoder instance constructed as ``cls(config)``.

    Raises:
        ValueError: If the resolved encoder name is not a key of ``ENC2CLS``.
    """
    # Only the lookup sits inside the `try`: a KeyError raised while *constructing* the encoder
    # must propagate as-is instead of being reported as an invalid encoder type.
    try:
        encoder_cls = ENC2CLS[maybe_add_padding(config, config.encoder_layer)]
    except KeyError as err:
        raise ValueError(
            f"Invalid encoder layer type: {config.encoder_layer=}, must be one of {ENC2CLS.keys()}. "
            f"{config.padding=} will be automatically prepended to `config.encoder_layer` if unspecified."
        ) from err
    return encoder_cls(config)
|