File size: 725 Bytes
750522a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
from transformers import PreTrainedModel
from .configuration_ss4m import SimpleStories4MConfig
from .nano_gpt_model import NanoGPT

class SimpleStories4MModel(PreTrainedModel):
    """Expose a ``NanoGPT`` network through the ``PreTrainedModel`` interface.

    The wrapped network is constructed from fields of the config object and
    kept on ``self.model``; ``config_class`` is ``SimpleStories4MConfig``.
    """

    config_class = SimpleStories4MConfig

    def __init__(self, config):
        """Create the wrapped ``NanoGPT`` from *config*.

        Args:
            config: Configuration object that provides the ``vocab_size``,
                ``block_size``, ``n_embed``, ``n_heads``, ``n_layers`` and
                ``dropout`` attributes.
        """
        super().__init__(config)
        # NanoGPT receives a single dict of settings; every entry is the
        # config attribute of the same name, copied over in this order.
        field_names = (
            "vocab_size",
            "block_size",
            "n_embed",
            "n_heads",
            "n_layers",
            "dropout",
        )
        settings = {field: getattr(config, field) for field in field_names}
        self.model = NanoGPT(settings)

    def forward(self, tensor, targets=None):
        """Run the wrapped network and return its result unchanged.

        Args:
            tensor: Input forwarded as the first positional argument.
                NOTE(review): presumably a batch of token ids -- confirm
                against ``NanoGPT``.
            targets: Optional second positional argument for the wrapped
                network; defaults to ``None``.

        Returns:
            Whatever ``self.model`` returns for ``(tensor, targets)``.
        """
        network = self.model
        return network(tensor, targets)