import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
from backtrack_sampler import BacktrackSampler, CreativeWritingStrategy
from backtrack_sampler.provider.transformers_provider import TransformersProvider
import torch
import spaces

description = """## Compare Creative Writing: Standard Sampler vs. Backtrack Sampler with Creative Writing Strategy
This is a demo of the [Backtrack Sampler](https://github.com/Mihaiii/backtrack_sampler) framework using its "Creative Writing Strategy".
<br />On the left is the output of the standard sampler; on the right, the output produced by Backtrack Sampler.
"""

model_name = "unsloth/Llama-3.2-1B-Instruct"
device = torch.device('cuda')
tokenizer = AutoTokenizer.from_pretrained(model_name)

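# Two copies of the same checkpoint: model1 runs plain transformers generate(),
# while model2 is handed to the backtrack-sampler provider below.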
model1 = AutoModelForCausalLM.from_pretrained(model_name).to(device)

model2 = AutoModelForCausalLM.from_pretrained(model_name)

provider = TransformersProvider(model2, tokenizer, device)
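# Strategy hyperparameters; roughly, the *_flat values tune when the token
# distribution counts as "flat" (a backtrack trigger) and eos_penalty scales
# down the EOS probability. See the backtrack_sampler README for exact semantics.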
strategy = CreativeWritingStrategy(provider, 
                                   top_p_flat = 0.65,
                                   top_k_threshold_flat = 9,
                                   eos_penalty = 0.75)
creative_sampler = BacktrackSampler(provider, strategy)

def create_chat_template_messages(history, prompt):
    # Replay prior turns in chronological order, then append the new prompt last
    # so the template's add_generation_prompt lands after the latest user turn.
    messages = []
    for user_text, assistant_text in history:
        messages.append({"role": "user", "content": user_text})
        messages.append({"role": "assistant", "content": assistant_text})
    messages.append({"role": "user", "content": prompt})
    
    return messages

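# ZeroGPU on Hugging Face Spaces: allocate a GPU for at most 60 seconds per call.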
@spaces.GPU(duration=60)
def generate_responses(prompt, history):
    messages = create_chat_template_messages(history, prompt)
    wrapped_prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    # wrapped_prompt already contains the chat template's special tokens,
    # so skip adding them again when encoding.
    inputs = tokenizer.encode(wrapped_prompt, add_special_tokens=False, return_tensors="pt").to(device)

    # The backtrack sampler yields token ids one at a time; collect them,
    # then decode the full sequence once.
    generated_list = []
    for token in creative_sampler.generate(wrapped_prompt, max_new_tokens=1024, temperature=1):
        generated_list.append(token)
    custom_output = tokenizer.decode(generated_list, skip_special_tokens=True)
    # Standard transformers pass; do_sample=True so the temperature is actually used.
    standard_output = model1.generate(inputs, max_new_tokens=1024, temperature=1, do_sample=True)
    standard_response = tokenizer.decode(standard_output[0][len(inputs[0]):], skip_special_tokens=True)

    return standard_response.strip(), custom_output.strip()

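# UI: two chat panels side by side, both driven by a single prompt box.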
with gr.Blocks(theme=gr.themes.Citrus()) as demo:
    gr.Markdown(description)

    with gr.Row():
        standard_chat = gr.Chatbot(label="Standard Sampler")
        custom_chat = gr.Chatbot(label="Creative Writing Strategy")

    with gr.Row():
        prompt_input = gr.Textbox(label="Enter your prompt", placeholder="Type your message here...", lines=1)

    examples = [
        "Write me a short story about a talking dog who wants to be a detective.",
        "Tell me a short tale of a dragon who is afraid of heights.",
        "Create a short story where aliens land on Earth, but they just want to throw a party."
    ]

    gr.Examples(examples=examples, inputs=prompt_input)

    submit_button = gr.Button("Submit")

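    # Run both samplers on the prompt and append one (user, bot) tuple per panel.
    # Note: generate_responses conditions both samplers on the standard panel's history.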
    def update_chat(prompt, standard_history, custom_history):
        standard_response, custom_response = generate_responses(prompt, standard_history)

        standard_history = standard_history + [(prompt, standard_response)]
        custom_history = custom_history + [(prompt, custom_response)]

        return standard_history, custom_history, ""

    prompt_input.submit(fn=update_chat, inputs=[prompt_input, standard_chat, custom_chat], outputs=[standard_chat, custom_chat, prompt_input])
    submit_button.click(fn=update_chat, inputs=[prompt_input, standard_chat, custom_chat], outputs=[standard_chat, custom_chat, prompt_input])

demo.queue().launch(debug=True)