""" Gradio interface for TinyStories Llama model chat. """ import gradio as gr from transformers import AutoModelForCausalLM, AutoTokenizer import torch import warnings import os warnings.filterwarnings('ignore', category=UserWarning) MODEL_REPO = os.environ.get("MODEL_REPO", "sdobson/tinystories-llama-15m") print(f"Loading model and tokenizer from {MODEL_REPO}...") model = AutoModelForCausalLM.from_pretrained(MODEL_REPO) tokenizer = AutoTokenizer.from_pretrained(MODEL_REPO) device = "cuda" if torch.cuda.is_available() else "cpu" model = model.to(device) model.eval() print(f"Model loaded on {device}") print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}") def generate_story( prompt, max_length=200, temperature=0.8, top_k=50, top_p=0.9, do_sample=True ): """Generate a story continuation from the prompt.""" if not prompt.strip(): return "Please provide a story prompt!" # Tokenize input inputs = tokenizer(prompt, return_tensors="pt").to(device) # Generate with torch.no_grad(): outputs = model.generate( **inputs, max_length=max_length, temperature=temperature, top_k=top_k, top_p=top_p, do_sample=do_sample, pad_token_id=tokenizer.eos_token_id, eos_token_id=tokenizer.eos_token_id, ) # Decode and return generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True) return generated_text with gr.Blocks(title="TinyStories Story Generator") as demo: gr.Markdown( """ # TinyStories Llama Model Chat This is a small Llama-architecture model trained on the TinyStories dataset. It generates simple, coherent children's stories using vocabulary that a typical 3-4 year old would understand. **Try starting your story with:** - "Once upon a time, there was a..." - "One day, a little boy named..." - "In a small town, there lived a..." """ ) with gr.Row(): with gr.Column(): prompt_input = gr.Textbox( label="Story Prompt", placeholder="Once upon a time, there was a", lines=3 ) with gr.Accordion("Generation Settings", open=False): max_length_slider = gr.Slider( minimum=50, maximum=256, value=200, step=10, label="Max Length (tokens)" ) temperature_slider = gr.Slider( minimum=0.1, maximum=2.0, value=0.8, step=0.1, label="Temperature (higher = more creative)" ) top_k_slider = gr.Slider( minimum=1, maximum=100, value=50, step=1, label="Top-k" ) top_p_slider = gr.Slider( minimum=0.1, maximum=1.0, value=0.9, step=0.05, label="Top-p (nucleus sampling)" ) do_sample_checkbox = gr.Checkbox( label="Use Sampling", value=True ) generate_btn = gr.Button("Generate Story", variant="primary") with gr.Column(): output_text = gr.Textbox( label="Generated Story", lines=15, show_copy_button=True ) gr.Examples( examples=[ ["Once upon a time, there was a little girl named Lily."], ["One day, a little boy found a magic"], ["The little dog was very happy because"], ["In a small garden, there lived a"], ["Timmy wanted to play with his friend, but"], ], inputs=prompt_input, label="Example Prompts" ) generate_btn.click( fn=generate_story, inputs=[ prompt_input, max_length_slider, temperature_slider, top_k_slider, top_p_slider, do_sample_checkbox ], outputs=output_text ) if __name__ == "__main__": demo.launch()