File size: 3,846 Bytes
47fe629 a970429 47fe629 a970429 207e0eb 47fe629 207e0eb 47fe629 207e0eb a405953 207e0eb 5d9d006 6863e73 760adf8 6863e73 207e0eb 47fe629 207e0eb 47fe629 a405953 a970429 47fe629 a405953 a970429 a405953 207e0eb 760adf8 207e0eb 47fe629 207e0eb 47fe629 207e0eb a970429 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 |
from collections import namedtuple
import spaces
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
# Markdown heading text for the app.
# NOTE(review): gr.Interface below is given its own literal title, so this
# constant is not what the UI header displays.
title = """# Minitron Story Generator"""

# Markdown shown above the form. Blank lines separate paragraphs, headings and
# the instruction list so each renders as its own element.
description = """
# Minitron

Minitron is a family of small language models (SLMs) obtained by pruning [NVIDIA's](https://huggingface.co/nvidia) Nemotron-4 15B model, Llama-3.1-8B, or Mistral NeMo models.
We prune the number of transformer blocks, embedding size, attention heads, and MLP intermediate dimension, then perform continued training with distillation to arrive at the final models.

# Short Story Generator

Welcome to the Short Story Generator! This application helps you create unique short stories based on your inputs.
It shows the output of several models in the Minitron family side by side so you can compare them.

**Instructions:**

1. **Main Character:** Describe the main character of your story. For example, "a brave knight" or "a curious cat".
2. **Setting:** Describe the setting where your story takes place. For example, "in an enchanted forest" or "in a bustling city".
3. **Plot Twist:** Add an interesting plot twist to make the story exciting. For example, "discovers a hidden treasure" or "finds a secret portal to another world".

After filling in these details, click the "Submit" button, and each model will generate a short story for you.
"""
# Gradio input widgets, in the same order as generate_story()'s parameters:
# three free-text story details followed by three generation-setting sliders.
_story_fields = [
    ("Main Character", "e.g. a brave knight"),
    ("Setting", "e.g. in an enchanted forest"),
    ("Plot Twist", "e.g. discovers a hidden treasure"),
]
inputs = [gr.Textbox(label=label, placeholder=hint) for label, hint in _story_fields] + [
    gr.Slider(minimum=1, maximum=2048, value=64, step=1, label="Max new tokens"),
    gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
    gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
]
# Everything needed to run one checkpoint: display name, causal LM, tokenizer.
Model = namedtuple("Model", "name llm tokenizer")

# Hugging Face Hub repo ids of the Minitron checkpoints compared side by side.
model_paths = [
    "nvidia/Llama-3.1-Minitron-4B-Width-Base",
    "nvidia/Llama-3.1-Minitron-4B-Depth-Base",
    "nvidia/Mistral-NeMo-Minitron-8B-Base",
]

# Target device map and weight dtype passed to from_pretrained() when loading.
device = "cuda"
dtype = torch.bfloat16
# Load the tokenizers and models.
def _load_model(path):
    """Build a ``Model`` for the Hub repo id *path*.

    The display name is the last "/"-separated component of *path*; weights
    are loaded with ``torch_dtype=dtype`` and ``device_map=device``.
    """
    return Model(
        name=path.rsplit("/", 1)[-1],
        llm=AutoModelForCausalLM.from_pretrained(path, torch_dtype=dtype, device_map=device),
        tokenizer=AutoTokenizer.from_pretrained(path),
    )


models = [_load_model(path) for path in model_paths]
# One story textbox per loaded model, labelled with its name, in `models` order.
outputs = [
    gr.Textbox(label="Generated Story ({})".format(entry.name))
    for entry in models
]
# Define the prompt format
def create_prompt(instruction):
    """Wrap *instruction* in an "### Instruction / ### Response" prompt template.

    NOTE(review): generate_story() in this file builds its own prompt and does
    not call this helper.
    """
    header = (
        "Below is an instruction that describes a task.\n\n"
        "Write a response that appropriately completes the request.\n\n"
    )
    return f"{header}### Instruction:\n{instruction}\n\n### Response:"
@spaces.GPU
def generate_story(character, setting, plot_twist, max_tokens, temperature, top_p):
    """Generate one short story per loaded model from the user's story details.

    Args:
        character: Description of the main character.
        setting: Where the story takes place.
        plot_twist: Twist to work into the story.
        max_tokens: Maximum number of *new* tokens each model may generate
            (the "Max new tokens" slider); the prompt does not count against it.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling cumulative-probability cutoff.

    Returns:
        A list with one generated continuation (prompt excluded) per entry in
        ``models``, in the same order as the output text boxes.
    """
    prompt = (
        "Write a short story with the following details:\n"
        f"Main character: {character}\n"
        f"Setting: {setting}\n"
        f"Plot twist: {plot_twist}\n\n"
        "Story:"
    )
    output_texts = []
    for model in models:
        tokenizer = model.tokenizer
        # Tokenize via __call__ so the attention mask is passed to generate()
        # along with the input ids.
        encoded = tokenizer(prompt, return_tensors="pt").to(model.llm.device)
        prompt_len = encoded["input_ids"].shape[-1]
        # Base-model tokenizers may define no pad token; fall back to EOS so
        # generate() does not have to guess (and warn).
        pad_token_id = tokenizer.pad_token_id
        if pad_token_id is None:
            pad_token_id = tokenizer.eos_token_id
        output_ids = model.llm.generate(
            **encoded,
            # max_new_tokens, not max_length: max_length includes the prompt,
            # which would leave little or no room for the story itself.
            max_new_tokens=int(max_tokens),
            # Without sampling, generate() decodes greedily and ignores
            # temperature/top_p, making those sliders no-ops.
            do_sample=True,
            temperature=float(temperature),
            top_p=float(top_p),
            num_return_sequences=1,
            pad_token_id=pad_token_id,
        )
        # Decode only the newly generated tokens. Slicing the decoded text by
        # len(prompt) breaks whenever decode(encode(prompt)) != prompt.
        new_tokens = output_ids[0][prompt_len:]
        output_texts.append(tokenizer.decode(new_tokens, skip_special_tokens=True))
    return output_texts
# Create the Gradio interface: the story-detail widgets feed generate_story(),
# whose list of stories fills one textbox per model.
# NOTE(review): the module-level `title` string is not passed here; the literal
# below is what the UI header shows.
demo = gr.Interface(
    generate_story,
    inputs,
    outputs,
    description=description,
    title="Short Story Generator",
)
if __name__ == "__main__":
    # Start the Gradio server only when run as a script, not on import.
    demo.launch()