# gpt-neo/models/activations.py
import random

import mesh_tensorflow as mtf
import tensorflow.compat.v1 as tf

# Activations that map one-to-one onto built-in Mesh TensorFlow ops.
BASE_FNS = {'gelu': mtf.gelu,
            'relu': mtf.relu,
            'sigmoid': mtf.sigmoid,
            'tanh': mtf.tanh,
            'selu': mtf.selu,
            'elu': mtf.elu,
            'abs': mtf.abs,
            'sin': mtf.sin,
            'cos': mtf.cos,
            'sign': mtf.sign,
            'silu': mtf.swish,
            'softplus': mtf.softplus
            }

def _arcsinh(x):
    # Inverse hyperbolic sine via the identity arcsinh(x) = log(x + sqrt(1 + x^2)).
    return mtf.log(x + mtf.sqrt(1 + x ** 2))

def _var(x, init):
    # Scalar trainable parameter on the same mesh and dtype as `x`; the random
    # hex suffix keeps variable names unique across calls.
    return mtf.get_variable(x.mesh, f"activation-{random.randint(0, 2 ** 32):x}", [],
                            initializer=tf.constant_initializer(init), dtype=x.dtype)

def _pos_var(x, val):
    # Trainable scalar constrained to stay strictly above `val`
    # (softplus keeps the learned offset positive).
    return mtf.softplus(_var(x, 0)) + val
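
# Illustrative sketch (not one of the registered entries below): the two
# helpers above compose into parametric activations, e.g. a learnable scaled
# tanh y = a * tanh(b * x) with `a` initialised to 1 and `b` kept positive:
#
#   _scaled_tanh = lambda x: _var(x, 1) * mtf.tanh(x * _pos_var(x, 1))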
def _rrelu(x):
    # Randomized leaky ReLU: the negative-side slope is sampled once at
    # graph-construction time, not resampled per training step.
    negative_scale = random.random()
    return (negative_scale * mtf.abs(x) + x) / (1 + negative_scale)

def _elish(x):
    # ELiSH: x * sigmoid(x) for x > 0, (exp(x) - 1) * sigmoid(x) otherwise.
    cond = mtf.cast(mtf.greater(x, 0), x.dtype)
    exp = mtf.exp(x)
    return cond * x / (1 + 1 / exp) + (1 - cond) * (exp - 1) / (1 / exp + 1)

# Hand-rolled activations and parametric variants built from mtf primitives.
CUSTOM_FNS = {'lrelu001': lambda x: mtf.leaky_relu(x, alpha=0.01),
              'lrelu020': lambda x: mtf.leaky_relu(x, alpha=0.20),
              'id': lambda x: x,
              'triangle_relax': lambda x: mtf.sin(x) - mtf.sin(3 * x) / 9 + mtf.sin(5 * x) / 25 - mtf.sin(7 * x) / 49,
              'square_relax': lambda x: mtf.cos(x) - mtf.cos(3 * x) / 3 + mtf.cos(5 * x) / 5 - mtf.cos(7 * x) / 7,
              'spike': lambda x: 1 / (1 + x ** 2),
              'spike2': lambda x: mtf.exp(-x ** 2),
              'tanhshrink': lambda x: x - mtf.tanh(x),
              'softsign': lambda x: x / (mtf.abs(x) + 1),
              'softmax': lambda x: mtf.softmax(x, x.shape[-1]),
              'logsoftmax': lambda x: mtf.log_softmax(x, x.shape[-1]),
              'bipolarsigmoid': lambda x: mtf.sigmoid(x) * 2 - 1,
              'rrelu': _rrelu,
              'elish': _elish,
              'arcsinh': _arcsinh,
              'aria': lambda x: x * (_var(x, 0) + _var(x, 1) / (
                      _pos_var(x, 0) + _var(x, 1) * mtf.exp(_var(x, -1) * x) ** (1 / _pos_var(x, 1)))),
              'prelu': lambda x: mtf.leaky_relu(x, alpha=_var(x, 0.2)),
              'parcsinh': lambda x: _var(x, 1) * _arcsinh(x * _pos_var(x, 1)),
              'psoftplus': lambda x: _var(x, 1) * mtf.softplus(x * _var(x, 1)) + _var(x, 0),
              'proottanh': lambda x: (x ** _pos_var(x, 2) + _pos_var(x, 1)) ** (1 / _pos_var(x, 3)) * mtf.tanh(x),
              'maxsig': lambda x: mtf.maximum(x, mtf.sigmoid(x)),
              'cosid': lambda x: mtf.cos(x) - x,
              'minsin': lambda x: mtf.minimum(x, mtf.sin(x)),
              'maxtanh': lambda x: mtf.maximum(x, mtf.tanh(x)),
              'mish': lambda x: x * mtf.tanh(mtf.softplus(x)),
              'tanhexp': lambda x: x * mtf.tanh(mtf.exp(x)),
              'lisht': lambda x: x * mtf.tanh(x),
              'seagull': lambda x: mtf.log(1 + x ** 2),
              'snake': lambda x: x + mtf.sin(x) ** 2,
              'roottanh': lambda x: (x ** 2 + 1) ** (1 / 3) * mtf.tanh(x),
              'softplusmone': lambda x: mtf.softplus(x) - 1
              }
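
# Extending the registry is a one-liner (hypothetical example; 'squish' is
# not a real entry): any callable from mtf.Tensor to mtf.Tensor works.
#
#   CUSTOM_FNS['squish'] = lambda x: x * mtf.sigmoid(2 * x)
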
def get_activation_fn(params):
    # Look up the activation named in the model config, defaulting to GELU.
    if "activation_fn" in params:
        activation_fn = params["activation_fn"]
    else:
        print("Defaulting to GELU activation (see here: https://arxiv.org/abs/1606.08415)")
        activation_fn = "gelu"
    if activation_fn in BASE_FNS:
        return BASE_FNS[activation_fn]
    if activation_fn in CUSTOM_FNS:
        return CUSTOM_FNS[activation_fn]
    raise ValueError(f'unknown activation function "{activation_fn}" in config')
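
# Minimal usage sketch (assumes an existing mtf.Tensor `h` and a model config
# dict `params`; both are placeholders, not defined in this module):
#
#   act = get_activation_fn(params)   # e.g. params = {"activation_fn": "mish"}
#   h = act(h)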