"""
Applies the mish function element-wise:
mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + exp(x)))
"""

# import pytorch
import torch
import torch.nn.functional as F
from torch import nn

# Detect the native implementation by feature rather than by comparing
# version strings: ``torch.__version__ >= "1.9"`` is a lexicographic string
# comparison when the version is a plain ``str``, so e.g. "1.10.0" compares
# as older than "1.9". Evaluated once at import time instead of on every
# forward pass.
_HAS_NATIVE_MISH = hasattr(F, "mish")


@torch.jit.script
def mish(input: torch.Tensor) -> torch.Tensor:
    """
    Applies the mish function element-wise:
    mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + exp(x)))

    TorchScript-compiled implementation, used by ``Mish`` as a fallback when
    ``torch.nn.functional.mish`` is not available.
    See additional documentation for the ``Mish`` class.

    Args:
        input: tensor of any shape.

    Returns:
        Tensor of the same shape as ``input``.
    """
    return input * torch.tanh(F.softplus(input))


class Mish(nn.Module):
    """
    Applies the mish function element-wise:
    mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + exp(x)))

    Uses ``torch.nn.functional.mish`` when the installed PyTorch provides it,
    otherwise falls back to the TorchScript ``mish`` function defined in this
    module. Both paths compute the same formula.

    Shape:
        - Input: (N, *) where * means any number of additional dimensions
        - Output: (N, *), same shape as the input

    Examples:
        >>> m = Mish()
        >>> input = torch.randn(2)
        >>> output = m(input)

    Reference: https://pytorch.org/docs/stable/generated/torch.nn.Mish.html
    """

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the function.

        Args:
            input: tensor of any shape.

        Returns:
            Tensor of the same shape as ``input`` with mish applied element-wise.
        """
        if _HAS_NATIVE_MISH:
            return F.mish(input)
        return mish(input)