File size: 2,298 Bytes
2b59497 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 |
import jax
import jax.numpy as jnp
import flax.linen as nn
# Registry mapping an activation name to a callable that applies it.
# Insertion order is the order in which the names are listed to callers.
activation_function = {
    # Stateless activations taken directly from flax.linen.
    "relu": nn.relu,
    "gelu": nn.gelu,
    "silu": nn.silu,
    "swish": nn.silu,  # alias: same callable as "silu"
    "tanh": nn.tanh,
    "sigmoid": nn.sigmoid,
    "softplus": nn.softplus,
    "softmax": nn.softmax,
    "leaky_relu": nn.leaky_relu,
    "elu": nn.elu,
    "selu": nn.selu,
    # TeLU: x * tanh(exp(x)). The exponent is capped at 10 because
    # tanh(exp(10)) already rounds to exactly 1, so the forward value is
    # unchanged; without the cap, exp(x) overflows to inf for large x and
    # autodiff evaluates 0 * inf, yielding NaN gradients.
    "telu": lambda x: x * jnp.tanh(jnp.exp(jnp.minimum(x, 10.0))),
    "mish": lambda x: x * jnp.tanh(nn.softplus(x)),
    # Parametric activations: a new module instance is built on every call.
    # The class name is resolved at call time, so the classes may be defined
    # further down the file.
    # NOTE(review): presumably these must be called from inside a parent Flax
    # module so the new instance can register its parameters - confirm.
    "cauchy": lambda x: cauchy()(x),
    "identity": lambda x: x,
    "react": lambda x: react()(x),
}
# https://arxiv.org/pdf/2503.02267v1
class react(nn.Module):
    """Trainable activation ``(1 - exp(a*x + b)) / (1 + exp(c*x + d))``.

    The coefficients are four scalar (shape ``()``) parameters named
    ``'a'``, ``'b'``, ``'c'`` and ``'d'``, each drawn from
    ``jax.nn.initializers.normal(0.1)``.
    """

    @nn.compact
    def __call__(self, x):
        """Apply the activation to ``x``.

        Args:
            x: Input to transform (presumably a JAX array; the scalar
                parameters are combined with it arithmetically).

        Returns:
            ``(1 - exp(a*x + b)) / (1 + exp(c*x + d))``.
        """
        scalar_init = jax.nn.initializers.normal(0.1)
        # Parameters are registered in the order a, b, c, d.
        a, b, c, d = [
            self.param(key, scalar_init, ())
            for key in ('a', 'b', 'c', 'd')
        ]
        # NOTE(review): jnp.exp may overflow for large arguments - confirm the
        # expected input range.
        numerator = 1 - jnp.exp(a * x + b)
        denominator = 1 + jnp.exp(c * x + d)
        return numerator / denominator
# https://arxiv.org/abs/2409.19221
class cauchy(nn.Module):
    """Trainable activation ``l1*x / (x^2 + d^2) + l2 / (x^2 + d^2)``.

    ``l1``, ``l2`` and ``d`` are scalar (shape ``()``) parameters stored
    under the names ``'lambda1'``, ``'lambda2'`` and ``'d'``; all three
    start at the constant 1.0.
    """

    @nn.compact
    def __call__(self, x):
        """Apply the activation to ``x``.

        Args:
            x: Input to transform (presumably a JAX array; the scalar
                parameters are combined with it arithmetically).

        Returns:
            ``lambda1 * x / (x**2 + d**2) + lambda2 / (x**2 + d**2)``.
        """
        ones_init = jax.nn.initializers.constant(1.0)
        # Parameters are registered in the order lambda1, lambda2, d.
        lam1 = self.param('lambda1', ones_init, ())
        lam2 = self.param('lambda2', ones_init, ())
        width = self.param('d', ones_init, ())

        # Both terms share the same denominator.
        denom = x**2 + width**2
        return lam1 * x / denom + lam2 / denom
def get_activation(name):
    """
    Look up an activation callable in the ``activation_function`` registry.

    Args:
        name (str): Registry key of the requested activation.

    Returns:
        Callable: The callable stored under ``name``.

    Raises:
        ValueError: If ``name`` is not a key of the registry; the message
            lists every supported name.
    """
    # No registry entry is None, so None unambiguously means "missing key".
    activation = activation_function.get(name)
    if activation is None:
        supported = list(activation_function.keys())
        raise ValueError(f"Activation function `{name}` is not supported. Supported activations are : {supported}")
    return activation
def list_activations():
    """
    List the names of all registered activation functions.

    Returns:
        list: A new list holding the keys of ``activation_function`` in the
            order they were inserted into the registry.
    """
    # Iterating a dict yields its keys; list() returns an independent copy.
    return list(activation_function)