Spaces:
Running
on
Zero
Running
on
Zero
""" | |
Applies the mish function element-wise: | |
mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + exp(x))) | |
""" | |
# import pytorch | |
import torch | |
import torch.nn.functional as F | |
from torch import nn | |
def mish(input):
    """
    Compute the Mish activation of a tensor, element by element.

    The activation is defined as
        mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + exp(x)))

    Args:
        input: tensor to activate.

    Returns:
        The product of ``input`` and ``tanh(softplus(input))``.

    The module wrapper ``Mish`` exposes the same computation as a layer.
    """
    # Smooth gate derived from softplus, then squashed by tanh.
    gate = torch.tanh(F.softplus(input))
    return input * gate
class Mish(nn.Module):
    """
    Applies the mish function element-wise:
    mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + exp(x)))

    Shape:
        - Input: (N, *) where * means, any number of additional
          dimensions
        - Output: (N, *), same shape as the input

    Examples:
        >>> m = Mish()
        >>> input = torch.randn(2)
        >>> output = m(input)

    Reference: https://pytorch.org/docs/stable/generated/torch.nn.Mish.html
    """

    def __init__(self):
        """
        Init method. The module holds no parameters or buffers.
        """
        super().__init__()

    def forward(self, input):
        """
        Forward pass of the function.

        Uses the builtin ``F.mish`` when the installed PyTorch provides it,
        otherwise falls back to the module-level ``mish`` helper.

        Args:
            input: tensor to activate.

        Returns:
            Tensor holding mish(input).
        """
        # Detect the feature instead of comparing version strings: string
        # comparison is lexicographic, so "1.10.0" >= "1.9" is False and
        # PyTorch 1.10-1.13 would wrongly take the fallback path.
        if hasattr(F, "mish"):
            return F.mish(input)
        return mish(input)