Spaces:
Runtime error
Runtime error
| """ | |
| Gradio interface for TinyStories Llama model chat. | |
| """ | |
| import gradio as gr | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
| import torch | |
| import warnings | |
| import os | |
# Suppress UserWarning-category warnings for the rest of the process.
warnings.filterwarnings("ignore", category=UserWarning)

# The repository id can be overridden through the MODEL_REPO environment variable.
MODEL_REPO = os.environ.get("MODEL_REPO", "sdobson/tinystories-llama-15m")

print(f"Loading model and tokenizer from {MODEL_REPO}...")
model = AutoModelForCausalLM.from_pretrained(MODEL_REPO)
tokenizer = AutoTokenizer.from_pretrained(MODEL_REPO)

# Use CUDA when torch reports it as available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

model = model.to(device)
model.eval()  # switch the module to evaluation mode before serving requests
print(f"Model loaded on {device}")

_parameter_total = sum(weights.numel() for weights in model.parameters())
print(f"Model parameters: {_parameter_total:,}")
def generate_story(
    prompt,
    max_length=200,
    temperature=0.8,
    top_k=50,
    top_p=0.9,
    do_sample=True,
):
    """Generate a story continuation from the prompt.

    Args:
        prompt: Text the story should start with.
        max_length: Upper bound on the total sequence length in tokens
            (prompt tokens plus generated tokens). Coerced to ``int``.
        temperature: Sampling temperature; only used when ``do_sample`` is true.
        top_k: Top-k cutoff; only used when ``do_sample`` is true. Coerced
            to ``int``.
        top_p: Nucleus-sampling threshold; only used when ``do_sample`` is true.
        do_sample: Sample from the distribution when true, otherwise decode
            without the sampling parameters.

    Returns:
        The decoded text for the first generated sequence (special tokens
        removed), or a short user-facing message when the prompt is empty or
        already fills the whole ``max_length`` budget.
    """
    # Treat a missing prompt (None) the same as an empty/whitespace-only one
    # instead of failing on ``None.strip()``.
    if not prompt or not prompt.strip():
        return "Please provide a story prompt!"

    # UI sliders can hand over floats (e.g. 200.0); token counts must be ints.
    max_length = int(max_length)

    # Tokenize input
    inputs = tokenizer(prompt, return_tensors="pt").to(device)

    # ``max_length`` counts the prompt tokens too, so a prompt that already
    # reaches the limit leaves no room for any continuation.
    prompt_tokens = inputs["input_ids"].shape[-1]
    if prompt_tokens >= max_length:
        return (
            "Your prompt is too long for the selected max length. "
            "Please shorten it or increase Max Length."
        )

    generation_kwargs = {
        "max_length": max_length,
        "do_sample": bool(do_sample),
        "pad_token_id": tokenizer.eos_token_id,
        "eos_token_id": tokenizer.eos_token_id,
    }
    # Sampling knobs are only meaningful when sampling is enabled; leaving
    # them out otherwise avoids handing unused flags to generate().
    if do_sample:
        generation_kwargs.update(
            temperature=float(temperature),
            top_k=int(top_k),
            top_p=float(top_p),
        )

    # Generate (no gradients are needed for inference)
    with torch.no_grad():
        outputs = model.generate(**inputs, **generation_kwargs)

    # Decode and return
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
# Introductory text shown at the top of the page.
INTRO_MARKDOWN = (
    "\n"
    "# TinyStories Llama Model Chat\n"
    "This is a small Llama-architecture model trained on the TinyStories dataset.\n"
    "It generates simple, coherent children's stories using vocabulary that a typical 3-4 year old would understand.\n"
    "**Try starting your story with:**\n"
    '- "Once upon a time, there was a..."\n'
    '- "One day, a little boy named..."\n'
    '- "In a small town, there lived a..."\n'
)

# Prompts offered as one-click examples below the main row.
EXAMPLE_PROMPTS = [
    "Once upon a time, there was a little girl named Lily.",
    "One day, a little boy found a magic",
    "The little dog was very happy because",
    "In a small garden, there lived a",
    "Timmy wanted to play with his friend, but",
]

# NOTE(review): the nesting below is reconstructed from the statement order;
# confirm that the button sits outside the accordion and that the examples
# and click wiring live directly under the Blocks context.
with gr.Blocks(title="TinyStories Story Generator") as demo:
    gr.Markdown(INTRO_MARKDOWN)

    with gr.Row():
        # Left column: prompt, generation settings, and the trigger button.
        with gr.Column():
            prompt_input = gr.Textbox(
                label="Story Prompt",
                placeholder="Once upon a time, there was a",
                lines=3,
            )

            with gr.Accordion("Generation Settings", open=False):
                max_length_slider = gr.Slider(
                    minimum=50, maximum=256, value=200, step=10,
                    label="Max Length (tokens)",
                )
                temperature_slider = gr.Slider(
                    minimum=0.1, maximum=2.0, value=0.8, step=0.1,
                    label="Temperature (higher = more creative)",
                )
                top_k_slider = gr.Slider(
                    minimum=1, maximum=100, value=50, step=1,
                    label="Top-k",
                )
                top_p_slider = gr.Slider(
                    minimum=0.1, maximum=1.0, value=0.9, step=0.05,
                    label="Top-p (nucleus sampling)",
                )
                do_sample_checkbox = gr.Checkbox(label="Use Sampling", value=True)

            generate_btn = gr.Button("Generate Story", variant="primary")

        # Right column: the generated text.
        with gr.Column():
            output_text = gr.Textbox(
                label="Generated Story",
                lines=15,
                show_copy_button=True,
            )

    gr.Examples(
        examples=[[example] for example in EXAMPLE_PROMPTS],
        inputs=prompt_input,
        label="Example Prompts",
    )

    # The control order must match generate_story's positional parameters.
    generation_controls = [
        prompt_input,
        max_length_slider,
        temperature_slider,
        top_k_slider,
        top_p_slider,
        do_sample_checkbox,
    ]
    generate_btn.click(
        fn=generate_story,
        inputs=generation_controls,
        outputs=output_text,
    )

if __name__ == "__main__":
    demo.launch()