# gpt2-query2sae/modeling_frozen_lm_mlp.py
from typing import Optional

import torch
import torch.nn as nn
from transformers import PretrainedConfig, PreTrainedModel


class FrozenLMConfig(PretrainedConfig):
"""
Configuration class for FrozenLM MLP model.
"""
    model_type = "frozen_lm_mlp"

    def __init__(
self,
input_dim: int = 256,
hidden_dim: int = 128,
output_dim: int = 24576,
**kwargs,
):
super().__init__(**kwargs)
self.input_dim = input_dim
self.hidden_dim = hidden_dim
self.output_dim = output_dim
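

# Usage sketch (illustrative, not part of the checkpoint code):
# PretrainedConfig gives FrozenLMConfig JSON save/load round-trips for free.
# The "./frozen_lm_mlp" path below is a placeholder.
#
#     cfg = FrozenLMConfig(input_dim=256, hidden_dim=128, output_dim=24576)
#     cfg.save_pretrained("./frozen_lm_mlp")   # writes config.json
#     cfg = FrozenLMConfig.from_pretrained("./frozen_lm_mlp")

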
class FrozenLMModel(PreTrainedModel):
"""
Simple MLP model: input_dim -> hidden_dim -> ReLU -> output_dim
"""
    config_class = FrozenLMConfig

    def __init__(self, config: FrozenLMConfig):
super().__init__(config)
# MLP layers - using original names to match checkpoint
self.fc1 = nn.Linear(config.input_dim, config.hidden_dim)
self.act = nn.ReLU()
self.fc2 = nn.Linear(config.hidden_dim, config.output_dim)
        self.post_init()

    def forward(
self,
input_ids: Optional[torch.Tensor] = None,
**kwargs
) -> torch.Tensor:
"""
Forward pass through MLP.
Args:
input_ids: Input tensor of shape (batch_size, input_dim) or (batch_size, seq_len, input_dim)
Returns:
torch.Tensor: Output tensor of shape (batch_size, output_dim)
"""
if input_ids is None:
raise ValueError("input_ids must be provided")
input_ids = input_ids.float()
# If input has sequence dimension, pool by averaging
if input_ids.dim() == 3: # (batch_size, seq_len, input_dim)
input_ids = input_ids.mean(dim=1) # (batch_size, input_dim)
# Forward through MLP - using original layer names
return self.fc2(self.act(self.fc1(input_ids)))


# Register for AutoConfig and AutoModel
FrozenLMConfig.register_for_auto_class()
FrozenLMModel.register_for_auto_class("AutoModel")
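

# Minimal smoke test (a sketch, not part of the original checkpoint code):
# builds the model from the default config and checks that both 2-D and 3-D
# inputs yield (batch_size, output_dim) outputs. With the registrations
# above, the model should also be loadable from the Hub via
# AutoModel.from_pretrained("mksethi/gpt2-query2sae", trust_remote_code=True)
# (repo id assumed from this repository's name).
if __name__ == "__main__":
    config = FrozenLMConfig()
    model = FrozenLMModel(config).eval()

    flat = torch.randn(2, config.input_dim)        # (batch_size, input_dim)
    seq = torch.randn(2, 7, config.input_dim)      # (batch_size, seq_len, input_dim)

    with torch.no_grad():
        out_flat = model(input_ids=flat)
        out_seq = model(input_ids=seq)

    assert out_flat.shape == (2, config.output_dim)
    assert out_seq.shape == (2, config.output_dim)
    print("ok:", tuple(out_flat.shape))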