Spaces:
Running
on
Zero
Running
on
Zero
File size: 4,024 Bytes
d620330 3b34fd9 d620330 6904d6f fa9f7fb 50b34b7 d620330 d1e1697 d620330 d1e1697 d620330 37a7540 d620330 37a7540 3f59c7c 4c52a93 d3f84d6 d91d40e d620330 17af049 90f50bb d620330 4c52a93 d620330 d3f84d6 d620330 3b34fd9 d3f84d6 d620330 fbd8807 d620330 ec3cf95 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 |
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
from backtrack_sampler import BacktrackSampler, CreativeWritingStrategy
from backtrack_sampler.provider.transformers_provider import TransformersProvider
import torch
import spaces
import asyncio
description = """## Compare Creative Writing: Standard Sampler vs. Backtrack Sampler with Creative Writing Strategy
This is a demo of the [Backtrack Sampler](https://github.com/Mihaiii/backtrack_sampler) framework using "Creative Writing Strategy".
<br />On the left is the output of the standard sampler and on the right the output privided by Backtrack Sampler.
"""
model_name = "unsloth/Llama-3.2-1B-Instruct"
device = torch.device('cuda')
tokenizer = AutoTokenizer.from_pretrained(model_name)
model1 = AutoModelForCausalLM.from_pretrained(model_name).to("cuda")
model2 = AutoModelForCausalLM.from_pretrained(model_name)
provider = TransformersProvider(model2, tokenizer, device)
strategy = CreativeWritingStrategy(provider,
top_p_flat = 0.65,
top_k_threshold_flat = 9,
eos_penalty = 0.75)
creative_sampler = BacktrackSampler(provider, strategy)
def create_chat_template_messages(history, prompt):
messages = [{"role": "user", "content": prompt}]
for i, (input_text, response_text) in enumerate(history):
messages.append({"role": "user" if i % 2 == 0 else "assistant", "content": input_text})
messages.append({"role": "assistant", "content": response_text})
return messages
@spaces.GPU(duration=60)
def generate_responses(prompt, history):
messages = create_chat_template_messages(history, prompt)
wrapped_prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
#it already has special tokens from wrapped_prompt
inputs = tokenizer.encode(wrapped_prompt, add_special_tokens=False, return_tensors="pt").to("cuda")
async def custom_sampler_task():
generated_list = []
generator = creative_sampler.generate(wrapped_prompt, max_new_tokens=1024, temperature=1)
for token in generator:
generated_list.append(token)
return tokenizer.decode(generated_list, skip_special_tokens=True)
custom_output = asyncio.run(custom_sampler_task())
standard_output = model1.generate(inputs, max_new_tokens=1024, temperature=1)
standard_response = tokenizer.decode(standard_output[0][len(inputs[0]):], skip_special_tokens=True)
return standard_response.strip(), custom_output.strip()
with gr.Blocks(theme=gr.themes.Citrus()) as demo:
gr.Markdown(description)
with gr.Row():
standard_chat = gr.Chatbot(label="Standard Sampler")
custom_chat = gr.Chatbot(label="Creative Writing Strategy")
with gr.Row():
prompt_input = gr.Textbox(label="Enter your prompt", placeholder="Type your message here...", lines=1)
examples = [
"Write me a short story about a talking dog who wants to be a detective.",
"Tell me a short tale of a dragon who is afraid of heights.",
"Create a short story where aliens land on Earth, but they just want to throw a party."
]
gr.Examples(examples=examples, inputs=prompt_input)
submit_button = gr.Button("Submit")
def update_chat(prompt, standard_history, custom_history):
standard_response, custom_response = generate_responses(prompt, standard_history)
standard_history = standard_history + [(prompt, standard_response)]
custom_history = custom_history + [(prompt, custom_response)]
return standard_history, custom_history, ""
prompt_input.submit(fn=update_chat, inputs=[prompt_input, standard_chat, custom_chat], outputs=[standard_chat, custom_chat, prompt_input])
submit_button.click(fn=update_chat, inputs=[prompt_input, standard_chat, custom_chat], outputs=[standard_chat, custom_chat, prompt_input])
demo.queue().launch(debug=True)
|