# coding=utf-8
# Copyright 2021 The Deeplab2 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Defines a set of useful activation functions."""

import functools

import tensorflow as tf


def gelu(input_tensor, approximate=False):
  """Gaussian Error Linear Unit.

  Reference:
  Gaussian Error Linear Units (GELUs), Dan Hendrycks, Kevin Gimpel, arXiv 2016.

  Args:
    input_tensor: A tensor with an arbitrary shape.
    approximate: A boolean, whether to enable approximation.

  Returns:
    The activated input tensor.
  """
  return tf.keras.activations.gelu(input_tensor, approximate=approximate)


def hard_sigmoid(input_tensor):
  """Hard sigmoid activation function.

  Args:
    input_tensor: A tensor with an arbitrary shape.

  Returns:
    The activated input tensor.
  """
  input_tensor = tf.convert_to_tensor(input_tensor)
  return tf.nn.relu6(input_tensor + tf.constant(3.)) * 0.16667


def relu6(input_tensor):
  """Relu6 activation function.

  Args:
    input_tensor: A tensor with an arbitrary shape.

  Returns:
    The activated input tensor.
  """
  input_tensor = tf.convert_to_tensor(input_tensor)
  return tf.nn.relu6(input_tensor)


def swish(input_tensor):
  """Swish or SiLU activation function.

  Args:
    input_tensor: A tensor with an arbitrary shape.

  Returns:
    The activated input tensor.
  """
  input_tensor = tf.convert_to_tensor(input_tensor)
  return tf.nn.silu(input_tensor)


def hard_swish(input_tensor):
  """Hard Swish function.

  Args:
    input_tensor: A tensor with an arbitrary shape.

  Returns:
    The activated input tensor.
  """
  input_tensor = tf.convert_to_tensor(input_tensor)
  return input_tensor * tf.nn.relu6(
      input_tensor + tf.constant(3.)) * (1. / 6.)


def identity(input_tensor):
  """Identity function.

  Useful for helping in quantization.

  Args:
    input_tensor: A tensor with an arbitrary shape.

  Returns:
    The activated input tensor.
  """
  input_tensor = tf.convert_to_tensor(input_tensor)
  return tf.identity(input_tensor)


def get_activation(identifier):
  """Gets activation function via input identifier.

  This function returns the specified customized activation function, if there
  is any. Otherwise, tf.keras.activations.get is called.

  Args:
    identifier: A string, name of the activation function.

  Returns:
    The specified activation function.
  """
  if isinstance(identifier, str):
    name_to_fn = {
        'gelu': functools.partial(gelu, approximate=False),
        'approximated_gelu': functools.partial(gelu, approximate=True),
        'silu': swish,
        'swish': swish,
        'hard_swish': hard_swish,
        'relu6': relu6,
        'hard_sigmoid': hard_sigmoid,
        'identity': identity,
        'none': identity,
    }
    identifier = str(identifier).lower()
    if identifier in name_to_fn:
      return name_to_fn[identifier]
  return tf.keras.activations.get(identifier)
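

# A minimal usage sketch, not part of the original module: it demonstrates
# resolving activations by name through get_activation and applying them to a
# tensor. The example names match the keys of name_to_fn above; the input
# values are illustrative only.
if __name__ == '__main__':
  example_input = tf.constant([-3.0, -1.0, 0.0, 1.0, 3.0])
  for name in ('relu6', 'gelu', 'approximated_gelu', 'swish', 'hard_swish',
               'hard_sigmoid', 'none'):
    activation_fn = get_activation(name)
    # Each resolved callable maps the input tensor elementwise.
    print(name, activation_fn(example_input).numpy())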