from dataclasses import dataclass
from typing import Union

import torch
import torch.nn as nn
import torch.nn.functional as F

from utils import bias_gelu_impl
from mamba_config import MambaConfig


class MLP(nn.Module):
    def __init__(
        self, config: MambaConfig, is_expert: bool = False, layer_idx=None
    ):
        super().__init__()

        self.config: MambaConfig = config
        self.layer = layer_idx

        # fc1 and fc2 both start from the configured FFN width; fc1 may be
        # widened below when the gated linear unit is enabled.
        ffn_hidden_size_1 = self.config.ffn_hidden_size
        ffn_hidden_size_2 = self.config.ffn_hidden_size

        # With a gated linear unit, fc1 produces two halves (value and gate),
        # so its output width is doubled; the GLU halves it again before fc2.
        if self.config.gated_linear_unit:
            ffn_hidden_size_1 *= 2
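
        # Illustrative (hypothetical) shape arithmetic, not tied to any real config:
        # with hidden_size=1024 and ffn_hidden_size=4096, linear_fc1 maps 1024 -> 8192,
        # torch.chunk splits that into two 4096-wide halves, and linear_fc2 maps
        # 4096 back down to 1024.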

        self.linear_fc1 = nn.Linear(self.config.hidden_size, ffn_hidden_size_1, bias=self.config.add_bias_linear, device=self.config.device)
        # Tag fc1 with the expert flag so other code can identify expert MLPs.
        self.linear_fc1.is_expert = is_expert

        if self.config.gated_linear_unit:
            # GLU activation: split fc1's doubled output into two halves, apply the
            # configured activation to the first half, and multiply it elementwise
            # by the second.
            def glu(x):
                x = torch.chunk(x, 2, dim=-1)
                return self.config.activation_func(x[0]) * x[1]

            self.activation_func = glu
        else:
            self.activation_func = self.config.activation_func

        self.linear_fc2 = nn.Linear(ffn_hidden_size_2, self.config.hidden_size, bias=self.config.add_bias_linear, device=self.config.device)

    def forward(self, hidden_states, inference_params=None):
        # Standard two-layer MLP: project up, apply the (possibly gated) activation,
        # and project back down to the model's hidden size. inference_params is
        # accepted for interface compatibility and is unused here.
        intermediate = self.linear_fc1(hidden_states)
        intermediate = self.activation_func(intermediate)
        output = self.linear_fc2(intermediate)
        return output
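

if __name__ == "__main__":
    # Minimal usage sketch, assuming MambaConfig can be constructed with exactly the
    # keyword fields MLP reads above; the real MambaConfig constructor may differ.
    config = MambaConfig(
        hidden_size=64,
        ffn_hidden_size=256,
        gated_linear_unit=True,
        add_bias_linear=False,
        activation_func=F.gelu,
        device="cpu",
    )
    mlp = MLP(config)
    x = torch.randn(2, 16, config.hidden_size)  # (batch, sequence, hidden)
    y = mlp(x)
    print(y.shape)  # expected: torch.Size([2, 16, 64]); the hidden width is preserved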