import torch
import torch.nn as nn
from transformers import PreTrainedModel, PretrainedConfig
from typing import Optional


class FrozenLMConfig(PretrainedConfig):
    """
    Configuration class for the FrozenLM MLP model.
    """

    model_type = "frozen_lm_mlp"

    def __init__(
        self,
        input_dim: int = 256,
        hidden_dim: int = 128,
        output_dim: int = 24576,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim


class FrozenLMModel(PreTrainedModel):
    """
    Simple MLP model: input_dim -> hidden_dim -> ReLU -> output_dim
    """

    config_class = FrozenLMConfig

    def __init__(self, config: FrozenLMConfig):
        super().__init__(config)

        # MLP layers - using original names to match the checkpoint
        self.fc1 = nn.Linear(config.input_dim, config.hidden_dim)
        self.act = nn.ReLU()
        self.fc2 = nn.Linear(config.hidden_dim, config.output_dim)

        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        """
        Forward pass through the MLP.

        Args:
            input_ids: Input tensor of shape (batch_size, input_dim) or
                (batch_size, seq_len, input_dim).

        Returns:
            torch.Tensor: Output tensor of shape (batch_size, output_dim).
        """
        if input_ids is None:
            raise ValueError("input_ids must be provided")

        input_ids = input_ids.float()

        # If the input has a sequence dimension, pool by averaging over it
        if input_ids.dim() == 3:  # (batch_size, seq_len, input_dim)
            input_ids = input_ids.mean(dim=1)  # (batch_size, input_dim)

        # Forward through the MLP - using original layer names
        return self.fc2(self.act(self.fc1(input_ids)))


# Register for AutoConfig and AutoModel
FrozenLMConfig.register_for_auto_class()
FrozenLMModel.register_for_auto_class("AutoModel")
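

# Minimal usage sketch (an added illustration, not part of the original file):
# instantiates the model with the default config and runs a forward pass on
# random 2-D and 3-D inputs to check the expected output shape. Input sizes
# (batch of 4, sequence length of 10) are arbitrary assumptions.
if __name__ == "__main__":
    config = FrozenLMConfig()
    model = FrozenLMModel(config)
    model.eval()

    with torch.no_grad():
        # 2-D input: (batch_size, input_dim) is passed straight through the MLP
        flat = torch.randn(4, config.input_dim)
        out_flat = model(input_ids=flat)

        # 3-D input: (batch_size, seq_len, input_dim) is mean-pooled over seq_len first
        seq = torch.randn(4, 10, config.input_dim)
        out_seq = model(input_ids=seq)

    print(out_flat.shape)  # torch.Size([4, 24576])
    print(out_seq.shape)   # torch.Size([4, 24576])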