---
language:
- en
pipeline_tag: text-generation
tags:
- distillation
- model_hub_mixin
- pytorch_model_hub_mixin
- simple-stories
datasets:
- lennart-finke/SimpleStories
---
|
To load this model with the [simple_stories_train](https://github.com/danbraunai/simple_stories_train) library, run:

```python
|
from typing import Any |
|
|
|
import torch |
|
import torch.nn as nn |
|
from huggingface_hub import PyTorchModelHubMixin |
|
|
|
from simple_stories_train.models.llama import Llama, LlamaConfig |
|
from simple_stories_train.models.model_configs import MODEL_CONFIGS_DICT |
|
|
|
class LlamaTransformer(
    nn.Module,
    PyTorchModelHubMixin,
    repo_url="https://github.com/danbraunai/simple_stories_train",
    language=["en"],
    pipeline_tag="text-generation",
):
    """Thin ``nn.Module`` wrapper around ``Llama`` that adds Hugging Face Hub I/O.

    Inheriting from ``PyTorchModelHubMixin`` provides ``from_pretrained`` /
    ``save_pretrained`` / ``push_to_hub``. The class keyword arguments
    (``repo_url``, ``language``, ``pipeline_tag``) are metadata handed to the
    mixin.

    NOTE: the constructor parameter is deliberately named ``**config`` and
    annotated ``Any``. The mixin inspects ``__init__``'s signature when it
    re-creates the model, so renaming it may change loading behaviour.
    """

    def __init__(self, **config: Any):
        # All keyword arguments are forwarded unchanged to ``LlamaConfig``.
        super().__init__()
        llama_config = LlamaConfig(**config)
        self.llama = Llama(llama_config)

    def forward(self, x: torch.Tensor):
        """Call the wrapped ``Llama`` module on ``x`` and return its output as-is."""
        wrapped = self.llama
        return wrapped(x)
|
|
|
HUB_REPO_NAME = "lennart-finke/SimpleStories-125M"

# `from_pretrained` is a classmethod: it constructs and returns a new model
# from the files in the Hub repo. Building a `LlamaTransformer` first and then
# calling `.from_pretrained` on that instance would just allocate a full
# model's weights only to discard them.
# (`MODEL_CONFIGS_DICT` is only needed to build an untrained model, e.g.
# `LlamaTransformer(**MODEL_CONFIGS_DICT["d12"])`.)
model = LlamaTransformer.from_pretrained(HUB_REPO_NAME)
|
```
|
|
|
- Library: [simple_stories_train](https://github.com/danbraunai/simple_stories_train)