stack-llama / app.py
lewtun's picture
lewtun HF staff
Use text-generation inference 🔥🔥
9d55eb4
raw history blame
No virus
8.58 kB
import json
import os
import gradio as gr
# import torch
# from transformers import (AutoModelForCausalLM, AutoTokenizer,
# TextIteratorStreamer, set_seed)
from huggingface_hub import Repository
from text_generation import Client
# from threading import Thread
# Visual theme for the demo: Monochrome base with indigo/blue accents and slate
# neutrals, small corner radius, and Open Sans with generic sans-serif fallbacks.
theme = gr.themes.Monochrome(
    font=[
        gr.themes.GoogleFont("Open Sans"),
        "ui-sans-serif",
        "system-ui",
        "sans-serif",
    ],
    radius_size=gr.themes.sizes.radius_sm,
    neutral_hue="slate",
    secondary_hue="blue",
    primary_hue="indigo",
)
# Hugging Face access token read from the environment (None when unset).
HF_TOKEN = os.environ.get("HF_TOKEN", None)
# os.environ["TOKENIZERS_PARALLELISM"] = "false"

if HF_TOKEN:
    # Local clone of the prompts dataset; save_inputs_and_outputs appends to
    # data/prompts.jsonl and pushes this repo. Only created when a token is set.
    repo = Repository(
        local_dir="data", clone_from="trl-lib/stack-llama-prompts", use_auth_token=HF_TOKEN, repo_type="dataset"
    )

# Streaming client for the hosted StackLLaMa endpoint. The Authorization header is
# only sent when a token is configured; otherwise the literal "Bearer None" would be sent.
client = Client(
    "https://api-inference.huggingface.co/models/trl-lib/llama-se-rl-merged",
    headers={"Authorization": f"Bearer {HF_TOKEN}"} if HF_TOKEN else None,
)
# device = "cuda" if torch.cuda.is_available() else "cpu"
# model_id = "trl-lib/llama-se-rl-merged"
# print(f"Loading model: {model_id}")
# if device == "cpu":
# model = AutoModelForCausalLM.from_pretrained(model_id, low_cpu_mem_usage=True, use_auth_token=HF_TOKEN)
# else:
# model = AutoModelForCausalLM.from_pretrained(
# model_id, device_map="auto", load_in_8bit=True, use_auth_token=HF_TOKEN
# )
# tokenizer = AutoTokenizer.from_pretrained(model_id, use_auth_token=HF_TOKEN)
# Prompt sent to the model: the user's question after "Question: ", then a blank
# line and "Answer:" for the model to continue from.
PROMPT_TEMPLATE = "Question: {prompt}\n\nAnswer:"
def save_inputs_and_outputs(inputs, outputs, generate_kwargs):
    """Append one prompt/completion record to data/prompts.jsonl and push the repo to the Hub.

    Args:
        inputs: The prompt sent to the model.
        outputs: The generated completion.
        generate_kwargs: Sampling arguments used for the generation; must be JSON-serializable.

    Returns:
        The value returned by ``repo.push_to_hub()``.

    Relies on the module-level ``repo`` clone, which is only created when ``HF_TOKEN`` is set.
    """
    record = {"inputs": inputs, "outputs": outputs, "generate_kwargs": generate_kwargs}
    # ensure_ascii=False writes raw non-ASCII characters, so pin the file encoding
    # instead of depending on the platform default.
    with open(os.path.join("data", "prompts.jsonl"), "a", encoding="utf-8") as f:
        f.write(json.dumps(record, ensure_ascii=False) + "\n")
    # Push only after the file is closed, so the new line is flushed to disk first.
    return repo.push_to_hub()
