storyGPT: Small Language Model (SLM) for Story Generation

This model is a decoder-only Transformer (~57M parameters) pretrained from scratch on the roneneldan/TinyStories dataset. It is designed to generate coherent, grammatically correct short stories.

Model Configuration

  • Embedding Dimension (n_embd): 512
  • Attention Heads (n_head): 8
  • Transformer Layers (n_layer): 10
  • Context Length: 256 tokens
  • Vocabulary Size: 50,257 (GPT-2 encoding)
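
As a rough check on the ~57M figure, the following sketch recomputes the parameter count from this configuration. It rests on several assumptions that this card does not confirm:

  • The blocks follow a GPT-2/nanoGPT-style layout: a fused QKV projection, a 4x MLP expansion, and linear layers and LayerNorms with biases.
  • Positional embeddings are learned.
  • The output head is tied to the token embedding.

gpt_param_count is a hypothetical helper written for this illustration.

```python
# Sketch: estimate the parameter count from the configuration listed above.
# ASSUMPTIONS (not confirmed by this model card): GPT-2/nanoGPT-style blocks
# with biases, a 4x MLP expansion, learned positional embeddings, and an
# output head tied to the token embedding.


def gpt_param_count(n_embd: int, n_layer: int, vocab_size: int, block_size: int) -> int:
    """Hypothetical helper: count parameters of an assumed GPT-2-style model."""
    d = n_embd
    token_emb = vocab_size * d           # token embedding table (wte)
    pos_emb = block_size * d             # learned positional embeddings (assumed)
    per_block = (
        2 * d                            # first LayerNorm: weight + bias
        + d * 3 * d + 3 * d              # fused query/key/value projection
        + d * d + d                      # attention output projection
        + 2 * d                          # second LayerNorm: weight + bias
        + d * 4 * d + 4 * d              # MLP up-projection (4x, assumed)
        + 4 * d * d + d                  # MLP down-projection
    )
    final_ln = 2 * d                     # final LayerNorm before the output head
    # A tied output head reuses token_emb, so it adds no parameters (assumed).
    return token_emb + pos_emb + n_layer * per_block + final_ln


total = gpt_param_count(n_embd=512, n_layer=10, vocab_size=50257, block_size=256)
head_dim = 512 // 8  # n_embd / n_head: size of each attention head
print(f"{total:,} parameters (~{total / 1e6:.1f}M), head_dim={head_dim}")
# prints: 57,387,520 parameters (~57.4M), head_dim=64
```

Under these assumptions the total comes to 57,387,520. An untied output head would add another 50,257 × 512 = 25,731,584 parameters, for roughly 83M in total, so the ~57M figure suggests the output head is tied to the token embedding.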

How to Download and Use Programmatically

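You can load the model straight from the Hugging Face Hub without cloning the whole repository. The script below downloads only the model definition, its config, and the best checkpoint, then generates a story. It requires torch, tiktoken, and huggingface_hub.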

import os
import sys
import torch
import tiktoken
from huggingface_hub import hf_hub_download

# Define repository and destination
repo_id = "justjuu/story-gpt"
local_dir = "./storyGPT_model"
os.makedirs(local_dir, exist_ok=True)

# 1. Download model definition, config, and weights
hf_hub_download(repo_id=repo_id, filename="model.py", local_dir=local_dir)
hf_hub_download(repo_id=repo_id, filename="config.py", local_dir=local_dir)
checkpoint_path = hf_hub_download(repo_id=repo_id, filename="checkpoints/storyGPT_best.pt", local_dir=local_dir)

# 2. Import the model definition from the downloaded directory.
#    Insert it at the front of sys.path so it takes precedence over any other module named "model".
sys.path.insert(0, local_dir)
from model import GPT

# 3. Load model weights
device = "cuda" if torch.cuda.is_available() else "cpu"
# weights_only=False is required because the checkpoint pickles a custom GPTConfig object.
# Unpickling can execute arbitrary code, so only load checkpoints from sources you trust.
checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
model = GPT(checkpoint['gpt_config'])
model.load_state_dict(checkpoint['model_state_dict'])
model.to(device)
model.eval()

# 4. Generate story
enc = tiktoken.get_encoding("gpt2")
prompt = "Once upon a time, there was a little boy named Timmy who found a magic key."
# Encode the prompt and add a batch dimension: shape (1, num_prompt_tokens)
context = torch.tensor(enc.encode_ordinary(prompt), dtype=torch.long, device=device).unsqueeze(0)

print("Generating...")
with torch.no_grad():
    out = model.generate(context, max_new_tokens=200, temperature=0.8, top_k=100)
    print(enc.decode(out[0].tolist()))
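
The generate call above samples with temperature=0.8 and top_k=100. Since model.py is not reproduced in this card, the following is only a sketch of how temperature and top-k sampling typically work in nanoGPT-style decoders. It is written in plain Python so it runs without torch.

The sketch makes these assumptions:

  • logits_fn and toy_logits are hypothetical stand-ins for the GPT forward pass.
  • Cropping the context to the last 256 tokens assumes this model's generate behaves like nanoGPT's.

```python
import math
import random
from typing import Callable, List

BLOCK_SIZE = 256  # context length from the configuration above


def sample_next(logits: List[float], temperature: float, top_k: int, rng: random.Random) -> int:
    """Pick one token id using temperature scaling and top-k filtering."""
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    scaled = [x / temperature for x in logits]
    # Keep only the k highest-scoring token ids.
    k = min(top_k, len(scaled))
    top = sorted(range(len(scaled)), key=lambda i: scaled[i], reverse=True)[:k]
    # Softmax over the surviving logits (subtract the max for numerical stability).
    m = max(scaled[i] for i in top)
    weights = [math.exp(scaled[i] - m) for i in top]
    return rng.choices(top, weights=weights, k=1)[0]


def generate(
    idx: List[int],
    logits_fn: Callable[[List[int]], List[float]],  # hypothetical stand-in for the model
    max_new_tokens: int,
    temperature: float = 1.0,
    top_k: int = 100,
    seed: int = 0,
) -> List[int]:
    """Autoregressively append max_new_tokens sampled token ids to idx."""
    rng = random.Random(seed)
    out = list(idx)
    for _ in range(max_new_tokens):
        # Feed at most the last BLOCK_SIZE tokens (assumed nanoGPT-style cropping).
        context = out[-BLOCK_SIZE:]
        out.append(sample_next(logits_fn(context), temperature, top_k, rng))
    return out


def toy_logits(context: List[int]) -> List[float]:
    """Hypothetical toy model over 10 tokens that prefers the token after the last one."""
    nxt = (context[-1] + 1) % 10
    return [5.0 if t == nxt else 0.0 for t in range(10)]


print(generate([0], toy_logits, max_new_tokens=8, top_k=1))
# prints: [0, 1, 2, 3, 4, 5, 6, 7, 8]
```

Lower temperatures sharpen the distribution toward the most likely token, while top_k limits sampling to the k highest-scoring candidates at each step. With top_k=1 the loop reduces to greedy decoding, which is why the toy example is deterministic.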
