```python
import torch
import torch.nn as nn

from modules.cond import cast


class GEGLU(nn.Module):
    """#### Class representing the GEGLU activation function.

    GEGLU is a gated linear unit (GLU) variant that uses GELU as its
    gating function: the input is projected to twice the output width,
    split into a value half and a gate half, and the GELU of the gate
    scales the value.

    #### Args:
        - `dim_in` (int): The input dimension.
        - `dim_out` (int): The output dimension.
    """

    def __init__(self, dim_in: int, dim_out: int):
        super().__init__()
        # Project to 2 * dim_out so the result can be split into value and gate.
        self.proj = cast.manual_cast.Linear(dim_in, dim_out * 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """#### Forward pass for the GEGLU activation function.

        #### Args:
            - `x` (torch.Tensor): The input tensor.

        #### Returns:
            - `torch.Tensor`: The output tensor.
        """
        x, gate = self.proj(x).chunk(2, dim=-1)
        return x * torch.nn.functional.gelu(gate)
```
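
For reference, a minimal usage sketch follows. It assumes the `GEGLU` class above is importable (including the repo's `modules.cond.cast` dependency); the dimensions are illustrative, typical of a transformer feed-forward block, not taken from the repo.

```python
# Minimal usage sketch; dimensions are illustrative assumptions.
layer = GEGLU(dim_in=512, dim_out=2048)
x = torch.randn(4, 77, 512)   # (batch, sequence, dim_in)
out = layer(x)                # gated output, shape (4, 77, 2048)
```

Note that `self.proj` produces a single tensor of width `2 * dim_out`, which `chunk(2, dim=-1)` splits into the value and gate halves; this one fused projection is what distinguishes GEGLU from applying two separate linear layers.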