ZequnZ committed on
Commit
8270dc8
1 Parent(s): 9028733
Files changed (1)
  1. app.py +195 -0
app.py ADDED
@@ -0,0 +1,195 @@
+ import os
+ import random
+ import time
+ from typing import Iterator
+
+ import gradio as gr
+ from text_generation import Client
+
+ model_id = "mistralai/Mistral-7B-Instruct-v0.1"
+
+ API_URL = "https://api-inference.huggingface.co/models/" + model_id
+ # Read the Hugging Face API token from the environment rather than hardcoding a secret in the source.
+ HF_TOKEN = os.environ.get("HF_TOKEN")
+ SYSTEM_PROMPT = "I want you to act as a great assistant. You will provide trustworthy information and inspire me to think more, using supportive language."
+
+ client = Client(
+     API_URL,
+     headers={"Authorization": f"Bearer {HF_TOKEN}"},
+ )
+ EOS_STRING = "</s>"
+ EOT_STRING = "<EOT>"
+
+ # Default sampling parameters; call_llm below rebuilds this dict from the UI sliders on every request.
+ generate_kwargs = dict(
+     max_new_tokens=50,
+     do_sample=True,
+     top_p=0.9,
+     top_k=20,
+     temperature=0.6,
+ )
+
+
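+ # generate_prompts assembles a single Mistral-instruct style prompt string: the system prompt goes
+ # into a leading [INST] ... [/INST] block, earlier chat turns are folded in as plain-text context in
+ # a second block, and the newest user message is appended last so the model continues from it.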
+ def generate_prompts(
+     sys_prompt: str, input: str, history: list[tuple[str, str]]
+ ) -> str:
+     prompt = f"<s>[INST] {sys_prompt} [/INST]</s>\n\n"
+     context = ""
+     for user_input, model_output in history:
+         # prompt+=f"[INST]{input} {model_output}[/INST]"
+         # prompt+=f"[User input]{user_input} [Model output]{model_output}\n\n"
+         if user_input != "":
+             context += f"{user_input}:\n{model_output}\n"
+     if context != "":
+         prompt += "[INST] Below is some context between me and you, which can be used as a reference to answer [Next user input]; stop when you finish answering:\n"
+         prompt += context
+         prompt += "[/INST]\n\n[Next user input]:\n\n"
+     prompt += f"{input}\n"
+     return prompt
+
+
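+ # For illustration, with a hypothetical one-turn history, a call such as
+ # generate_prompts(SYSTEM_PROMPT, "And in winter?", [("What is the capital of Norway?", "Oslo.")])
+ # would produce roughly:
+ #
+ #   <s>[INST] I want you to act as a great assistant. ... [/INST]</s>
+ #
+ #   [INST] Below is some context between me and you, ... :
+ #   What is the capital of Norway?:
+ #   Oslo.
+ #   [/INST]
+ #
+ #   [Next user input]:
+ #
+ #   And in winter?
+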
+ # theme = gr.themes.Base()
+ theme = "WeixuanYuan/Soft_dark"
+
+ with gr.Blocks(theme=theme) as demo:
+     gr.Markdown("# Chat with Mistral-7B\n[Github](https://github.com/ZequnZ/Chat-with-Mistral-7B)")
+     with gr.Row():
+         chatbot = gr.Chatbot(scale=6)
+
+         with gr.Column(variant="compact", scale=1):
+             gr.Markdown("## Parameters:")
+             max_new_tokens = gr.Slider(
+                 label="Max new tokens",
+                 minimum=1,
+                 maximum=1024,
+                 step=1,
+                 value=128,
+             )
+             temperature = gr.Slider(
+                 label="Temperature",
+                 minimum=0.1,
+                 maximum=2,
+                 step=0.1,
+                 value=0.6,
+             )
+             top_p = gr.Slider(
+                 label="Top-p (nucleus sampling)",
+                 minimum=0.05,
+                 maximum=1.0,
+                 step=0.05,
+                 value=0.9,
+             )
+             top_k = gr.Slider(
+                 label="Top-k",
+                 minimum=1,
+                 maximum=100,
+                 step=1,
+                 value=10,
+             )
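+             # These slider values are read on every submit and forwarded to call_llm below
+             # as the generation parameters for that request.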
+
+     with gr.Row():
+         textbox = gr.Textbox(
+             show_label=False,
+             placeholder="What do you wanna ask?",
+             scale=10,
+         )
+         submit_bt = gr.Button("✔️ Submit", scale=1, variant="primary")
+     with gr.Row():
+         clear_bt = gr.Button("🗑️ Clear")
+         remove_bt = gr.Button("← Remove last input")
+         retry_bt = gr.Button("🔄 Retry")
+
+     system_prompt = gr.Textbox(
+         label="System prompt/Instruction",
+         value=SYSTEM_PROMPT,
+         lines=3,
+         interactive=True,
+     )
+
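+     # Callback helpers. The Chatbot value is a list of [user_message, bot_message] pairs;
+     # each handler below receives that history list and returns an updated one.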
+     # Submit the message in the textbox: append the new user turn with an empty reply slot.
+     def sub_msg(user_message, history) -> tuple[str, list[tuple[str, str]]]:
+         if history is not None:
+             return "", history + [[user_message, None]]
+         else:
+             return "", [[user_message, None]]
+
+     # Drop the most recent (user, bot) pair entirely.
+     def remove_last_dialogue(history: list[tuple[str, str]]) -> list[tuple[str, str]]:
+         if history:
+             history.pop()
+         return history
+
+     # Keep the last user message but clear its answer so it can be regenerated.
+     def remove_last_output(history: list[tuple[str, str]]) -> list[tuple[str, str]]:
+         if history:
+             last_dialogue = history.pop()
+             history.append([last_dialogue[0], None])
+         return history
+
+     def output_messages(history: list[tuple[str, str]]) -> list[tuple[str, str]]:
+         return history
+
+     # Leftover demo streamer: types out a random canned reply character by character.
+     # It is not wired to any event below; call_llm is the handler actually used.
+     def bot(history: list[tuple[str, str]]) -> Iterator[list[tuple[str, str]]]:
+         bot_message = random.choice(["How are you?", "I love you", "I'm very hungry"])
+         history[-1][1] = ""
+         for character in bot_message:
+             history[-1][1] += character
+             time.sleep(0.05)
+             yield history
+
+     # Stream a completion from the Inference API and progressively fill the last chat turn.
+     def call_llm(
+         history: list[tuple[str, str]],
+         max_new_tokens: int,
+         temperature: float,
+         top_p: float,
+         top_k: int,
+         sys_prompt: str,
+     ) -> Iterator[list[tuple[str, str]]]:
+         generate_kwargs = dict(
+             do_sample=True,
+             max_new_tokens=max_new_tokens,
+             top_p=top_p,
+             top_k=top_k,
+             temperature=temperature,
+         )
+         if history:
+             prompt = generate_prompts(sys_prompt, history[-1][0], history[:-1])
+             history[-1][1] = ""
+             print("prompt: ", prompt)
+
+             stream = client.generate_stream(prompt, **generate_kwargs)
+             time.sleep(3)
+
+             for response in stream:
+                 if response.token.text != EOS_STRING:
+                     history[-1][1] += response.token.text
+                     time.sleep(0.05)
+                     yield history
+         return []
+
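+     # Event wiring: submitting (Enter in the textbox or the Submit button) first runs sub_msg to
+     # append the new turn with an empty reply, then chains call_llm, whose yields stream the
+     # generated tokens into the Chatbot.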
+     textbox.submit(sub_msg, [textbox, chatbot], [textbox, chatbot], queue=False).then(
+         fn=call_llm,
+         inputs=[chatbot, max_new_tokens, temperature, top_p, top_k, system_prompt],
+         outputs=chatbot,
+     )
+     submit_bt.click(
+         sub_msg, [textbox, chatbot], [textbox, chatbot], queue=False, show_progress=True
+     ).then(
+         fn=call_llm,
+         inputs=[chatbot, max_new_tokens, temperature, top_p, top_k, system_prompt],
+         outputs=chatbot,
+     )
+
+     # Clear all the history
+     clear_bt.click(lambda: None, None, chatbot, queue=False)
+
+     # Remove the last turn, then re-render the remaining messages.
+     remove_bt.click(remove_last_dialogue, [chatbot], [chatbot], queue=False).then(
+         output_messages, chatbot, chatbot
+     )
+
+     # Retry: clear the last answer and regenerate it with call_llm.
+     retry_bt.click(
+         fn=remove_last_output, inputs=[chatbot], outputs=[chatbot], queue=False
+     ).then(
+         fn=call_llm,
+         inputs=[chatbot, max_new_tokens, temperature, top_p, top_k, system_prompt],
+         outputs=chatbot,
+     )
+
+
+ if __name__ == "__main__":
+     demo.launch()
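To try the app locally (assuming gradio and text_generation are installed and HF_TOKEN is exported in the environment, as the code above expects), run python app.py; demo.launch() then serves the Blocks UI at Gradio's default local address.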