"""
Activation functions for the LLM model.
"""

import jax
import jax.numpy as jnp
from typing import Callable


def gelu(x: jnp.ndarray) -> jnp.ndarray:
    """
    Gaussian Error Linear Unit (GELU) activation function.
    
    Args:
        x: Input tensor
        
    Returns:
        GELU activation applied to input
    """
    return 0.5 * x * (1 + jnp.tanh(jnp.sqrt(2 / jnp.pi) * (x + 0.044715 * x**3)))
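
# Note: the expression above is the tanh approximation of GELU. jax.nn.gelu
# (whose default is approximate=True) computes the same approximation, so it
# should be usable as a drop-in replacement up to minor numerical differences.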


def swiglu(x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
    """
    SwiGLU activation function (Swish-Gated Linear Unit).
    Used in modern LLMs such as PaLM and LLaMA.

    Args:
        x: Value tensor
        y: Gate tensor

    Returns:
        SwiGLU activation applied to the inputs: x * SiLU(y)
    """
    # Gate with SiLU (Swish), i.e. y * sigmoid(y), rather than a plain sigmoid:
    # SiLU gating is what distinguishes SwiGLU from the sigmoid-gated GLU.
    return x * jax.nn.silu(y)


def relu(x: jnp.ndarray) -> jnp.ndarray:
    """
    Rectified Linear Unit (ReLU) activation function.
    
    Args:
        x: Input tensor
        
    Returns:
        ReLU activation applied to input
    """
    return jnp.maximum(0, x)


class GELU:
    """GELU activation function class."""
    
    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        return gelu(x)


class SwiGLU:
    """SwiGLU activation function class."""
    
    def __call__(self, x: jnp.ndarray, gate: jnp.ndarray) -> jnp.ndarray:
        return swiglu(x, gate)


class ReLU:
    """ReLU activation function class."""
    
    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        return relu(x)


def get_activation_fn(name: str) -> Callable:
    """
    Get an activation function by name.

    Args:
        name: Name of the activation function ('gelu', 'swiglu', or 'relu'),
            case-insensitive.

    Returns:
        The corresponding activation function.

    Raises:
        ValueError: If the activation function is not supported.
    """
    activations = {
        'gelu': gelu,
        'swiglu': swiglu,
        'relu': relu,
    }
    key = name.lower()
    if key not in activations:
        raise ValueError(f"Activation function {name} not supported")
    return activations[key]
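

# --- Minimal usage sketch (illustrative only, not part of the module API) ---
# Exercises the functional and class-based activations on a small random batch.
# The shapes and PRNG seeds below are arbitrary choices for demonstration.
if __name__ == "__main__":
    key = jax.random.PRNGKey(0)
    x = jax.random.normal(key, (4, 8))          # dummy batch of activations
    gate = jax.random.normal(jax.random.PRNGKey(1), (4, 8))

    # Look up an activation by name and apply it.
    act = get_activation_fn("gelu")
    print("gelu output shape:", act(x).shape)

    # SwiGLU takes a value tensor and a gate tensor of the same shape.
    print("swiglu output shape:", swiglu(x, gate).shape)

    # The class-based wrappers behave identically to the plain functions.
    print("relu output shape:", ReLU()(x).shape)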