File size: 2,087 Bytes
d9c2681
a11b14a
72182ee
 
d9c2681
a11b14a
 
d9c2681
 
 
5e8673b
d9c2681
5e8673b
 
2d0a5c6
5e8673b
 
 
 
a11b14a
72182ee
a11b14a
 
 
d9c2681
a11b14a
d9c2681
72182ee
d9c2681
5e8673b
d9c2681
5e8673b
a11b14a
d9c2681
299f700
 
 
 
d9c2681
5e8673b
d9c2681
 
 
 
 
 
 
72182ee
d9c2681
 
72182ee
d9c2681
 
72182ee
d9c2681
72182ee
 
b45ca52
 
d9c2681
72182ee
 
 
d9c2681
299f700
 
5e8673b
 
72182ee
d9c2681
72182ee
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import torch
import torch.nn as nn
from transformers import PreTrainedModel, PretrainedConfig
from typing import Optional


class FrozenLMConfig(PretrainedConfig):
    """
    Hyper-parameters for the FrozenLM MLP model.

    Args:
        input_dim: Size of the feature vector fed into the MLP.
        hidden_dim: Width of the single hidden layer.
        output_dim: Size of the vector produced by the MLP.
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
    """
    model_type = "frozen_lm_mlp"

    def __init__(
        self,
        input_dim: int = 256,
        hidden_dim: int = 128,
        output_dim: int = 24576,
        **kwargs,
    ):
        # Let the base class consume the generic config options first.
        super().__init__(**kwargs)
        # Store the three layer sizes (assigned left to right).
        self.input_dim, self.hidden_dim, self.output_dim = (
            input_dim,
            hidden_dim,
            output_dim,
        )


class FrozenLMModel(PreTrainedModel):
    """
    Simple MLP model: input_dim -> hidden_dim -> ReLU -> output_dim

    Inputs that carry a sequence dimension are mean-pooled over that
    dimension before being passed through the MLP.
    """
    config_class = FrozenLMConfig

    def __init__(self, config: FrozenLMConfig):
        """
        Build the two linear layers from the sizes stored in ``config``.

        Args:
            config: Supplies ``input_dim``, ``hidden_dim`` and ``output_dim``.
        """
        super().__init__(config)

        # MLP layers - using original names to match checkpoint
        self.fc1 = nn.Linear(config.input_dim, config.hidden_dim)
        self.act = nn.ReLU()
        self.fc2 = nn.Linear(config.hidden_dim, config.output_dim)

        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        **kwargs
    ) -> torch.Tensor:
        """
        Forward pass through MLP.

        Args:
            input_ids: Input tensor of shape (batch_size, input_dim) or
                (batch_size, seq_len, input_dim). Despite the name, this is
                treated as a tensor of features, not token ids; it is cast
                to the floating-point dtype of the model weights.
            **kwargs: Accepted for call-site compatibility and ignored.

        Returns:
            torch.Tensor: Output tensor of shape (batch_size, output_dim)

        Raises:
            ValueError: If ``input_ids`` is None.
        """
        if input_ids is None:
            raise ValueError("input_ids must be provided")

        # Match the dtype of the first layer instead of hard-coding float32,
        # so the model also runs when its weights are fp16/bf16/fp64.
        # If the weights are not floating point, keep the previous float32
        # behaviour rather than truncating the features to integers.
        target_dtype = self.fc1.weight.dtype
        if not target_dtype.is_floating_point:
            target_dtype = torch.float32
        features = input_ids.to(target_dtype)

        # If input has sequence dimension, pool by averaging
        if features.dim() == 3:  # (batch_size, seq_len, input_dim)
            features = features.mean(dim=1)  # (batch_size, input_dim)

        # Forward through MLP - using original layer names
        return self.fc2(self.act(self.fc1(features)))


# Make both custom classes discoverable through the Auto* API.
# "AutoConfig" is the default for configs; it is spelled out here so the
# two registrations read in parallel.
FrozenLMConfig.register_for_auto_class("AutoConfig")
FrozenLMModel.register_for_auto_class("AutoModel")