WuChengyue committed
Commit
ad41ac1
β€’
1 Parent(s): 086c1ca

Create app.py

Files changed (1)
  1. app.py +87 -0
app.py ADDED
@@ -0,0 +1,87 @@
+ import gradio as gr
+ import torch
+ import html
+ from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
+ from threading import Thread
+
+ model_name_or_path = 'TencentARC/Mistral_Pro_8B_v0.1'
+
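+ # load the tokenizer and model weights from the Hugging Face Hub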
+ tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=False)
+ model = AutoModelForCausalLM.from_pretrained(model_name_or_path)
+
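+ # cast the weights to fp16 and move the model to the GPU (assumes a CUDA device is available)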
+ model.half().cuda()
+
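+ # render one {"role": ..., "content": ...} message into the
+ # <|system|>/<|user|>/<|assistant|> prompt format used to build the prompt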
+ def convert_message(message):
+     message_text = ""
+     if message["content"] is None and message["role"] == "assistant":
+         message_text += "<|assistant|>\n"  # final message: open the assistant turn for generation
+     elif message["role"] == "system":
+         message_text += "<|system|>\n" + message["content"].strip() + "\n"
+     elif message["role"] == "user":
+         message_text += "<|user|>\n" + message["content"].strip() + "\n"
+     elif message["role"] == "assistant":
+         message_text += "<|assistant|>\n" + message["content"].strip() + "\n"
+     else:
+         raise ValueError("Invalid role: {}".format(message["role"]))
+     # gradio cleaning - it converts text to HTML entities;
+     # we would need special handling where we want to keep the HTML...
+     message_text = html.unescape(message_text)
+     # it also converts newlines to <br>, so undo that.
+     message_text = message_text.replace("<br>", "\n")
+     return message_text
+
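+ # rebuild the prompt from the Gradio chat history, walking backwards so the
+ # most recent turns are kept within the token budget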
+ def convert_history(chat_history, max_input_length=1024):
+     history_text = ""
+     idx = len(chat_history) - 1
+     # add messages in reverse order until we hit max_input_length
+     while len(tokenizer(history_text).input_ids) < max_input_length and idx >= 0:
+         user_message, chatbot_message = chat_history[idx]
+         user_message = convert_message({"role": "user", "content": user_message})
+         chatbot_message = convert_message({"role": "assistant", "content": chatbot_message})
+         history_text = user_message + chatbot_message + history_text
+         idx = idx - 1
+     # if nothing was added, add <|assistant|> to start generation.
+     if history_text == "":
+         history_text = "<|assistant|>\n"
+     return history_text
+
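+ # run generation on a background thread and hand back the streamer so the
+ # UI can consume tokens as they are produced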
+ @torch.inference_mode()
+ def instruct(instruction, max_token_output=1024):
+     input_text = instruction
+     streamer = TextIteratorStreamer(tokenizer, skip_prompt=True)
+     inputs = tokenizer(input_text, return_tensors='pt', truncation=False)
+     inputs["input_ids"] = inputs["input_ids"].cuda()
+     inputs["attention_mask"] = inputs["attention_mask"].cuda()
+     generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=max_token_output, do_sample=False)
+     thread = Thread(target=model.generate, kwargs=generation_kwargs)
+     thread.start()
+     return streamer
+
+
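+ # Gradio UI: a single Chatbot tab with a textbox for input and a clear button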
+ with gr.Blocks() as demo:
+     # chatbot-style model
+     with gr.Tab("Chatbot"):
+         chatbot = gr.Chatbot([], elem_id="chatbot")
+         msg = gr.Textbox()
+         clear = gr.Button("Clear")
+         # fn to add user message to history
+         def user(user_message, history):
+             return "", history + [[user_message, None]]
+
+         def bot(history):
+             prompt = convert_history(history)
+             streaming_out = instruct(prompt)
+             history[-1][1] = ""
+             for new_token in streaming_out:
+                 history[-1][1] += new_token
+                 yield history
+
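+         # on submit: first record the user turn, then stream the bot reply into the chat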
+         msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
+             bot, chatbot, chatbot
+         )
+
+         clear.click(lambda: None, None, chatbot, queue=False)
+
+ if __name__ == "__main__":
+     demo.queue().launch(share=True)