multimodalart's picture
Move performance metrics below
8caf605
raw history blame
No virus
4.77 kB
import html
import os

import gradio as gr
import mdtex2html
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation import GenerationConfig
# Load the tokenizer, model and generation settings once, when the module is
# executed. All three come from the same checkpoint, so the id lives in one place.
_MODEL_ID = "Qwen/Qwen-14B-Chat-int4"

tokenizer = AutoTokenizer.from_pretrained(_MODEL_ID, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    _MODEL_ID,
    device_map="auto",
    trust_remote_code=True,
).eval()
model.generation_config = GenerationConfig.from_pretrained(
    _MODEL_ID,
    trust_remote_code=True,
)
# Postprocess function
def postprocess(self, y):
    """Render every chat turn in ``y`` through ``mdtex2html.convert``.

    Args:
        self: The object this function is bound to; not read here.
        y: A list of ``(message, response)`` pairs, or ``None``.

    Returns:
        An empty list when ``y`` is ``None``; otherwise ``y`` itself, with each
        pair replaced in place by a tuple of converted strings. ``None``
        entries inside a pair are kept as ``None`` and not converted.
    """
    if y is None:
        return []

    def _render(text):
        # Leave missing halves of a turn untouched.
        return text if text is None else mdtex2html.convert(text)

    for index, pair in enumerate(y):
        user_msg, bot_msg = pair
        y[index] = (_render(user_msg), _render(bot_msg))
    return y
# Monkey-patch: rebind the `postprocess` attribute on the gr.Chatbot class itself
# (not on one instance) to the function defined above, which runs each
# message/response through mdtex2html.convert.
# NOTE(review): presumably Gradio calls Chatbot.postprocess when preparing chat
# history for display — confirm against the installed Gradio version.
gr.Chatbot.postprocess = postprocess
# Text parsing function
def _parse_text(text):
lines = text.split("\n")
lines = [line for line in lines if line != ""]
count = 0
for i, line in enumerate(lines):
if "```" in line:
count += 1
items = line.split("`")
if count % 2 == 1:
lines[i] = f'<pre><code class="language-{items[-1]}">'
else:
lines[i] = f"<br></code></pre>"
else:
if i > 0:
if count % 2 == 1:
line = line.replace("`", r"\`")
line = line.replace("<", "&lt;")
line = line.replace(">", "&gt;")
line = line.replace(" ", "&nbsp;")
line = line.replace("*", "&ast;")
line = line.replace("_", "&lowbar;")
line = line.replace("-", "&#45;")
line = line.replace(".", "&#46;")
line = line.replace("!", "&#33;")
line = line.replace("(", "&#40;")
line = line.replace(")", "&#41;")
line = line.replace("$", "&#36;")
lines[i] = "<br>" + line
text = "".join(lines)
return text
# Demo launching function
def _launch_demo(args, model, tokenizer, config):
    """Build the Gradio chat interface around ``model`` and start serving it.

    Args:
        args: Not used by this function; kept so existing callers still work.
        model: Object providing ``chat_stream(tokenizer, query, history=...,
            generation_config=...)``, iterated to obtain responses.
        tokenizer: Passed through to ``model.chat_stream``.
        config: Passed to ``model.chat_stream`` as ``generation_config``.
    """

    def predict(_query, _chatbot, _task_history):
        """Stream the reply to ``_query`` into the chat window.

        Yields ``_chatbot`` after every value produced by
        ``model.chat_stream``; each value replaces the current last entry, so
        the final value is treated as the complete answer. When the stream is
        exhausted, the raw ``(_query, response)`` pair is appended to
        ``_task_history``.
        """
        # The rendered query does not change while streaming; compute it once.
        rendered_query = _parse_text(_query)
        print(f"User: {rendered_query}")
        _chatbot.append((rendered_query, ""))
        raw_response = ""
        for response in model.chat_stream(
            tokenizer, _query, history=_task_history, generation_config=config
        ):
            raw_response = response
            _chatbot[-1] = (rendered_query, _parse_text(response))
            yield _chatbot
        print(f"History: {_task_history}")
        # Store the raw model text, not the HTML-escaped display text:
        # _task_history is passed back to chat_stream as `history` on the next
        # turn, so escaped markup here would end up in the model's context.
        _task_history.append((_query, raw_response))
        print(f"Qwen-Chat: {_parse_text(raw_response)}")

    def regenerate(_chatbot, _task_history):
        """Drop the most recent turn and generate a new answer to its query."""
        if not _task_history:
            # Nothing to regenerate; leave the chat window unchanged.
            yield _chatbot
            return
        item = _task_history.pop(-1)
        _chatbot.pop(-1)
        yield from predict(item[0], _chatbot, _task_history)

    def reset_user_input():
        """Return an update that clears the input textbox."""
        return gr.update(value="")

    def reset_state(_chatbot, _task_history):
        """Clear both the displayed chat and the model history, then free memory."""
        _task_history.clear()
        _chatbot.clear()
        import gc

        gc.collect()
        torch.cuda.empty_cache()
        return _chatbot

    with gr.Blocks() as demo:
        gr.Markdown(
            "\n"
            "## Qwen-14B-Chat: A Large Language Model by Alibaba Cloud\n"
            "**Space created by [@artificialguybr](https://twitter.com/artificialguybr) based on QWEN Code. Thanks HF for GPU!**\n"
            "**Qwen is currently SOTA in the benchmarks for 14B models.**\n"
        )
        # NOTE(review): `queue=True` is passed to the component constructor
        # exactly as before — confirm the installed Gradio version accepts it.
        chatbot = gr.Chatbot(label='Qwen-Chat', elem_classes="control-height", queue=True)
        query = gr.Textbox(lines=2, label='Input')
        # State holding the (query, response) pairs handed to the model as history.
        task_history = gr.State([])

        with gr.Row():
            empty_btn = gr.Button("🧹 Clear History")
            # These two labels were mis-decoded emoji ("πŸš€", "πŸ€”οΈ"); restored.
            submit_btn = gr.Button("🚀 Submit")
            regen_btn = gr.Button("🤔️ Regenerate")

        # Submit triggers two handlers: stream the answer, and clear the textbox.
        submit_btn.click(predict, [query, chatbot, task_history], [chatbot], show_progress=True, queue=True)
        submit_btn.click(reset_user_input, [], [query])
        empty_btn.click(reset_state, [chatbot, task_history], outputs=[chatbot], show_progress=True)
        regen_btn.click(regenerate, [chatbot, task_history], [chatbot], show_progress=True, queue=True)

        gr.Markdown(
            "### Performance Metrics:\n"
            "- **MMLU Accuracy**:\n"
            "- 0-shot: 64.6\n"
            "- 5-shot: 66.5\n"
            "- **HumanEval Pass@1**: 43.9\n"
            "- **GSM8K Accuracy**:\n"
            "- 0-shot: 60.1\n"
            "- 8-shot: 59.3\n"
        )

    demo.queue(max_size=20)
    demo.launch()
# Main execution
if __name__ == "__main__":
    # Start the demo with the module-level model, tokenizer and the generation
    # config attached to the model; the `args` slot is not supplied.
    _launch_demo(
        args=None,
        model=model,
        tokenizer=tokenizer,
        config=model.generation_config,
    )