import gc
from threading import Thread

import gradio as gr
import mdtex2html
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from transformers.generation import GenerationConfig

# Initialize model and tokenizer
model_name_or_path = "TheBloke/OpenHermes-2-Mistral-7B-GPTQ"
model = AutoModelForCausalLM.from_pretrained(
    model_name_or_path,
    device_map="auto",
    trust_remote_code=False,
    revision="main",
)
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=True)
model.generation_config = GenerationConfig.from_pretrained(model_name_or_path, trust_remote_code=False)
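# NOTE (assumption, not stated in the original file): loading a GPTQ
# checkpoint through AutoModelForCausalLM relies on the `optimum` +
# `auto-gptq` integration in recent versions of transformers; both
# packages must be installed or from_pretrained will fail.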

# Postprocess function: render chat messages as HTML via mdtex2html so
# markdown and LaTeX in queries and responses display correctly.
def postprocess(self, y):
    if y is None:
        return []
    for i, (message, response) in enumerate(y):
        y[i] = (
            None if message is None else mdtex2html.convert(message),
            None if response is None else mdtex2html.convert(response),
        )
    return y

# Monkey-patch the Chatbot component to use the converter above.
gr.Chatbot.postprocess = postprocess

# Text parsing function: HTML-escape text outside fenced code blocks and
# wrap the blocks themselves in <pre><code> tags for display.
def _parse_text(text):
    lines = text.split("\n")
    lines = [line for line in lines if line != ""]
    count = 0
    for i, line in enumerate(lines):
        if "```" in line:
            count += 1
            items = line.split("`")
            if count % 2 == 1:
                # Opening fence: start a code block, keeping the language tag.
                lines[i] = f'<pre><code class="language-{items[-1]}">'
            else:
                # Closing fence.
                lines[i] = "<br></code></pre>"
        else:
            if i > 0:
                if count % 2 == 1:
                    line = line.replace("`", r"\`")
                    line = line.replace("<", "&lt;")
                    line = line.replace(">", "&gt;")
                    line = line.replace(" ", "&nbsp;")
                    line = line.replace("*", "&ast;")
                    line = line.replace("_", "&lowbar;")
                    line = line.replace("-", "&#45;")
                    line = line.replace(".", "&#46;")
                    line = line.replace("!", "&#33;")
                    line = line.replace("(", "&#40;")
                    line = line.replace(")", "&#41;")
                    line = line.replace("$", "&#36;")
                lines[i] = "<br>" + line
    text = "".join(lines)
    return text
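# Illustrative example of the escaping above (assumed input, not part of
# the original app):
#   _parse_text("```python\nprint('hi')\n```")
# returns:
#   '<pre><code class="language-python"><br>print&#40;'hi'&#41;<br></code></pre>'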

# Demo launching function
def _launch_demo(args, model, tokenizer, config):
    def predict(_query, _chatbot, _task_history):
        print(f"User: {_parse_text(_query)}")
        _chatbot.append((_parse_text(_query), ""))
        # Rebuild the conversation as chat messages and format it with the
        # tokenizer's chat template (OpenHermes 2 uses ChatML; this assumes
        # the tokenizer ships a chat template).
        messages = []
        for q, r in _task_history:
            messages += [{"role": "user", "content": q}, {"role": "assistant", "content": r}]
        messages.append({"role": "user", "content": _query})
        input_ids = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt").to(model.device)
        # This Mistral model has no Qwen-style chat_stream(); stream tokens
        # with TextIteratorStreamer while generate() runs in a background thread.
        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
        Thread(target=model.generate, kwargs=dict(inputs=input_ids, streamer=streamer, generation_config=config)).start()
        full_response = ""
        for new_text in streamer:
            full_response += new_text
            _chatbot[-1] = (_parse_text(_query), _parse_text(full_response))
            yield _chatbot
        print(f"History: {_task_history}")
        _task_history.append((_query, full_response))
        print(f"OpenHermes: {_parse_text(full_response)}")

    def regenerate(_chatbot, _task_history):
        if not _task_history:
            yield _chatbot
            return
        item = _task_history.pop(-1)
        _chatbot.pop(-1)
        yield from predict(item[0], _chatbot, _task_history)

    def reset_user_input():
        return gr.update(value="")

    def reset_state(_chatbot, _task_history):
        _task_history.clear()
        _chatbot.clear()
        gc.collect()
        torch.cuda.empty_cache()
        return _chatbot
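    # Design note: reset_state mutates the gr.State list in place, so only
    # the chatbot needs to be returned; gc.collect() plus
    # torch.cuda.empty_cache() release cached GPU memory held over from
    # earlier generations.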

    with gr.Blocks() as demo:
        gr.Markdown("""
## OpenHermes V2 - Mistral 7B: a Mistral 7B fine-tune by Teknium!
**Space created by [@artificialguybr](https://twitter.com/artificialguybr). Model by [@Teknium1](https://twitter.com/Teknium1). Thanks to HF for the GPU!**
**OpenHermes V2 Mistral 7B was trained on 900,000 instructions; it surpasses all previous Hermes models of 13B and below, and matches 70B models on some benchmarks!**
""")
        chatbot = gr.Chatbot(label='OpenHermes-2-Mistral-7B', elem_classes="control-height")
        query = gr.Textbox(lines=2, label='Input')
        task_history = gr.State([])

        with gr.Row():
            submit_btn = gr.Button("🚀 Submit")
            empty_btn = gr.Button("🧹 Clear History")
            regen_btn = gr.Button("🤔️ Regenerate")

        submit_btn.click(predict, [query, chatbot, task_history], [chatbot], show_progress=True, queue=True)  # Enable queue
        submit_btn.click(reset_user_input, [], [query], queue=False)  # No queue for resetting
        empty_btn.click(reset_state, [chatbot, task_history], outputs=[chatbot], show_progress=True, queue=False)  # No queue for clearing
        regen_btn.click(regenerate, [chatbot, task_history], [chatbot], show_progress=True, queue=True)  # Enable queue

    demo.queue(max_size=20)
    demo.launch()

# Main execution
if __name__ == "__main__":
    _launch_demo(None, model, tokenizer, model.generation_config)
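# To run this Space locally (assumption: a CUDA GPU and the GPTQ stack are
# available), install the dependencies and launch the script:
#   pip install gradio mdtex2html torch transformers optimum auto-gptq
#   python app.py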