# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from typing import Callable, Union

import torch
import torch.nn as nn
from torch import Tensor


class CustomGLU(nn.Module):
    r"""Custom Gated Linear Unit activation.
    Applies a modified gated linear unit :math:`a * f(b)` where :math:`a` is the first half
    of the input matrices, :math:`b` is the second half, and :math:`f` is a provided activation
    function (e.g. sigmoid, swish).

    Args:
        activation (nn.Module): The custom activation to apply in the Gated Linear Unit.
        dim (int): The dimension on which to split the input. Default: -1

    Shape:
        - Input: :math:`(\ast_1, N, \ast_2)` where `*` means any number of additional
          dimensions
        - Output: :math:`(\ast_1, M, \ast_2)` where :math:`M = N / 2`

    Examples::
        >>> m = CustomGLU(nn.Sigmoid())
        >>> input = torch.randn(4, 2)
        >>> output = m(input)
    """
    def __init__(self, activation: nn.Module, dim: int = -1):
        super(CustomGLU, self).__init__()
        self.dim = dim
        self.activation = activation

    def forward(self, x: Tensor):
        # Split the input into two halves along `dim`, then gate the first half
        # with the activated second half: a * f(b).
        assert x.shape[self.dim] % 2 == 0  # M = N / 2
        a, b = torch.chunk(x, 2, dim=self.dim)
        return a * self.activation(b)


class SwiGLU(CustomGLU):
    """SiLU Gated Linear Unit activation.
    Applies SiLU Gated Linear Unit :math:`a * SiLU(b)` where :math:`a` is
    the first half of the input matrices, :math:`b` is the second half.

    Args:
        dim (int): The dimension on which to split the input. Default: -1
    """
    def __init__(self, dim: int = -1):
        super(SwiGLU, self).__init__(nn.SiLU(), dim)


class GeGLU(CustomGLU):
    """GeLU Gated Linear Unit activation.
    Applies GeLU Gated Linear Unit :math:`a * GELU(b)` where :math:`a` is
    the first half of the input matrices, :math:`b` is the second half.

    Args:
        dim (int): The dimension on which to split the input. Default: -1
    """
    def __init__(self, dim: int = -1):
        super(GeGLU, self).__init__(nn.GELU(), dim)


class ReGLU(CustomGLU):
    """ReLU Gated Linear Unit activation.
    Applies ReLU Gated Linear Unit :math:`a * ReLU(b)` where :math:`a` is
    the first half of the input matrices, :math:`b` is the second half.

    Args:
        dim (int): The dimension on which to split the input. Default: -1
    """
    def __init__(self, dim: int = -1):
        super(ReGLU, self).__init__(nn.ReLU(), dim)


def get_activation_fn(
    activation: Union[str, Callable[[Tensor], Tensor]]
) -> Union[str, Callable[[Tensor], Tensor]]:
    """Helper function to map an activation string to the activation class.
    If the supplied activation is not a recognized string, it is returned unchanged.

    Args:
        activation (Union[str, Callable[[Tensor], Tensor]]): Activation to check.
    """
    if isinstance(activation, str):
        if activation == "reglu":
            return ReGLU()
        elif activation == "geglu":
            return GeGLU()
        elif activation == "swiglu":
            return SwiGLU()
    return activation
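

if __name__ == "__main__":
    # Illustrative usage sketch, not part of the original module: it shows how the
    # GLU variants halve the split dimension and how `get_activation_fn` resolves
    # activation names. It only relies on the classes defined above and on torch.
    x = torch.randn(4, 8)

    swiglu = SwiGLU(dim=-1)
    y = swiglu(x)
    # The last dimension is split into two halves of 4, so the output is (4, 4).
    assert y.shape == (4, 4)

    # Known strings map to the corresponding GLU module; anything else is returned as-is.
    assert isinstance(get_activation_fn("geglu"), GeGLU)
    assert get_activation_fn(torch.relu) is torch.relu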