from transformers import PreTrainedModel

from .configuration_ss4m import SimpleStories4MConfig
from .nano_gpt_model import NanoGPT


class SimpleStories4MModel(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` wrapper around a ``NanoGPT`` network.

    Six architecture fields are copied from the config into the plain dict
    that the ``NanoGPT`` constructor takes. The wrapped network is stored as
    ``self.model``, and ``forward`` hands its arguments to it unchanged.
    """

    config_class = SimpleStories4MConfig

    def __init__(self, config):
        """Build the wrapped ``NanoGPT`` from *config*.

        Args:
            config: Expected to be a ``SimpleStories4MConfig`` (the declared
                ``config_class``). It must expose ``vocab_size``,
                ``block_size``, ``n_embed``, ``n_heads``, ``n_layers`` and
                ``dropout``.
        """
        super().__init__(config)
        # Each key is also the name of the config attribute it is read from.
        # The tuple order fixes the dict's key order.
        hparam_names = (
            "vocab_size",
            "block_size",
            "n_embed",
            "n_heads",
            "n_layers",
            "dropout",
        )
        nano_gpt_hparams = {name: getattr(config, name) for name in hparam_names}
        self.model = NanoGPT(nano_gpt_hparams)

    def forward(self, tensor, targets=None):
        """Run the wrapped ``NanoGPT`` on the given inputs.

        Args:
            tensor: Model input, passed through untouched. Presumably a batch
                of token ids; confirm against ``NanoGPT``.
            targets: Optional, passed through untouched. Presumably labels
                used for a loss; confirm against ``NanoGPT``.

        Returns:
            Whatever ``NanoGPT`` returns, unmodified.
        """
        net = self.model
        return net(tensor, targets)