"""
Gradio interface for TinyStories Llama model chat.
"""
import os
import warnings

import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

warnings.filterwarnings("ignore", category=UserWarning)
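
# Checkpoint to load; override via the MODEL_REPO environment variable.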
MODEL_REPO = os.environ.get("MODEL_REPO", "sdobson/tinystories-llama-15m")
print(f"Loading model and tokenizer from {MODEL_REPO}...")
model = AutoModelForCausalLM.from_pretrained(MODEL_REPO)
tokenizer = AutoTokenizer.from_pretrained(MODEL_REPO)
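# Prefer a GPU when available; a ~15M-parameter model runs acceptably on CPU too.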
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)
model.eval()
print(f"Model loaded on {device}")
print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")
def generate_story(
prompt,
max_length=200,
temperature=0.8,
top_k=50,
top_p=0.9,
do_sample=True
):
"""Generate a story continuation from the prompt."""
if not prompt.strip():
return "Please provide a story prompt!"
    # Tokenize the prompt and move the tensors to the model's device.
    inputs = tokenizer(prompt, return_tensors="pt").to(device)

    # Generate without tracking gradients. Note that max_length counts the
    # prompt tokens as well as the newly generated ones.
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_length=max_length,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            do_sample=do_sample,
            # Llama tokenizers ship without a pad token, so reuse EOS.
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )
# Decode and return
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
return generated_text
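
# Example: generate_story("Once upon a time, there was a", max_length=120)
# returns the prompt plus the model's continuation as a single string.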
with gr.Blocks(title="TinyStories Story Generator") as demo:
gr.Markdown(
"""
        # TinyStories Story Generator
        This is a small Llama-architecture model trained on the TinyStories dataset.
        It generates simple, coherent children's stories using vocabulary that a typical 3- to 4-year-old would understand.
**Try starting your story with:**
- "Once upon a time, there was a..."
- "One day, a little boy named..."
- "In a small town, there lived a..."
"""
)
with gr.Row():
with gr.Column():
prompt_input = gr.Textbox(
label="Story Prompt",
placeholder="Once upon a time, there was a",
lines=3
)
with gr.Accordion("Generation Settings", open=False):
max_length_slider = gr.Slider(
minimum=50,
maximum=256,
value=200,
step=10,
label="Max Length (tokens)"
)
temperature_slider = gr.Slider(
minimum=0.1,
maximum=2.0,
value=0.8,
step=0.1,
label="Temperature (higher = more creative)"
)
top_k_slider = gr.Slider(
minimum=1,
maximum=100,
value=50,
step=1,
label="Top-k"
)
top_p_slider = gr.Slider(
minimum=0.1,
maximum=1.0,
value=0.9,
step=0.05,
label="Top-p (nucleus sampling)"
)
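                # When unchecked below, decoding is greedy and the sampling
                # sliders above are ignored.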
do_sample_checkbox = gr.Checkbox(
label="Use Sampling",
value=True
)
generate_btn = gr.Button("Generate Story", variant="primary")
with gr.Column():
output_text = gr.Textbox(
label="Generated Story",
lines=15,
show_copy_button=True
)
gr.Examples(
examples=[
["Once upon a time, there was a little girl named Lily."],
["One day, a little boy found a magic"],
["The little dog was very happy because"],
["In a small garden, there lived a"],
["Timmy wanted to play with his friend, but"],
],
inputs=prompt_input,
label="Example Prompts"
)
generate_btn.click(
fn=generate_story,
inputs=[
prompt_input,
max_length_slider,
temperature_slider,
top_k_slider,
top_p_slider,
do_sample_checkbox
],
outputs=output_text
)
if __name__ == "__main__":
demo.launch()