|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Defines a set of useful activation functions.""" |
|
import functools |
|
import tensorflow as tf |
|
|
|
|
|
def gelu(input_tensor, approximate=False):
    """Applies the Gaussian Error Linear Unit (GELU) activation.

    This is a thin wrapper that delegates to `tf.keras.activations.gelu`.

    Reference:
      Gaussian Error Linear Units (GELUs), Dan Hendrycks, Kevin Gimpel,
      arXiv 2016.

    Args:
      input_tensor: A tensor with an arbitrary shape.
      approximate: A boolean, forwarded unchanged as the `approximate`
        argument of `tf.keras.activations.gelu`.

    Returns:
      The activated input tensor.
    """
    keras_gelu = tf.keras.activations.gelu
    return keras_gelu(input_tensor, approximate=approximate)
|
|
|
|
|
def hard_sigmoid(input_tensor):
    """Hard sigmoid activation function.

    Computes `relu6(x + 3) * 0.16667`, a piecewise-linear stand-in for the
    sigmoid.

    Args:
      input_tensor: A tensor with an arbitrary shape.

    Returns:
      The activated input tensor, with the same dtype as the converted input.
    """
    input_tensor = tf.convert_to_tensor(input_tensor)
    # Build the offset in the input's own dtype. A bare `tf.constant(3.)` is
    # float32, which makes the addition fail with a dtype mismatch for
    # float16 / bfloat16 / float64 inputs (e.g. under mixed precision).
    offset = tf.constant(3., dtype=input_tensor.dtype)
    # NOTE: 0.16667 is a rounded 1/6. It is deliberately left unchanged so
    # that results for float32 inputs are not perturbed by this fix.
    return tf.nn.relu6(input_tensor + offset) * 0.16667
|
|
|
|
|
def relu6(input_tensor):
    """Applies the ReLU6 activation.

    The input is first converted with `tf.convert_to_tensor` and then passed
    to `tf.nn.relu6`.

    Args:
      input_tensor: A tensor with an arbitrary shape.

    Returns:
      The activated input tensor.
    """
    return tf.nn.relu6(tf.convert_to_tensor(input_tensor))
|
|
|
|
|
def swish(input_tensor):
    """Applies the Swish (a.k.a. SiLU) activation.

    The input is first converted with `tf.convert_to_tensor` and then passed
    to `tf.nn.silu`.

    Args:
      input_tensor: A tensor with an arbitrary shape.

    Returns:
      The activated input tensor.
    """
    return tf.nn.silu(tf.convert_to_tensor(input_tensor))
|
|
|
|
|
def hard_swish(input_tensor):
    """Hard Swish activation function.

    Computes `x * relu6(x + 3) * (1 / 6)`.

    Args:
      input_tensor: A tensor with an arbitrary shape.

    Returns:
      The activated input tensor, with the same dtype as the converted input.
    """
    input_tensor = tf.convert_to_tensor(input_tensor)
    # Build the offset in the input's own dtype. A bare `tf.constant(3.)` is
    # float32, which makes the addition fail with a dtype mismatch for
    # float16 / bfloat16 / float64 inputs (e.g. under mixed precision).
    offset = tf.constant(3., dtype=input_tensor.dtype)
    return input_tensor * tf.nn.relu6(input_tensor + offset) * (1. / 6.)
|
|
|
|
|
def identity(input_tensor):
    """Passes the input through `tf.identity`.

    Useful for helping in quantization.

    Args:
      input_tensor: A tensor with an arbitrary shape.

    Returns:
      The result of `tf.identity` applied to the converted input tensor.
    """
    return tf.identity(tf.convert_to_tensor(input_tensor))
|
|
|
|
|
def get_activation(identifier):
    """Resolves an activation function from an identifier.

    String identifiers are lowercased and looked up among the customized
    activations defined in this module. Anything not found there (including
    every non-string identifier) is handed to `tf.keras.activations.get`.

    Args:
      identifier: A string, name of the activation function. Non-string
        values are passed to `tf.keras.activations.get` unchanged.

    Returns:
      The specified activation function.
    """
    # Non-string identifiers bypass the custom table entirely.
    if not isinstance(identifier, str):
        return tf.keras.activations.get(identifier)

    # Name matching is case-insensitive; the lowercased name is also what
    # gets passed to Keras when there is no custom match.
    lowered = str(identifier).lower()
    custom_activations = {
        'gelu': functools.partial(gelu, approximate=False),
        'approximated_gelu': functools.partial(gelu, approximate=True),
        'silu': swish,
        'swish': swish,
        'hard_swish': hard_swish,
        'relu6': relu6,
        'hard_sigmoid': hard_sigmoid,
        'identity': identity,
        'none': identity,
    }
    custom_fn = custom_activations.get(lowered)
    if custom_fn is not None:
        return custom_fn
    return tf.keras.activations.get(lowered)
|
|