mistral-nanotron / modeling_mistral.py
thomwolf's picture
thomwolf HF staff
update all
aa5ff8c
# coding=utf-8
# Copyright 2018 HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch Mistral model.
"""
from typing import Dict, Optional, Union
import inspect
import torch
from flash_attn import bert_padding
from flash_attn.flash_attn_interface import (
flash_attn_varlen_func,
flash_attn_with_kvcache,
)
from flash_attn.layers.rotary import RotaryEmbedding as FlashRotaryEmbedding
from nanotron import distributed as dist
from nanotron import logging
from nanotron.config import ParallelismArgs, RecomputeGranularity
from nanotron.generation.generate_store import AttachableStore
from nanotron.logging import log_rank
from nanotron.models import NanotronModel
from nanotron.nn.layer_norm import TritonRMSNorm
from nanotron.parallel import ParallelContext
from nanotron.parallel.parameters import NanotronParameter
from nanotron.parallel.pipeline_parallel.block import (
PipelineBlock,
TensorPointer,
)
from nanotron.parallel.pipeline_parallel.p2p import P2P
from nanotron.parallel.tensor_parallel.functional import sharded_cross_entropy
from nanotron.parallel.tensor_parallel.nn import (
TensorParallelColumnLinear,
TensorParallelEmbedding,
TensorParallelLinearMode,
TensorParallelRowLinear,
)
from nanotron.random import RandomStates
from nanotron.utils import checkpoint_method
from nanotron.nn.activations import ACT2FN
from torch import nn
from config_mistral_7b import MistralConfig
# Module-level logger obtained through nanotron's logging helper.
logger = logging.get_logger(__name__)

# Sliding-window attention needs a flash-attn build whose `flash_attn_varlen_func`
# accepts `window_size`; probe the signature once at import time.
_flash_supports_window_size = "window_size" in inspect.signature(flash_attn_varlen_func).parameters
class RotaryEmbedding(nn.Module):
    """Rotary position embedding implemented as a complex multiplication.

    A table of unit-modulus rotations covering positions ``[0, end)`` is kept in the
    non-persistent buffer ``freqs_cis``, stored as ``(real, imag)`` float pairs. The
    table is built lazily on the first forward pass and rebuilt with a doubled ``end``
    whenever the sequence length or the last position id reaches ``end``.
    """

    def __init__(self, dim: int, end: int, theta: float = 10000.0):
        """
        Args:
            dim: size of the rotated dimension; must be even.
            end: number of positions the table initially covers.
            theta: base used to derive the per-pair rotation frequencies.
        """
        super().__init__()
        assert dim % 2 == 0
        self.dim = dim
        self.end = end
        self.theta = theta
        # TODO @nouamane: Figure out why we can't set `DTypeInvariantTensor` ...
        # TODO @thomasw21: Complex buffers break DDP, instead we store float and view them as complex
        self.freqs_cis: torch.Tensor
        self._initialized_buffer = False

    def init_rotary_embeddings(self):
        """(Re)build the ``freqs_cis`` table on the CUDA device for the current ``end``."""
        if self._initialized_buffer is True:
            # The table already matches the current `end`; nothing to do.
            return

        half_dim = self.dim // 2
        self.register_buffer(
            "freqs_cis",
            torch.empty(self.end, half_dim, 2, dtype=torch.float, device="cuda"),
            persistent=False,
        )
        assert self.freqs_cis.device.type == "cuda"
        # TODO @nouamane: Once we figure out how to do the DTypeInvariantTensor, this can be removed and changed to an assert
        if self.freqs_cis.dtype != torch.float:
            self.freqs_cis = self.freqs_cis.to(torch.float)
        assert self.freqs_cis.dtype == torch.float

        # One frequency per (even, odd) pair of channels: theta ** (-2i / dim).
        exponents = torch.arange(0, self.dim, 2, dtype=torch.float, device="cuda")[:half_dim] / self.dim
        inverse_frequencies = 1.0 / (self.theta**exponents)
        positions = torch.arange(self.end, device="cuda")
        angles = torch.outer(positions, inverse_frequencies).float()
        # exp(i * angle), then split into (real, imag) so that the buffer stays real-valued.
        unit_rotations = torch.polar(torch.ones_like(angles), angles)
        self.freqs_cis.copy_(torch.view_as_real(unit_rotations))
        self._initialized_buffer = True

    def forward(
        self,
        x: torch.Tensor,  # [batch_size, seq_length, num_heads, d_qk]
        position_ids: Optional[torch.LongTensor],  # [batch_size, seq_length]
    ):
        """Rotate ``x`` according to ``position_ids`` (or to ``0..seq_length-1`` when None).

        Returns a tensor with the same shape and dtype as ``x``.

        Raises:
            ValueError: if the last position id is negative or not below ``self.end``.
        """
        batch_size, seq_length, num_heads, inner_dim = x.shape
        last_position = None if position_ids is None else position_ids[-1, -1]

        # Grow the table (by doubling) until it covers everything requested.
        while (
            last_position is not None and last_position >= self.end
        ) or seq_length >= self.end:  # TODO @nouamane: check if this causes cpu-gpu sync
            self.end *= 2
            self._initialized_buffer = False
        if self._initialized_buffer is False:
            self.init_rotary_embeddings()

        input_dtype = x.dtype
        assert inner_dim % 2 == 0
        # Pair up the channels so that each pair can be viewed as one complex number.
        paired = x.view(batch_size, seq_length, num_heads, inner_dim // 2, 2)
        if paired.dtype == torch.bfloat16:
            paired = paired.float()
        complex_x = torch.view_as_complex(paired)  # [batch_size, seq_length, num_heads, inner_dim // 2]

        if position_ids is None:
            rotations = self.freqs_cis[None, :seq_length, None, :]
        else:
            # TODO(kunhao): Should None follow the num_heads dimension?
            if last_position < 0 or last_position >= self.end:  # Quick test hopefully
                raise ValueError(f"Position ids must be in the range [0, {self.end}), but got {position_ids}")
            rotations = self.freqs_cis[position_ids][:, :, None, :]

        rotated = complex_x * torch.view_as_complex(rotations)
        flat = torch.view_as_real(rotated).view(batch_size, seq_length, num_heads, inner_dim)
        return flat.type(input_dtype)
class GLUActivation(nn.Module):
    """Gated activation applied to a fused ``[gate | up]`` tensor."""

    def __init__(self, act_fn_name: str):
        """Look up the activation callable registered under ``act_fn_name`` in ``ACT2FN``."""
        super().__init__()
        self.act = ACT2FN[act_fn_name]

    def forward(self, merged_states: torch.Tensor):
        """Split the last dimension in two halves and return ``act(first_half) * second_half``."""
        half_width = merged_states.shape[-1] // 2
        gate, up = merged_states.split(half_width, dim=-1)
        return self.act(gate) * up
class MLP(nn.Module):
    """Gated feed-forward block built from tensor-parallel linear layers.

    The gate and up projections are fused into a single column-parallel matmul; the
    result goes through the gated activation and a row-parallel down projection.
    """

    def __init__(
        self,
        config: MistralConfig,
        parallel_config: Optional[ParallelismArgs],
        tp_pg: dist.ProcessGroup,
    ):
        """
        Args:
            config: model configuration (reads `hidden_size`, `intermediate_size`, `hidden_act`).
            parallel_config: parallelism settings; when None, ALL_REDUCE mode without
                async communication is used.
            tp_pg: tensor-parallel process group handed to the linear layers.
        """
        super().__init__()

        # TODO @thomasw21: refactor so that we store that default in a single place.
        if parallel_config is None:
            tp_mode = TensorParallelLinearMode.ALL_REDUCE
            tp_linear_async_communication = False
        else:
            tp_mode = parallel_config.tp_mode
            tp_linear_async_communication = parallel_config.tp_linear_async_communication

        ffn_size = config.intermediate_size
        # Gate and up weights live side by side in one matrix; each is its own contiguous chunk.
        gate_up_contiguous_chunks = (
            ffn_size,  # shape of gate_linear
            ffn_size,  # shape of up_linear
        )
        self.gate_up_proj = TensorParallelColumnLinear(
            config.hidden_size,
            2 * ffn_size,
            pg=tp_pg,
            mode=tp_mode,
            bias=False,
            async_communication=tp_linear_async_communication,
            contiguous_chunks=gate_up_contiguous_chunks,
        )
        self.down_proj = TensorParallelRowLinear(
            ffn_size,
            config.hidden_size,
            pg=tp_pg,
            mode=tp_mode,
            bias=False,
            # Async communication on the row-parallel side is only requested in REDUCE_SCATTER mode.
            async_communication=tp_linear_async_communication and tp_mode is TensorParallelLinearMode.REDUCE_SCATTER,
        )
        # TODO @nouamane: why can't we torch.jit.script GLUActivation?
        self.split_silu_mul = GLUActivation(config.hidden_act)

    def forward(self, hidden_states):  # [seq_length, batch_size, hidden_dim]
        """Apply the gated MLP and return the result under the `"hidden_states"` key."""
        fused_gate_up = self.gate_up_proj(hidden_states)
        activated = self.split_silu_mul(fused_gate_up)
        return {"hidden_states": self.down_proj(activated)}
class CoreAttention(nn.Module):
    """Variable-length flash attention over unpadded query/key/value tensors.

    Sequence boundaries are derived from the boolean sequence masks and handed to
    `flash_attn_varlen_func`. An optional causal sliding window is applied when
    `config.sliding_window_size` is set.
    """

    def __init__(self, config: MistralConfig, parallel_config: Optional[ParallelismArgs], layer_idx: int):
        """
        Args:
            config: model configuration (reads `hidden_size`, `num_attention_heads`,
                `attn_pdrop`, `sliding_window_size`).
            parallel_config: accepted for interface symmetry; not used here.
            layer_idx: index of the owning layer; not used here.
        """
        super().__init__()
        # TODO @thomasw21: GPT has a weird `d_kv` config which I'm guessing is essentially a `d_qkv`
        assert (
            config.hidden_size % config.num_attention_heads == 0
        ), f"Hidden size {config.hidden_size} must be divisible by number of attention heads {config.num_attention_heads}."
        self.d_qk = config.hidden_size // config.num_attention_heads
        self.d_v = config.hidden_size // config.num_attention_heads
        self.dropout = config.attn_pdrop

        self.checkpoint_attention = False  # Because flash_attn already does checkpointing

        if config.sliding_window_size is not None:
            assert (
                _flash_supports_window_size
            ), "Current version of flash-attn doesn't support sliding window: `pip install flash-attn>=2.3`"
        self.sliding_window_size = config.sliding_window_size  # if layer_idx not in config.global_attn_layers else None

    @checkpoint_method(attr_name="checkpoint_attention")
    def forward(
        self,
        query_states: torch.Tensor,  # [batch_size * q_length, num_heads, inner_dim]
        key_states: torch.Tensor,  # [batch_size * kv_length, 1, inner_dim]
        value_states: torch.Tensor,  # [batch_size * kv_length, 1, inner_dim]
        q_sequence_mask: torch.Tensor,  # torch.BoolTensor [batch_size, q_length] (can be broadcasted to that size)
        kv_sequence_mask: torch.Tensor,  # torch.BoolTensor [batch_size, kv_length] (can be broadcasted to that size)
    ):
        """Run flash attention and return its output tensor.

        The per-sample lengths are taken from the row sums of the two masks. Dropout
        is only active in training mode.
        """
        # TODO @thomasw21: Compute once, instead of computing for each layers.
        # Cumulative sequence lengths, with a leading 0, as expected by the varlen kernel.
        cu_seqlens_q = torch.zeros((q_sequence_mask.shape[0] + 1), dtype=torch.int32, device=query_states.device)
        cu_seqlens_k = torch.zeros((kv_sequence_mask.shape[0] + 1), dtype=torch.int32, device=query_states.device)
        torch.cumsum(q_sequence_mask.sum(-1, dtype=torch.int32), dim=0, dtype=torch.int32, out=cu_seqlens_q[1:])
        torch.cumsum(kv_sequence_mask.sum(-1, dtype=torch.int32), dim=0, dtype=torch.int32, out=cu_seqlens_k[1:])

        # TODO(kunhao): flash attn's causal means that the query can only attend to the keys before it. This is not
        # what we want if we are using kv cache. This is a hack as we always have q_length == 1 when using kv cache.
        causal = q_sequence_mask.shape[1] != 1

        # Only forward `window_size` when a sliding window is configured: flash-attn builds
        # that predate the argument would otherwise fail with a TypeError even though no
        # window was requested. Omitting it is equivalent to the unbounded (-1, -1) default.
        window_kwargs = {}
        if self.sliding_window_size is not None:
            window_kwargs["window_size"] = (self.sliding_window_size - 1, 0)

        attn_output = flash_attn_varlen_func(
            q=query_states,
            k=key_states,
            v=value_states,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_k=cu_seqlens_k,
            max_seqlen_q=q_sequence_mask.shape[1],
            max_seqlen_k=kv_sequence_mask.shape[1],
            dropout_p=self.dropout if self.training else 0.0,
            softmax_scale=None,  # defaults to 1/sqrt(d_qk)
            causal=causal,
            return_attn_probs=False,
            **window_kwargs,
        )

        return attn_output
def pad_to_right(tensor, mask, new_tensor=None):
    """Move the valid entries of every row to the front, leaving padding on the right.

    (Useful for prefilling key/value states from left-padded inputs.)

    Args:
        tensor: (batch_size, seqlen, d1, d2)
        mask: (batch_size, seqlen) boolean, True where `tensor` holds a real value
        new_tensor: optional destination of shape (batch_size, new_tensor_seqlen, d1, d2);
            it is written in place. When omitted, a zero tensor shaped like `tensor` is used.
    Returns:
        new_tensor: (batch_size, new_tensor_seqlen, d1, d2)
        right_padded_mask: (batch_size, seqlen)
    """
    seqlen = mask.shape[1]
    # Number of valid (unmasked) entries in each row.
    valid_counts = mask.sum(1)
    # Row i of the right-padded mask is True for its first `valid_counts[i]` columns.
    column_ids = torch.arange(seqlen, device=mask.device)
    right_padded_mask = column_ids < valid_counts[:, None]

    if new_tensor is None:
        new_tensor = torch.zeros_like(tensor)

    # Boolean indexing enumerates positions row by row, so the per-row order is kept.
    destination = new_tensor[:, :seqlen, :, :]
    destination[right_padded_mask] = tensor[mask]
    return new_tensor, right_padded_mask
class CausalSelfAttention(nn.Module, AttachableStore):
    """Causal self-attention with a fused QKV projection and rotary position embeddings.

    Two execution paths exist in `forward`:
      * Inference, when a local store is attached: keys/values are cached in the store.
        The first call (no "key" entry yet) is the prefill; later calls append to the
        cache through `flash_attn_with_kvcache`.
      * Training, when no store is attached: rotary embeddings come from flash-attn's
        rotary module and attention is delegated to `CoreAttention`.
    """

    def __init__(
        self,
        config: MistralConfig,
        parallel_config: Optional[ParallelismArgs],
        tp_pg: dist.ProcessGroup,
        layer_idx: int,
    ):
        """
        Args:
            config: model configuration. If it has no `num_key_value_heads` attribute,
                the attribute is set on `config` to `num_attention_heads` (with a warning).
            parallel_config: parallelism settings; when None, ALL_REDUCE mode without
                async communication is used.
            tp_pg: tensor-parallel process group; heads are split across its ranks.
            layer_idx: index of this layer, forwarded to `CoreAttention`.
        """
        super().__init__()
        # Tensor parallel considerations: We split tensors along head dimension
        assert (
            config.num_attention_heads % tp_pg.size() == 0
        ), f"Number of attention heads ({config.num_attention_heads}) must be divisible by TP size ({tp_pg.size()})."
        try:
            assert (
                config.num_key_value_heads % tp_pg.size() == 0
            ), f"Number of key/value heads ({config.num_key_value_heads}) must be divisible by TP size ({tp_pg.size()})."
        except AttributeError:
            log_rank(
                "WARNING: num_key_value_heads not defined, assuming it is equal to num_attention_heads",
                logger=logger,
                level=logging.WARNING,
                rank=0,
            )
            # If num_key_value_heads is not defined, we assume that it is equal to num_attention_heads
            config.num_key_value_heads = config.num_attention_heads
        assert (
            config.num_attention_heads % config.num_key_value_heads == 0
        ), f"Number of attention heads ({config.num_attention_heads}) must be divisible by number of key/value heads ({config.num_key_value_heads})."
        # Per-rank head counts after splitting across the tensor-parallel group.
        self.n_local_q_heads = config.num_attention_heads // tp_pg.size()
        self.n_local_kv_heads = config.num_key_value_heads // tp_pg.size()
        self.n_repeats = config.num_attention_heads // config.num_key_value_heads
        self.is_gqa = config.num_attention_heads != config.num_key_value_heads  # Whether we are using GQA or not
        self.d_qk = config.hidden_size // config.num_attention_heads
        self.d_v = config.hidden_size // config.num_attention_heads
        self.d_model = config.hidden_size
        # TODO @thomasw21: refactor so that we store that default in a single place.
        tp_mode = parallel_config.tp_mode if parallel_config is not None else TensorParallelLinearMode.ALL_REDUCE
        tp_linear_async_communication = (
            parallel_config.tp_linear_async_communication if parallel_config is not None else False
        )
        # Build the slice config for self.qkv_proj for save/load:
        # sharding is done within each contiguous chunk.
        qkv_contiguous_chunks = (
            config.num_attention_heads * self.d_qk,  # shape of q
            config.num_key_value_heads * self.d_qk,  # shape of k
            config.num_key_value_heads * self.d_qk,  # shape of v
        )
        self.qkv_proj = TensorParallelColumnLinear(
            self.d_model,
            config.num_attention_heads * self.d_qk + 2 * config.num_key_value_heads * self.d_qk,
            pg=tp_pg,
            mode=tp_mode,
            bias=False,
            async_communication=tp_linear_async_communication,
            contiguous_chunks=qkv_contiguous_chunks,
        )
        # Rotary embedding used on the inference path (supports explicit position ids).
        # TODO(kunhao): We want to have only one version per device and not one version per layer.
        self.rotary_embedding = RotaryEmbedding(
            dim=self.d_qk,
            end=config.max_position_embeddings,
            theta=config.rope_theta
        )
        # NOTE: Only supported for training (TODO(fmom): position_ids not supported yet)
        self.flash_rotary_embedding = FlashRotaryEmbedding(dim=self.d_qk, base=config.rope_theta, interleaved=True)
        self.o_proj = TensorParallelRowLinear(
            config.num_attention_heads * self.d_qk,
            self.d_model,
            pg=tp_pg,
            mode=tp_mode,
            bias=False,
            async_communication=tp_linear_async_communication,
        )
        self.attention = CoreAttention(
            config,
            parallel_config=parallel_config,
            layer_idx=layer_idx,
        )
        # Sequence capacity of the KV cache allocated at prefill time.
        self.prefill_kv_len = (
            config.max_position_embeddings
        )  # TODO @nouamane: compute based on free memory, because in rope we can surpass max_position_embeddings

    def forward(
        self,
        hidden_states,  # [seq_length, batch_size, hidden_size]
        sequence_mask,  # [batch_size, seq_length]
    ):
        """Compute self-attention over `hidden_states`.

        Returns:
            A dict with "hidden_states" (output of `o_proj`, sequence-first) and
            "sequence_mask" (the input mask, passed through unchanged).
        """
        qkv_states = self.qkv_proj(
            hidden_states
        )  # [seq_length, batch_size, n_local_q_heads * d_qk + 2 * n_local_kv_heads * d_qk]
        q_length, batch_size, _ = qkv_states.shape
        if self.is_gqa:
            # Q and K/V have different head counts: split by explicit sizes, then go batch-first.
            query_states, key_states, value_states = torch.split(
                qkv_states,
                [
                    self.n_local_q_heads * self.d_qk,
                    self.n_local_kv_heads * self.d_qk,
                    self.n_local_kv_heads * self.d_qk,
                ],
                dim=-1,
            )
            query_states = (
                query_states.transpose(0, 1).contiguous().view(batch_size, q_length, self.n_local_q_heads, self.d_qk)
            )
            key_states = (
                key_states.transpose(0, 1).contiguous().view(batch_size, q_length, self.n_local_kv_heads, self.d_qk)
            )
            value_states = (
                value_states.transpose(0, 1).contiguous().view(batch_size, q_length, self.n_local_kv_heads, self.d_qk)
            )
        else:
            # Same head count for Q, K and V: a single view + permute yields all three.
            query_states, key_states, value_states = (
                qkv_states.view(q_length, batch_size, 3, self.n_local_q_heads, self.d_qk)
                .permute(2, 1, 0, 3, 4)
                .contiguous()
            )  # [3, batch_size, seq_length, n_local_q_heads, d_qk]
        store = self.get_local_store()
        if store is not None:  # Inference case
            # Double check that we use store only at inference time
            assert key_states.requires_grad is False
            assert value_states.requires_grad is False
            if "position_offsets" in store:
                # Continue counting from the offsets saved by the previous call.
                old_position_offsets = store["position_offsets"]
                position_ids = old_position_offsets[:, None] + sequence_mask
            else:
                # First call: positions are the running count of unmasked tokens, zero-based.
                position_ids = torch.cumsum(sequence_mask, dim=-1, dtype=torch.int32) - 1
            # Last position id of each sample; saved in the store and used as `cache_seqlens` below.
            position_offsets = position_ids[:, -1]
            # Compute rotary embeddings
            # Note: keep track of old rotary embedding end to check if we need to enlarge k_cache and v_cache
            old_rotary_embed_end = self.rotary_embedding.end
            query_states = self.rotary_embedding(query_states, position_ids=position_ids)
            key_states = self.rotary_embedding(key_states, position_ids=position_ids)
            if "key" not in store:
                # First inference iteration (Prefill)
                # TODO @nouamane: support custom masking
                # assert that [ False, False, False, False, True, True, True, True, True, True] is accepted
                # but [ False, False, False, False, True, True, False, False, True, True] is not (can't mask in the middle of sequence)
                assert ~(
                    sequence_mask[:, :-1] & (~sequence_mask[:, 1:])  # True is never followed by False
                ).any(), "Can't mask in the middle of sequence, please make sure that pads are at the left of the sequence if existing"
                # preallocate k_cache, v_cache to self.prefill_kv_len
                k_cache = torch.zeros(
                    (
                        batch_size,
                        self.prefill_kv_len,
                        self.n_local_kv_heads,
                        self.d_qk,
                    ),
                    dtype=query_states.dtype,
                    device=query_states.device,
                )
                v_cache = torch.zeros(
                    (batch_size, self.prefill_kv_len, self.n_local_kv_heads, self.d_v),
                    dtype=query_states.dtype,
                    device=query_states.device,
                )
                # Remove pad tokens from query/key/value states and concatenate the samples.
                # cu_seqlens_* are the cumulative sequence lengths of the unpadded tensors.
                (query_unpad, indices_q, cu_seqlens_q, max_seqlen_q) = bert_padding.unpad_input(
                    query_states,
                    sequence_mask,
                )
                (key_unpad, indices_k, cu_seqlens_k, max_seqlen_k) = bert_padding.unpad_input(
                    key_states, sequence_mask
                )
                (value_unpad, _, _, _) = bert_padding.unpad_input(value_states, sequence_mask)
                output_unpad = flash_attn_varlen_func(
                    q=query_unpad,  # (total_q, n_local_q_heads, d_qk)
                    k=key_unpad,  # (total_kv, n_local_kv_heads, d_qk)
                    v=value_unpad,  # (total_kv, n_local_kv_heads, d_v)
                    cu_seqlens_q=cu_seqlens_q,
                    cu_seqlens_k=cu_seqlens_k,
                    max_seqlen_q=max_seqlen_q,
                    max_seqlen_k=max_seqlen_k,
                    dropout_p=0.0,
                    softmax_scale=None,
                    causal=True,  # True in prefill phase, False in subsequent phases
                    return_attn_probs=False,
                )  # (total_unpadded, n_local_q_heads, d_v)
                # Scatter the unpadded outputs back to their padded [batch, seq] layout.
                attention_output = bert_padding.pad_input(
                    output_unpad, indices_q, batch_size, q_length
                )  # (batch_size, q_length, n_local_q_heads, d_v)
                # Fill the caches in place with the right-padded keys/values.
                pad_to_right(key_states, sequence_mask, new_tensor=k_cache)
                pad_to_right(value_states, sequence_mask, new_tensor=v_cache)
            else:
                # Pull pre-computed key/value states
                # Subsequent inference iterations (q_length=1)
                k_cache = store["key"]
                v_cache = store["value"]
                # NOTE(fmom): According to flash_attn_with_kvcache, "If you pass in k / v, you must make sure that the cache is large enough to hold the new values"
                # Since rotary embedding has changed (to enable larger context), we need to enlarge k_cache and v_cache
                if self.rotary_embedding.end > old_rotary_embed_end:
                    k_cache = torch.cat(
                        [
                            k_cache,
                            torch.zeros(
                                (
                                    batch_size,
                                    self.rotary_embedding.end - old_rotary_embed_end,
                                    self.n_local_kv_heads,
                                    self.d_qk,
                                ),
                                dtype=query_states.dtype,
                                device=query_states.device,
                            ),
                        ],
                        dim=1,
                    )
                    v_cache = torch.cat(
                        [
                            v_cache,
                            torch.zeros(
                                (
                                    batch_size,
                                    self.rotary_embedding.end - old_rotary_embed_end,
                                    self.n_local_kv_heads,
                                    self.d_v,
                                ),
                                dtype=query_states.dtype,
                                device=query_states.device,
                            ),
                        ],
                        dim=1,
                    )
                # NOTE(review): the checks below require the cache length to be *equal* to the
                # rotary table size, although their messages say "smaller than".
                assert (
                    k_cache.shape[1] == self.rotary_embedding.end
                ), f"Cache size {k_cache.shape[1]} is smaller than rotary embedding end {self.rotary_embedding.end}"
                assert (
                    v_cache.shape[1] == self.rotary_embedding.end
                ), f"Cache size {v_cache.shape[1]} is smaller than rotary embedding end {self.rotary_embedding.end}"
                # [batch_size, seq_length, num_heads, d_qk]
                query_states = query_states.view(
                    batch_size, q_length, self.n_local_q_heads, self.d_qk
                )  # [batch_size, q_length, self.n_heads, d_qk]
                kv_length = key_states.shape[1]
                key_states = key_states.view(
                    batch_size, kv_length, self.n_local_kv_heads, self.d_qk
                )  # [batch_size, kv_length, self.n_heads, d_qk]
                value_states = value_states.view(
                    batch_size, kv_length, self.n_local_kv_heads, self.d_v
                )  # [batch_size, kv_length, self.n_heads, d_v]
                # Rotary embeddings were already applied above, hence rotary_cos/rotary_sin are None.
                attention_output = flash_attn_with_kvcache(
                    query_states,
                    k_cache,
                    v_cache,
                    key_states,
                    value_states,
                    rotary_cos=None,
                    rotary_sin=None,
                    # TODO @nouamane: seems like this doesn't help to indicate padding (for the first iteration it's just 0)
                    cache_seqlens=position_offsets.contiguous(),
                    softmax_scale=None,
                    causal=True,
                    rotary_interleaved=False,  # GPT-NeoX style
                )
            # Persist the caches and offsets for the next decoding step.
            store.update(
                {
                    "key": k_cache,  # flash-attn has updated with new key_states using cache_seqlens
                    "value": v_cache,
                    "position_offsets": position_offsets,
                }
            )
        else:  # Training case
            # Apply rotary embeddings to query/key states
            # NOTE: The layout is different from models/mistral.py which is [batch_size, num_heads, seq_length, d_qk]
            # Here it is, [batch_size, seq_length, num_heads, d_qk]
            # [2, batch_size, seq_length, num_heads, d_qk]
            key_value_states = torch.cat([key_states.unsqueeze(0), value_states.unsqueeze(0)], dim=0)
            # [batch_size, seq_length, 2, num_heads, d_qk]
            key_value_states = key_value_states.permute(1, 2, 0, 3, 4).contiguous()
            query_states, key_value_states = self.flash_rotary_embedding(query_states, kv=key_value_states)
            # Split the stacked kv tensor back into keys and values along dim 2
            # (each piece keeps a singleton dim there; the views below flatten it away).
            key_states, value_states = torch.split(key_value_states, 1, dim=2)
            # Queries and keys cover the same tokens in training, so they share one mask.
            q_sequence_mask = sequence_mask
            kv_sequence_mask = sequence_mask
            kv_length = key_states.shape[1]
            # [batch_size, seq_length, num_heads, d_qk]
            # Shaping for use in `flash-attn` version of flash-attn: `flash_attn_unpadded_func`
            query_states = query_states.view(
                batch_size * q_length, self.n_local_q_heads, self.d_qk
            )  # [batch_size * q_length, self.n_heads, d_qk]
            key_states = key_states.view(
                batch_size * kv_length, self.n_local_kv_heads, self.d_qk
            )  # [batch_size * kv_length, self.n_heads, d_qk]
            value_states = value_states.view(
                batch_size * kv_length, self.n_local_kv_heads, self.d_v
            )  # [batch_size * kv_length, self.n_heads, d_v]
            attention_output = self.attention(
                query_states=query_states,
                key_states=key_states,
                value_states=value_states,
                q_sequence_mask=q_sequence_mask,
                kv_sequence_mask=kv_sequence_mask,
            )
        # Merge the heads and go back to the sequence-first layout expected by `o_proj`.
        attention_output = (
            attention_output.contiguous().view(batch_size, q_length, self.n_local_q_heads * self.d_v).transpose(0, 1)
        )
        output = self.o_proj(attention_output)
        return {"hidden_states": output, "sequence_mask": sequence_mask}
class MistralDecoderLayer(nn.Module):
    """One pre-norm transformer layer: RMSNorm -> attention -> residual, RMSNorm -> MLP -> residual."""

    def __init__(
        self,
        config: MistralConfig,
        parallel_config: Optional[ParallelismArgs],
        tp_pg: dist.ProcessGroup,
        layer_idx: int,
    ):
        """Build the two norms, the attention module and the MLP from `config`."""
        super().__init__()
        norm_kwargs = {"eps": config.rms_norm_eps}
        self.input_layernorm = TritonRMSNorm(config.hidden_size, **norm_kwargs)
        self.attn = CausalSelfAttention(
            config=config,
            parallel_config=parallel_config,
            tp_pg=tp_pg,
            layer_idx=layer_idx,
        )
        self.post_attention_layernorm = TritonRMSNorm(config.hidden_size, **norm_kwargs)
        self.mlp = MLP(config=config, parallel_config=parallel_config, tp_pg=tp_pg)

    def forward(
        self,
        hidden_states: Union[torch.Tensor, TensorPointer],
        sequence_mask: Union[torch.Tensor, TensorPointer],
    ) -> Dict[str, Union[torch.Tensor, TensorPointer]]:
        """Run the layer and return the new hidden states together with the sequence mask
        reported by the attention module."""
        # Attention sub-block with residual connection.
        attn_result = self.attn(
            hidden_states=self.input_layernorm(hidden_states),
            sequence_mask=sequence_mask,
        )
        after_attention = attn_result["hidden_states"] + hidden_states

        # Feed-forward sub-block with residual connection.
        mlp_input = self.post_attention_layernorm(after_attention)
        mlp_output = self.mlp(hidden_states=mlp_input)["hidden_states"]

        return {
            "hidden_states": mlp_output + after_attention,
            "sequence_mask": attn_result["sequence_mask"],
        }
class Embedding(nn.Module, AttachableStore):
    """Tensor-parallel token embedding that also tracks consumed lengths at inference time."""

    def __init__(self, tp_pg: dist.ProcessGroup, config: MistralConfig, parallel_config: Optional[ParallelismArgs]):
        """Create the embedding table; ALL_REDUCE mode is used when `parallel_config` is None."""
        super().__init__()
        if parallel_config is None:
            tp_mode = TensorParallelLinearMode.ALL_REDUCE
        else:
            tp_mode = parallel_config.tp_mode
        self.token_embedding = TensorParallelEmbedding(
            num_embeddings=config.vocab_size,
            embedding_dim=config.hidden_size,
            padding_idx=config.pad_token_id,
            pg=tp_pg,
            mode=tp_mode,
        )
        self.pg = tp_pg

    def forward(self, input_ids: torch.Tensor, input_mask: torch.Tensor):  # [batch_size, seq_length]
        """Embed `input_ids` and return the embeddings under the `"input_embeds"` key.

        When a local store is attached, `store["past_length"]` is incremented by the
        number of unmasked tokens of each row.
        """
        store = self.get_local_store()
        if store is not None:
            if "past_length" not in store:
                # No history yet: start every row at zero.
                past_length = torch.zeros(1, dtype=torch.long, device=input_ids.device).expand(input_ids.shape[0])
            else:
                past_length = store["past_length"]

            # The last column of the running sum is the per-row count of unmasked tokens.
            running_count = input_mask.cumsum(-1, dtype=torch.long)
            # Store new past_length in store
            store["past_length"] = past_length + running_count[:, -1]

        # Format input in `[seq_length, batch_size]` to support high TP with low batch_size
        sequence_first_ids = input_ids.transpose(0, 1)
        return {"input_embeds": self.token_embedding(sequence_first_ids)}
class MistralModel(nn.Module):
    """Build pipeline graph

    Every stage (embedding, decoder layers, final norm, LM head, fp32 cast) is wrapped in
    a `PipelineBlock` sharing a single `P2P` communicator.
    """

    def __init__(
        self,
        config: MistralConfig,
        parallel_context: ParallelContext,
        parallel_config: Optional[ParallelismArgs],
    ):
        """
        Args:
            config: model configuration.
            parallel_context: provides the pipeline (`pp_pg`), tensor (`tp_pg`) and world
                (`world_pg`) process groups.
            parallel_config: parallelism settings; when None, ALL_REDUCE mode without
                async communication is used.
        """
        super().__init__()

        # Declare all the nodes
        self.p2p = P2P(parallel_context.pp_pg, device=torch.device("cuda"))
        self.config = config
        self.parallel_config = parallel_config
        self.parallel_context = parallel_context
        self.tp_mode = parallel_config.tp_mode if parallel_config is not None else TensorParallelLinearMode.ALL_REDUCE
        tp_linear_async_communication = (
            parallel_config.tp_linear_async_communication if parallel_config is not None else False
        )

        self.token_position_embeddings = PipelineBlock(
            p2p=self.p2p,
            module_builder=Embedding,
            module_kwargs={
                "tp_pg": parallel_context.tp_pg,
                "config": config,
                "parallel_config": parallel_config,
            },
            module_input_keys={"input_ids", "input_mask"},
            module_output_keys={"input_embeds"},
        )

        self.decoder = nn.ModuleList(
            [
                PipelineBlock(
                    p2p=self.p2p,
                    module_builder=MistralDecoderLayer,
                    module_kwargs={
                        "config": config,
                        "parallel_config": parallel_config,
                        "tp_pg": parallel_context.tp_pg,
                        "layer_idx": layer_idx,
                    },
                    module_input_keys={"hidden_states", "sequence_mask"},
                    module_output_keys={"hidden_states", "sequence_mask"},
                )
                for layer_idx in range(config.num_hidden_layers)
            ]
        )

        self.final_layer_norm = PipelineBlock(
            p2p=self.p2p,
            module_builder=TritonRMSNorm,
            module_kwargs={"hidden_size": config.hidden_size, "eps": config.rms_norm_eps},
            module_input_keys={"input"},
            module_output_keys={"hidden_states"},
        )  # TODO

        self.lm_head = PipelineBlock(
            p2p=self.p2p,
            # Understand that this means that we return sharded logits that are going to need to be gathered
            module_builder=TensorParallelColumnLinear,
            module_kwargs={
                "in_features": config.hidden_size,
                "out_features": config.vocab_size,
                "pg": parallel_context.tp_pg,
                "bias": False,
                # TODO @thomasw21: refactor so that we store that default in a single place.
                "mode": self.tp_mode,
                "async_communication": tp_linear_async_communication,
            },
            module_input_keys={"x"},
            module_output_keys={"logits"},
        )

        # Stateless stage that upcasts the logits to float32.
        self.cast_to_fp32 = PipelineBlock(
            p2p=self.p2p,
            module_builder=lambda: lambda x: x.float(),
            module_kwargs={},
            module_input_keys={"x"},
            module_output_keys={"output"},
        )

    def forward(
        self,
        input_ids: Union[torch.Tensor, TensorPointer],  # [batch_size, seq_length]
        input_mask: Union[torch.Tensor, TensorPointer],  # [batch_size, seq_length]
    ):
        """Return the fp32 sharded logits only (first element of `forward_with_hidden_states`)."""
        return self.forward_with_hidden_states(input_ids=input_ids, input_mask=input_mask)[0]

    def forward_with_hidden_states(
        self,
        input_ids: Union[torch.Tensor, TensorPointer],  # [batch_size, seq_length]
        input_mask: Union[torch.Tensor, TensorPointer],  # [batch_size, seq_length]
    ):
        """Run the full pipeline.

        Returns:
            A tuple `(fp32_sharded_logits, hidden_states)` where `hidden_states` is the
            output of the final layer norm, i.e. the input of the LM head.
        """
        # all tensors are optional as most ranks don't need anything from the dataloader.

        output = self.token_position_embeddings(input_ids=input_ids, input_mask=input_mask)

        # Each decoder block consumes and produces the same pair of keys.
        hidden_encoder_states = {
            "hidden_states": output["input_embeds"],
            "sequence_mask": input_mask,
        }
        for encoder_block in self.decoder:
            hidden_encoder_states = encoder_block(**hidden_encoder_states)

        hidden_states = self.final_layer_norm(input=hidden_encoder_states["hidden_states"])["hidden_states"]

        sharded_logits = self.lm_head(x=hidden_states)["logits"]

        fp32_sharded_logits = self.cast_to_fp32(x=sharded_logits)["output"]

        return fp32_sharded_logits, hidden_states

    def get_block_compute_costs(self):
        """Computes the compute cost of each block in the model so that we can do a better job of load balancing."""
        model_config = self.config
        d_ff = model_config.intermediate_size
        d_qkv = model_config.hidden_size // model_config.num_attention_heads
        block_compute_costs = {
            # CausalSelfAttention (qkv proj + attn out) + MLP
            MistralDecoderLayer: 4 * model_config.num_attention_heads * d_qkv * model_config.hidden_size
            + 3 * d_ff * model_config.hidden_size,
            # This is the last lm_head
            TensorParallelColumnLinear: model_config.vocab_size * model_config.hidden_size,
        }
        return block_compute_costs

    def get_flops_per_sec(self, iteration_time_in_sec, sequence_length, global_batch_size):
        """Get flops per second for a given model

        The totals returned by `get_flops` (not visible in this part of the file) are
        divided by the iteration time, the world size and 1e12.

        Returns:
            A tuple `(model_flops_per_s, hardware_flops_per_s)`.
        """
        world_size = self.parallel_context.world_pg.size()

        # Configs without `num_key_value_heads` are treated as plain multi-head attention.
        num_key_values_heads = getattr(self.config, "num_key_value_heads", self.config.num_attention_heads)

        # `parallel_config` is Optional everywhere else in this class; do not dereference it
        # blindly here. NOTE(review): assumes `get_flops` accepts `recompute_granularity=None`
        # to mean "no recomputation" -- confirm against its definition.
        recompute_granularity = (
            self.parallel_config.recompute_granularity if self.parallel_config is not None else None
        )

        model_flops, hardware_flops = get_flops(
            num_layers=self.config.num_hidden_layers,
            hidden_size=self.config.hidden_size,
            num_heads=self.config.num_attention_heads,
            num_key_value_heads=num_key_values_heads,
            vocab_size=self.config.vocab_size,
            ffn_hidden_size=self.config.intermediate_size,
            seq_len=sequence_length,
            batch_size=global_batch_size,
            recompute_granularity=recompute_granularity,
        )

        model_flops_per_s = model_flops / (iteration_time_in_sec * world_size * 1e12)
        hardware_flops_per_s = hardware_flops / (iteration_time_in_sec * world_size * 1e12)
        return model_flops_per_s, hardware_flops_per_s
@torch.jit.script
def masked_mean(loss: torch.Tensor, label_mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    """Mean of `loss` over the positions selected by `label_mask`.

    The masked sum is accumulated in `dtype` and divided by the sum of `label_mask`.
    """
    masked_total = (loss * label_mask).sum(dtype=dtype)
    selected_count = label_mask.sum()
    return masked_total / selected_count
class Loss(nn.Module):
    """Masked mean of the sharded cross-entropy over the tensor-parallel group."""

    def __init__(self, tp_pg: dist.ProcessGroup):
        """Remember the tensor-parallel process group used by the sharded cross-entropy."""
        super().__init__()
        self.tp_pg = tp_pg

    def forward(
        self,
        sharded_logits: torch.Tensor,  # [seq_length, batch_size, logits]
        label_ids: torch.Tensor,  # [batch_size, seq_length]
        label_mask: torch.Tensor,  # [batch_size, seq_length]
    ) -> Dict[str, torch.Tensor]:
        """Return `{"loss": scalar}` where the scalar is the mean loss over `label_mask`."""
        # Megatron by defaults cast everything in fp32. `--f16-lm-cross-entropy` is an option you can use to keep current precision.
        # https://github.com/NVIDIA/Megatron-LM/blob/f267e6186eae1d6e2055b412b00e2e545a8e896a/megatron/model/gpt_model.py#L38
        # The logits are sequence-first, so the labels are transposed to match ...
        sequence_first_labels = label_ids.transpose(0, 1).contiguous()
        per_token_loss = sharded_cross_entropy(
            sharded_logits, sequence_first_labels, group=self.tp_pg, dtype=torch.float
        )
        # ... and the per-token loss is brought back to batch-first to line up with `label_mask`.
        per_token_loss = per_token_loss.transpose(0, 1)

        # TODO @thomasw21: It's unclear what kind of normalization we want to do.
        # I think indexing causes a sync we don't actually want
        # (hence the masked mean instead of `loss[label_mask].sum()`).
        return {"loss": masked_mean(per_token_loss, label_mask, dtype=torch.float)}
class MistralForTraining(NanotronModel):
    """Training wrapper that chains a `MistralModel` with a `Loss` module wrapped in a `PipelineBlock`."""

    def __init__(
        self,
        config: MistralConfig,
        parallel_context: ParallelContext,
        parallel_config: Optional[ParallelismArgs],
        random_states: Optional[RandomStates] = None,
    ):
        """Build the underlying model and its loss block.

        Args:
            config: Model configuration, forwarded to `MistralModel`.
            parallel_context: Parallel context, forwarded to `MistralModel`; its `tp_pg` is given to `Loss`.
            parallel_config: Parallelism arguments, forwarded to `MistralModel`.
            random_states: Accepted for interface compatibility; not used by this constructor.
        """
        super().__init__()
        self.model = MistralModel(config=config, parallel_context=parallel_context, parallel_config=parallel_config)
        # The loss shares the model's p2p object.
        self.loss = PipelineBlock(
            p2p=self.model.p2p,
            module_builder=Loss,
            module_kwargs={"tp_pg": parallel_context.tp_pg},
            module_input_keys={
                "sharded_logits",
                "label_ids",
                "label_mask",
            },
            module_output_keys={"loss"},
        )
        self.parallel_context = parallel_context
        self.config = config
        self.parallel_config = parallel_config

    def forward(
        self,
        input_ids: Union[torch.Tensor, TensorPointer],
        input_mask: Union[torch.Tensor, TensorPointer],
        label_ids: Union[torch.Tensor, TensorPointer],
        label_mask: Union[torch.Tensor, TensorPointer],
    ) -> Dict[str, Union[torch.Tensor, TensorPointer]]:
        """Run the model on the inputs, then feed its sharded logits to the loss block.

        Returns:
            A dict with a single key, `"loss"`.
        """
        sharded_logits = self.model(
            input_ids=input_ids,
            input_mask=input_mask,
        )
        loss = self.loss(
            sharded_logits=sharded_logits,
            label_ids=label_ids,
            label_mask=label_mask,
        )["loss"]
        return {"loss": loss}

    @torch.no_grad()
    def init_model_randomly(self, init_method, scaled_init_method):
        """Initialize model parameters randomly.

        Args:
            init_method (callable): Applied in place to the weight of every `TensorParallelColumnLinear`
                and `TensorParallelEmbedding`.
            scaled_init_method (callable): Applied in place to the weight of every `TensorParallelRowLinear`.

        Note:
            - `TritonRMSNorm` weights are filled with 1.
            - Biases (accepted on the column/row linear layers only) are zeroed.
            - Tied parameters are initialized once, under their tied full name.

        Raises:
            AssertionError: If a handled module exposes unexpected parameter names, if a parameter is not a
                `NanotronParameter`, or if some model parameter was left uninitialized.
            ValueError: If a parameter is named neither "weight" nor "bias".
        """
        model = self
        initialized_parameters = set()

        # Handle tensor parallelism
        module_id_to_prefix = {id(module): f"{module_name}." for module_name, module in model.named_modules()}
        # Fix the root_model
        module_id_to_prefix[id(model)] = ""

        def full_name_of(param, untied_name):
            """Return the name `param` is tracked under: its tied full name if tied, else `untied_name`."""
            if param.is_tied:
                return param.get_tied_info().get_full_name_from_module_id_to_prefix(
                    module_id_to_prefix=module_id_to_prefix
                )
            return untied_name

        def fill_with_ones(param):
            """Weight initializer used for the normalization layers."""
            # TODO @thomasw21: Sometimes we actually want 0
            param.fill_(1)

        def init_module(module_name, module, weight_init, allowed_name_sets):
            """Initialize every not-yet-initialized parameter of `module` (weight: `weight_init`, bias: zeros)."""
            param_names = {name for name, _ in module.named_parameters()}
            assert (
                param_names in allowed_name_sets
            ), f"Unexpected parameters {param_names} in `{module_name}`, expected one of {allowed_name_sets}"
            for param_name, param in module.named_parameters():
                assert isinstance(param, NanotronParameter)
                full_param_name = full_name_of(param, f"{module_name}.{param_name}")
                if full_param_name in initialized_parameters:
                    # Already initialized
                    continue

                if "weight" == param_name:
                    weight_init(param)
                elif "bias" == param_name:
                    param.zero_()
                else:
                    raise ValueError(f"Who the fuck is {param_name}?")
                initialized_parameters.add(full_param_name)

        weight_only = ({"weight"},)
        weight_maybe_bias = ({"weight"}, {"weight", "bias"})
        # Somehow Megatron-LM does something super complicated for the tensor-parallel layers,
        # https://github.com/NVIDIA/Megatron-LM/blob/2360d732a399dd818d40cbe32828f65b260dee11/megatron/core/tensor_parallel/layers.py#L96
        # What it does:
        #  - instantiate a buffer of the `full size` in fp32
        #  - run init method on it
        #  - shard result to get only a specific shard
        # Instead we just run the init method directly on the local parameter, since they are scalar independent.
        #
        # Rules are checked in order; the first matching module type wins.
        # Each rule is (module type, weight initializer, accepted sets of parameter names).
        init_rules = (
            (TensorParallelColumnLinear, init_method, weight_maybe_bias),
            (TensorParallelRowLinear, scaled_init_method, weight_maybe_bias),
            (TritonRMSNorm, fill_with_ones, weight_only),
            # TODO @thomasw21: Handle tied embeddings
            (TensorParallelEmbedding, init_method, weight_only),
        )

        for module_name, module in model.named_modules():
            for module_type, weight_init, allowed_name_sets in init_rules:
                if isinstance(module, module_type):
                    init_module(module_name, module, weight_init, allowed_name_sets)
                    break

        # Every parameter of the model must have been reached through one of the handled module types.
        expected_parameters = {full_name_of(param, name) for name, param in model.named_parameters()}
        assert (
            initialized_parameters == expected_parameters
        ), f"Somehow the initialized set of parameters don't match:\n - Expected: {expected_parameters}\n - Got: {initialized_parameters}"

    def get_block_compute_costs(self):
        """Computes the compute cost of each block in the model so that we can do a better job of load balancing."""
        return self.model.get_block_compute_costs()

    def get_flops_per_sec(self, iteration_time_in_sec, sequence_length, global_batch_size):
        """Get flops per second for a given model"""
        return self.model.get_flops_per_sec(iteration_time_in_sec, sequence_length, global_batch_size)
def get_flops(
    num_layers,
    hidden_size,
    num_heads,
    vocab_size,
    seq_len,
    kv_channels=None,
    ffn_hidden_size=None,
    batch_size=1,
    recompute_granularity=None,
    glu_activation=False,
    num_key_value_heads=None,
):
    """Counts flops in a decoder-only model

    Args:
        num_layers: number of decoder layers
        hidden_size: hidden size of the model
        num_heads: number of (query) heads in the model
        vocab_size: size of the vocabulary
        seq_len: sequence length of the decoder
        kv_channels: size of each attention head. Defaults to `hidden_size // num_heads`.
        ffn_hidden_size: hidden size of the FFN. Defaults to `4 * hidden_size`.
        batch_size: batch size
        recompute_granularity: Activation recomputation method. Either None, FULL or SELECTIVE. Check Megatron-LM docs for more info.
        glu_activation: whether the FFN uses a gated activation, i.e. two matmuls in its first layer.
        num_key_value_heads: number of key/value heads in the model. Defaults to 1, i.e. a single key/value
            head shared across all query heads (MQA). Pass the real count for GQA/MHA models.

    Returns:
        model_flops: flops in the model (should be independent of the hardware and model implementation)
        hardware_flops: flops in the hardware (actual flops performed on the hardware). Check 6.3 in https://arxiv.org/pdf/2205.05198.pdf

    Raises:
        ValueError: if `recompute_granularity` is not None, FULL or SELECTIVE.
    """
    if kv_channels is None:
        assert hidden_size % num_heads == 0
        kv_channels = hidden_size // num_heads
    if ffn_hidden_size is None:
        ffn_hidden_size = 4 * hidden_size
    if num_key_value_heads is None:
        # Historical behavior of this function: key/value projections shared across all heads (MQA).
        num_key_value_heads = 1

    # In the following we mark the reduced dimension with parentheses
    # decoder
    # self attention
    ## q projection
    decoder_q_proj_flops_fwd = 2 * num_layers * batch_size * seq_len * (hidden_size) * num_heads * kv_channels
    ## kv projection (one K and one V projection per key/value head)
    decoder_kv_proj_flops_fwd = (
        2 * num_layers * batch_size * seq_len * (hidden_size) * 2 * num_key_value_heads * kv_channels
    )
    ## qk logits
    decoder_qk_logits_flops_fwd = 2 * num_layers * batch_size * num_heads * seq_len * (kv_channels) * seq_len
    ### NOTE: sliding window attention (local attention) is not accounted for; with a window, the second
    ### `seq_len` factor of the qk/v logits terms would be the window size instead.
    ## v logits
    decoder_v_logits_flops_fwd = 2 * num_layers * batch_size * num_heads * seq_len * (seq_len) * kv_channels
    ## attn out
    decoder_attn_out_flops_fwd = 2 * num_layers * batch_size * num_heads * seq_len * (kv_channels) * hidden_size

    # FF
    ## 1st layer
    if glu_activation:
        # 3 matmuls instead of 2 in FFN, so the first layer costs twice as much
        # ref. https://arxiv.org/pdf/2002.05202.pdf
        # Used for example in T5 v1.1
        decoder_ffn_1_flops_fwd = 4 * num_layers * batch_size * seq_len * (hidden_size) * ffn_hidden_size
    else:
        decoder_ffn_1_flops_fwd = 2 * num_layers * batch_size * seq_len * (hidden_size) * ffn_hidden_size
    ## 2nd layer
    decoder_ffn_2_flops_fwd = 2 * num_layers * batch_size * seq_len * (ffn_hidden_size) * hidden_size

    decoder_flops_fwd = (
        decoder_q_proj_flops_fwd
        + decoder_kv_proj_flops_fwd
        + decoder_qk_logits_flops_fwd
        + decoder_v_logits_flops_fwd
        + decoder_attn_out_flops_fwd
        + decoder_ffn_1_flops_fwd
        + decoder_ffn_2_flops_fwd
    )

    # lm head
    lm_head_flops_fwd = 2 * batch_size * seq_len * (hidden_size) * vocab_size

    # the bwd pass requires double the flops in case of matmuls to calculate the gradients with respect to
    # both input and weight tensors
    model_flops = 3 * (decoder_flops_fwd + lm_head_flops_fwd)  # 1 for fwd + 2 for bwd

    if recompute_granularity is None:
        hardware_flops = model_flops
    elif recompute_granularity is RecomputeGranularity.FULL:
        # Note: we don't recompute lm head activs
        hardware_flops = model_flops + decoder_flops_fwd  # + activ recomputation
    elif recompute_granularity is RecomputeGranularity.SELECTIVE:
        # all terms with s^2 are flops that are recomputed
        # ref. appendix A: https://arxiv.org/pdf/2205.05198.pdf
        recomputed_decoder_flops = decoder_qk_logits_flops_fwd + decoder_v_logits_flops_fwd
        hardware_flops = model_flops + recomputed_decoder_flops
    else:
        raise ValueError("recompute_granularity must be one of 'full' or 'selective'")

    return model_flops, hardware_flops