# def generate(instruction, temperature=0.9, max_new_tokens=128, top_p=0.95, top_k=100):
# set_seed(42)
# formatted_instruction = PROMPT_TEMPLATE.format(prompt=instruction)
# temperature = float(temperature)
# top_p = float(top_p)
# streamer = TextIteratorStreamer(tokenizer)
# model_inputs = tokenizer(formatted_instruction, return_tensors="pt", truncation=True, max_length=2048).to(device)
# generate_kwargs = dict(
# top_p=top_p,
# temperature=temperature,
# max_new_tokens=max_new_tokens,
# do_sample=True,
# top_k=top_k,
# eos_token_id=tokenizer.eos_token_id,
# pad_token_id=tokenizer.eos_token_id,
# )
# t = Thread(target=model.generate, kwargs={**dict(model_inputs, streamer=streamer), **generate_kwargs})
# t.start()
# output = ""
# hidden_output = ""
# for new_text in streamer:
# # skip streaming until new text is available
# if len(hidden_output) <= len(formatted_instruction):
# hidden_output += new_text
# continue
# # replace eos token
# # if tokenizer.eos_token in new_text:
# # new_text = new_text.replace(tokenizer.eos_token, "")
# output += new_text
# yield output
# if HF_TOKEN:
# print("Pushing prompt and completion to the Hub")
# save_inputs_and_outputs(formatted_instruction, output, generate_kwargs)
# return output
def sanitize_generation_params(temperature, top_p, top_k):
    """Coerce UI slider values into sampling arguments for the text-generation client.

    The sliders allow ``top_p=1`` and ``top_k=0``. The text_generation client is
    expected to require ``0 < top_p < 1`` and ``top_k > 0`` (confirm against the
    installed version). Both values conventionally mean "no filtering", so they are
    mapped to ``None``, which omits the parameter. A ``top_p`` at or below zero is
    clamped to 0.01, the most restrictive value kept. ``None`` inputs pass through.

    Returns:
        A ``(temperature, top_p, top_k)`` tuple of ``(float, float | None, int | None)``.
    """
    temperature = float(temperature)
    if top_p is not None:
        top_p = float(top_p)
        if top_p >= 1.0:
            top_p = None
        elif top_p <= 0.0:
            top_p = 0.01
    if top_k is not None:
        top_k = int(top_k)
        if top_k <= 0:
            top_k = None
    return temperature, top_p, top_k


def generate(instruction, temperature=0.9, max_new_tokens=256, top_p=0.95, top_k=100):
    """Stream an answer to ``instruction`` from the hosted StackLLaMa endpoint.

    Args:
        instruction: The user's question; wrapped in ``PROMPT_TEMPLATE``.
        temperature: Sampling temperature.
        max_new_tokens: Maximum number of tokens to generate; values <= 0 yield "".
        top_p: Nucleus-sampling threshold; 1.0 disables it.
        top_k: Top-k cutoff; 0 disables it.

    Yields:
        str: The answer accumulated so far, after each generated token.
    """
    formatted_instruction = PROMPT_TEMPLATE.format(prompt=instruction)
    temperature, top_p, top_k = sanitize_generation_params(temperature, top_p, top_k)
    max_new_tokens = int(max_new_tokens)
    if max_new_tokens <= 0:
        # The slider allows 0; there is nothing to generate, so skip the network call.
        yield ""
        return ""
    stream = client.generate_stream(
        formatted_instruction,
        temperature=temperature,
        # NOTE(review): truncates the prompt to 999 tokens, presumably to leave room for
        # max_new_tokens in the model's context window — confirm the intended limit.
        truncate=999,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        top_k=top_k,
    )
    output = ""
    for response in stream:
        # Skip special tokens (such as end-of-sequence) so their text never reaches the answer.
        if response.token.special:
            continue
        output += response.token.text
        yield output
    return output
# streamer = TextIteratorStreamer(tokenizer)
# model_inputs = tokenizer(formatted_instruction, return_tensors="pt", truncation=True, max_length=2048).to(device)
# generate_kwargs = dict(
# top_p=top_p,
# temperature=temperature,
# max_new_tokens=max_new_tokens,
# do_sample=True,
# top_k=top_k,
# # eos_token_id=tokenizer.eos_token_id,
# # pad_token_id=tokenizer.eos_token_id,
# )
# t = Thread(target=model.generate, kwargs={**dict(model_inputs, streamer=streamer), **generate_kwargs})
# t.start()
# output = ""
# hidden_output = ""
# for new_text in streamer:
# # skip streaming until new text is available
# if len(hidden_output) <= len(formatted_instruction):
# hidden_output += new_text
# continue
# # replace eos token
# # if tokenizer.eos_token in new_text:
# # new_text = new_text.replace(tokenizer.eos_token, "")
# output += new_text
# yield output
# if HF_TOKEN:
# print("Pushing prompt and completion to the Hub")
# save_inputs_and_outputs(formatted_instruction, output, generate_kwargs)
# return output
# Sample questions offered under the input box (passed to gr.Examples).
examples = [
    "A llama is in my lawn. How do I get rid of him?",
    "How do I create an array in C++ which contains all even numbers between 1 and 10?",
    "How can I sort a list in Python?",
    "How can I write a Java function to generate the nth Fibonacci number?",
    "How many helicopters can a llama eat in one sitting?",
]
def process_example(args):
    """Run ``generate`` to completion for one example question and return the final answer.

    Used as the ``fn`` of ``gr.Examples``, which needs the finished text rather than a
    stream. ``generate`` is called with only the question, so its default sampling
    arguments apply. Returns ``""`` if ``generate`` yields nothing.
    """
    answer = ""  # fallback so an empty stream can't leave the result unbound
    for answer in generate(args):
        pass
    return answer
