"""
Activation functions for the LLM model.
"""
import jax
import jax.numpy as jnp
from typing import Callable


def gelu(x: jnp.ndarray) -> jnp.ndarray:
    """
    Gaussian Error Linear Unit (GELU) activation function (tanh approximation).

    Args:
        x: Input tensor

    Returns:
        GELU activation applied to input
    """
    return 0.5 * x * (1 + jnp.tanh(jnp.sqrt(2 / jnp.pi) * (x + 0.044715 * x**3)))
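

# Note (informational, not executed): the formula in gelu() is the tanh
# approximation of GELU, so it should agree closely with JAX's built-in
# jax.nn.gelu(x, approximate=True). Illustrative sanity check with made-up values:
#
#     x = jnp.linspace(-3.0, 3.0, 7)
#     assert jnp.allclose(gelu(x), jax.nn.gelu(x, approximate=True), atol=1e-6)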


def swiglu(x: jnp.ndarray, gate: jnp.ndarray) -> jnp.ndarray:
    """
    SwiGLU activation function (Swish-Gated Linear Unit).

    Used in modern LLMs such as PaLM and LLaMA.

    Args:
        x: Input tensor (the linear branch)
        gate: Gating tensor, passed through the Swish (SiLU) nonlinearity

    Returns:
        SwiGLU activation applied to inputs
    """
    # SwiGLU gates the linear branch with Swish(gate) = gate * sigmoid(gate);
    # a plain sigmoid gate here would give the original GLU, not SwiGLU.
    return x * jax.nn.silu(gate)
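

def swiglu_ffn(x: jnp.ndarray,
               w_gate: jnp.ndarray,
               w_up: jnp.ndarray,
               w_down: jnp.ndarray) -> jnp.ndarray:
    """
    Illustrative sketch (not part of the original module): how swiglu() is
    typically wired into a transformer feed-forward block. The weight names
    w_gate, w_up, and w_down are hypothetical stand-ins for learned projections.

    Args:
        x: Input of shape (..., d_model)
        w_gate: Gate projection of shape (d_model, d_ff)
        w_up: Up projection of shape (d_model, d_ff)
        w_down: Down projection of shape (d_ff, d_model)

    Returns:
        Output of shape (..., d_model)
    """
    return swiglu(x @ w_up, x @ w_gate) @ w_down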


def relu(x: jnp.ndarray) -> jnp.ndarray:
    """
    Rectified Linear Unit (ReLU) activation function.

    Args:
        x: Input tensor

    Returns:
        ReLU activation applied to input
    """
    return jnp.maximum(0, x)


class GELU:
    """GELU activation function class."""

    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        return gelu(x)


class SwiGLU:
    """SwiGLU activation function class."""

    def __call__(self, x: jnp.ndarray, gate: jnp.ndarray) -> jnp.ndarray:
        return swiglu(x, gate)


class ReLU:
    """ReLU activation function class."""

    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        return relu(x)


def get_activation_fn(name: str) -> Callable:
    """
    Get activation function by name.

    Args:
        name: Name of activation function ('gelu', 'swiglu', or 'relu')

    Returns:
        Activation function

    Raises:
        ValueError: If activation function is not supported
    """
    activations = {
        'gelu': gelu,
        'swiglu': swiglu,
        'relu': relu,
    }
    key = name.lower()
    if key not in activations:
        raise ValueError(f"Activation function {name} not supported")
    return activations[key]
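

if __name__ == "__main__":
    # Minimal usage sketch with made-up values; not part of the original module.
    x = jnp.linspace(-3.0, 3.0, 7)

    # The lookup is case-insensitive.
    act = get_activation_fn('GELU')
    print("gelu:", act(x))

    # The tanh-approximate gelu above should closely match JAX's built-in version.
    print("matches jax.nn.gelu(approximate=True):",
          bool(jnp.allclose(gelu(x), jax.nn.gelu(x, approximate=True), atol=1e-6)))

    # SwiGLU takes two tensors: a linear branch and a gate.
    print("swiglu:", swiglu(x, x))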