---
language:
  - en
pipeline_tag: text-generation
tags:
  - distillation
  - model_hub_mixin
  - pytorch_model_hub_mixin
  - simple-stories
datasets:
  - lennart-finke/SimpleStories
---

To load this model with the code from the [simple_stories_train](https://github.com/danbraunai/simple_stories_train) repository, run:

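```python
from typing import Any

import torch
import torch.nn as nn
from huggingface_hub import PyTorchModelHubMixin

from simple_stories_train.models.llama import Llama, LlamaConfig
from simple_stories_train.models.model_configs import MODEL_CONFIGS_DICT


class LlamaTransformer(
    nn.Module,
    PyTorchModelHubMixin,
    repo_url="https://github.com/danbraunai/simple_stories_train",
    language=["en"],
    pipeline_tag="text-generation",
):
    """Thin wrapper that adds Hub save/load support to the repo's Llama model."""

    def __init__(self, **config: Any):
        super().__init__()
        self.llama = Llama(LlamaConfig(**config))

    def forward(self, x: torch.Tensor):
        return self.llama(x)


HUB_REPO_NAME = "lennart-finke/SimpleStories-125M"
# Hyperparameters of the "d12" configuration defined in simple_stories_train.
config = MODEL_CONFIGS_DICT["d12"]

# from_pretrained is a classmethod: it builds a new LlamaTransformer and loads the
# Hub weights into it, so there is no need to instantiate the model first.
# Extra keyword arguments (here the d12 config) are forwarded to __init__.
model = LlamaTransformer.from_pretrained(HUB_REPO_NAME, **config)
model.eval()  # inference mode (disables dropout, if the model uses any)
```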
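`LlamaTransformer` takes all of its hyperparameters as `**config` keyword arguments. In recent `huggingface_hub` releases, `PyTorchModelHubMixin` records JSON-serializable `__init__` arguments in a `config.json` saved next to the weights and passes them back to `__init__` when `from_pretrained` is called, which is what lets a keyword-only wrapper like this be rebuilt from the Hub. The stdlib-only sketch below mimics that round trip. `ToyConfig`, `ToyModel`, `save_config` and `load_model` are hypothetical names used only for illustration; this is a simplified picture of the idea, not the `huggingface_hub` implementation.

```python
import json
import tempfile
from dataclasses import dataclass
from pathlib import Path
from typing import Any


@dataclass
class ToyConfig:
    # Hypothetical stand-in for LlamaConfig; field names are illustrative only.
    n_layer: int = 2
    n_embd: int = 64


class ToyModel:
    def __init__(self, **config: Any):
        # Same pattern as LlamaTransformer: keep the raw kwargs, build a typed config.
        self.init_kwargs = config
        self.config = ToyConfig(**config)


def save_config(model: ToyModel, directory: Path) -> None:
    # Persist the __init__ kwargs as JSON, analogous to the mixin's config.json.
    (directory / "config.json").write_text(json.dumps(model.init_kwargs))


def load_model(directory: Path) -> ToyModel:
    # Read the kwargs back and forward them to __init__ via ** unpacking.
    kwargs = json.loads((directory / "config.json").read_text())
    return ToyModel(**kwargs)


with tempfile.TemporaryDirectory() as tmp:
    original = ToyModel(n_layer=4, n_embd=128)
    save_config(original, Path(tmp))
    restored = load_model(Path(tmp))

print(restored.config)  # ToyConfig(n_layer=4, n_embd=128)
```

Because the wrapper accepts arbitrary keyword arguments, any field stored in `config.json` reaches `LlamaConfig` unchanged, so the saved file alone is enough to rebuild the architecture before the weights are loaded.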