|
import gradio as gr |
|
import re |
|
from transformers import AutoModelForCausalLM, AutoTokenizer |
|
|
|
# Hugging Face Hub repository id used for both the model weights and the tokenizer.
model_name_or_path = "teknium/OpenHermes-2-Mistral-7B"

# Load the causal LM once at import time; every request handler below reuses it.
model = AutoModelForCausalLM.from_pretrained(
    model_name_or_path,
    revision="main",          # pin to the repository's "main" revision
    device_map="auto",        # device placement is delegated to the library
    load_in_8bit=True,        # request 8-bit weight loading
    trust_remote_code=False,  # never execute code shipped inside the model repo
)

# Tokenizer from the same repository, using the fast implementation.
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=True)

# Default system persona for the assistant.
BASE_SYSTEM_MESSAGE = "I carefully provide accurate, factual, thoughtful, nuanced answers and am brilliant at reasoning."
|
|
|
def make_prediction(prompt, max_tokens=None, temperature=None, top_p=None, top_k=None, repetition_penalty=None):
    """Generate a completion for ``prompt`` and yield the decoded text once.

    This is a generator that yields exactly one string: the decoded output
    sequence (prompt followed by the continuation) with special tokens
    removed.

    Args:
        prompt: Raw text fed to the tokenizer as-is.
        max_tokens: Upper bound on the number of newly generated tokens.
            ``None`` leaves the model's own generation default in place.
        temperature: Sampling temperature. ``None`` leaves it unset; ``0``
            selects greedy decoding.
        top_p: Nucleus-sampling threshold, or ``None`` to leave unset.
        top_k: Top-k sampling cut-off, or ``None`` to leave unset.
        repetition_penalty: Repetition penalty, or ``None`` to leave unset.

    Yields:
        The decoded text of the first returned sequence.
    """
    # Keep the inputs on the same device as the model instead of leaving
    # them wherever the tokenizer created them.
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.device)

    # Bound the *new* tokens so a long prompt cannot eat the whole budget.
    generation_kwargs = {
        "max_new_tokens": max_tokens,
        "repetition_penalty": repetition_penalty,
    }

    # temperature / top_p / top_k only take effect when sampling is turned on.
    sampling = {"temperature": temperature, "top_p": top_p, "top_k": top_k}
    sampling = {name: value for name, value in sampling.items() if value is not None}
    if sampling and sampling.get("temperature", 1.0) > 0:
        generation_kwargs.update(sampling, do_sample=True)

    # Only forward values the caller actually supplied, so omitted arguments
    # fall back to the model's generation defaults rather than an explicit None.
    generation_kwargs = {
        name: value for name, value in generation_kwargs.items() if value is not None
    }

    out = model.generate(input_ids, **generation_kwargs)
    yield tokenizer.decode(out[0], skip_special_tokens=True)
|
|
|
def clear_chat(chat_history_state, chat_message):
    """Return the reset values for the chat history state and the message box.

    Both arguments are accepted only so the function matches the
    inputs/outputs it is wired to; their incoming values are ignored and
    never mutated.

    Returns:
        A ``(history, message)`` tuple: a fresh empty list and an empty string.
    """
    return [], ''
|
|
|
def user(message, history):
    """Record a new user turn and clear the input box.

    Args:
        message: Text the user just submitted.
        history: List of ``[user_text, assistant_text]`` pairs, or a falsy
            value (``None`` / empty list) when the conversation is new.

    Returns:
        ``("", history)`` where the empty string is the new textbox value
        and ``history`` has ``[message, ""]`` appended. A non-empty
        ``history`` is extended in place; a falsy one is replaced by a new
        list.
    """
    turns = history if history else []
    # The assistant slot starts empty; it is filled in by the reply step.
    turns.append([message, ""])
    return "", turns
|
|
|
def chat(history, system_message, max_tokens, temperature, top_p, top_k, repetition_penalty):
    """Generate the assistant reply for the most recent user turn.

    Only the latest user message (``history[-1][0]``) is sent to the model;
    earlier turns are kept in ``history`` for display but are not part of
    the prompt.

    Args:
        history: List of ``[user_text, assistant_text]`` pairs, or a falsy
            value for an empty conversation.
        system_message: System prompt text. When blank or ``None`` the
            module-level ``BASE_SYSTEM_MESSAGE`` is used instead.
        max_tokens: Maximum number of newly generated tokens.
        temperature: Sampling temperature; ``0`` (or ``None``) selects
            greedy decoding.
        top_p: Nucleus-sampling threshold (used only when sampling).
        top_k: Top-k cut-off (used only when sampling).
        repetition_penalty: Repetition penalty passed to ``generate``.

    Returns:
        ``(history, history, "")``: the updated history twice (once for the
        chatbot display, once for the state) and an empty string for the
        message box.
    """
    history = history or []
    user_prompt = history[-1][0] if history else ""

    # An empty system box falls back to the default persona instead of
    # sending an empty system turn.
    system_prompt = (system_message or "").strip() or BASE_SYSTEM_MESSAGE

    # ChatML layout: each turn is wrapped in <|im_start|>ROLE ... <|im_end|>,
    # and the prompt ends with an open assistant turn for the model to fill.
    prompt_template = (
        f"<|im_start|>system\n{system_prompt}<|im_end|>\n"
        f"<|im_start|>user\n{user_prompt}<|im_end|>\n"
        "<|im_start|>assistant\n"
    )

    # Follow the model's device rather than hard-coding CUDA.
    input_ids = tokenizer(prompt_template, return_tensors="pt").input_ids.to(model.device)

    # Bound the *new* tokens: a total-length cap would also count the
    # prompt, so a long prompt could leave no room for a reply.
    generation_kwargs = {"max_new_tokens": int(max_tokens)}
    if repetition_penalty is not None:
        generation_kwargs["repetition_penalty"] = repetition_penalty

    # The sampling controls are ignored unless sampling is switched on, and
    # a temperature of 0 is not valid for sampling, so treat it as greedy.
    if temperature is not None and temperature > 0:
        generation_kwargs.update(do_sample=True, temperature=temperature)
        if top_p is not None:
            generation_kwargs["top_p"] = top_p
        if top_k is not None:
            generation_kwargs["top_k"] = int(top_k)
    else:
        generation_kwargs["do_sample"] = False

    output = model.generate(input_ids=input_ids, **generation_kwargs)

    # The returned sequence starts with the prompt tokens; slice them off
    # instead of splitting the decoded text on the word "assistant", which
    # breaks whenever the reply itself contains that word.
    new_tokens = output[0][input_ids.shape[-1]:]
    assistant_response = tokenizer.decode(new_tokens, skip_special_tokens=True)
    # Safety net in case the end-of-turn marker is not registered as a
    # special token and therefore survives decoding.
    assistant_response = assistant_response.split("<|im_end|>")[0].strip()

    if history:
        history[-1][1] = (history[-1][1] or "") + assistant_response
    else:
        history.append(["", assistant_response])

    return history, history, ""
