"""
Activation functions for the LLM model.
"""
import jax
import jax.numpy as jnp
from typing import Callable


def gelu(x: jnp.ndarray) -> jnp.ndarray:
    """
    Gaussian Error Linear Unit (GELU) activation function (tanh approximation).

    Args:
        x: Input tensor

    Returns:
        GELU activation applied to input
    """
    return 0.5 * x * (1 + jnp.tanh(jnp.sqrt(2 / jnp.pi) * (x + 0.044715 * x**3)))
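

# Illustrative sanity check (a minimal sketch, not part of the original module):
# the tanh formulation above is the same approximation used by
# jax.nn.gelu(approximate=True), so the two should agree to float tolerance.
def _gelu_matches_jax_reference() -> bool:
    """Compare the local gelu against jax.nn.gelu(approximate=True) on a small sample."""
    x = jnp.linspace(-3.0, 3.0, 7)
    return bool(jnp.allclose(gelu(x), jax.nn.gelu(x, approximate=True), atol=1e-5))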


def swiglu(x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
    """
    SwiGLU activation function (Swish-Gated Linear Unit).
    Used in modern LLMs such as PaLM and LLaMA.

    Args:
        x: Value (up-projection) tensor
        y: Gate tensor

    Returns:
        SwiGLU activation applied to inputs
    """
    # SwiGLU gates the value tensor with Swish/SiLU of the gate tensor:
    # x * silu(y) = x * (y * sigmoid(y)). A plain sigmoid gate would be GLU, not SwiGLU.
    return x * jax.nn.silu(y)
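

# Usage sketch (illustrative only; w_up and w_gate are hypothetical projection
# matrices, not defined elsewhere in this module): in a SwiGLU feed-forward
# block the hidden state is projected twice and the two halves are combined here.
def _swiglu_ffn_sketch(x: jnp.ndarray, w_up: jnp.ndarray, w_gate: jnp.ndarray) -> jnp.ndarray:
    """Combine a value projection and a gate projection with swiglu."""
    value = x @ w_up    # value ("up") projection
    gate = x @ w_gate   # gate projection, passed through SiLU inside swiglu
    return swiglu(value, gate)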


def relu(x: jnp.ndarray) -> jnp.ndarray:
    """
    Rectified Linear Unit (ReLU) activation function.

    Args:
        x: Input tensor

    Returns:
        ReLU activation applied to input
    """
    return jnp.maximum(0, x)


class GELU:
    """GELU activation function class."""

    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        return gelu(x)


class SwiGLU:
    """SwiGLU activation function class."""

    def __call__(self, x: jnp.ndarray, gate: jnp.ndarray) -> jnp.ndarray:
        return swiglu(x, gate)


class ReLU:
    """ReLU activation function class."""

    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        return relu(x)


def get_activation_fn(name: str) -> Callable:
    """
    Get activation function by name.

    Args:
        name: Name of activation function ('gelu', 'swiglu', or 'relu')

    Returns:
        Activation function

    Raises:
        ValueError: If activation function is not supported
    """
    activations = {
        'gelu': gelu,
        'swiglu': swiglu,
        'relu': relu,
    }
    key = name.lower()
    if key not in activations:
        raise ValueError(f"Activation function {name} not supported")
    return activations[key]
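

# Minimal usage sketch (assumption: run as a script; not part of the original
# module): look up activations by name and apply them to a small test tensor.
if __name__ == "__main__":
    sample = jnp.array([-1.0, 0.0, 1.0])
    print("gelu:", get_activation_fn("gelu")(sample))
    print("swiglu:", get_activation_fn("swiglu")(sample, sample))
    print("relu:", get_activation_fn("relu")(sample))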