|
""" |
|
Applies the mish function element-wise: |
|
mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + exp(x))) |
|
""" |
|
|
|
|
|
import torch |
|
import torch.nn.functional as F |
|
from torch import nn |
|
|
|
@torch.jit.script
def mish(input: torch.Tensor) -> torch.Tensor:
    """
    Compute the Mish activation element-wise.

        mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + exp(x)))

    This TorchScript-compiled implementation is what the ``Mish`` module
    falls back to on PyTorch builds that predate the native kernel.

    Args:
        input: tensor of any shape.

    Returns:
        A tensor with the same shape as ``input``.
    """
    # The tanh(softplus(x)) factor acts as a smooth, self-gating multiplier.
    gate = F.softplus(input).tanh()
    return input * gate
|
|
|
class Mish(nn.Module):
    """
    Applies the mish function element-wise:

        mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + exp(x)))

    Shape:
        - Input: (*), where * means any number of dimensions
        - Output: (*), same shape as the input

    Examples:
        >>> m = Mish()
        >>> input = torch.randn(2)
        >>> output = m(input)

    Reference: https://pytorch.org/docs/stable/generated/torch.nn.Mish.html
    """

    # Detect the native kernel once, by presence, instead of comparing
    # ``torch.__version__`` to "1.9" as a string: a plain lexicographic
    # comparison orders "1.10" before "1.9", and re-evaluating it on every
    # forward call is wasted work.
    _has_native_mish = hasattr(F, "mish")

    def __init__(self):
        """
        Init method. The module holds no parameters or buffers.
        """
        super().__init__()

    def forward(self, input):
        """
        Forward pass: apply Mish to ``input`` element-wise.

        Args:
            input: tensor of any shape.

        Returns:
            A tensor with the same shape as ``input``.
        """
        if self._has_native_mish:
            # Native implementation shipped with PyTorch.
            return F.mish(input)
        # Older PyTorch: use the TorchScript-compiled module-level fallback.
        return mish(input)