from collections import namedtuple

import spaces
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

title = """# Minitron Story Generator"""
description = """
# Minitron

Minitron is a family of small language models (SLMs) obtained by pruning [NVIDIA's](https://huggingface.co/nvidia) Nemotron-4 15B, Llama-3.1-8B, or Mistral NeMo models.
We prune the number of transformer blocks, the embedding size, the number of attention heads, and the MLP intermediate dimension, then perform continued training with distillation to arrive at the final models.

# Short Story Generator
Welcome to the Short Story Generator! This application helps you create unique short stories based on your inputs.

This application will show you the output of several models in the Minitron family. Outputs are shown side by side so you can compare them.

**Instructions:**
1. **Main Character:** Describe the main character of your story. For example, "a brave knight" or "a curious cat".
2. **Setting:** Describe the setting where your story takes place. For example, "in an enchanted forest" or "in a bustling city".
3. **Plot Twist:** Add an interesting plot twist to make the story exciting. For example, "discovers a hidden treasure" or "finds a secret portal to another world".

After filling in these details, click the "Submit" button, and a short story will be generated for you.
"""

inputs = [
    gr.Textbox(label="Main Character", placeholder="e.g. a brave knight"),
    gr.Textbox(label="Setting", placeholder="e.g. in an enchanted forest"),
    gr.Textbox(label="Plot Twist", placeholder="e.g. discovers a hidden treasure"),
    gr.Slider(minimum=1, maximum=2048, value=64, step=1, label="Max new tokens"),
    gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
    gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
]

Model = namedtuple('Model', ['name', 'llm', 'tokenizer'])

model_paths = [
    "nvidia/Llama-3.1-Minitron-4B-Width-Base",
    "nvidia/Llama-3.1-Minitron-4B-Depth-Base",
    "nvidia/Mistral-NeMo-Minitron-8B-Base",
]

device = "cuda"
dtype = torch.bfloat16
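# bfloat16 keeps fp32's dynamic range at half the memory cost, which matters
# here because three checkpoints are held on the GPU at once.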

# Load the tokenizers and models.
models = [
    Model(
        name=p.split("/")[-1],
        llm=AutoModelForCausalLM.from_pretrained(p, torch_dtype=dtype, device_map=device),
        tokenizer=AutoTokenizer.from_pretrained(p),
    ) for p in model_paths
]

outputs = [
    gr.Textbox(label=f"Generated Story ({model.name})") for model in models
]

# Define the prompt format.
def create_prompt(instruction):
    PROMPT = '''Below is an instruction that describes a task.\n\nWrite a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Response:'''
    return PROMPT.format(instruction=instruction)
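
# For reference, create_prompt("Tell me a story") yields an instruction/response
# template of the form:
#
#   Below is an instruction that describes a task.
#
#   Write a response that appropriately completes the request.
#
#   ### Instruction:
#   Tell me a story
#
#   ### Response:
#
# Note: generate_story below builds its own free-form prompt and does not call
# this helper.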


@spaces.GPU
def generate_story(character, setting, plot_twist, max_tokens, temperature, top_p):
    """Define the function to generate the story."""
    prompt = f"Write a short story with the following details:\nMain character: {character}\nSetting: {setting}\nPlot twist: {plot_twist}\n\nStory:"
    
    output_texts = []
    
    for model in models:
        input_ids = model.tokenizer.encode(prompt, return_tensors="pt").to(model.llm.device)
        # Use max_new_tokens (not max_length) so the slider controls the generated
        # length, and do_sample=True so temperature and top_p actually take effect.
        output_ids = model.llm.generate(input_ids, max_new_tokens=max_tokens, num_return_sequences=1,
                                        do_sample=True, temperature=temperature, top_p=top_p)
        # Decode only the newly generated tokens, not the echoed prompt.
        output_text = model.tokenizer.decode(output_ids[0][input_ids.shape[-1]:], skip_special_tokens=True)
        output_texts.append(output_text)

    return output_texts


# Create the Gradio interface
demo = gr.Interface(
    fn=generate_story,
    inputs=inputs,
    outputs=outputs,
    title="Short Story Generator",
    description=description
)
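
# Local smoke test (hypothetical call; assumes a CUDA-capable GPU and network
# access to download the checkpoints):
#   stories = generate_story("a brave knight", "in an enchanted forest",
#                            "discovers a hidden treasure", 64, 0.7, 0.95)
#   print(stories[0])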
  
if __name__ == "__main__":
    demo.launch()