| from __future__ import annotations |
|
|
| import numpy as np |
| import logging |
| import math |
| import sys |
| from abc import abstractmethod |
| from collections import defaultdict |
| from functools import partial |
| from typing import ( |
| Callable, |
| Dict, |
| Iterable, |
| List, |
| NamedTuple, |
| Optional, |
| Sequence, |
| Set, |
| Tuple, |
| cast, |
| ) |
| from dataclasses import fields |
| from typing import List, Optional, Tuple, Union |
|
|
| import torch |
| import torch.backends.cuda |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torch import einsum |
| from transformers import PreTrainedModel |
| from transformers.modeling_outputs import CausalLMOutputWithPast |
| from transformers.models.auto import AutoModel |
| from transformers.cache_utils import Cache |
|
|
| from .configuration_llada import ( |
| LLaDAConfig, |
| StrEnum, |
| InitFnType, |
| ActivationType, |
| BlockType, |
| LayerNormType, |
| ModelConfig, |
| ActivationCheckpointingStrategy, |
| ) |
|
|
# ``MutableMapping`` is subscripted further down in this module
# (``MutableMapping[str, torch.Tensor]``). The ``collections.abc`` class only
# supports subscripting from Python 3.9 on; Python 3.8 needs the ``typing`` alias.
# Compare the whole (major, minor) tuple: looking at ``minor`` alone would
# mis-classify any interpreter whose major version is not 3.
if sys.version_info >= (3, 9):
    from collections.abc import MutableMapping
elif sys.version_info >= (3, 8):
    from typing import MutableMapping
else:
    raise SystemExit("This script supports Python 3.8 or higher")
|
|
# Names exported by ``from <this module> import *``.
# NOTE(review): ``SiLU`` and ``LLaDALlamaBlock`` are defined in this module but
# are not listed here; ``LLaDAModel`` and ``LLaDAGenerateOutput`` are presumably
# defined further down the file -- confirm.
__all__ = [
    "LayerNormBase",
    "LayerNorm",
    "RMSLayerNorm",
    "GemmaRMSLayerNorm",
    "RotaryEmbedding",
    "Activation",
    "GELU",
    "ReLU",
    "SwiGLU",
    "LLaDABlock",
    "LLaDASequentialBlock",
    "LLaDAModel",
    "LLaDAOutput",
    "LLaDAGenerateOutput",
]


# Module-level logger, named after this module.
log = logging.getLogger(__name__)
|
|
|
|
class ModuleType(StrEnum):
    """
    Role of a module inside the network. :func:`init_weights` uses it to choose the
    standard deviation when ``config.init_fn`` is ``InitFnType.full_megatron``.
    """

    in_module = "in"  # std = config.init_std
    out_module = "out"  # std = config.init_std / sqrt(2 * config.n_layers)
    emb = "emb"  # std = config.init_std
    final_out = "final_out"  # std = config.d_model ** -0.5
|
|
|
|
def init_weights(
    config: ModelConfig,
    module: Union[nn.Linear, nn.Embedding],
    d: Optional[int] = None,
    layer_id: Optional[int] = None,
    std_factor: float = 1.0,
    type_of_module: Optional[ModuleType] = None,
) -> None:
    """
    Initialize, in place, the weights of a linear or embedding module.

    :param config: The model config; ``config.init_fn`` selects the scheme.
    :param module: The linear or embedding submodule to initialize.
    :param d: The effective input dimensionality of the weights. This could be smaller than the actual
        dimensions for fused layers. Defaults to ``config.d_model``.
    :param layer_id: When set, the standard deviation for the "mitchell" method will be adjusted by
        ``1 / sqrt(2 * (layer_id + 1))``.
    :param std_factor: Multiplier applied to the standard deviation by the "normal", "mitchell" and
        "fan_in" schemes.
    :param type_of_module: Role of the module; required (and only read) by the "full_megatron" scheme.
    :raises RuntimeError: For "full_megatron" when ``type_of_module`` is missing or unknown.
    :raises NotImplementedError: When ``config.init_fn`` is not one of the handled schemes.
    """
    fan_in = config.d_model if d is None else d
    scheme = config.init_fn

    if scheme == InitFnType.normal:
        sigma = config.init_std * std_factor
        if config.init_cutoff_factor is None:
            nn.init.normal_(module.weight, mean=0.0, std=sigma)
        else:
            bound = config.init_cutoff_factor * sigma
            nn.init.trunc_normal_(module.weight, mean=0.0, std=sigma, a=-bound, b=bound)
    elif scheme == InitFnType.mitchell:
        sigma = std_factor / math.sqrt(fan_in)
        if layer_id is not None:
            sigma = sigma / math.sqrt(2 * (layer_id + 1))
        # Truncate at three standard deviations.
        nn.init.trunc_normal_(module.weight, mean=0.0, std=sigma, a=-3 * sigma, b=3 * sigma)
    elif scheme == InitFnType.kaiming_normal:
        nn.init.kaiming_normal_(module.weight, nonlinearity="relu")
    elif scheme == InitFnType.fan_in:
        sigma = std_factor / math.sqrt(fan_in)
        nn.init.normal_(module.weight, mean=0.0, std=sigma)
    elif scheme == InitFnType.full_megatron:
        if type_of_module is None:
            raise RuntimeError(f"When using the {InitFnType.full_megatron} init, every module must have a type.")

        cutoff = 3 if config.init_cutoff_factor is None else config.init_cutoff_factor

        # The standard deviation depends only on the role of the module.
        if type_of_module in (ModuleType.in_module, ModuleType.emb):
            sigma = config.init_std
        elif type_of_module == ModuleType.out_module:
            sigma = config.init_std / math.sqrt(2.0 * config.n_layers)
        elif type_of_module == ModuleType.final_out:
            sigma = config.d_model**-0.5
        else:
            raise RuntimeError(f"Unknown module type '{type_of_module}'")
        nn.init.trunc_normal_(
            module.weight,
            mean=0.0,
            std=sigma,
            a=-cutoff * sigma,
            b=cutoff * sigma,
        )
    else:
        raise NotImplementedError(config.init_fn)

    # Everything below only concerns linear layers.
    if not isinstance(module, nn.Linear):
        return

    if module.bias is not None:
        nn.init.zeros_(module.bias)

    # Layers flagged with ``_is_residual`` are scaled down under the "normal" scheme.
    if scheme == InitFnType.normal and getattr(module, "_is_residual", False):
        with torch.no_grad():
            module.weight.div_(math.sqrt(2 * config.n_layers))
|
|
|
|
def ensure_finite_(x: torch.Tensor, check_neg_inf: bool = True, check_pos_inf: bool = False):
    """
    Clamp infinities of ``x`` in place to the finite extremes of its dtype.

    ``float("-inf")`` entries become ``torch.finfo(x.dtype).min`` when ``check_neg_inf`` is ``True``;
    ``float("inf")`` entries become ``torch.finfo(x.dtype).max`` when ``check_pos_inf`` is ``True``.
    Nothing is returned; ``x`` itself is modified.
    """
    replacements = (
        (check_neg_inf, float("-inf"), "min"),
        (check_pos_inf, float("inf"), "max"),
    )
    for enabled, infinity, bound_name in replacements:
        if enabled:
            x.masked_fill_(x == infinity, getattr(torch.finfo(x.dtype), bound_name))
|
|
|
|
def activation_checkpoint_function(cfg: ModelConfig):
    """
    Build the activation-checkpointing callable used by the transformer blocks.

    :param cfg: The model config; only its three dropout probabilities are read.
    :returns: ``torch.utils.checkpoint.checkpoint`` pre-bound with ``use_reentrant=False``
        and a ``preserve_rng_state`` flag.

    The RNG state only has to be stashed and restored when the checkpointed region
    is stochastic, i.e. when at least one dropout probability is non-zero; otherwise
    the recomputation during backward would draw different dropout masks than the
    original forward pass. With all dropouts at 0.0 the (costly) RNG bookkeeping is
    skipped.
    """
    # Imported lazily, as in the original code.
    from torch.utils.checkpoint import checkpoint

    uses_dropout = not (
        (cfg.attention_dropout == 0.0) and (cfg.embedding_dropout == 0.0) and (cfg.residual_dropout == 0.0)
    )
    return partial(
        checkpoint,
        preserve_rng_state=uses_dropout,
        use_reentrant=False,
    )
|
|
|
|
class BufferCache(dict, MutableMapping[str, torch.Tensor]):
    """
    Cache for attention biases and other things that would normally be stored as buffers.

    This is a plain ``dict`` subclass that adds no behavior of its own; entries are tensors
    looked up by string key.

    We avoid using buffers because we've run into various issues doing so with FSDP.
    In general it appears the way FSDP handles buffers is not well-defined.
    It doesn't shard them but apparently it does synchronize them across processes, which we want to avoid
    since (A) it isn't necessary, and (B) we sometimes have `-inf` in these biases which might get turned into
    NaNs when they're synchronized due to casting or some other issue.
    """
|
|
|
|
def _non_meta_init_device(config: ModelConfig) -> torch.device:
    """
    Return a concrete (non-"meta") device for materializing tensors.

    ``config.init_device`` is used when it is set to anything other than ``"meta"``;
    otherwise CUDA is chosen when available, falling back to the CPU.
    """
    requested = config.init_device
    if requested is None or requested == "meta":
        return torch.device("cuda" if torch.cuda.is_available() else "cpu")
    return torch.device(requested)
|
|
|
|
class Dropout(nn.Dropout):
    """``nn.Dropout`` that returns its input object untouched when ``p`` is exactly 0."""

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # Only call into the functional dropout when there is something to drop.
        if self.p != 0.0:
            return F.dropout(input, self.p, self.training, self.inplace)
        return input
