File size: 10,791 Bytes
f24563f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 |
"""
Transformer blocks for the LLM model.
"""
import jax
import jax.numpy as jnp
import flax.linen as nn
from typing import Optional, Tuple, Dict, Any, Callable, Union
import math
from model.attention import MultiHeadAttention, MultiQueryAttention, RotaryMultiQueryAttention
class FeedForward(nn.Module):
    """
    Gated feed-forward sub-layer using a SwiGLU-style activation.

    Two parallel Dense projections map the input to ``hidden_dim``. The SiLU
    of the gate projection is multiplied element-wise with the up projection,
    and a third Dense layer maps the product to ``dim`` before dropout.

    Attributes:
        dim: Input and output dimension (feature size of ``down_proj``)
        hidden_dim: Hidden dimension (feature size of ``gate_proj``/``up_proj``)
        dropout_rate: Dropout probability applied to the projected output
        dtype: Data type passed to the Dense layers
    """

    dim: int
    hidden_dim: int
    dropout_rate: float = 0.0
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        def projection(features: int, name: str) -> nn.Dense:
            # Every projection shares the dtype and a normal(stddev=0.02)
            # kernel initializer; only width and submodule name differ.
            return nn.Dense(
                features=features,
                dtype=self.dtype,
                kernel_init=nn.initializers.normal(stddev=0.02),
                name=name,
            )

        self.gate_proj = projection(self.hidden_dim, "gate_proj")
        self.up_proj = projection(self.hidden_dim, "up_proj")
        self.down_proj = projection(self.dim, "down_proj")
        self.dropout = nn.Dropout(rate=self.dropout_rate)

    def __call__(self, x: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray:
        """
        Run the gated feed-forward network on ``x``.

        Args:
            x: Input tensor [batch_size, seq_len, dim]
            deterministic: Forwarded to the dropout layer; True disables dropout

        Returns:
            Output tensor [batch_size, seq_len, dim]
        """
        # SiLU-activated gate branch modulates the linear "up" branch.
        activated_gate = jax.nn.silu(self.gate_proj(x))
        gated = activated_gate * self.up_proj(x)
        # Map back to the model dimension, then regularize with dropout.
        projected = self.down_proj(gated)
        return self.dropout(projected, deterministic=deterministic)
class TransformerBlock(nn.Module):
    """
    Transformer block: self-attention followed by a feed-forward network.

    Each sub-layer is preceded by its own LayerNorm and wrapped in a residual
    connection. Attention is provided by ``MultiHeadAttention`` and the
    feed-forward part by ``FeedForward``.

    Attributes:
        dim: Hidden dimension
        num_heads: Number of attention heads
        hidden_dim: Hidden dimension in feed-forward network
        dropout_rate: Dropout probability (attention output and feed-forward)
        attention_dropout_rate: Dropout probability passed to the attention module
        layer_norm_epsilon: Epsilon for layer normalization
        dtype: Data type for computations
    """

    dim: int
    num_heads: int
    hidden_dim: int
    dropout_rate: float = 0.0
    attention_dropout_rate: float = 0.0
    layer_norm_epsilon: float = 1e-5
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        def layer_norm(name: str) -> nn.LayerNorm:
            # Both norms share epsilon and dtype; only the name differs.
            return nn.LayerNorm(
                epsilon=self.layer_norm_epsilon,
                dtype=self.dtype,
                name=name,
            )

        # Normalization applied before each sub-layer.
        self.input_layernorm = layer_norm("input_layernorm")
        self.post_attention_layernorm = layer_norm("post_attention_layernorm")

        # Self-attention sub-layer.
        self.attention = MultiHeadAttention(
            dim=self.dim,
            num_heads=self.num_heads,
            dropout_rate=self.attention_dropout_rate,
            dtype=self.dtype,
            name="attention",
        )

        # Feed-forward sub-layer.
        self.feed_forward = FeedForward(
            dim=self.dim,
            hidden_dim=self.hidden_dim,
            dropout_rate=self.dropout_rate,
            dtype=self.dtype,
            name="feed_forward",
        )

        # Dropout applied to the attention output before the residual add.
        self.dropout = nn.Dropout(rate=self.dropout_rate)

    def __call__(
        self,
        hidden_states: jnp.ndarray,
        attention_mask: Optional[jnp.ndarray] = None,
        position_ids: Optional[jnp.ndarray] = None,
        past_key_value: Optional[Tuple[jnp.ndarray, jnp.ndarray]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        deterministic: bool = True,
    ) -> Tuple[jnp.ndarray, ...]:
        """
        Run the block on ``hidden_states``.

        Args:
            hidden_states: Input tensor [batch_size, seq_len, dim]
            attention_mask: Attention mask [batch_size, 1, seq_len, seq_len]
            position_ids: Position indices [batch_size, seq_len]
            past_key_value: Cached key and value tensors for incremental decoding
            output_attentions: Whether to return attention weights
            use_cache: Whether to use cached key and values
            deterministic: Whether to use deterministic operations (no dropout)

        Returns:
            Tuple whose first element is the block output, followed by every
            element the attention module returned after its first one
            (documented here as attention_weights and present_key_value;
            the exact contents depend on the attention module).
        """
        # --- Attention sub-layer: norm -> attention -> dropout -> residual ---
        normed_input = self.input_layernorm(hidden_states)
        attn_result = self.attention(
            hidden_states=normed_input,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            deterministic=deterministic,
        )
        attn_out = self.dropout(attn_result[0], deterministic=deterministic)
        after_attention = hidden_states + attn_out

        # --- Feed-forward sub-layer: norm -> FFN -> residual ---
        normed_attention = self.post_attention_layernorm(after_attention)
        ffn_out = self.feed_forward(normed_attention, deterministic=deterministic)
        block_output = after_attention + ffn_out

        # Pass through whatever extras the attention module produced.
        return (block_output,) + attn_result[1:]
class TransformerLayer(nn.Module):
    """
    Transformer layer with multi-query attention and feed-forward network.

    Each sub-layer is preceded by its own LayerNorm and wrapped in a residual
    connection. The attention module is ``RotaryMultiQueryAttention`` when
    ``use_rope`` is True and ``MultiQueryAttention`` otherwise.

    Attributes:
        dim: Hidden dimension
        num_query_heads: Number of query heads
        num_kv_heads: Number of key-value heads
        hidden_dim: Hidden dimension in feed-forward network; when None,
            ``4 * dim`` is used (exposed as ``actual_hidden_dim`` after setup)
        max_seq_len: Maximum sequence length for RoPE (only passed to the
            attention module when ``use_rope`` is True)
        dropout_rate: Dropout probability (attention output and feed-forward)
        attention_dropout_rate: Dropout probability for attention
        layer_norm_epsilon: Epsilon for layer normalization
        use_rope: Whether to use rotary position embeddings
        dtype: Data type for computations
    """

    dim: int
    num_query_heads: int
    num_kv_heads: int = 1
    # None is a valid value here (it selects the 4 * dim default), so the
    # annotation must be Optional rather than a bare int.
    hidden_dim: Optional[int] = None
    max_seq_len: int = 4096
    dropout_rate: float = 0.0
    attention_dropout_rate: float = 0.0
    layer_norm_epsilon: float = 1e-5
    use_rope: bool = True
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        # Resolve the feed-forward width, falling back to 4 * dim.
        self.actual_hidden_dim = (
            4 * self.dim if self.hidden_dim is None else self.hidden_dim
        )

        # Layer normalization applied before each sub-layer.
        self.input_layernorm = nn.LayerNorm(
            epsilon=self.layer_norm_epsilon,
            dtype=self.dtype,
            name="input_layernorm",
        )
        self.post_attention_layernorm = nn.LayerNorm(
            epsilon=self.layer_norm_epsilon,
            dtype=self.dtype,
            name="post_attention_layernorm",
        )

        # Keyword arguments common to both attention variants; keeping them in
        # one place prevents the two branches from drifting apart.
        attention_kwargs = dict(
            dim=self.dim,
            num_query_heads=self.num_query_heads,
            num_kv_heads=self.num_kv_heads,
            dropout_rate=self.attention_dropout_rate,
            dtype=self.dtype,
            name="attention",
        )
        if self.use_rope:
            # Only the rotary variant receives max_seq_len.
            self.attention = RotaryMultiQueryAttention(
                max_seq_len=self.max_seq_len,
                **attention_kwargs,
            )
        else:
            self.attention = MultiQueryAttention(**attention_kwargs)

        # Feed-forward network
        self.feed_forward = FeedForward(
            dim=self.dim,
            hidden_dim=self.actual_hidden_dim,
            dropout_rate=self.dropout_rate,
            dtype=self.dtype,
            name="feed_forward",
        )

        # Dropout applied to the attention output before the residual add.
        self.dropout = nn.Dropout(rate=self.dropout_rate)

    def __call__(
        self,
        hidden_states: jnp.ndarray,
        attention_mask: Optional[jnp.ndarray] = None,
        position_ids: Optional[jnp.ndarray] = None,
        past_key_value: Optional[Tuple[jnp.ndarray, jnp.ndarray]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        deterministic: bool = True,
    ) -> Tuple[jnp.ndarray, ...]:
        """
        Apply transformer layer.

        Args:
            hidden_states: Input tensor [batch_size, seq_len, dim]
            attention_mask: Attention mask [batch_size, 1, seq_len, seq_len]
            position_ids: Position indices [batch_size, seq_len]
            past_key_value: Cached key and value tensors for incremental decoding
            output_attentions: Whether to return attention weights
            use_cache: Whether to use cached key and values
            deterministic: Whether to use deterministic operations (no dropout)

        Returns:
            Tuple whose first element is the layer output, followed by every
            element the attention module returned after its first one
            (documented here as attention_weights and present_key_value;
            the exact contents depend on the attention module).
        """
        # Self-attention: normalize, attend, drop out, add the residual.
        residual = hidden_states
        attention_outputs = self.attention(
            hidden_states=self.input_layernorm(hidden_states),
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            deterministic=deterministic,
        )
        attended = self.dropout(attention_outputs[0], deterministic=deterministic)
        hidden_states = residual + attended

        # Feed-forward: normalize, transform, add the residual.
        residual = hidden_states
        transformed = self.feed_forward(
            self.post_attention_layernorm(hidden_states),
            deterministic=deterministic,
        )
        hidden_states = residual + transformed

        # Forward any extra attention outputs unchanged.
        return (hidden_states,) + attention_outputs[1:]
|