""" |
|
Transformer blocks for the LLM model. |
|
""" |
|
|
|
import jax |
|
import jax.numpy as jnp |
|
import flax.linen as nn |
|
from typing import Optional, Tuple, Dict, Any, Callable, Union |
|
import math |
|
|
|
from model.attention import MultiHeadAttention, MultiQueryAttention, RotaryMultiQueryAttention |
|
|
|
|
|
class FeedForward(nn.Module):
    """
    Feed-forward network with SwiGLU activation.

    Attributes:
        dim: Input and output dimension
        hidden_dim: Hidden dimension
        dropout_rate: Dropout probability
        dtype: Data type for computations
    """
    dim: int
    hidden_dim: int
    dropout_rate: float = 0.0
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.gate_proj = nn.Dense(
            features=self.hidden_dim,
            dtype=self.dtype,
            kernel_init=nn.initializers.normal(stddev=0.02),
            name="gate_proj"
        )
        self.up_proj = nn.Dense(
            features=self.hidden_dim,
            dtype=self.dtype,
            kernel_init=nn.initializers.normal(stddev=0.02),
            name="up_proj"
        )
        self.down_proj = nn.Dense(
            features=self.dim,
            dtype=self.dtype,
            kernel_init=nn.initializers.normal(stddev=0.02),
            name="down_proj"
        )
        self.dropout = nn.Dropout(rate=self.dropout_rate)

    def __call__(self, x: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray:
        """
        Apply feed-forward network.

        Args:
            x: Input tensor [batch_size, seq_len, dim]
            deterministic: Whether to use deterministic operations (no dropout)

        Returns:
            Output tensor [batch_size, seq_len, dim]
        """
        # SwiGLU: silu(gate_proj(x)) * up_proj(x), then project back to dim.
        gate = self.gate_proj(x)
        gate = jax.nn.silu(gate)
        up = self.up_proj(x)

        hidden = gate * up

        output = self.down_proj(hidden)
        output = self.dropout(output, deterministic=deterministic)

        return output

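
# Usage sketch (illustrative only, not part of the model): FeedForward is a
# standalone Flax module, so it can be initialized and applied directly. The
# shapes and hyperparameters below are example values, not defaults.
#
#     ffn = FeedForward(dim=512, hidden_dim=2048, dropout_rate=0.1)
#     x = jnp.ones((2, 16, 512))
#     variables = ffn.init(jax.random.PRNGKey(0), x)
#     y = ffn.apply(variables, x, deterministic=True)  # shape (2, 16, 512)
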
class TransformerBlock(nn.Module):
    """
    Transformer block with attention and feed-forward network.

    Attributes:
        dim: Hidden dimension
        num_heads: Number of attention heads
        hidden_dim: Hidden dimension in feed-forward network
        dropout_rate: Dropout probability
        attention_dropout_rate: Dropout probability for attention
        layer_norm_epsilon: Epsilon for layer normalization
        dtype: Data type for computations
    """
    dim: int
    num_heads: int
    hidden_dim: int
    dropout_rate: float = 0.0
    attention_dropout_rate: float = 0.0
    layer_norm_epsilon: float = 1e-5
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.input_layernorm = nn.LayerNorm(
            epsilon=self.layer_norm_epsilon,
            dtype=self.dtype,
            name="input_layernorm"
        )
        self.post_attention_layernorm = nn.LayerNorm(
            epsilon=self.layer_norm_epsilon,
            dtype=self.dtype,
            name="post_attention_layernorm"
        )

        self.attention = MultiHeadAttention(
            dim=self.dim,
            num_heads=self.num_heads,
            dropout_rate=self.attention_dropout_rate,
            dtype=self.dtype,
            name="attention"
        )

        self.feed_forward = FeedForward(
            dim=self.dim,
            hidden_dim=self.hidden_dim,
            dropout_rate=self.dropout_rate,
            dtype=self.dtype,
            name="feed_forward"
        )

        self.dropout = nn.Dropout(rate=self.dropout_rate)

    def __call__(
        self,
        hidden_states: jnp.ndarray,
        attention_mask: Optional[jnp.ndarray] = None,
        position_ids: Optional[jnp.ndarray] = None,
        past_key_value: Optional[Tuple[jnp.ndarray, jnp.ndarray]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        deterministic: bool = True,
    ) -> Tuple[jnp.ndarray, ...]:
        """
        Apply transformer block.

        Args:
            hidden_states: Input tensor [batch_size, seq_len, dim]
            attention_mask: Attention mask [batch_size, 1, seq_len, seq_len]
            position_ids: Position indices [batch_size, seq_len]
            past_key_value: Cached key and value tensors for incremental decoding
            output_attentions: Whether to return attention weights
            use_cache: Whether to return the updated key/value cache for incremental decoding
            deterministic: Whether to use deterministic operations (no dropout)

        Returns:
            Tuple of (output, attention_weights, present_key_value); elements
            after the first are present only when requested
        """
        # Pre-norm attention sub-layer with residual connection.
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)

        attention_outputs = self.attention(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            deterministic=deterministic,
        )

        hidden_states = attention_outputs[0]
        hidden_states = self.dropout(hidden_states, deterministic=deterministic)
        hidden_states = residual + hidden_states

        # Pre-norm feed-forward sub-layer with residual connection.
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)

        hidden_states = self.feed_forward(hidden_states, deterministic=deterministic)
        hidden_states = residual + hidden_states

        # Pass through any extra attention outputs (attention weights, cache).
        outputs = (hidden_states,) + attention_outputs[1:]

        return outputs

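
# Usage sketch (illustrative only): TransformerBlock follows the same pattern.
# It assumes MultiHeadAttention from model.attention accepts the keyword
# arguments passed in __call__ above; the sizes below are example values.
#
#     block = TransformerBlock(dim=512, num_heads=8, hidden_dim=2048)
#     x = jnp.ones((2, 16, 512))
#     variables = block.init(jax.random.PRNGKey(0), x)
#     out = block.apply(variables, x, deterministic=True)[0]
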
class TransformerLayer(nn.Module):
    """
    Transformer layer with multi-query attention and feed-forward network.

    Attributes:
        dim: Hidden dimension
        num_query_heads: Number of query heads
        num_kv_heads: Number of key-value heads
        hidden_dim: Hidden dimension in feed-forward network (defaults to 4 * dim)
        max_seq_len: Maximum sequence length for RoPE
        dropout_rate: Dropout probability
        attention_dropout_rate: Dropout probability for attention
        layer_norm_epsilon: Epsilon for layer normalization
        use_rope: Whether to use rotary position embeddings
        dtype: Data type for computations
    """
    dim: int
    num_query_heads: int
    num_kv_heads: int = 1
    hidden_dim: Optional[int] = None
    max_seq_len: int = 4096
    dropout_rate: float = 0.0
    attention_dropout_rate: float = 0.0
    layer_norm_epsilon: float = 1e-5
    use_rope: bool = True
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        # Default the feed-forward hidden dimension to 4 * dim if not provided.
        if self.hidden_dim is None:
            self.actual_hidden_dim = 4 * self.dim
        else:
            self.actual_hidden_dim = self.hidden_dim

        self.input_layernorm = nn.LayerNorm(
            epsilon=self.layer_norm_epsilon,
            dtype=self.dtype,
            name="input_layernorm"
        )
        self.post_attention_layernorm = nn.LayerNorm(
            epsilon=self.layer_norm_epsilon,
            dtype=self.dtype,
            name="post_attention_layernorm"
        )

        if self.use_rope:
            self.attention = RotaryMultiQueryAttention(
                dim=self.dim,
                num_query_heads=self.num_query_heads,
                num_kv_heads=self.num_kv_heads,
                max_seq_len=self.max_seq_len,
                dropout_rate=self.attention_dropout_rate,
                dtype=self.dtype,
                name="attention"
            )
        else:
            self.attention = MultiQueryAttention(
                dim=self.dim,
                num_query_heads=self.num_query_heads,
                num_kv_heads=self.num_kv_heads,
                dropout_rate=self.attention_dropout_rate,
                dtype=self.dtype,
                name="attention"
            )

        self.feed_forward = FeedForward(
            dim=self.dim,
            hidden_dim=self.actual_hidden_dim,
            dropout_rate=self.dropout_rate,
            dtype=self.dtype,
            name="feed_forward"
        )

        self.dropout = nn.Dropout(rate=self.dropout_rate)

    def __call__(
        self,
        hidden_states: jnp.ndarray,
        attention_mask: Optional[jnp.ndarray] = None,
        position_ids: Optional[jnp.ndarray] = None,
        past_key_value: Optional[Tuple[jnp.ndarray, jnp.ndarray]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        deterministic: bool = True,
    ) -> Tuple[jnp.ndarray, ...]:
        """
        Apply transformer layer.

        Args:
            hidden_states: Input tensor [batch_size, seq_len, dim]
            attention_mask: Attention mask [batch_size, 1, seq_len, seq_len]
            position_ids: Position indices [batch_size, seq_len]
            past_key_value: Cached key and value tensors for incremental decoding
            output_attentions: Whether to return attention weights
            use_cache: Whether to return the updated key/value cache for incremental decoding
            deterministic: Whether to use deterministic operations (no dropout)

        Returns:
            Tuple of (output, attention_weights, present_key_value); elements
            after the first are present only when requested
        """
        # Pre-norm attention sub-layer with residual connection.
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)

        attention_outputs = self.attention(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            deterministic=deterministic,
        )

        hidden_states = attention_outputs[0]
        hidden_states = self.dropout(hidden_states, deterministic=deterministic)
        hidden_states = residual + hidden_states

        # Pre-norm feed-forward sub-layer with residual connection.
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)

        hidden_states = self.feed_forward(hidden_states, deterministic=deterministic)
        hidden_states = residual + hidden_states

        # Pass through any extra attention outputs (attention weights, cache).
        outputs = (hidden_states,) + attention_outputs[1:]

        return outputs
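
if __name__ == "__main__":
    # Minimal smoke-test sketch (illustrative values only). It assumes the
    # attention modules imported from model.attention accept the constructor
    # and call arguments used in this file.
    layer = TransformerLayer(dim=256, num_query_heads=8, num_kv_heads=2)
    x = jnp.ones((1, 8, 256), dtype=jnp.float32)
    variables = layer.init(jax.random.PRNGKey(0), x)
    out = layer.apply(variables, x, deterministic=True)[0]
    print("TransformerLayer output shape:", out.shape)  # expected (1, 8, 256)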