|
|
|
|
class LayerNormBase(nn.Module):
    """
    Common base of the layer-norm variants in this module.

    Owns the optional affine parameters (``weight``/``bias``), the ``eps`` value and the
    normalized shape; subclasses provide :meth:`forward`. Use :meth:`build` to create the
    variant selected by ``config.layer_norm_type``.
    """

    def __init__(
        self,
        config: ModelConfig,
        *,
        size: Optional[int] = None,
        elementwise_affine: Optional[bool] = True,
        eps: float = 1e-05,
    ):
        """
        :param config: The model config.
        :param size: Width of the normalized (last) dimension; defaults to ``config.d_model``.
        :param elementwise_affine: Whether to create ``weight`` (and possibly ``bias``). ``None``
            defers to ``config.layer_norm_with_affine``.
        :param eps: Numerical-stability constant stored on the module.
        """
        super().__init__()
        self.config = config
        self.eps = eps
        self.normalized_shape = (size or config.d_model,)

        affine = elementwise_affine or (elementwise_affine is None and self.config.layer_norm_with_affine)
        if not affine:
            # No learnable parameters at all.
            self.register_parameter("bias", None)
            self.register_parameter("weight", None)
            return

        self.weight = nn.Parameter(torch.ones(self.normalized_shape, device=config.init_device))
        # ``bias_for_layer_norm`` wins when set; otherwise follow the global bias switch.
        with_bias = self.config.bias_for_layer_norm
        if with_bias is None:
            with_bias = self.config.include_bias
        if with_bias:
            self.bias = nn.Parameter(torch.zeros(self.normalized_shape, device=config.init_device))
        else:
            self.register_parameter("bias", None)

    @abstractmethod
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError

    @classmethod
    def build(cls, config: ModelConfig, size: Optional[int] = None, **kwargs) -> LayerNormBase:
        """
        Instantiate the layer-norm variant named by ``config.layer_norm_type``.

        Extra keyword arguments are forwarded to the chosen constructor.

        :raises NotImplementedError: For an unrecognized ``config.layer_norm_type``.
        """
        kind = config.layer_norm_type
        if kind == LayerNormType.default:
            return LayerNorm(config, size=size, low_precision=False, **kwargs)
        if kind == LayerNormType.low_precision:
            return LayerNorm(config, size=size, low_precision=True, **kwargs)
        if kind == LayerNormType.rms:
            return RMSLayerNorm(config, size=size, **kwargs)
        if kind == LayerNormType.gemma_rms:
            return GemmaRMSLayerNorm(config, size=size, **kwargs)
        raise NotImplementedError(f"Unknown LayerNorm type: '{config.layer_norm_type}'")

    def _cast_if_autocast_enabled(self, tensor: torch.Tensor, dtype: Optional[torch.dtype] = None) -> torch.Tensor:
        """
        Cast ``tensor`` to ``dtype`` (or the active autocast dtype when ``dtype`` is ``None``)
        if autocast is enabled for the tensor's device type; otherwise return it unchanged.
        """
        device_type = tensor.device.type
        if device_type == "cuda" and torch.is_autocast_enabled():
            autocast_dtype = torch.get_autocast_gpu_dtype
        elif device_type == "cpu" and torch.is_autocast_cpu_enabled():
            autocast_dtype = torch.get_autocast_cpu_dtype
        else:
            return tensor
        # The autocast dtype is only queried when no explicit dtype was given.
        return tensor.to(dtype=autocast_dtype() if dtype is None else dtype)

    def reset_parameters(self):
        """Reset ``weight`` to ones and ``bias`` to zeros, for whichever of them exist."""
        if self.weight is not None:
            nn.init.ones_(self.weight)
        if self.bias is not None:
            nn.init.zeros_(self.bias)
|
|
|
|
class LayerNorm(LayerNormBase):
    """
    The default :class:`LayerNorm` implementation which can optionally run in low precision.
    """

    def __init__(
        self,
        config: ModelConfig,
        size: Optional[int] = None,
        low_precision: bool = False,
        elementwise_affine: Optional[bool] = None,
        eps: float = 1e-05,
    ):
        """
        :param low_precision: When ``True``, inputs and parameters are cast to the active autocast
            dtype and the normalization itself runs with autocast disabled.
        """
        super().__init__(config, size=size, elementwise_affine=elementwise_affine, eps=eps)
        self.low_precision = low_precision

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply ``F.layer_norm`` over the last dimension of ``x``."""
        if not self.low_precision:
            return F.layer_norm(x, self.normalized_shape, weight=self.weight, bias=self.bias, eps=self.eps)

        downcast = self._cast_if_autocast_enabled
        device_type = x.device.type
        x_lp = downcast(x)
        weight_lp = None if self.weight is None else downcast(self.weight)
        bias_lp = None if self.bias is None else downcast(self.bias)
        # Autocast is switched off so the already-cast operands are used as they are.
        with torch.autocast(enabled=False, device_type=device_type):
            return F.layer_norm(x_lp, self.normalized_shape, weight=weight_lp, bias=bias_lp, eps=self.eps)
|
|
|
|
class RMSLayerNorm(LayerNormBase):
    """
    RMS layer norm, a simplified :class:`LayerNorm` implementation
    """

    def __init__(
        self,
        config: ModelConfig,
        size: Optional[int] = None,
        elementwise_affine: Optional[bool] = None,
        eps: float = 1e-5,
    ):
        # NOTE: the ``eps`` argument is accepted but not used; ``config.rms_norm_eps`` is
        # what gets stored on the module.
        super().__init__(config, size=size, elementwise_affine=elementwise_affine, eps=config.rms_norm_eps)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Scale ``x`` by the reciprocal root-mean-square of its last dimension (computed in
        float32 with autocast disabled), then apply ``weight``/``bias`` when present.
        """
        input_dtype = x.dtype
        with torch.autocast(enabled=False, device_type=x.device.type):
            x_fp32 = x.to(torch.float32)
            mean_square = x_fp32.pow(2).mean(-1, keepdim=True)
            normed = (x_fp32 * torch.rsqrt(mean_square + self.eps)).to(input_dtype)

        if self.weight is None:
            return normed
        scaled = self.weight * normed
        if self.bias is None:
            return scaled
        return scaled + self.bias
