# coding=utf-8
# Copyright 2021 The Deeplab2 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Defines a set of useful activation functions."""
import functools
import tensorflow as tf


def gelu(input_tensor, approximate=False):
  """Gaussian Error Linear Unit.

  Reference:
    Gaussian Error Linear Units (GELUs), Dan Hendrycks, Kevin Gimpel,
    arXiv 2016.

  Args:
    input_tensor: A tensor with an arbitrary shape.
    approximate: A boolean, whether to use the tanh-based approximation of
      GELU instead of the exact formulation.

  Returns:
    The activated input tensor.
  """
  return tf.keras.activations.gelu(input_tensor, approximate=approximate)


def hard_sigmoid(input_tensor):
  """Hard sigmoid activation function.

  Args:
    input_tensor: A tensor with an arbitrary shape.

  Returns:
    The activated input tensor.
  """
  input_tensor = tf.convert_to_tensor(input_tensor)
  # 0.16667 approximates the 1/6 scaling factor of the hard sigmoid.
  return tf.nn.relu6(input_tensor + tf.constant(3.)) * 0.16667


def relu6(input_tensor):
  """Relu6 activation function.

  Args:
    input_tensor: A tensor with an arbitrary shape.

  Returns:
    The activated input tensor.
  """
  input_tensor = tf.convert_to_tensor(input_tensor)
  return tf.nn.relu6(input_tensor)


def swish(input_tensor):
  """Swish or SiLU activation function.

  Args:
    input_tensor: A tensor with an arbitrary shape.

  Returns:
    The activated input tensor.
  """
  input_tensor = tf.convert_to_tensor(input_tensor)
  return tf.nn.silu(input_tensor)


def hard_swish(input_tensor):
  """Hard Swish function.

  Args:
    input_tensor: A tensor with an arbitrary shape.

  Returns:
    The activated input tensor.
  """
  input_tensor = tf.convert_to_tensor(input_tensor)
  return input_tensor * tf.nn.relu6(
      input_tensor + tf.constant(3.)) * (1. / 6.)


def identity(input_tensor):
  """Identity function.

  Useful to help with quantization.

  Args:
    input_tensor: A tensor with an arbitrary shape.

  Returns:
    The activated input tensor.
  """
  input_tensor = tf.convert_to_tensor(input_tensor)
  return tf.identity(input_tensor)


def get_activation(identifier):
  """Gets activation function via input identifier.

  This function returns the specified customized activation function, if there
  is any. Otherwise, tf.keras.activations.get is called.

  Args:
    identifier: A string, name of the activation function. Names not listed
      below (and non-string identifiers) are forwarded to
      tf.keras.activations.get.

  Returns:
    The specified activation function.
  """
  if isinstance(identifier, str):
    name_to_fn = {
        'gelu': functools.partial(gelu, approximate=False),
        'approximated_gelu': functools.partial(gelu, approximate=True),
        'silu': swish,
        'swish': swish,
        'hard_swish': hard_swish,
        'relu6': relu6,
        'hard_sigmoid': hard_sigmoid,
        'identity': identity,
        'none': identity,
    }
    identifier = str(identifier).lower()
    if identifier in name_to_fn:
      return name_to_fn[identifier]
  return tf.keras.activations.get(identifier)
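

if __name__ == '__main__':
  # Minimal usage sketch (illustrative only, not part of the original deeplab2
  # module): look up a few activations by name and apply them to a random
  # feature map.
  features = tf.random.normal([2, 65, 65, 32])
  for name in ('gelu', 'approximated_gelu', 'hard_swish', 'relu6', 'none'):
    activation_fn = get_activation(name)
    print(name, activation_fn(features).shape)
  # Unregistered string names (e.g. 'relu') fall through to
  # tf.keras.activations.get, so standard Keras activations also work.
  print('relu', get_activation('relu')(features).shape)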