|
|
|
|
|
# Initial value of the "System Message" textbox.
start_message = ""

# Stylesheet handed to gr.Blocks: stacks the page as a full-height flex
# column and lets the element with id "chatbot" grow, scroll and be resized
# vertically.
CSS = """
.contain { display: flex; flex-direction: column; }
.gradio-container { height: 100vh !important; }
#component-0 { height: 100%; }
#chatbot { flex-grow: 1; overflow: auto; resize: vertical; }
"""
|
# Header text rendered at the top of the page.
_HEADER_MARKDOWN = """
## OpenHermes-V2 Finetuned on Mistral 7B
**Space created by [@artificialguybr](https://twitter.com/artificialguybr). Model by [@Teknium1](https://twitter.com/Teknium1). Thanks HF for GPU!**
**OpenHermes-V2 is currently SOTA in some benchmarks for 7B models.**
**Hermes 2 model was trained on 900,000 instructions, and surpasses all previous versions of Hermes 13B and below, and matches 70B on some benchmarks! Hermes 2 changes the game with strong multiturn chat skills, system prompt capabilities, and uses ChatML format. It's quality, diversity and scale is unmatched in the current OS LM landscape. Not only does it do well in benchmarks, but also in unmeasured capabilities, like Roleplaying, Tasks, and more.**
"""

# NOTE(review): the container nesting below was reconstructed from syntax
# (the indentation was not available); confirm it against the intended layout.
# Components are created in the original order, which matters because the
# stylesheet targets "#component-0".
with gr.Blocks(css=CSS) as demo:
    # Page header.
    with gr.Row():
        with gr.Column():
            gr.Markdown(_HEADER_MARKDOWN)

    # Conversation transcript; elem_id matches the "#chatbot" CSS rule.
    with gr.Row():
        chatbot = gr.Chatbot(elem_id="chatbot")

    # User input box.
    with gr.Row():
        message = gr.Textbox(
            label="What do you want to chat about?",
            placeholder="Ask me anything.",
            lines=3,
        )

    # Action buttons.
    with gr.Row():
        submit = gr.Button(value="Send message", variant="secondary").style(full_width=True)
        clear = gr.Button(value="New topic", variant="secondary").style(full_width=False)
        stop = gr.Button(value="Stop", variant="secondary").style(full_width=False)

    # Generation settings, collapsed by default.
    with gr.Accordion("Show Model Parameters", open=False):
        with gr.Row():
            with gr.Column():
                max_tokens = gr.Slider(20, 2500, label="Max Tokens", step=20, value=500)
                temperature = gr.Slider(0.0, 2.0, label="Temperature", step=0.1, value=0.4)
                top_p = gr.Slider(0.0, 1.0, label="Top P", step=0.05, value=0.95)
                top_k = gr.Slider(1, 100, label="Top K", step=1, value=40)
                repetition_penalty = gr.Slider(1.0, 2.0, label="Repetition Penalty", step=0.1, value=1.1)

    system_msg = gr.Textbox(
        start_message,
        label="System Message",
        interactive=True,
        visible=True,
        placeholder="System prompt. Provide instructions which you want the model to remember.",
        lines=5,
    )

    # Per-session conversation history shared by the handlers below.
    chat_history_state = gr.State()

    # "New topic": reset the stored history and the input box, then blank the
    # chatbot display with a second handler.
    clear.click(
        clear_chat,
        inputs=[chat_history_state, message],
        outputs=[chat_history_state, message],
        queue=False,
    )
    clear.click(lambda: None, None, chatbot, queue=False)

    # "Send message": first record the user turn, then generate the reply.
    user_turn_event = submit.click(
        fn=user,
        inputs=[message, chat_history_state],
        outputs=[message, chat_history_state],
        queue=True,
    )
    submit_click_event = user_turn_event.then(
        fn=chat,
        inputs=[chat_history_state, system_msg, max_tokens, temperature, top_p, top_k, repetition_penalty],
        outputs=[chatbot, chat_history_state, message],
        queue=True,
    )

    # "Stop": cancel the reply-generation event.
    stop.click(fn=None, inputs=None, outputs=None, cancels=[submit_click_event], queue=False)

demo.queue(max_size=128, concurrency_count=2)