import torch
import torch.nn as nn
from transformers import PreTrainedModel, PretrainedConfig
from typing import Optional


class FrozenLMConfig(PretrainedConfig):
    """
    Configuration class for the FrozenLM MLP model.
    """

    model_type = "frozen_lm_mlp"

    def __init__(
        self,
        input_dim: int = 256,
        hidden_dim: int = 128,
        output_dim: int = 24576,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim


class FrozenLMModel(PreTrainedModel):
    """
    Simple MLP model: input_dim -> hidden_dim -> ReLU -> output_dim
    """

    config_class = FrozenLMConfig

    def __init__(self, config: FrozenLMConfig):
        super().__init__(config)
        # MLP layers - using the original names to match the checkpoint
        self.fc1 = nn.Linear(config.input_dim, config.hidden_dim)
        self.act = nn.ReLU()
        self.fc2 = nn.Linear(config.hidden_dim, config.output_dim)
        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        """
        Forward pass through the MLP.

        Args:
            input_ids: Input tensor of shape (batch_size, input_dim) or
                (batch_size, seq_len, input_dim).

        Returns:
            torch.Tensor: Output tensor of shape (batch_size, output_dim).
        """
        if input_ids is None:
            raise ValueError("input_ids must be provided")
        input_ids = input_ids.float()
        # If the input has a sequence dimension, pool it by averaging
        if input_ids.dim() == 3:  # (batch_size, seq_len, input_dim)
            input_ids = input_ids.mean(dim=1)  # (batch_size, input_dim)
        # Forward through the MLP - using the original layer names
        return self.fc2(self.act(self.fc1(input_ids)))


# Register for AutoConfig and AutoModel (enables loading this custom
# architecture with trust_remote_code=True)
FrozenLMConfig.register_for_auto_class()
FrozenLMModel.register_for_auto_class("AutoModel")
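

# Usage sketch: a minimal smoke test with random inputs, assuming the
# default config dimensions; it only illustrates the expected I/O shapes,
# not real checkpoint data.
if __name__ == "__main__":
    config = FrozenLMConfig()
    model = FrozenLMModel(config)
    model.eval()

    features = torch.randn(2, config.input_dim)     # (batch_size, input_dim)
    sequence = torch.randn(2, 5, config.input_dim)  # (batch_size, seq_len, input_dim)

    with torch.no_grad():
        out = model(input_ids=features)             # (2, output_dim)
        out_pooled = model(input_ids=sequence)      # mean-pooled over seq_len

    assert out.shape == (2, config.output_dim)
    assert out_pooled.shape == (2, config.output_dim)

    # Round-trip through the Auto classes (the path below is a
    # hypothetical local directory):
    # model.save_pretrained("./frozen_lm_mlp")
    # from transformers import AutoModel
    # reloaded = AutoModel.from_pretrained("./frozen_lm_mlp", trust_remote_code=True)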