# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn as nn
from torch import Tensor
from typing import Union, Callable


class CustomGLU(nn.Module):
    r"""Custom Gated Linear Unit activation.
    Applies a modified gated linear unit :math:`a * f(b)`, where :math:`a` is the first half
    of the input matrices, :math:`b` is the second half, and :math:`f` is a provided activation
    function (e.g. sigmoid, swish, etc.).

    Args:
        activation (nn.Module): The custom activation to apply in the Gated Linear Unit.
        dim (int): the dimension on which to split the input. Default: -1.

    Shape:
        - Input: :math:`(\ast_1, N, \ast_2)` where `*` means any number of
          additional dimensions
        - Output: :math:`(\ast_1, M, \ast_2)` where :math:`M = N / 2`

    Examples::
        >>> m = CustomGLU(nn.Sigmoid())
        >>> input = torch.randn(4, 2)
        >>> output = m(input)
    """
    def __init__(self, activation: nn.Module, dim: int = -1):
        super(CustomGLU, self).__init__()
        self.dim = dim
        self.activation = activation

    def forward(self, x: Tensor) -> Tensor:
        assert x.shape[self.dim] % 2 == 0  # M = N / 2
        a, b = torch.chunk(x, 2, dim=self.dim)
        return a * self.activation(b)


class SwiGLU(CustomGLU):
    """SiLU Gated Linear Unit activation.
    Applies the SiLU Gated Linear Unit :math:`a * SiLU(b)`, where :math:`a` is
    the first half of the input matrices and :math:`b` is the second half.

    Args:
        dim (int): the dimension on which to split the input. Default: -1.
    """
    def __init__(self, dim: int = -1):
        super(SwiGLU, self).__init__(nn.SiLU(), dim)


class GeGLU(CustomGLU):
    """GeLU Gated Linear Unit activation.
    Applies the GeLU Gated Linear Unit :math:`a * GELU(b)`, where :math:`a` is
    the first half of the input matrices and :math:`b` is the second half.

    Args:
        dim (int): the dimension on which to split the input. Default: -1.
    """
    def __init__(self, dim: int = -1):
        super(GeGLU, self).__init__(nn.GELU(), dim)


class ReGLU(CustomGLU):
    """ReLU Gated Linear Unit activation.
    Applies the ReLU Gated Linear Unit :math:`a * ReLU(b)`, where :math:`a` is
    the first half of the input matrices and :math:`b` is the second half.

    Args:
        dim (int): the dimension on which to split the input. Default: -1.
    """
    def __init__(self, dim: int = -1):
        super(ReGLU, self).__init__(nn.ReLU(), dim)


def get_activation_fn(
    activation: Union[str, Callable[[Tensor], Tensor]]
) -> Union[str, Callable[[Tensor], Tensor]]:
    """Helper function to map an activation string to its activation module.
    If the supplied activation is not a recognized string, it is passed back
    unchanged.

    Args:
        activation (Union[str, Callable[[Tensor], Tensor]]): Activation to check.
    """
    if isinstance(activation, str):
        if activation == "reglu":
            return ReGLU()
        elif activation == "geglu":
            return GeGLU()
        elif activation == "swiglu":
            return SwiGLU()
    return activation
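

# A minimal usage sketch: it exercises CustomGLU, the SwiGLU/GeGLU/ReGLU
# variants, and get_activation_fn on a small random tensor. The tensor shape
# is an arbitrary example; any tensor whose split dimension is even works.
if __name__ == "__main__":
    x = torch.randn(4, 8)  # last dimension must be even so it can be chunked in two

    # Each GLU variant halves the split dimension: (4, 8) -> (4, 4).
    for glu in (CustomGLU(nn.Sigmoid()), SwiGLU(), GeGLU(), ReGLU()):
        y = glu(x)
        assert y.shape == (4, 4)

    # Recognized strings come back as module instances; anything else
    # (e.g. a plain callable) is returned unchanged.
    assert isinstance(get_activation_fn("swiglu"), SwiGLU)
    assert get_activation_fn(torch.relu) is torch.relu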