|
|
|
|
class GemmaRMSLayerNorm(LayerNormBase):
    """
    Gemma RMS layer norm, a simplified :class:`LayerNorm` implementation

    Unlike :class:`RMSLayerNorm`, the normalized input is multiplied by ``1 + weight``.
    """

    def __init__(
        self,
        config: ModelConfig,
        size: Optional[int] = None,
        elementwise_affine: Optional[bool] = None,
        eps: float = 1e-5,
    ):
        # NOTE: the ``eps`` argument is accepted but not used; ``config.rms_norm_eps`` is
        # what gets stored on the module.
        super().__init__(config, size=size, elementwise_affine=elementwise_affine, eps=config.rms_norm_eps)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Scale ``x`` by the reciprocal root-mean-square of its last dimension (computed in
        float32 with autocast disabled), then apply ``1 + weight`` and ``bias`` when present.
        """
        input_dtype = x.dtype
        with torch.autocast(enabled=False, device_type=x.device.type):
            x_fp32 = x.to(torch.float32)
            mean_square = x_fp32.pow(2).mean(-1, keepdim=True)
            normed = (x_fp32 * torch.rsqrt(mean_square + self.eps)).to(input_dtype)

        if self.weight is None:
            return normed
        scaled = normed * (1 + self.weight)
        if self.bias is None:
            return scaled
        return scaled + self.bias
|
|
|
|
class RotaryEmbedding(nn.Module):
    """
    [Rotary positional embeddings (RoPE)](https://arxiv.org/abs/2104.09864).

    This variant does not use a single 1-D frequency ladder. Each flat position index is
    decomposed into coordinates on a 4-axis grid, and a per-head frequency matrix maps those
    coordinates to rotation angles. The resulting sin/cos tables are stored in the shared
    :class:`BufferCache` under ``"rope_pos_sin"`` / ``"rope_pos_cos"`` rather than as buffers.
    """

    def __init__(self, config: ModelConfig, cache: BufferCache,
                 n_heads=None, min_freq=np.pi, max_freq=np.pi*32, grid_frac=0.381, rev=True, tp=True, sqrt=True,
                 direction_spacing=math.pi * (math.sqrt(5) - 1) / 2, cutlast=None, linear=False, scale=(lambda x: x)):
        """
        :param config: The model config.
        :param cache: Shared cache in which the sin/cos tables are stored.
        :param n_heads: Number of heads to build tables for; defaults to
            ``config.num_key_value_heads`` when present and truthy, else ``config.n_kv_heads``.
        :param min_freq: Lower end of the frequency range of the directional ("grid") channels.
        :param max_freq: Upper end of that frequency range.
        :param grid_frac: Fraction of the ``dim // 2`` channels used as directional channels.
        :param rev: Reverse the order of the directional frequencies.
        :param tp: Swap the two components of each direction vector.
        :param sqrt: Double the axis exponents and use ``sqrt(2)`` instead of ``2`` as the
            ratio between consecutive per-axis frequencies.
        :param direction_spacing: Angular step between consecutive direction vectors, or the
            string ``'eq'`` for equally spaced, seeded-shuffled directions.
        :param cutlast: When truthy, number of frequencies generated for axis 0 instead of
            "however many channels are left".
        :param linear: Interpolate linearly in ``1 / frequency`` instead of geometrically in frequency.
        :param scale: Function applied to the ``[0, 1]`` interpolation grid.
        """
        super().__init__()
        self.config = config
        # Name-mangled to ``_RotaryEmbedding__cache``.
        self.__cache = cache

        self.n_heads = n_heads or getattr(config, 'num_key_value_heads', None) or config.n_kv_heads
        self.min_freq = min_freq
        self.max_freq = max_freq
        self.grid_frac = grid_frac
        self.rev = rev
        self.tp = tp
        self.sqrt = sqrt
        self.direction_spacing = direction_spacing
        self.cutlast = cutlast
        self.linear = linear
        self.scale = scale
        # Per-head channel count.
        self.dim = config.d_model // config.n_heads

        # Warm up the cache for the configured maximum sequence length.
        self.get_rotary_embedding(config.max_sequence_length, _non_meta_init_device(config))

    def get_rotary_embedding(self, seq_len: int, device: torch.device, position_ids: torch.Tensor=None, dtype: torch.dtype=None) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Return ``(pos_sin, pos_cos)`` covering at least ``seq_len`` positions.

        Cached tables are reused (and moved to ``device`` if needed) when they are long enough;
        otherwise they are rebuilt and stored in the cache. The result is passed through
        :meth:`apply_position_ids`, so with ``position_ids`` the tables are gathered per position
        and cast to ``dtype``; without it the full tables are returned.
        """
        # Fast path: both tables are cached and long enough along the position axis (-2).
        if (
            (pos_sin := self.__cache.get("rope_pos_sin")) is not None
            and (pos_cos := self.__cache.get("rope_pos_cos")) is not None
            and pos_sin.shape[-2] >= seq_len
            and pos_cos.shape[-2] >= seq_len
        ):
            if pos_sin.device != device:
                pos_sin = pos_sin.to(device)
                self.__cache["rope_pos_sin"] = pos_sin
            if pos_cos.device != device:
                pos_cos = pos_cos.to(device)
                self.__cache["rope_pos_cos"] = pos_cos
            return self.apply_position_ids(pos_sin, pos_cos, position_ids, dtype)

        # Grid of 2**e0 x 2**e1 x 2**e2 x 2**e3 cells; grow the first axis until the grid
        # has at least ``seq_len`` cells.
        exponents = [1, 1, 5, 5]
        while np.prod(2 ** np.array(exponents)) < seq_len:
            exponents[0] += 1

        # Row p of ``position_mapping`` holds the four grid coordinates of flat position p.
        mgrid = list(torch.meshgrid(*map(torch.arange, 2 ** np.array(exponents)), indexing='ij'))
        position_mapping = torch.stack(mgrid, axis=-1).flatten(0, -2).float()

        # The first ``n_grid`` channels are placeholders (frequency 0 on axis 0); they are
        # overwritten with the directional frequencies further down.
        self.n_grid = int(self.dim // 2 * self.grid_frac)
        frequencies = [(0, 0)] * self.n_grid
        div = 2
        if self.sqrt:
            exponents = np.array(exponents) * 2
            div = np.sqrt(2)
        # Visits axis 1, then axis 0 (axes 3 and 2 are dropped by the ``[2:]`` slice).
        # Axis 1 gets ``exp`` frequencies; axis 0 gets ``cutlast`` frequencies when that is
        # truthy, otherwise as many as are needed to reach ``dim // 2`` channels.
        for tgt, exp in list(enumerate(exponents))[::-1][2:]:
            for i in range(exp if tgt else self.cutlast or self.dim // 2 - len(frequencies)):
                frequencies.append((np.pi / div ** i, tgt))
        vals, tgts = map(list, zip(*frequencies))
        # Dense (axis, channel) matrix: channel j has frequency vals[j] on axis tgts[j].
        frequencies = torch.zeros(len(mgrid), self.dim // 2, dtype=torch.float)
        frequencies[tgts, torch.arange(len(tgts))] = torch.tensor(vals).float()

        # Magnitudes of the directional frequencies, interpolated between min_freq and max_freq.
        if self.linear: omega_F = 1/(1/self.min_freq + (1/self.max_freq - 1/self.min_freq) * self.scale(torch.linspace(0, 1, self.n_grid, dtype=torch.float)))
        else: omega_F = self.min_freq * (self.max_freq / self.min_freq) ** self.scale(torch.linspace(0, 1, self.n_grid, dtype=torch.float))
        if self.rev: omega_F = omega_F.flip(0)
        # Direction angle for every (head, directional channel) pair.
        if self.direction_spacing == 'eq':
            phi_hF = torch.linspace(0, np.pi, self.n_grid+1, dtype=torch.float)[:-1].unsqueeze(0).expand(self.n_heads, self.n_grid)
            phi_hF = phi_hF + torch.linspace(0, phi_hF[0, 1], self.n_heads+1, dtype=torch.float)[:-1, torch.newaxis]
            # NOTE: this reseeds NumPy's global RNG.
            np.random.seed(42)
            phi_hF = torch.stack([phi_hF[j, torch.tensor(np.random.permutation(self.n_grid))] for j in range(self.n_heads)])
        else: phi_hF = torch.arange(self.n_heads * self.n_grid).reshape(self.n_heads, self.n_grid).float() * self.direction_spacing
        directions_hF2 = torch.stack((torch.cos(phi_hF), torch.sin(phi_hF)), dim=-1)

        if self.tp: directions_hF2 = directions_hF2.flip(-1)
        freqs_hF2 = omega_F.unsqueeze(-1) * directions_hF2

        # One frequency matrix per head; the directional frequencies go onto the last two
        # grid axes. NOTE(review): the ``/ 32`` presumably matches the 2**5 extent of those
        # axes -- confirm.
        frequencies = frequencies[torch.newaxis].repeat(self.n_heads, 1, 1)
        frequencies[:, -2:, :self.n_grid] = freqs_hF2.swapaxes(-2, -1) / 32

        # Angles per (head, position, channel); the half-width table is duplicated along the
        # channel axis.
        emb = (position_mapping.float() @ frequencies).repeat(1, 1, 2)
        with torch.autocast(device.type, enabled=False):
            pos_cos = emb.cos().to(device=device, non_blocking=True)
            pos_sin = emb.sin().to(device=device, non_blocking=True)
        self.__cache["rope_pos_cos"] = pos_cos
        self.__cache["rope_pos_sin"] = pos_sin

        return self.apply_position_ids(pos_sin, pos_cos, position_ids, dtype)

    @staticmethod
    def rotate_half(x: torch.Tensor) -> torch.Tensor:
        """Split the last dimension of a 4-D ``x`` into halves ``(x1, x2)`` and return ``cat(-x2, x1)``."""
        B, nh, T, hs = x.size()
        x = x.view(B, nh, T, 2, hs // 2)
        x1, x2 = x.unbind(dim=-2)
        return torch.cat((-x2, x1), dim=-1)

    def apply_position_ids(self, pos_sin: torch.Tensor, pos_cos: torch.Tensor, position_ids: torch.Tensor, dtype: torch.dtype) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Gather table rows for ``position_ids`` and cast them to ``dtype``; with
        ``position_ids=None`` the tables are returned unchanged.
        """
        if position_ids is not None:
            # Bring the position axis to the front, index it, then swap the gathered axis back.
            # Assumes position_ids is (batch, seq), giving (batch, heads, seq, dim) -- TODO confirm.
            pos_cos = pos_cos.swapaxes(0, -2)[position_ids].swapaxes(1, -2).to(dtype)
            pos_sin = pos_sin.swapaxes(0, -2)[position_ids].swapaxes(1, -2).to(dtype)
        return pos_sin, pos_cos

    @staticmethod
    def apply_rotary_pos_emb(pos_sin: torch.Tensor, pos_cos: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """Rotate ``t``: ``t * cos + rotate_half(t) * sin``, returned in ``t``'s dtype."""
        return ((t * pos_cos) + (RotaryEmbedding.rotate_half(t) * pos_sin)).to(t.dtype)

    def forward(self, q: torch.Tensor, k: torch.Tensor, position_ids=None) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Apply the rotary embedding to ``q`` and ``k`` and return them in their original dtypes.

        NOTE(review): despite the ``None`` default, ``position_ids`` is required here
        (``position_ids.max()`` is called unconditionally).
        """
        if self.config.rope_full_precision:
            q_, k_ = q.float(), k.float()
        else:
            q_, k_ = q, k

        with torch.autocast(q.device.type, enabled=False):
            query_len, key_len = q_.shape[-2], k_.shape[-2]
            # ``position_ids.max()+1`` (a tensor) is the number of positions the tables must cover.
            pos_sin, pos_cos = self.get_rotary_embedding(position_ids.max()+1, q_.device, position_ids, q_.dtype)
            # The queries line up with the last ``query_len`` key positions.
            q_ = RotaryEmbedding.apply_rotary_pos_emb(
                pos_sin[:, :, key_len - query_len : key_len, :],
                pos_cos[:, :, key_len - query_len : key_len, :],
                q_,
            )
            k_ = RotaryEmbedding.apply_rotary_pos_emb(pos_sin, pos_cos, k_)
        return q_.type_as(q), k_.type_as(k)
|
|
|
|
class Activation(nn.Module):
    """
    Base class for the feed-forward activation functions.

    Subclasses implement :meth:`forward` and :attr:`output_multiplier`; :meth:`build`
    creates the activation selected by ``config.activation_type``.
    """

    def __init__(self, config: ModelConfig):
        super().__init__()
        self.config = config

    @abstractmethod
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError

    @property
    @abstractmethod
    def output_multiplier(self) -> float:
        """Ratio of the activation's output width to its input width."""
        raise NotImplementedError

    @classmethod
    def build(cls, config: ModelConfig) -> Activation:
        """
        Instantiate the activation named by ``config.activation_type``.

        :raises NotImplementedError: For an unrecognized activation type.
        """
        kind = config.activation_type
        if kind == ActivationType.gelu:
            return cast(Activation, GELU(approximate="none"))
        if kind == ActivationType.relu:
            return cast(Activation, ReLU(inplace=False))
        if kind == ActivationType.silu:
            return cast(Activation, SiLU(inplace=False))
        if kind == ActivationType.swiglu:
            return SwiGLU(config)
        raise NotImplementedError(f"Unknown activation: '{config.activation_type}'")
|
|
|
|
class GELU(nn.GELU):
    """``nn.GELU`` with the ``output_multiplier`` property that the transformer blocks read."""

    @property
    def output_multiplier(self) -> float:
        # The blocks multiply the MLP hidden size by this value to size ``ff_out``.
        return 1.0
|
|
|
|
class ReLU(nn.ReLU):
    """``nn.ReLU`` with the ``output_multiplier`` property that the transformer blocks read."""

    @property
    def output_multiplier(self) -> float:
        # The blocks multiply the MLP hidden size by this value to size ``ff_out``.
        return 1.0
|
|
class SiLU(nn.SiLU):
    """``nn.SiLU`` with the ``output_multiplier`` property that the transformer blocks read."""

    @property
    def output_multiplier(self) -> float:
        # The blocks multiply the MLP hidden size by this value to size ``ff_out``.
        return 1.0
|
|
class SwiGLU(Activation):
    """Gated activation: the last dimension is split in two and one half gates the other."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # First half carries the values, second half is the gate.
        value, gate = torch.chunk(x, 2, dim=-1)
        return F.silu(gate) * value

    @property
    def output_multiplier(self) -> float:
        # The output is half as wide as the input.
        return 0.5
|
|
|
|
def causal_attention_bias(seq_len: int, device: torch.device) -> torch.FloatTensor:
    """
    Build an additive causal bias of shape ``(1, 1, seq_len, seq_len)``.

    Entries strictly above the diagonal hold the most negative float32 value;
    all other entries are 0.
    """
    index = torch.arange(seq_len, device=device)
    # True where the column index is larger than the row index.
    above_diagonal = index.unsqueeze(0) > index.unsqueeze(1)
    bias = torch.zeros(seq_len, seq_len, device=device, dtype=torch.float)
    bias.masked_fill_(above_diagonal, torch.finfo(bias.dtype).min)
    return bias.view(1, 1, seq_len, seq_len)
|
|
|
|
def get_causal_attention_bias(cache: BufferCache, seq_len: int, device: torch.device) -> torch.Tensor:
    """
    Return the causal attention bias stored in ``cache`` under ``"causal_attention_bias"``,
    building it when it is missing or too short for ``seq_len``.

    A cached bias is returned as stored (it may cover more than ``seq_len`` positions) and is
    moved to ``device`` first if it lives elsewhere.
    """
    cached = cache.get("causal_attention_bias")
    if cached is None or cached.shape[-1] < seq_len:
        # Cache miss: build a new bias with autocast disabled and remember it.
        with torch.autocast(device.type, enabled=False):
            fresh = causal_attention_bias(seq_len, device)
        cache["causal_attention_bias"] = fresh
        return fresh

    if cached.device != device:
        cached = cached.to(device)
        cache["causal_attention_bias"] = cached
    return cached
|
|
|
|
def alibi_attention_bias(seq_len: int, config: ModelConfig, device: torch.device) -> torch.FloatTensor:
    """
    Build an ALiBi-style bias of shape ``(1, config.n_heads, seq_len, seq_len)``.

    Entry ``[0, h, i, j]`` is ``-|i - j| / 2 ** m_h`` with
    ``m_h = (h + 1) * config.alibi_bias_max / config.n_heads``.
    """
    positions = torch.arange(1 - seq_len, 1, dtype=torch.float, device=device)
    # Pairwise (column - row) offsets, then the negated absolute distance.
    offsets = positions.view(1, 1, 1, seq_len) - positions.view(1, 1, seq_len, 1)
    neg_distance = offsets.abs().mul(-1)

    # One exponent per head.
    head_index = torch.arange(1, config.n_heads + 1, dtype=torch.float, device=device)
    head_exponent = head_index * (config.alibi_bias_max / config.n_heads)

    slopes = 1.0 / (2 ** head_exponent.view(1, config.n_heads, 1, 1))
    return neg_distance * slopes
|
|
|
|
class LLaDABlock(nn.Module):
    """
    A base class for transformer block implementations.

    It owns the pieces the concrete blocks share: residual dropout, optional q/k layer norms,
    the activation, the attention and feed-forward output projections, and the attention
    computation itself. Subclasses add the input projections and implement :meth:`forward`;
    use :meth:`build` to create the block selected by ``config.block_type``.
    """
    # Class-level flag so the attention-backend message is printed only once per class.
    ATTN_INFO_SHOWN = False

    def __init__(self, layer_id: int, config: ModelConfig, cache: BufferCache):
        super().__init__()
        self.layer_id = layer_id
        self.config = config
        # MLP hidden width: explicit value if configured, else ``mlp_ratio * d_model``.
        self.hidden_size = (
            config.mlp_hidden_size if config.mlp_hidden_size is not None else config.mlp_ratio * config.d_model
        )
        # Name-mangled to ``_LLaDABlock__cache``.
        self.__cache = cache
        assert config.d_model % config.n_heads == 0

        # Set through ``set_activation_checkpointing``.
        self._activation_checkpoint_fn = None

        # Dropout applied to the residual branches.
        self.dropout = Dropout(config.residual_dropout)

        # Optional layer norms applied to the projected keys and queries.
        self.k_norm: Optional[LayerNormBase] = None
        self.q_norm: Optional[LayerNormBase] = None
        if config.attention_layer_norm:
            self.k_norm = LayerNormBase.build(
                config,
                size=(config.d_model // config.n_heads) * config.effective_n_kv_heads,
                elementwise_affine=config.attention_layer_norm_with_affine,
            )
            self.q_norm = LayerNormBase.build(config, elementwise_affine=config.attention_layer_norm_with_affine)

        # Activation; its output width must be a whole number of features.
        self.act = Activation.build(config)
        assert (self.act.output_multiplier * self.hidden_size) % 1 == 0

        # Attention output projection.
        self.attn_out = nn.Linear(
            config.d_model, config.d_model, bias=config.include_bias, device=config.init_device
        )

        # Feed-forward output projection.
        self.ff_out = nn.Linear(
            int(self.act.output_multiplier * self.hidden_size),
            config.d_model,
            bias=config.include_bias,
            device=config.init_device,
        )
        # Read by ``init_weights`` to rescale this layer under the "normal" init scheme.
        self.ff_out._is_residual = True

        # Pick the attention backend: ``flash_attn_interface`` first, then ``flash_attn``,
        # when flash attention is requested; otherwise torch SDPA is used.
        self.flash_attn_func = None
        if config.flash_attention:
            try:
                from flash_attn_interface import flash_attn_func
                if not self.__class__.ATTN_INFO_SHOWN: print('MODEL USING FLASH_ATTENTION_3')
            except ModuleNotFoundError:
                from flash_attn import flash_attn_func
                if not self.__class__.ATTN_INFO_SHOWN: print('MODEL USING FLASH_ATTENTION_2')
            self.flash_attn_func = flash_attn_func
        else:
            if not self.__class__.ATTN_INFO_SHOWN: print('MODEL USING TORCH.SDPA ATTENTION')
        self.__class__.ATTN_INFO_SHOWN = True

    def reset_parameters(self):
        """Re-initialize the q/k norms (if any) and the two output projections."""
        if self.k_norm is not None:
            self.k_norm.reset_parameters()
        if self.q_norm is not None:
            self.q_norm.reset_parameters()
        init_weights(
            self.config,
            self.attn_out,
            d=self.config.d_model,
            layer_id=self.layer_id,
            type_of_module=ModuleType.out_module,
        )
        init_weights(
            self.config,
            self.ff_out,
            d=self.ff_out.in_features,
            layer_id=self.layer_id,
            type_of_module=ModuleType.out_module,
        )

    def set_activation_checkpointing(self, strategy: Optional[ActivationCheckpointingStrategy]):
        """Enable checkpointing inside the block for the ``fine_grained`` strategy; disable it otherwise."""
        if strategy == ActivationCheckpointingStrategy.fine_grained:
            self._activation_checkpoint_fn = activation_checkpoint_function(self.config)
        else:
            self._activation_checkpoint_fn = None

    @classmethod
    def _cast_attn_bias(cls, bias: torch.Tensor, input_dtype: torch.dtype) -> torch.Tensor:
        """
        Cast ``bias`` to the active autocast dtype for its device (or to ``input_dtype`` when
        autocast is off). After an actual cast, ``-inf`` entries are replaced by the dtype minimum.
        """
        target_dtype = input_dtype
        if bias.device.type == "cuda" and torch.is_autocast_enabled():
            target_dtype = torch.get_autocast_gpu_dtype()
        elif bias.device.type == "cpu" and torch.is_autocast_cpu_enabled():
            target_dtype = torch.get_autocast_cpu_dtype()
        if bias.dtype != target_dtype:
            bias = bias.to(target_dtype)
            ensure_finite_(bias, check_neg_inf=True, check_pos_inf=False)
        return bias

    def _scaled_dot_product_attention(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
        dropout_p: float = 0.0,
        is_causal: bool = False,
    ) -> torch.Tensor:
        """
        Computes scaled dot product attention on query, key and value tensors, applying dropout
        if a probability greater than 0.0 is specified.

        NOTE: as written, ``attn_mask`` only decides whether the flash-attention path may be
        taken, and ``is_causal`` is not read at all; both backends are called without a mask
        and with causality disabled.
        """
        if self.flash_attn_func is not None and attn_mask is None:
            # The transposes swap dims 1 and 2 going in and swap them back coming out.
            r = self.flash_attn_func(
                q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), **({'dropout_p': dropout_p} if dropout_p else {}), causal=False
            )
            return r.transpose(1, 2)
        else:
            # Repeat the key/value heads so their count matches the query heads.
            assert k.size(1) == v.size(1)
            num_kv_heads = k.size(1)
            num_q_heads = q.size(1)
            if num_q_heads != num_kv_heads:
                assert num_q_heads % num_kv_heads == 0
                k = k.repeat_interleave(num_q_heads // num_kv_heads, dim=1, output_size=num_q_heads)
                v = v.repeat_interleave(num_q_heads // num_kv_heads, dim=1, output_size=num_q_heads)

            return F.scaled_dot_product_attention(
                q,
                k,
                v,
                attn_mask=None,
                dropout_p=dropout_p,
                is_causal=False,
            )

    def attention(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        attention_bias: Optional[torch.Tensor] = None,
        layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        use_cache: bool = False,
        rope=None,
    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
        """
        Run multi-head attention on already-projected ``q``/``k``/``v`` and apply ``attn_out``.

        :param q: Projected queries, a 3-D tensor unpacked as ``(B, T, C)``.
        :param k: Projected keys.
        :param v: Projected values.
        :param attention_bias: Optional additive bias (see the note in the body).
        :param layer_past: Cached ``(key, value)`` to prepend along the sequence axis.
        :param use_cache: Whether to return the (concatenated) keys and values.
        :param rope: ``(pos_sin, pos_cos)`` pair; required when ``config.rope`` is set.
        :returns: ``(output, present)`` where ``present`` is ``(k, v)`` or ``None``.
        """
        B, T, C = q.size()
        dtype = k.dtype

        # Optional q/k layer norm, cast back to the key dtype.
        if self.q_norm is not None and self.k_norm is not None:
            q = self.q_norm(q).to(dtype=dtype)
            k = self.k_norm(k).to(dtype=dtype)

        # Split the channel dimension into heads and move the head axis ahead of the sequence axis.
        q = q.view(B, T, self.config.n_heads, C // self.config.n_heads).transpose(1, 2)
        k = k.view(B, T, self.config.effective_n_kv_heads, C // self.config.n_heads).transpose(1, 2)
        v = v.view(B, T, self.config.effective_n_kv_heads, C // self.config.n_heads).transpose(1, 2)

        if layer_past is not None:
            past_key, past_value = layer_past
            k = torch.cat((past_key, k), dim=-2)
            v = torch.cat((past_value, v), dim=-2)

        # Captured before the rotary embedding is applied below.
        present = (k, v) if use_cache else None
        query_len, key_len = q.shape[-2], k.shape[-2]

        if self.config.rope:
            # The queries line up with the last ``query_len`` key positions.
            pos_sin, pos_cos = rope
            q = RotaryEmbedding.apply_rotary_pos_emb(
                pos_sin[:, :, key_len - query_len: key_len, :],
                pos_cos[:, :, key_len - query_len: key_len, :],
                q,
            ).type_as(q)
            k = RotaryEmbedding.apply_rotary_pos_emb(pos_sin, pos_cos, k).type_as(k)

        if attention_bias is not None:
            # NOTE(review): the bias is sliced and cast here, but the call below passes
            # ``attn_mask=None``, so it does not reach the attention computation -- confirm
            # this is intended.
            attention_bias = self._cast_attn_bias(
                attention_bias[:, :, key_len - query_len : key_len, :key_len], dtype
            )

        # Attention dropout is only active in training mode.
        att = self._scaled_dot_product_attention(
            q,
            k,
            v,
            attn_mask=None,
            dropout_p=0.0 if not self.training else self.config.attention_dropout,
            is_causal=False,
        )

        # Merge the heads back into the channel dimension.
        att = att.transpose(1, 2).contiguous().view(B, T, C)

        # Output projection.
        return self.attn_out(att), present

    @abstractmethod
    def forward(
        self,
        x: torch.Tensor,
        attention_bias: Optional[torch.FloatTensor] = None,
        layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
        # NOTE: the concrete blocks in this module additionally accept a ``rope`` argument.
        raise NotImplementedError

    @classmethod
    def build(cls, layer_id: int, config: ModelConfig, cache: BufferCache) -> LLaDABlock:
        """
        Instantiate the block implementation named by ``config.block_type``.

        :raises NotImplementedError: For an unrecognized block type.
        """
        if config.block_type == BlockType.sequential:
            return LLaDASequentialBlock(layer_id, config, cache)
        elif config.block_type == BlockType.llama:
            return LLaDALlamaBlock(layer_id, config, cache)
        else:
            raise NotImplementedError(f"Unknown block type: '{config.block_type}'")
|
|
|
|
class LLaDASequentialBlock(LLaDABlock):
    """
    This is a typical transformer block where the output is computed as ``MLP(LN(x + Attention(LN(x))))``
    (plus another skip connection).

    Queries, keys and values come from one fused projection (``att_proj``).
    """

    def __init__(self, layer_id: int, config: ModelConfig, cache: BufferCache):
        super().__init__(layer_id, config, cache)
        # Norms in front of the attention and feed-forward sub-layers.
        self.attn_norm = LayerNorm.build(config)
        self.ff_norm = LayerNorm.build(config)

        # Widths of the q, k and v slices of the fused projection output.
        kv_width = config.effective_n_kv_heads * (config.d_model // config.n_heads)
        self.fused_dims = (config.d_model, kv_width, kv_width)
        self.att_proj = nn.Linear(
            config.d_model,
            sum(self.fused_dims),
            bias=config.include_bias | config.include_qkv_bias,
            device=config.init_device,
        )

        # Feed-forward input projection.
        self.ff_proj = nn.Linear(
            config.d_model, self.hidden_size, bias=config.include_bias, device=config.init_device
        )

    def reset_parameters(self):
        """Re-initialize the shared parameters, both norms and both input projections."""
        super().reset_parameters()
        self.attn_norm.reset_parameters()
        self.ff_norm.reset_parameters()
        # The input projections are initialized independently of the layer index.
        for projection in (self.att_proj, self.ff_proj):
            init_weights(
                self.config, projection, d=self.config.d_model, layer_id=None, type_of_module=ModuleType.in_module
            )

    def forward(
        self,
        x: torch.Tensor,
        attention_bias: Optional[torch.Tensor] = None,
        layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        use_cache: bool = False,
        rope=None,
    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
        """
        Apply the attention sub-layer and then the feed-forward sub-layer, each with a
        residual connection.

        :returns: ``(hidden_states, cache)`` where ``cache`` is what :meth:`attention` returns
            as its second element.
        """
        checkpoint_fn = self._activation_checkpoint_fn

        def run(fn, *args, **kwargs):
            # Route the call through activation checkpointing when it is enabled.
            if checkpoint_fn is None:
                return fn(*args, **kwargs)
            return checkpoint_fn(fn, *args, **kwargs)

        # Attention sub-layer.
        q, k, v = self.att_proj(run(self.attn_norm, x)).split(self.fused_dims, dim=-1)
        att, cache = run(
            self.attention, q, k, v, attention_bias, layer_past=layer_past, use_cache=use_cache, rope=rope
        )
        x = x + self.dropout(att)

        # Feed-forward sub-layer.
        residual = x
        x = self.ff_proj(run(self.ff_norm, x))
        x = run(self.act, x)
        x = self.dropout(self.ff_out(x))

        return residual + x, cache
|
|
|
|
class LLaDALlamaBlock(LLaDABlock):
    """
    This is a transformer block where the output is computed as ``MLP(LN(x + Attention(LN(x))))``
    (plus another skip connection). This block is similar to `LLaDASequentialBlock`
    but some operations have slightly different implementations to imitate the
    behavior of Llama.

    Queries, keys and values use separate projections, and the feed-forward layer multiplies
    the activated ``ff_proj`` output with the ``up_proj`` output before ``ff_out``.
    """

    def __init__(self, layer_id: int, config: ModelConfig, cache: BufferCache):
        super().__init__(layer_id, config, cache)
        # Norms in front of the attention and feed-forward sub-layers.
        self.attn_norm = LayerNorm.build(config)
        self.ff_norm = LayerNorm.build(config)
        # Name-mangled to ``_LLaDALlamaBlock__cache``.
        self.__cache = cache

        # Separate attention input projections.
        head_dim = config.d_model // config.n_heads
        q_proj_out_dim = config.d_model
        k_proj_out_dim = config.effective_n_kv_heads * head_dim
        v_proj_out_dim = config.effective_n_kv_heads * head_dim
        self.q_proj = nn.Linear(
            config.d_model, q_proj_out_dim, bias=config.include_bias | config.include_qkv_bias, device=config.init_device
        )
        self.k_proj = nn.Linear(
            config.d_model, k_proj_out_dim, bias=config.include_bias | config.include_qkv_bias, device=config.init_device
        )
        self.v_proj = nn.Linear(
            config.d_model, v_proj_out_dim, bias=config.include_bias | config.include_qkv_bias, device=config.init_device
        )

        # Feed-forward projection whose output goes through the activation.
        self.ff_proj = nn.Linear(
            config.d_model, self.hidden_size, bias=config.include_bias, device=config.init_device
        )
        # Feed-forward projection that is multiplied with the activated ``ff_proj`` output.
        self.up_proj = nn.Linear(
            config.d_model, self.hidden_size, bias=config.include_bias, device=config.init_device
        )

    def reset_parameters(self):
        """Re-initialize the shared parameters, both norms and all five input projections."""
        super().reset_parameters()
        self.attn_norm.reset_parameters()
        self.ff_norm.reset_parameters()
        # The input projections are initialized independently of the layer index.
        # ``type_of_module`` is passed (as ``LLaDASequentialBlock`` does) because the
        # ``full_megatron`` init scheme raises a RuntimeError for modules without a type;
        # the other schemes do not read it.
        for projection in (self.q_proj, self.k_proj, self.v_proj, self.ff_proj, self.up_proj):
            init_weights(
                self.config,
                projection,
                d=self.config.d_model,
                layer_id=None,
                type_of_module=ModuleType.in_module,
            )

    def forward(
        self,
        x: torch.Tensor,
        attention_bias: Optional[torch.Tensor] = None,
        layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        use_cache: bool = False,
        rope=None,
    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
        """
        Apply the attention sub-layer and then the gated feed-forward sub-layer, each with a
        residual connection.

        :param x: Input hidden states.
        :param attention_bias: Optional bias forwarded to :meth:`attention`.
        :param layer_past: Cached ``(key, value)`` forwarded to :meth:`attention`.
        :param use_cache: Whether :meth:`attention` should return keys and values.
        :param rope: ``(pos_sin, pos_cos)`` pair forwarded to :meth:`attention`.
        :returns: ``(hidden_states, cache)`` where ``cache`` is what :meth:`attention` returns
            as its second element.
        """
        # Attention input projections; the norm and projections run outside checkpointing.
        x_normed = self.attn_norm(x)
        q = self.q_proj(x_normed)
        k = self.k_proj(x_normed)
        v = self.v_proj(x_normed)

        # Attention scores and output projection.
        if self._activation_checkpoint_fn is not None:
            att, cache = self._activation_checkpoint_fn(
                self.attention, q, k, v, attention_bias, layer_past=layer_past, use_cache=use_cache, rope=rope
            )
        else:
            att, cache = self.attention(q, k, v, attention_bias, layer_past=layer_past, use_cache=use_cache, rope=rope)

        # First residual connection.
        x = x + self.dropout(att)

        # Gated feed-forward sub-layer with the second residual connection.
        og_x = x
        if self._activation_checkpoint_fn is not None:
            x = self._activation_checkpoint_fn(self.ff_norm, x)
        else:
            x = self.ff_norm(x)
        x, x_up = self.ff_proj(x), self.up_proj(x)
        if self._activation_checkpoint_fn is not None:
            x = self._activation_checkpoint_fn(self.act, x)
        else:
            x = self.act(x)
        x = x * x_up
        x = self.ff_out(x)
        x = self.dropout(x)
        x = og_x + x

        return x, cache
|
|
|
|
class LLaDAOutput(NamedTuple):
    """
    Values returned by ``LLaDAModel.forward``.

    ``logits``
        A tensor of shape `(batch_size, seq_len, vocab_size)` with unnormalized scores, i.e.
        the values *before* any (log) softmax. When ``forward`` is called with a truthy
        ``skip_final_steps`` this slot carries the hidden state of the last block instead.

    ``attn_key_values``
        Attention keys and values from each block, or ``None`` when caching was not requested.

    ``hidden_states``
        Hidden states collected while running the blocks, or ``None`` when they were not
        requested.
    """

    # Field order is part of the interface: callers slice this tuple positionally.
    logits: torch.FloatTensor
    attn_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]]
    hidden_states: Optional[Tuple[torch.Tensor]]
|
|
|
|
class LLaDAGenerateOutput(NamedTuple):
    """
    Result of a generation call.

    ``token_ids``
        The generated token IDs, a tensor of shape `(batch_size, beam_size, max_steps)`.
        These do *not* include the original input IDs.

    ``scores``
        The scores of the generated sequences, a tensor of shape `(batch_size, beam_size)`.
    """

    # Field order is part of the interface (positional construction and unpacking).
    token_ids: torch.LongTensor
    scores: torch.FloatTensor
|
|
|
|
class LLaDABlockGroup(nn.ModuleList):
    """
    A list of transformer blocks that is run as a single unit.

    ``layer_offset`` is added to a block's position inside the group to obtain the layer
    index that the ``one_in_*`` activation checkpointing strategies are tested against.
    """

    def __init__(self, config: ModelConfig, layer_offset: int, modules: Optional[Iterable[nn.Module]] = None):
        super().__init__(modules)
        self.config = config
        self.layer_offset = layer_offset
        self.activation_checkpointing_strategy: Optional[ActivationCheckpointingStrategy] = None
        self._activation_checkpoint_fn = activation_checkpoint_function(self.config)

    def _should_checkpoint(self, layer_idx: int) -> bool:
        """Return whether the block at (offset-adjusted) ``layer_idx`` is run under checkpointing."""
        strategy = self.activation_checkpointing_strategy
        if strategy is None:
            return False
        if strategy == ActivationCheckpointingStrategy.whole_layer:
            return True
        # Periodic strategies checkpoint every `period`-th layer, starting with layer 0.
        periodic = (
            (ActivationCheckpointingStrategy.one_in_two, 2),
            (ActivationCheckpointingStrategy.one_in_three, 3),
            (ActivationCheckpointingStrategy.one_in_four, 4),
        )
        return any(strategy == candidate and layer_idx % period == 0 for candidate, period in periodic)

    def forward(
        self,
        x: torch.Tensor,
        attention_bias: Optional[torch.FloatTensor] = None,
        layers_past: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
        use_cache: bool = False,
        rope=None,
    ) -> Tuple[torch.Tensor, Optional[List[Tuple[torch.Tensor, torch.Tensor]]]]:
        """
        Run ``x`` through every block of the group, in order.

        :param x: Input hidden states.
        :param attention_bias: Passed unchanged to every block.
        :param layers_past: Optional per-block past key/values, indexed by the block's
            position *inside this group*.
        :param use_cache: If ``True``, collect and return the cache produced by each block.
        :param rope: Passed unchanged to every block. ``LLaDAModel.forward`` calls the group
            with this keyword, so it has to be accepted here and handed on to the blocks,
            exactly as the model does when it calls blocks directly.
        :returns: ``(hidden_states, caches)`` where ``caches`` is a list with one entry per
            block when ``use_cache`` is ``True`` and ``None`` otherwise.
        """
        attn_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = [] if use_cache else None
        for local_idx, block in enumerate(self):
            layer_past = None if layers_past is None else layers_past[local_idx]
            if self._should_checkpoint(local_idx + self.layer_offset):
                # Checkpointed path: the wrapper receives the block plus its arguments.
                x, cache = self._activation_checkpoint_fn(
                    block, x, attention_bias=attention_bias, layer_past=layer_past, use_cache=use_cache, rope=rope
                )
            else:
                x, cache = block(
                    x, attention_bias=attention_bias, layer_past=layer_past, use_cache=use_cache, rope=rope
                )
            if attn_key_values is not None:
                assert cache is not None
                attn_key_values.append(cache)
        return x, attn_key_values

    def reset_parameters(self):
        """Re-initialize the parameters of every block in the group."""
        for block in self:
            block.reset_parameters()

    def set_activation_checkpointing(self, strategy: Optional[ActivationCheckpointingStrategy]):
        """Record ``strategy`` on the group and propagate it to every block."""
        self.activation_checkpointing_strategy = strategy
        for block in self:
            block.set_activation_checkpointing(strategy)
|
|
|
|
class LLaDAModel(nn.Module):
    """
    The core network: token embedding, a stack of blocks (optionally organized in block
    groups), a final layer norm and an output projection (or the tied embedding matrix).

    Sub-modules live in ``self.transformer`` (an ``nn.ModuleDict``) under the keys ``wte``,
    ``emb_drop``, ``ln_f``, either ``blocks`` or ``block_groups``, and optionally ``wpe`` and
    ``ff_out``. A ``RotaryEmbedding`` is stored as ``self.rotary_emb`` when ``config.rope``
    is set.
    """

    def __init__(self, config: ModelConfig, init_params: bool = True):
        """
        :param config: Model configuration.
        :param init_params: If ``True`` and ``config.init_device`` is not ``"meta"``,
            ``reset_parameters()`` is called at the end of module construction.
        :raises Exception: On an unsupported combination of options (ALiBi with
            FlashAttention, ALiBi with RoPE, an embedding size smaller than the vocabulary,
            or a block group size that does not evenly divide the number of layers).
        """
        super().__init__()
        self.config = config
        self.__cache = BufferCache()

        # Validate config.
        if self.config.alibi and self.config.flash_attention:
            raise Exception("ALiBi is currently not supported with FlashAttention")

        if self.config.alibi and self.config.rope:
            raise Exception("ALiBi and RoPE are mutually exclusive")

        if self.config.embedding_size is not None and self.config.embedding_size != self.config.vocab_size:
            if self.config.embedding_size < self.config.vocab_size:
                raise Exception("embedding size should be at least as big as vocab size")
            elif self.config.embedding_size % 128 != 0:
                import warnings

                warnings.warn(
                    "Embedding size is not a multiple of 128! This could hurt throughput performance.", UserWarning
                )

        self.activation_checkpointing_strategy: Optional[ActivationCheckpointingStrategy] = None
        self._activation_checkpoint_fn: Callable = activation_checkpoint_function(self.config)

        if not (
            0 < self.config.block_group_size <= self.config.n_layers
            and self.config.n_layers % self.config.block_group_size == 0
        ):
            raise Exception("n layers must be divisible by block group size")

        # Process-wide scaled-dot-product-attention backend switches: the flash kernel is
        # enabled and the memory-efficient kernel is disabled.
        torch.backends.cuda.enable_flash_sdp(True)
        torch.backends.cuda.enable_mem_efficient_sdp(False)

        self.transformer = nn.ModuleDict(
            dict(
                wte=nn.Embedding(
                    config.embedding_size or config.vocab_size, config.d_model, device=config.init_device
                ),
                emb_drop=Dropout(config.embedding_dropout),
                ln_f=LayerNorm.build(config),
            )
        )

        # All blocks share this model's buffer cache.
        blocks = [LLaDABlock.build(i, config, self.__cache) for i in range(config.n_layers)]
        if self.config.block_group_size > 1:
            # Each group gets the index of its first layer as its layer offset.
            block_groups = [
                LLaDABlockGroup(config, i, blocks[i : i + config.block_group_size])
                for i in range(0, config.n_layers, config.block_group_size)
            ]
            self.transformer.update({"block_groups": nn.ModuleList(block_groups)})
        else:
            self.transformer.update({"blocks": nn.ModuleList(blocks)})

        # Learned absolute position embeddings are only created without ALiBi and RoPE.
        if not (self.config.alibi or self.config.rope):
            self.transformer.update(
                {"wpe": nn.Embedding(config.max_sequence_length, config.d_model, device=config.init_device)}
            )
        # A separate output projection only exists when the embedding matrix is not reused.
        if not config.weight_tying:
            self.transformer.update(
                {
                    "ff_out": nn.Linear(
                        config.d_model,
                        config.embedding_size or config.vocab_size,
                        bias=config.include_bias,
                        device=config.init_device,
                    )
                }
            )

        # Parameters created on the meta device are not initialized here.
        if init_params and self.config.init_device != "meta":
            self.reset_parameters()
        self.__num_fwd_flops: Optional[int] = None

        # Warm up cache.
        if self.config.alibi:
            get_causal_attention_bias(self.__cache, config.max_sequence_length, _non_meta_init_device(config))
            self.get_alibi_attention_bias(config.max_sequence_length, _non_meta_init_device(config))

        if self.config.rope:
            self.rotary_emb = RotaryEmbedding(config, self.__cache)

    def set_activation_checkpointing(self, strategy: Optional[ActivationCheckpointingStrategy]):
        """Record ``strategy`` on the model and propagate it to every block or block group."""
        self.activation_checkpointing_strategy = strategy
        if self.config.block_group_size != 1:
            for block_group in self.transformer.block_groups:
                block_group.set_activation_checkpointing(strategy)
        else:
            for block in self.transformer.blocks:
                block.set_activation_checkpointing(strategy)

    @property
    def device(self) -> torch.device:
        """Device of the token embedding weight, or the non-meta init device if that is ``meta``."""
        device: torch.device = self.transformer.wte.weight.device
        if device.type == "meta":
            return _non_meta_init_device(self.config)
        else:
            return device

    def reset_parameters(self):
        """Initialize embeddings, the final norm, the output projection and all blocks."""
        log.info("Initializing model parameters...")

        # Top-level embeddings / linear layers.
        init_weights(
            self.config,
            self.transformer.wte,
            std_factor=(0.5 * math.sqrt(self.config.d_model)) if self.config.scale_logits else 1.0,
            type_of_module=ModuleType.emb,
        )
        if hasattr(self.transformer, "wpe"):
            init_weights(self.config, self.transformer.wpe, type_of_module=ModuleType.emb)

        # Top-level layer norm.
        self.transformer.ln_f.reset_parameters()

        # Output weights (only present without weight tying).
        if hasattr(self.transformer, "ff_out"):
            init_weights(self.config, self.transformer.ff_out, type_of_module=ModuleType.final_out)

        # Let the blocks handle themselves.
        if self.config.block_group_size == 1:
            for block in self.transformer.blocks:
                block.reset_parameters()
        else:
            for block_group in self.transformer.block_groups:
                block_group.reset_parameters()

    def get_alibi_attention_bias(self, seq_len: int, device: torch.device) -> torch.Tensor:
        """
        Return the ALiBi attention bias, reusing the cached tensor when it is large enough.

        A cached bias whose last dimension is at least ``seq_len`` is returned as is (after
        being moved to ``device`` and re-cached if it lives elsewhere). Otherwise a new bias
        is computed with autocast disabled and stored in the cache.
        """
        if (alibi_bias := self.__cache.get("alibi_attention_bias")) is not None and alibi_bias.shape[
            -1
        ] >= seq_len:
            if alibi_bias.device != device:
                alibi_bias = alibi_bias.to(device)
                self.__cache["alibi_attention_bias"] = alibi_bias
            return alibi_bias
        with torch.autocast(device.type, enabled=False):
            alibi_bias = alibi_attention_bias(seq_len, self.config, device)
        self.__cache["alibi_attention_bias"] = alibi_bias
        return alibi_bias

    def _should_checkpoint_block(self, block_idx: int) -> bool:
        """Return whether the block at ``block_idx`` is run under activation checkpointing."""
        strategy = self.activation_checkpointing_strategy
        if strategy is None:
            return False
        if strategy == ActivationCheckpointingStrategy.whole_layer:
            return True
        # Periodic strategies checkpoint every `period`-th block, starting with block 0.
        periodic = (
            (ActivationCheckpointingStrategy.one_in_two, 2),
            (ActivationCheckpointingStrategy.one_in_three, 3),
            (ActivationCheckpointingStrategy.one_in_four, 4),
        )
        return any(strategy == candidate and block_idx % period == 0 for candidate, period in periodic)

    def forward(
        self,
        input_ids: torch.LongTensor,
        input_embeddings: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        attention_bias: Optional[torch.Tensor] = None,
        past_key_values: Optional[Sequence[Tuple[torch.Tensor, torch.Tensor]]] = None,
        use_cache: bool = False,
        last_logits_only: bool = False,
        output_hidden_states: Optional[bool] = None,
        skip_final_steps: Optional[bool] = None,
        position_ids=None,
    ) -> LLaDAOutput:
        """
        :param input_ids: A tensor of shape `(batch_size, seq_len)`. Only read when
            `input_embeddings` is `None`.
        :param input_embeddings: A tensor of shape `(batch_size, seq_len, d_model)` with input
            embeddings. When provided, it is treated as the output of the input embedding layer.
        :param attention_mask: A tensor of shape `(batch_size, seq_len)` that indicates
            which input IDs are masked. A `1` value in the mask means that
            the corresponding input ID should *not* be ignored. A `0` means
            that the corresponding input ID is masked. A mask that contains no `0` is dropped.

            This has the same meaning as the `attention_mask` in HuggingFace's `transformers`
            library.
        :param attention_bias: A tensor of shape `(batch_size, 1, seq_len, seq_len)` or
            `(1, 1, seq_len, seq_len)`. The bias is sliced with four indices, so it has to be
            four-dimensional.

            If the tensor is an int8 or bool tensor, a `True` or `1` at `attention_bias[:, :, i, j]`
            indicates that the i-th element in the sequence is allowed to attend to the j-th
            element in the sequence.

            If the tensor is a float tensor, it is used as an additive bias.

            When no bias is given but an effective `attention_mask` is, the bias is taken from
            `get_causal_attention_bias`. When neither is given, `None` is passed to the blocks.
        :param past_key_values: Must be `None`; key/value caching is rejected by an assertion.
        :param use_cache: Must be `False` for the same reason.
        :param last_logits_only: If `True`, only the last position of each sequence is kept
            before the final norm and output projection.
        :param output_hidden_states: If `True`, return the input of every block (or block
            group) and, unless `skip_final_steps` is set, the output of the final layer norm.
        :param skip_final_steps: If truthy, return right after the last block. The `logits`
            field of the result then holds that hidden state instead of vocabulary scores.
        :param position_ids: Positions handed to the rotary embedding. When `None`, the
            positions `past_length, ..., past_length + seq_len - 1` are used for every sequence.
        """
        # Basic checks for the masked-diffusion setting.
        assert not self.config.alibi, "Alibi length extrapolation is not supported for MDM."
        assert self.config.rope, "Rope must be used in Llama-Encoder for MDM."
        assert (past_key_values is None and not use_cache), "The kvcache is not suppotred for MDM."

        output_hidden_states = output_hidden_states if output_hidden_states is not None else False

        if past_key_values:
            assert len(past_key_values) == self.config.n_layers

        batch_size, seq_len = input_ids.size() if input_embeddings is None else input_embeddings.size()[:2]
        if past_key_values is None:
            past_length = 0
        else:
            past_length = past_key_values[0][0].size(-2)

        # Get embeddings of input.
        x = self.transformer.wte(input_ids) if input_embeddings is None else input_embeddings

        if self.config.input_emb_norm:
            x = x * (self.config.d_model**0.5)

        if not (self.config.alibi or self.config.rope):
            # Learned absolute position embeddings.
            pos = torch.arange(past_length, past_length + seq_len, dtype=torch.long, device=x.device).unsqueeze(0)
            pos_emb = self.transformer.wpe(pos)
            x = pos_emb + x
            rope = None
        else:
            if position_ids is None:
                # `position_ids` defaults to None in the signature but used to be dereferenced
                # unconditionally here; fall back to plain sequential positions.
                # NOTE(review): assumes `get_rotary_embedding` expects `(batch_size, seq_len)`
                # position ids, as passed through from the HF wrapper - confirm.
                position_ids = (
                    torch.arange(past_length, past_length + seq_len, dtype=torch.long, device=x.device)
                    .unsqueeze(0)
                    .expand(batch_size, -1)
                )
            rope = self.rotary_emb.get_rotary_embedding(position_ids.max() + 1, x.device, position_ids, x.dtype)

        # Apply dropout to the embeddings.
        x = self.transformer.emb_drop(x)

        # Turn a 0/1 padding mask into an additive bias: 0 where attended, dtype-min where masked.
        if attention_mask is not None and 0.0 in attention_mask:
            attention_mask = attention_mask.to(dtype=torch.float).view(batch_size, -1)[:, None, None, :]
            attention_mask = (1.0 - attention_mask) * torch.finfo(attention_mask.dtype).min
        else:
            attention_mask = None

        # Merge attention mask with attention bias.
        if (
            attention_bias is not None
            or attention_mask is not None
            or self.config.alibi
            or past_key_values is not None
        ):
            if attention_bias is None and self.config.alibi:
                attention_bias = get_causal_attention_bias(
                    self.__cache, past_length + seq_len, x.device
                ) + self.get_alibi_attention_bias(past_length + seq_len, x.device)
            elif attention_bias is None:
                attention_bias = get_causal_attention_bias(self.__cache, past_length + seq_len, x.device)
            elif attention_bias.dtype in (torch.int8, torch.bool):
                # "Allowed" flags become 1.0 (kept as is); "not allowed" (0) becomes dtype-min.
                attention_bias = attention_bias.to(dtype=torch.float)
                attention_bias.masked_fill_(attention_bias == 0.0, torch.finfo(attention_bias.dtype).min)

            # Trim the bias to the length that is actually attended over.
            mask_len = seq_len
            if attention_mask is not None:
                mask_len = attention_mask.shape[-1]
            elif past_key_values is not None:
                mask_len = past_key_values[0][0].shape[-2] + seq_len
            attention_bias = attention_bias[:, :, :mask_len, :mask_len].to(dtype=torch.float)

            # Add in the masking bias.
            if attention_mask is not None:
                attention_bias = attention_bias + attention_mask
                # Adding two dtype-min values overflows to -inf; `ensure_finite_` is called to
                # deal with negative infinities in place.
                ensure_finite_(attention_bias, check_neg_inf=True, check_pos_inf=False)

        attn_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = [] if use_cache else None

        # Inputs of each block / block group (only filled when `output_hidden_states`).
        all_hidden_states = []

        # Apply blocks one-by-one.
        if self.config.block_group_size == 1:
            for block_idx, block in enumerate(self.transformer.blocks):
                if output_hidden_states:
                    all_hidden_states.append(x)

                layer_past = None if past_key_values is None else past_key_values[block_idx]
                if self._should_checkpoint_block(block_idx):
                    # Checkpointed path: the wrapper receives the block plus its arguments.
                    x, cache = self._activation_checkpoint_fn(
                        block, x, attention_bias=attention_bias, layer_past=layer_past, use_cache=use_cache, rope=rope
                    )
                else:
                    x, cache = block(
                        x, attention_bias=attention_bias, layer_past=layer_past, use_cache=use_cache, rope=rope
                    )
                if attn_key_values is not None:
                    assert cache is not None
                    attn_key_values.append(cache)
        else:
            group_size = self.config.block_group_size
            for group_idx, block_group in enumerate(self.transformer.block_groups):
                if output_hidden_states:
                    all_hidden_states.append(x)

                layers_past = (
                    None
                    if past_key_values is None
                    else past_key_values[group_idx * group_size : (group_idx + 1) * group_size]
                )
                # NOTE: `rope` is passed by keyword, so the group's `forward` has to accept it.
                x, cache = block_group(
                    x, attention_bias=attention_bias, layers_past=layers_past, use_cache=use_cache, rope=rope
                )
                if attn_key_values is not None:
                    assert cache is not None
                    attn_key_values.extend(cache)

        if last_logits_only:
            # shape: (batch_size, 1, d_model)
            x = x[:, -1, :].unsqueeze(1)

        collected_states = tuple(all_hidden_states) if output_hidden_states else None
        if skip_final_steps:
            # Early exit: the un-normalized hidden state is returned in the `logits` slot.
            return LLaDAOutput(logits=x, attn_key_values=attn_key_values, hidden_states=collected_states)

        # Apply final layer norm.
        x = self.transformer.ln_f(x)
        if output_hidden_states:
            all_hidden_states.append(x)
            collected_states = tuple(all_hidden_states)

        # Project to vocabulary scores, reusing the embedding matrix under weight tying.
        if self.config.weight_tying:
            logits = F.linear(x, self.transformer.wte.weight, None)
        else:
            logits = self.transformer.ff_out(x)
        if self.config.scale_logits:
            logits.mul_(1 / math.sqrt(self.config.d_model))

        return LLaDAOutput(logits=logits, attn_key_values=attn_key_values, hidden_states=collected_states)
|
|
|
|
def create_model_config_from_pretrained_config(config: LLaDAConfig):
    """
    Build a ``ModelConfig`` from a ``LLaDAConfig``.

    For every dataclass field declared on ``ModelConfig`` the attribute with the same name is
    read from ``config`` and passed to the ``ModelConfig`` constructor.
    """
    field_values = {spec.name: getattr(config, spec.name) for spec in fields(ModelConfig)}
    return ModelConfig(**field_values)
|
|
|
|
class LLaDAModelLM(PreTrainedModel):
    """
    Thin Hugging Face ``PreTrainedModel`` wrapper around ``LLaDAModel``.
    """

    using_custom_rope_version = 'gg2m'
    config_class = LLaDAConfig
    base_model_prefix = "model"
    _no_split_modules = ["LLaDABlock", "LLaDASequentialBlock", "LLaDALlamaBlock"]

    def __init__(self, config: LLaDAConfig, model: Optional[LLaDAModel] = None, init_params: bool = False):
        """
        :param config: Hugging Face configuration object.
        :param model: An existing ``LLaDAModel`` to wrap. When falsy, a new one is built from
            ``config`` with its init device forced to ``"cpu"``.
        :param init_params: Forwarded to ``LLaDAModel`` when a new model is built.
        """
        super().__init__(config)

        if model:
            self.model = model
        else:
            model_config = create_model_config_from_pretrained_config(config)
            # A freshly built model always starts out on the CPU.
            model_config.init_device = "cpu"
            self.model = LLaDAModel(model_config, init_params=init_params)

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        attention_bias: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        skip_final_steps: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        position_ids=None,
        cache_position: Optional[Cache] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """
        Run the wrapped ``LLaDAModel`` and package its output the Hugging Face way.

        ``use_cache`` and ``return_dict`` fall back to the values on ``self.config`` when they
        are ``None``. ``labels`` only triggers a warning; no loss is computed here.
        ``cache_position`` is accepted but not used.

        :raises ValueError: If ``output_attentions`` is truthy.
        :returns: A ``CausalLMOutputWithPast`` when ``return_dict`` is truthy, otherwise the
            plain tuple ``(logits, attn_key_values, hidden_states)``.
        """
        if use_cache is None:
            use_cache = self.config.use_cache

        if output_attentions:
            raise ValueError("output_attentions is not yet supported in LLaDA")

        if return_dict is None:
            return_dict = self.config.use_return_dict

        # Note the different keyword: HF's `inputs_embeds` is the model's `input_embeddings`.
        outputs = self.model.forward(
            input_ids=input_ids,
            input_embeddings=inputs_embeds,
            attention_mask=attention_mask,
            attention_bias=attention_bias,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_hidden_states=output_hidden_states,
            skip_final_steps=skip_final_steps,
            position_ids=position_ids,
        )

        if labels is not None:
            import warnings

            warnings.warn("Note that for LLaDA, you cannot calculate the loss here.", UserWarning)

        if return_dict:
            return CausalLMOutputWithPast(
                logits=outputs.logits,
                past_key_values=outputs.attn_key_values,
                hidden_states=outputs.hidden_states,
            )

        # No loss is ever computed, so the tuple form is just the model output fields.
        return (outputs.logits,) + outputs[1:]

    def can_generate(self) -> bool:
        """Always ``True``."""
        return True

    def prepare_inputs_for_generation(
        self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple]] = None, **kwargs
    ):
        """
        Assemble the keyword arguments for one ``forward`` call during generation.

        With non-empty ``past_key_values`` only the last input ID of each sequence is kept.
        Extra keyword arguments are passed through; ``use_cache`` defaults to the config value.
        """
        if past_key_values:
            # Keep only the most recent token.
            input_ids = input_ids[:, -1:]

        model_inputs = {"input_ids": input_ids, "past_key_values": past_key_values, **kwargs}
        model_inputs["use_cache"] = kwargs.get("use_cache", self.config.use_cache)
        return model_inputs

    def get_input_embeddings(self) -> torch.nn.Module:
        """Return the token embedding module (``wte``)."""
        return self.model.transformer.wte

    def set_input_embeddings(self, value: torch.nn.Module):
        """Replace the token embedding module (``wte``)."""
        self.model.transformer.wte = value

    def get_output_embeddings(self):
        """Return ``wte`` under weight tying, otherwise the ``ff_out`` projection."""
        transformer = self.model.transformer
        return transformer.wte if self.config.weight_tying else transformer.ff_out

    def set_output_embeddings(self, value: torch.nn.Module):
        """Replace ``wte`` under weight tying, otherwise the ``ff_out`` projection."""
        target = "wte" if self.config.weight_tying else "ff_out"
        setattr(self.model.transformer, target, value)

    def tie_weights(self):
        """Under weight tying, make ``ff_out`` refer to the same module as ``wte``."""
        if not self.config.weight_tying:
            return
        self.model.transformer.ff_out = self.model.transformer.wte


# Register the wrapper so that `AutoModel` maps `LLaDAConfig` to `LLaDAModelLM`.
AutoModel.register(LLaDAConfig, LLaDAModelLM)