# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn as nn
from torch import Tensor
from typing import Union, Callable


class CustomGLU(nn.Module):
    r"""Custom Gated Linear Unit activation.

    Applies a modified gated linear unit :math:`a * f(b)` where :math:`a` is the first half
    of the input matrices, :math:`b` is the second half, and :math:`f` is a provided activation
    function (e.g. sigmoid, swish, etc.).

    Args:
        activation (nn.Module): The custom activation to apply in the Gated Linear Unit
        dim (int): the dimension on which to split the input. Default: -1

    Shape:
        - Input: :math:`(\ast_1, N, \ast_2)` where :math:`*` means any number of additional
          dimensions
        - Output: :math:`(\ast_1, M, \ast_2)` where :math:`M=N/2`

    Raises:
        ValueError: in ``forward``, if the input size along ``dim`` is not even.

    Examples::
        >>> m = CustomGLU(nn.Sigmoid())
        >>> input = torch.randn(4, 2)
        >>> output = m(input)
    """
    def __init__(self, activation: nn.Module, dim: int = -1):
        super().__init__()
        self.dim = dim
        self.activation = activation

    def forward(self, x: Tensor) -> Tensor:
        size = x.shape[self.dim]
        # Explicit check instead of `assert`, so it still runs under `python -O`.
        # With an odd size, torch.chunk returns unequal halves (e.g. 3 -> 2 and 1),
        # and `a * f(b)` could then silently broadcast to a wrong result.
        if size % 2 != 0:
            raise ValueError(
                f"CustomGLU expects an even size along dim {self.dim}, got {size}"
            )
        a, b = torch.chunk(x, 2, dim=self.dim)  # M = N / 2
        return a * self.activation(b)


class SwiGLU(CustomGLU):
    """SiLU Gated Linear Unit activation.

    Computes :math:`a * SiLU(b)`, where :math:`a` and :math:`b` are respectively the
    first and second halves of the input, split along ``dim``.

    Args:
        dim (int): the dimension on which to split the input. Default: -1
    """
    def __init__(self, dim: int = -1):
        gate = nn.SiLU()
        super().__init__(gate, dim)


class GeGLU(CustomGLU):
    """GELU Gated Linear Unit activation.

    Computes :math:`a * GELU(b)`, where :math:`a` and :math:`b` are respectively the
    first and second halves of the input, split along ``dim``.

    Args:
        dim (int): the dimension on which to split the input. Default: -1
    """
    def __init__(self, dim: int = -1):
        gate = nn.GELU()
        super().__init__(gate, dim)


class ReGLU(CustomGLU):
    """ReLU Gated Linear Unit activation.

    Computes :math:`a * ReLU(b)`, where :math:`a` and :math:`b` are respectively the
    first and second halves of the input, split along ``dim``.

    Args:
        dim (int): the dimension on which to split the input. Default: -1
    """
    def __init__(self, dim: int = -1):
        gate = nn.ReLU()
        super().__init__(gate, dim)


def get_activation_fn(
    activation: Union[str, Callable[[Tensor], Tensor]]
) -> Union[str, Callable[[Tensor], Tensor]]:
    """Resolve a gated-linear-unit activation name to a new module instance.

    The strings ``"reglu"``, ``"geglu"`` and ``"swiglu"`` are mapped to freshly
    constructed :class:`ReGLU`, :class:`GeGLU` and :class:`SwiGLU` instances
    (each with its default ``dim``). Anything else, whether an unrecognized string
    or a non-string value such as a callable, is returned unchanged.

    Args:
        activation (Union[str, Callable[[Tensor], Tensor]]): Activation to check

    Returns:
        Union[str, Callable[[Tensor], Tensor]]: the new GLU module for a recognized
        name, otherwise ``activation`` itself.
    """
    if not isinstance(activation, str):
        return activation
    glu_classes = {"reglu": ReGLU, "geglu": GeGLU, "swiglu": SwiGLU}
    glu_cls = glu_classes.get(activation)
    if glu_cls is None:
        return activation
    return glu_cls()