|
import torch |
|
import torch.nn as nn |
|
from transformers import PreTrainedModel, PretrainedConfig |
|
from typing import Optional |
|
|
|
|
|
class FrozenLMConfig(PretrainedConfig):
    """Configuration for the FrozenLM MLP model.

    Args:
        input_dim: Width of the feature vector fed to the first linear layer.
        hidden_dim: Width of the hidden layer.
        output_dim: Width of the final projection.
        **kwargs: Forwarded unchanged to ``PretrainedConfig``.
    """

    model_type = "frozen_lm_mlp"

    def __init__(
        self,
        input_dim: int = 256,
        hidden_dim: int = 128,
        output_dim: int = 24576,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Store the three layer widths as plain attributes so they are
        # serialized into config.json by the base class.
        for attr, value in (
            ("input_dim", input_dim),
            ("hidden_dim", hidden_dim),
            ("output_dim", output_dim),
        ):
            setattr(self, attr, value)
|
|
|
|
|
class FrozenLMModel(PreTrainedModel):
    """Two-layer perceptron: input_dim -> hidden_dim -> ReLU -> output_dim."""

    config_class = FrozenLMConfig

    def __init__(self, config: FrozenLMConfig):
        super().__init__(config)
        # Layers are kept as individual attributes (rather than an
        # nn.Sequential) so checkpoint state-dict keys stay "fc1"/"fc2".
        self.fc1 = nn.Linear(config.input_dim, config.hidden_dim)
        self.act = nn.ReLU()
        self.fc2 = nn.Linear(config.hidden_dim, config.output_dim)
        # Standard HF hook: runs weight init / final setup.
        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        """Run the MLP on a batch of feature vectors.

        Args:
            input_ids: Tensor of shape (batch, input_dim) or
                (batch, seq_len, input_dim); a 3-D input is mean-pooled
                over the sequence axis before projection.
            **kwargs: Ignored; accepted for HF-style call compatibility.

        Returns:
            torch.Tensor: Logits of shape (batch, output_dim).

        Raises:
            ValueError: If ``input_ids`` is not supplied.
        """
        if input_ids is None:
            raise ValueError("input_ids must be provided")

        # Cast up-front so integer token-style inputs also work with nn.Linear.
        features = input_ids.float()

        # (batch, seq, dim) -> (batch, dim) via average pooling over seq.
        if features.dim() == 3:
            features = features.mean(dim=1)

        hidden = self.act(self.fc1(features))
        return self.fc2(hidden)
|
|
|
|
|
|
|
# Register the custom config and model with the Hugging Face auto-class
# machinery, so they can be resolved via AutoConfig/AutoModel when this
# module is shipped alongside a checkpoint (presumably loaded with
# trust_remote_code=True — confirm against the publishing workflow).
FrozenLMConfig.register_for_auto_class()

FrozenLMModel.register_for_auto_class("AutoModel")