MiniStoryGPT
MiniStoryGPT is a compact, educational GPT-style language model built in PyTorch to demonstrate how to train a transformer from scratch. It is trained on the TinyStories dataset to generate short, child-friendly narratives. The model follows the architecture described in "Attention Is All You Need" and draws inspiration from Andrej Karpathy’s nanoGPT and Zero to Hero materials.
Purpose: This model is intended for education and experimentation, offering hands-on experience with building, training, and sampling from GPT-like models. It is not intended for production use.
Model Details
- Architecture: GPT-style transformer with 2 layers, 8 attention heads, and 768 embedding dimensions.
- Parameters: ~30 million
- Vocabulary Size: 10,000 tokens (remapped from GPT-2 tokenizer for efficiency)
- Training Data: TinyStories dataset (preprocessed into train.bin and val.bin)
- Training: ~50,000 iterations with a batch size of 32, context length of 512, and AdamW optimizer (learning rate 3e-4). Achieved a training loss of ~1.55.
- Checkpoint: MiniStoryGPT-30M.pth (367 MB), saved at iteration 20,000.
- Hardware: Trained on a single GPU (CUDA-compatible).
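As a sanity check on the ~30 million figure, the parameter count can be estimated from the sizes listed above. This is a rough sketch, not a readout of the checkpoint: the layer layout (bias-free Q/K/V projections, a 4x MLP, two LayerNorms per block, and an output head that is not tied to the token embedding) is assumed from nanoGPT-style models, and the names below are illustrative.

```python
# Rough parameter count for a GPT-style model with the listed sizes.
# The layer layout is an assumption (nanoGPT-like), not read from the checkpoint.
n_layer, n_embd = 2, 768
vocab_size, block_size = 10_000, 512


def linear(n_in, n_out, bias=True):
    """Number of parameters in a dense layer."""
    return n_in * n_out + (n_out if bias else 0)


token_emb = vocab_size * n_embd
pos_emb = block_size * n_embd                    # learned positional embeddings

attn = 3 * linear(n_embd, n_embd, bias=False)    # Q, K, V summed over all heads
attn += linear(n_embd, n_embd)                   # attention output projection
mlp = linear(n_embd, 4 * n_embd) + linear(4 * n_embd, n_embd)
layer_norms = 2 * (2 * n_embd)                   # scale + shift, two norms per block
block = attn + mlp + layer_norms

final_norm = 2 * n_embd
lm_head = linear(n_embd, vocab_size)             # assumed untied from token_emb

total = token_emb + pos_emb + n_layer * block + final_norm + lm_head
print(f"{total:,} parameters (~{total / 1e6:.1f}M)")
```

Under these assumptions the total lands just under 30 million. The number of attention heads does not enter the count, because the heads split the 768-dimensional embedding between them rather than adding to it.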
Installation
To use MiniStoryGPT, install the required dependencies:
pip install torch tiktoken huggingface_hub
Usage
Download the model and mappings from this Hugging Face repository:
from huggingface_hub import hf_hub_download
hf_hub_download(repo_id="johnnycwatt/MiniStoryGPT", filename="MiniStoryGPT-30M.pth", local_dir=".")
hf_hub_download(repo_id="johnnycwatt/MiniStoryGPT", filename="old_to_new.pt", local_dir=".")
hf_hub_download(repo_id="johnnycwatt/MiniStoryGPT", filename="new_to_old.pt", local_dir=".")
Run the provided sampler.py to generate stories:
python sampler.py
Example code to load and generate:
import torch
import tiktoken

# GPTLanguageModel is not defined in this snippet; it is assumed to be the
# model class from the GitHub repository's code. Define or import it first.

# Load model and mappings
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = GPTLanguageModel().to(device)
model.load_state_dict(torch.load("MiniStoryGPT-30M.pth", map_location=device))
model.eval()  # disable dropout while sampling
old_to_new = torch.load("old_to_new.pt", map_location=device)
new_to_old = torch.load("new_to_old.pt", map_location=device)
enc = tiktoken.get_encoding("gpt2")

# Remap the prompt from GPT-2 token IDs to the reduced vocabulary
prompt = "Once upon a time,"
original_context = enc.encode(prompt)
remapped_context = [old_to_new.get(token, 0) for token in original_context]
context = torch.tensor([remapped_context], dtype=torch.long, device=device)

# Generate, then map the output back to GPT-2 token IDs before decoding
with torch.no_grad():
    output = model.generate(context, max_new_tokens=300)
story = enc.decode([new_to_old.get(new_id, 0) for new_id in output[0].tolist()])
print(story)
The model requires old_to_new.pt and new_to_old.pt for token remapping due to the reduced vocabulary. See the GitHub repository for the full training and sampling code.
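To make the remapping concrete, here is a minimal sketch of how such a pair of lookup tables could be built and used. It assumes the reduced vocabulary keeps the most frequent GPT-2 token IDs in the corpus. That is a guess at the approach, not a description of how old_to_new.pt and new_to_old.pt were actually produced, and build_remap and the toy token IDs are hypothetical.

```python
from collections import Counter


def build_remap(token_ids, vocab_size):
    """Map the vocab_size most frequent original IDs to 0..vocab_size-1."""
    counts = Counter(token_ids)
    kept = [tok for tok, _ in counts.most_common(vocab_size)]
    old_to_new = {old: new for new, old in enumerate(kept)}
    new_to_old = {new: old for old, new in old_to_new.items()}
    return old_to_new, new_to_old


# Toy "corpus" of GPT-2-style token IDs (made-up values for illustration).
corpus = [257, 257, 257, 257, 640, 640, 640, 7454, 7454, 2402, 11]
old_to_new, new_to_old = build_remap(corpus, vocab_size=3)

# 2402 was too rare to keep, so .get(..., 0) falls back to new ID 0.
prompt = [7454, 2402, 257]
remapped = [old_to_new.get(t, 0) for t in prompt]
restored = [new_to_old.get(n, 0) for n in remapped]

print(remapped)
print(restored)
```

The round trip is lossless only for tokens inside the reduced vocabulary. With the same `.get(token, 0)` fallback as the example code above, an out-of-vocabulary token is silently replaced by whatever token holds new ID 0 (in this sketch, the most frequent one), so unusual words in a prompt may come back changed.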
Training Details
- Dataset: TinyStories, preprocessed into tokenized binaries (train.bin, val.bin) with a 10K-token vocabulary.
- Preprocessing: Uses tiktoken (GPT-2 tokenizer) with custom remapping to reduce vocab size.
- Hyperparameters:
  - Batch size: 32
  - Context length: 512
  - Learning rate: 3e-4
  - Dropout: 0.2
  - Positional embeddings: Learned (not sinusoidal)
- Loss: ~1.55 on training set at 50,000 iterations (validation loss ~1.60).
To reproduce training, run prepare_data.py and train.py from the GitHub repo.
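The figures in this section can be put in perspective with a little arithmetic. The sketch below assumes the reported losses are mean per-token cross-entropy in nats, which is what PyTorch's cross-entropy loss returns by default; if the training script reports something else, the perplexity values would not apply.

```python
import math

iterations, batch_size, context_length = 50_000, 32, 512
vocab_size = 10_000
train_loss, val_loss = 1.55, 1.60  # assumed to be in nats per token

# Token positions processed during training (sequences may repeat or overlap).
tokens_seen = iterations * batch_size * context_length

# Perplexity is the exponential of the per-token cross-entropy.
train_ppl = math.exp(train_loss)
val_ppl = math.exp(val_loss)

# Loss of a model that guesses uniformly over the vocabulary, for reference.
uniform_loss = math.log(vocab_size)

print(f"tokens processed: {tokens_seen:,}")
print(f"train perplexity: {train_ppl:.2f}, val perplexity: {val_ppl:.2f}")
print(f"uniform-guess loss: {uniform_loss:.2f}")
```

Read this way, a loss near 1.6 means the model is on average about as uncertain as a choice among roughly five equally likely tokens, compared with 10,000 for uniform guessing.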
Limitations
- Educational Focus: MiniStoryGPT is for learning, not optimized for production-grade performance.
- Output Quality: Generates simple, child-friendly stories but may produce incoherent or repetitive text due to its small size and limited training.
- Vocabulary: Uses a reduced 10K-token vocab, which may miss some nuances of the full GPT-2 tokenizer.
- Compute: Trained on a single GPU; scaling to larger datasets or models requires more resources.
License
Released under the MIT License. Feel free to use, modify, and distribute for research and educational purposes.
References
- Vaswani et al., "Attention Is All You Need" (https://arxiv.org/abs/1706.03762)
- Karpathy’s nanoGPT (https://github.com/karpathy/nanoGPT)
- Karpathy’s Zero to Hero Course (YouTube)
- TinyStories Dataset (https://huggingface.co/datasets/roneneldan/TinyStories)
- TinyStories Paper (https://arxiv.org/abs/2305.07759)
Contact
For questions or contributions, open an issue on the GitHub repository or contact johnnycwatt@gmail.com