# Intro text shown at the top of the page. Paragraphs are separated by blank lines so
# Markdown renders the links and bold text as separate paragraphs instead of folding
# them into the preceding <h1> HTML line.
DESCRIPTION = """<h1><center>🦙🦙🦙 StackLLaMa 🦙🦙🦙</center></h1>

StackLLaMa is a 7 billion parameter language model that has been trained on pairs of questions and answers from [Stack Exchange](https://stackexchange.com) using Reinforcement Learning from Human Feedback with the [TRL library](https://github.com/lvwerra/trl). For more details, check out our [blog post](https://huggingface.co/blog/stackllama).

Type in the box below and click the button to generate answers to your most pressing questions 🔥!

**Note:** we are collecting your prompts and model completions for research purposes.
"""

# Page layout: a wide column with the question box, streamed answer, submit button and
# examples, next to a narrow column of sampling sliders that are passed to `generate`.
with gr.Blocks(theme=theme, analytics_enabled=False, css=".generating {visibility: hidden}") as demo:
    with gr.Column():
        gr.Markdown(DESCRIPTION)
        with gr.Row():
            with gr.Column(scale=3):
                instruction = gr.Textbox(placeholder="Enter your question here", label="Question")
                with gr.Box():
                    gr.Markdown("**Answer**")
                    output = gr.Markdown()
                submit = gr.Button("Generate", variant="primary")
                # With cache_examples=True the example answers come from process_example,
                # which calls generate() with its default sampling arguments, not the
                # slider values below.
                gr.Examples(
                    examples=examples,
                    inputs=[instruction],
                    cache_examples=True,
                    fn=process_example,
                    outputs=[output],
                )
            with gr.Column(scale=1):
                temperature = gr.Slider(
                    label="Temperature",
                    value=0.8,
                    minimum=0.01,
                    maximum=2.0,
                    step=0.1,
                    interactive=True,
                    info="Higher values produce more diverse outputs",
                )
                max_new_tokens = gr.Slider(
                    label="Max new tokens",
                    value=128,
                    minimum=0,
                    maximum=2048,
                    step=4,
                    interactive=True,
                    info="The maximum numbers of new tokens",
                )
                # NOTE(review): the text_generation client is expected to require
                # 0 < top_p < 1 and top_k > 0, but these ranges include top_p=1.0 and
                # top_k=0 — confirm such values are handled before the request is sent.
                top_p = gr.Slider(
                    label="Top-p (nucleus sampling)",
                    value=0.95,
                    minimum=0.0,
                    maximum=1,
                    step=0.05,
                    interactive=True,
                    info="Higher values sample more low-probability tokens",
                )
                top_k = gr.Slider(
                    label="Top-k",
                    value=40,
                    minimum=0,
                    maximum=100,
                    step=2,
                    interactive=True,
                    info="Sample from top-k tokens",
                )
    # The button and pressing Enter in the question box trigger the same streaming generation.
    generate_inputs = [instruction, temperature, max_new_tokens, top_p, top_k]
    submit.click(generate, inputs=generate_inputs, outputs=[output])
    instruction.submit(generate, inputs=generate_inputs, outputs=[output])
# Route requests through Gradio's queue, processing one at a time (concurrency_count=1).
# queue() already enables queuing, so the deprecated launch(enable_queue=...) flag is omitted.
demo.queue(concurrency_count=1)
demo.launch()  # pass share=True for a public link when running locally