import mesh_tensorflow as mtf
import tensorflow.compat.v1 as tf
import random

BASE_FNS = {'gelu': mtf.gelu,
            'relu': mtf.relu,
            'sigmoid': mtf.sigmoid,
            'tanh': mtf.tanh,
            'selu': mtf.selu,
            'elu': mtf.elu,
            'abs': mtf.abs,
            'sin': mtf.sin,
            'cos': mtf.cos,
            'sign': mtf.sign,
            'silu': mtf.swish,
            'softplus': mtf.softplus
            }


def _arcsinh(x):
    # arcsinh(x) = log(x + sqrt(1 + x^2)): a smooth, sign-preserving squashing function.
    return mtf.log(x + mtf.sqrt(1 + x ** 2))


def _var(x, init):
    # Fresh trainable scalar on the same mesh/dtype as x; the random hex suffix
    # keeps the variable name unique across call sites.
    return mtf.get_variable(x.mesh, f"activation-{random.randint(0, 2 ** 32):x}", [],
                            initializer=tf.constant_initializer(init), dtype=x.dtype)


def _pos_var(x, val):
    # Trainable scalar constrained (via softplus) to stay strictly greater than val.
    return mtf.softplus(_var(x, 0)) + val


def _rrelu(x):
    # Randomized leaky ReLU: the negative-side slope is drawn once, at graph-build time.
    negative_scale = random.random()
    return (negative_scale * mtf.abs(x) + x) / (1 + negative_scale)


def _elish(x):
    # ELiSH: x * sigmoid(x) for x > 0 and (exp(x) - 1) * sigmoid(x) for x <= 0,
    # expressed in terms of exp(x) only.
    cond = mtf.cast(mtf.greater(x, 0), x.dtype)
    exp = mtf.exp(x)
    return cond * x / (1 + 1 / exp) + (1 - cond) * (exp - 1) / (1 / exp + 1)


CUSTOM_FNS = {'lrelu001': lambda x: mtf.leaky_relu(x, alpha=0.01),
              'lrelu020': lambda x: mtf.leaky_relu(x, alpha=0.20),
              'id': lambda x: x,
              'triangle_relax': lambda x: mtf.sin(x) - mtf.sin(3 * x) / 9 + mtf.sin(5 * x) / 25 - mtf.sin(7 * x) / 49,
              'square_relax': lambda x: mtf.cos(x) - mtf.cos(3 * x) / 3 + mtf.cos(5 * x) / 5 - mtf.cos(7 * x) / 7,
              'spike': lambda x: 1 / (1 + x ** 2),
              'spike2': lambda x: mtf.exp(-x ** 2),
              'tanhshrink': lambda x: x - mtf.tanh(x),
              'softsign': lambda x: x / (mtf.abs(x) + 1),
              'softmax': lambda x: mtf.softmax(x, x.shape[-1]),
              'logsoftmax': lambda x: mtf.log_softmax(x, x.shape[-1]),
              'bipolarsigmoid': lambda x: mtf.sigmoid(x) * 2 - 1,
              'rrelu': _rrelu,
              'elish': _elish,
              'arcsinh': _arcsinh,
              'aria': lambda x: x * (_var(x, 0) + _var(x, 1) / (
                      _pos_var(x, 0) + _var(x, 1) * mtf.exp(_var(x, -1) * x) ** (1 / _pos_var(x, 1)))),
              'prelu': lambda x: mtf.leaky_relu(x, alpha=_var(x, 0.2)),
              'parcsinh': lambda x: _var(x, 1) * _arcsinh(x * _pos_var(x, 1)),
              'psoftplus': lambda x: _var(x, 1) * mtf.softplus(x * _var(x, 1)) + _var(x, 0),
              'proottanh': lambda x: (x ** _pos_var(x, 2) + _pos_var(x, 1)) ** (1 / _pos_var(x, 3)) * mtf.tanh(x),
              'maxsig': lambda x: mtf.maximum(x, mtf.sigmoid(x)),
              'cosid': lambda x: mtf.cos(x) - x,
              'minsin': lambda x: mtf.minimum(x, mtf.sin(x)),
              'maxtanh': lambda x: mtf.maximum(x, mtf.tanh(x)),
              'mish': lambda x: x * mtf.tanh(mtf.softplus(x)),
              'tanhexp': lambda x: x * mtf.tanh(mtf.exp(x)),
              'lisht': lambda x: x * mtf.tanh(x),
              'seagull': lambda x: mtf.log(1 + x ** 2),
              'snake': lambda x: x + mtf.sin(x) ** 2,
              'roottanh': lambda x: (x ** 2 + 1) ** (1 / 3) * mtf.tanh(x),
              'softplusmone': lambda x: mtf.softplus(x) - 1
              }
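
# Illustrative sketch, not part of the original module: the parametric entries
# above ('prelu', 'aria', 'parcsinh', 'psoftplus', 'proottanh') call _var /
# _pos_var, so every application adds fresh trainable scalars to the graph.
# The graph/mesh/dimension names below are placeholders chosen for the demo.
def _demo_prelu_variables():
    graph = mtf.Graph()
    mesh = mtf.Mesh(graph, "prelu_demo")
    hidden = mtf.Dimension("hidden", 8)
    x = mtf.zeros(mesh, mtf.Shape([hidden]))
    CUSTOM_FNS['prelu'](x)  # adds one trainable slope variable
    CUSTOM_FNS['prelu'](x)  # adds a second, independent slope variable
    return len(graph.trainable_variables)  # expected: 2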


def get_activation_fn(params):
    if "activation_fn" in params:
        activation_fn = params["activation_fn"]
    else:
        print("Defaulting to GELU activation (see here: https://arxiv.org/abs/1606.08415)")
        activation_fn = "gelu"

    if activation_fn in BASE_FNS:
        return BASE_FNS[activation_fn]

    if activation_fn in CUSTOM_FNS:
        return CUSTOM_FNS[activation_fn]

    raise ValueError(f'unknown activation function "{activation_fn}" in config')
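
# Usage sketch, assuming this module is imported and called from model-building
# code elsewhere; the mesh construction below is illustrative only and uses a
# single-device placement mesh just to materialise the chosen activation.
if __name__ == "__main__":
    graph = mtf.Graph()
    mesh = mtf.Mesh(graph, "activation_demo")
    batch = mtf.Dimension("batch", 4)
    hidden = mtf.Dimension("hidden", 8)
    x = mtf.random_uniform(mesh, mtf.Shape([batch, hidden]), minval=-1.0, maxval=1.0)

    act = get_activation_fn({"activation_fn": "mish"})
    y = act(x)

    mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
        shape=[], layout={}, devices=[""])
    lowering = mtf.Lowering(graph, {mesh: mesh_impl})
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        print(sess.run(lowering.export_to_tf_tensor(y)))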