MiniStoryGPT

MiniStoryGPT is a compact, educational GPT-style language model built in PyTorch to demonstrate how to train a transformer from scratch. It is trained on the TinyStories dataset to generate short, child-friendly narratives. The model follows the architecture described in "Attention Is All You Need" and draws inspiration from Andrej Karpathy’s nanoGPT and Zero to Hero materials.

Purpose: This model is designed for education and experimentation, offering hands-on experience with building, training, and sampling from GPT-like models. It is not intended for production use.

Model Details

  • Architecture: GPT-style transformer with 2 layers, 8 attention heads, and 768 embedding dimensions.
  • Parameters: ~30 million
  • Vocabulary Size: 10,000 tokens (remapped from GPT-2 tokenizer for efficiency)
  • Training Data: TinyStories dataset (preprocessed into train.bin and val.bin)
  • Training: ~50,000 iterations with a batch size of 32, context length of 512, and AdamW optimizer (learning rate 3e-4). Achieved a training loss of ~1.55.
  • Checkpoint: MiniStoryGPT-30M.pth (367MB), saved at iteration 20,000.
  • Hardware: Trained on a single GPU (CUDA-compatible).
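
As a sanity check on the ~30 million figure, the parameter count can be estimated from the architecture listed above. This is a rough sketch, not the repository's own accounting: it assumes a standard GPT block with a 4x MLP expansion, an output head that is not tied to the token embedding, and it ignores biases and LayerNorm weights. The variable names are illustrative.

```python
# Rough parameter estimate for the configuration listed above.
# Assumptions (not confirmed by the model card): 4x MLP expansion,
# untied output head, biases and LayerNorm parameters ignored.
vocab_size = 10_000
n_embd = 768
n_layer = 2
block_size = 512  # context length

token_embedding = vocab_size * n_embd
position_embedding = block_size * n_embd          # learned positions
attention_per_block = 4 * n_embd * n_embd         # Q, K, V and output projection
mlp_per_block = 2 * n_embd * (4 * n_embd)         # up- and down-projection
blocks = n_layer * (attention_per_block + mlp_per_block)
output_head = n_embd * vocab_size                 # assumed untied

total = token_embedding + position_embedding + blocks + output_head
print(f"{total:,} parameters (~{total / 1e6:.1f}M)")
```

Under these assumptions the estimate lands close to the stated size. The number of attention heads does not change the count, because the heads split the same 768-dimensional projections.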

Installation

To use MiniStoryGPT, install the required dependencies:

pip install torch tiktoken huggingface_hub

Usage

Download the model and mappings from this Hugging Face repository:

from huggingface_hub import hf_hub_download
hf_hub_download(repo_id="johnnycwatt/MiniStoryGPT", filename="MiniStoryGPT-30M.pth", local_dir=".")
hf_hub_download(repo_id="johnnycwatt/MiniStoryGPT", filename="old_to_new.pt", local_dir=".")
hf_hub_download(repo_id="johnnycwatt/MiniStoryGPT", filename="new_to_old.pt", local_dir=".")

Run the provided sampler.py to generate stories:

python sampler.py

Example code to load and generate:

import torch
import tiktoken

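# GPTLanguageModel is defined in the repository's training and sampling code;
# import or copy that class definition before running this snippet.

# Load model and mappings
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = GPTLanguageModel().to(device)
model.load_state_dict(torch.load("MiniStoryGPT-30M.pth", map_location=device))
model.eval()  # disable dropout for generation
old_to_new = torch.load("old_to_new.pt", map_location=device)
new_to_old = torch.load("new_to_old.pt", map_location=device)
enc = tiktoken.get_encoding("gpt2")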

# Remap and generate
prompt = "Once upon a time,"
original_context = enc.encode(prompt)
remapped_context = [old_to_new.get(token, 0) for token in original_context]
context = torch.tensor([remapped_context], dtype=torch.long, device=device)
with torch.no_grad():
    output = model.generate(context, max_new_tokens=300)
    story = enc.decode([new_to_old.get(new_id, 0) for new_id in output[0].tolist()])
    print(story)

The model requires old_to_new.pt and new_to_old.pt for token remapping due to the reduced vocabulary. See the GitHub repository for the full training and sampling code.
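
To illustrate what the two mapping files do, the sketch below builds a reduced vocabulary from a toy list of token IDs. It is only one plausible way to construct such a mapping (keep the most frequent IDs and renumber them contiguously); the repository's actual preprocessing may differ, and the helper name build_remap is hypothetical. The fallback to ID 0 for unseen tokens mirrors the `.get(token, 0)` call in the example above.

```python
from collections import Counter

def build_remap(token_ids, vocab_size):
    """Keep the `vocab_size` most frequent token IDs and renumber them 0..N-1.

    Hypothetical helper for illustration; not taken from the repository.
    """
    counts = Counter(token_ids)
    kept = [tok for tok, _ in counts.most_common(vocab_size)]
    old_to_new = {old: new for new, old in enumerate(kept)}
    new_to_old = {new: old for old, new in old_to_new.items()}
    return old_to_new, new_to_old

# Toy corpus of original tokenizer IDs: 257 appears 3 times, 640 twice, 11 once.
corpus = [257, 640, 257, 11, 640, 257]
old_to_new, new_to_old = build_remap(corpus, vocab_size=2)

# Encoding: IDs outside the reduced vocabulary fall back to 0.
remapped = [old_to_new.get(tok, 0) for tok in corpus]
# Decoding: map reduced IDs back to original tokenizer IDs.
restored = [new_to_old.get(new_id, 0) for new_id in remapped]

print(remapped)
print(restored)
```

Note that the fallback is lossy: any token dropped from the reduced vocabulary decodes as whatever original token was assigned ID 0, so the round trip is exact only for tokens that were kept.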

Training Details

  • Dataset: TinyStories, preprocessed into tokenized binaries (train.bin, val.bin) with a 10K-token vocabulary.
  • Preprocessing: Uses tiktoken (GPT-2 tokenizer) with custom remapping to reduce vocab size.
  • Hyperparameters:
    • Batch size: 32
    • Context length: 512
    • Learning rate: 3e-4
    • Dropout: 0.2
    • Positional embeddings: Learned (not sinusoidal)
  • Loss: ~1.55 on training set at 50,000 iterations (validation loss ~1.60).
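
For a sense of scale, the hyperparameters above imply the following token throughput. This is a back-of-the-envelope sketch that assumes one batch per iteration with no gradient accumulation, which the model card does not state.

```python
# Tokens processed during training, from the listed hyperparameters.
# Assumption: each iteration consumes exactly one batch (no gradient accumulation).
batch_size = 32
context_length = 512
iterations = 50_000

tokens_per_iteration = batch_size * context_length
total_tokens = tokens_per_iteration * iterations

print(f"{tokens_per_iteration:,} tokens per iteration")
print(f"{total_tokens:,} tokens processed in total")
```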

To reproduce training, run prepare_data.py and train.py from the GitHub repo.

Limitations

  • Educational Focus: MiniStoryGPT is built for learning and is not optimized for production-grade performance.
  • Output Quality: Generates simple, child-friendly stories but may produce incoherent or repetitive text due to its small size and limited training.
  • Vocabulary: Uses a reduced 10K-token vocab, which may miss some nuances of the full GPT-2 tokenizer.
  • Compute: Trained on a single GPU; scaling to larger datasets or models requires more resources.

License

Released under the MIT License. Feel free to use, modify, and distribute for research and educational purposes.

Contact

For questions or contributions, open an issue on the GitHub repository or contact johnnycwatt@gmail.com
