Moses25 committed on
Commit
009c3c7
1 Parent(s): 52db87d
Files changed (1)
  1. chat_llama.py +373 -0
chat_llama.py ADDED
@@ -0,0 +1,373 @@
+ import torch
+ from peft import PeftModel
+ from transformers import AutoModelForCausalLM, AutoTokenizer, LlamaTokenizer, StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer
+
+ model_name = "../llama/llama_weight/Llama-2-7b-hf"
+ adapters_name = '../ctranslate2/checkpoint/base'
+
+ print(f"Starting to load the model {model_name} into memory")
+
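+ # Load the base Llama-2 checkpoint in fp16 and let device_map="auto" place it on the available devices.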
+ m = AutoModelForCausalLM.from_pretrained(
+     model_name,
+     # load_in_8bit=True,
+     torch_dtype=torch.float16,
+     device_map="auto"
+ )
+ print("finished loading the base model")
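+ # Attach the LoRA adapter and merge its weights into the base model so inference no longer needs PEFT hooks.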
+ m = PeftModel.from_pretrained(m, adapters_name)
+ m = m.merge_and_unload()
+ print("finished merging the adapter")
+ tok = LlamaTokenizer.from_pretrained(model_name)
+ tok.model_max_length = 8192
+ # tok.pad_token_id = 0
+
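+ # Token ids that the StopOnTokens criterion below treats as end-of-generation signals.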
+ stop_token_ids = [0]
+
+ print(f"Successfully loaded the model {model_name} into memory")
+
+
+ import datetime
+ import os
+ from threading import Event, Thread
+ from uuid import uuid4
+
+ import gradio as gr
+ import requests
+
+ max_new_tokens = 1536
+ start_message = """A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions."""
+
+ ORCA_PROMPT_DICT = {
+     "prompt_no_input": (
+         "### System:\n"
+         "You are an AI assistant that follows instruction extremely well. Help as much as you can."
+         "\n\n### User:\n"
+     ),
+     "prompt_input": (
+         "### System:\n"
+         "You are an AI assistant that follows instruction extremely well. Help as much as you can.\n\n"
+         "### User:\n"
+         "{instruction}"
+         "\n\n### Input:\n"
+         "{input}"
+         "\n\n### Response:"
+     ),
+ }
+
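+ # NOTE: the dict below overrides the ORCA_PROMPT_DICT defined above; neither version is referenced later (only llama2_prompt is used).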
+ ORCA_PROMPT_DICT = {
+     "prompt_no_input": (
+         "### System:\n"
+         "You are an AI assistant that follows instruction extremely well. Help as much as you can."
+     )
+ }
+
+ PROMPT_DICT = {
+     "prompt_input": (
+         "Below is an instruction that describes a task, paired with an input that provides further context. "
+         "Write a response that appropriately completes the request.\n\n"
+         "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"
+     ),
+     "prompt_no_input": (
+         "Below is an instruction that describes a task. "
+         "Write a response that appropriately completes the request.\n\n"
+         "{instruction}\n\n### Response:"
+     ),
+ }
+
+
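+ # Llama-2 chat template: system message inside <<SYS>> ... <</SYS>>, user turn wrapped in [INST] ... [/INST].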
+ llama2_prompt = {"prompt_no_input": """[INST] <<SYS>>
+ You are a helpful, respectful and honest assistant. Help as much as you can.
+ <</SYS>>
+
+ {instruction} [/INST] """}
+
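+ # Custom stopping criterion: end generation as soon as the most recently generated token is one of stop_token_ids.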
+ class StopOnTokens(StoppingCriteria):
+     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
+         for stop_id in stop_token_ids:
+             if input_ids[0][-1] == stop_id:
+                 return True
+         return False
+
+
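+ # Flatten the chat history into a single Llama-2 prompt string; earlier turns are wrapped as <s>[INST] ... [/INST] reply</s>.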
+ def convert_history_to_text(history):
+     if len(history) > 10:
+         print("*" * 30)
+         print("Conversation exceeded 10 rounds; restarting a new session with the oldest turns dropped")
+         history = history[10:]
+     # text = llama2_prompt['prompt_no_input'] + "".join(
+     #     [
+     #         "".join(
+     #             [
+     #                 # f"### Human: {item[0]}\n",
+     #                 # f"### Assistant: {item[1]}\n",
+     #                 # f"USER: {item[0]}",
+     #                 # ASSISTANT: {item[1]}
+     #                 # f"\n\n### User:\n{item[0]}",
+     #                 # f"\n\n### Response:{item[1]}"
+     #                 # f"### Instruction:\n{item[0]}\n\n",
+     #                 # f"### Response:{item[0]}"
+     #             ]
+     #         )
+     #         for item in history[:-1]
+     #     ]
+     # )
+     # text += "".join(
+     #     [
+     #         "".join(
+     #             [
+     #                 # f"### Human: {history[-1][0]}\n",
+     #                 # f"### Assistant: {history[-1][1]}\n",
+     #                 # f"USER: {history[-1][0]}",
+     #                 # "ASSISTANT: {history[-1][1]}"
+     #                 # f"\n\n### User:\n{history[-1][0]}",
+     #                 # f"\n\n### Response:{history[-1][1]}"
+     #                 f"### Instruction:\n{history[-1][0]}\n\n",
+     #                 f"### Response:{history[-1][1]}"
+     #             ]
+     #         )
+     #     ]
+     # )
+     start_msg = llama2_prompt['prompt_no_input'].format_map({"instruction": history[0][0]})
+     if len(history) > 1:
+         start_msg = start_msg + history[0][1] + "</s>"
+         for dialogue_his in history[1:-1]:
+             start_msg += f"<s>[INST] {dialogue_his[0]}[/INST]"
+             start_msg += f"{dialogue_his[1]}</s>"
+     if len(history) > 1:
+         start_msg += f"<s> [INST] {history[-1][0]} [/INST]"
+     print(f"input msg:{start_msg}")
+     return start_msg
+
+
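+ # Optionally POST the conversation and generation settings to an external endpoint given by the LOGGING_URL environment variable.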
+ def log_conversation(conversation_id, history, messages, generate_kwargs):
+     logging_url = os.getenv("LOGGING_URL", None)
+     if logging_url is None:
+         return
+
+     timestamp = datetime.datetime.now().strftime("%Y-%m-%dT%H:%M:%S")
+
+     data = {
+         "conversation_id": conversation_id,
+         "timestamp": timestamp,
+         "history": history,
+         "messages": messages,
+         "generate_kwargs": generate_kwargs,
+     }
+
+     try:
+         print(f"data:{data}")
+         requests.post(logging_url, json=data)
+     except requests.exceptions.RequestException as e:
+         print(f"Error logging conversation: {e}")
+
+
+ def user(message, history):
+     # Append the user's message to the conversation history
+     return "", history + [[message, ""]]
+
+
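+ # Streaming generation callback: runs model.generate in a background thread and yields the partially generated reply into the chat history.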
+ def bot(history, temperature, top_p, top_k, repetition_penalty, conversation_id):
+     print(f"history: {history}")
+     # Initialize a StopOnTokens object
+     stop = StopOnTokens()
+
+     # Construct the input message string for the model by concatenating the current system message and conversation history
+     messages = convert_history_to_text(history)
+
+     # Tokenize the messages string
+     input_ids = tok(messages, return_tensors="pt").input_ids
+     input_ids = input_ids.to(m.device)
+     streamer = TextIteratorStreamer(tok, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
+     generate_kwargs = dict(
+         input_ids=input_ids,
+         max_new_tokens=max_new_tokens,
+         temperature=temperature,
+         do_sample=temperature > 0.0,
+         top_p=top_p,
+         top_k=top_k,
+         num_beams=1,
+         repetition_penalty=repetition_penalty,
+         streamer=streamer,
+         stopping_criteria=StoppingCriteriaList([stop]),
+     )
+
+     stream_complete = Event()
+
+     def generate_and_signal_complete():
+         m.generate(**generate_kwargs)
+         stream_complete.set()
+
+     def log_after_stream_complete():
+         stream_complete.wait()
+         log_conversation(
+             conversation_id,
+             history,
+             messages,
+             {
+                 "top_k": top_k,
+                 "top_p": top_p,
+                 "temperature": temperature,
+                 "repetition_penalty": repetition_penalty,
+             },
+         )
+
+     t1 = Thread(target=generate_and_signal_complete)
+     t1.start()
+
+     t2 = Thread(target=log_after_stream_complete)
+     t2.start()
+
+     # Initialize an empty string to store the generated text
+     partial_text = ""
+     for new_text in streamer:
+         partial_text += new_text
+         history[-1][1] = partial_text
+         yield history
+
+
+ def get_uuid():
+     return str(uuid4())
+
+
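+ # Build the Gradio chat UI: chatbot pane, message box, control buttons, and advanced sampling sliders.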
+ with gr.Blocks(
+     theme=gr.themes.Soft(),
+     css=".disclaimer {font-variant-caps: all-small-caps;}",
+ ) as demo:
+     conversation_id = gr.State(get_uuid)
+     gr.Markdown(
+         """Dewu intelligent customer service chatbot
+         """
+     )
+     chatbot = gr.Chatbot().style(height=500)
+     with gr.Row():
+         with gr.Column():
+             msg = gr.Textbox(
+                 label="Chat Message Box",
+                 placeholder="Chat input box",
+                 show_label=False,
+             ).style(container=False)
+         with gr.Column():
+             with gr.Row():
+                 submit = gr.Button("Submit")
+                 stop = gr.Button("Stop")
+                 clear = gr.Button("Clear")
+     with gr.Row():
+         with gr.Accordion("Advanced Options:", open=False):
+             with gr.Row():
+                 with gr.Column():
+                     with gr.Row():
+                         temperature = gr.Slider(
+                             label="Temperature",
+                             value=0.8,
+                             minimum=0.0,
+                             maximum=1.0,
+                             step=0.1,
+                             interactive=True,
+                             info="Higher values produce more diverse outputs",
+                         )
+                 with gr.Column():
+                     with gr.Row():
+                         top_p = gr.Slider(
+                             label="Top-p (nucleus sampling)",
+                             value=0.83,
+                             minimum=0.0,
+                             maximum=1,
+                             step=0.01,
+                             interactive=True,
+                             info=(
+                                 "Sample from the smallest possible set of tokens whose cumulative probability "
+                                 "exceeds top_p. Set to 1 to disable and sample from all tokens."
+                             ),
+                         )
+                 with gr.Column():
+                     with gr.Row():
+                         top_k = gr.Slider(
+                             label="Top-k",
+                             value=4,
+                             minimum=0.0,
+                             maximum=200,
+                             step=1,
+                             interactive=True,
+                             info="Sample from a shortlist of top-k tokens — 0 to disable and sample from all tokens.",
+                         )
+                 with gr.Column():
+                     with gr.Row():
+                         repetition_penalty = gr.Slider(
+                             label="Repetition Penalty",
+                             value=1.3,
+                             minimum=1.0,
+                             maximum=2.0,
+                             step=0.1,
+                             interactive=True,
+                             info="Penalize repetition — 1.0 to disable.",
+                         )
+                 # with gr.Column():
+                 #     with gr.Row():
+                 #         repetition_penalty = gr.Slider(
+                 #             label="beam_size",
+                 #             value=3,
+                 #             minimum=1,
+                 #             maximum=10,
+                 #             step=1,
+                 #             interactive=True,
+                 #             info="Penalize repetition — 1.0 to disable.",
+                 #         )
+     with gr.Row():
+         gr.Markdown(
+             "Disclaimer: this model may produce output that is inconsistent with the facts and should not be relied on "
+             "for factually accurate information. It was trained on various public datasets together with some Dewu product data; "
+             "despite extensive data cleaning, its output may still contain problems.",
+             elem_classes=["disclaimer"],
+         )
+     with gr.Row():
+         gr.Markdown(
+             "Algorithm Team 2",
+             elem_classes=["disclaimer"],
+         )
+
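+     # Wire the UI events: user() appends the message to the history, then bot() streams the model's reply into the chatbot.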
+     submit_event = msg.submit(
+         fn=user,
+         inputs=[msg, chatbot],
+         outputs=[msg, chatbot],
+         queue=False,
+     ).then(
+         fn=bot,
+         inputs=[
+             chatbot,
+             temperature,
+             top_p,
+             top_k,
+             repetition_penalty,
+             conversation_id,
+         ],
+         outputs=chatbot,
+         queue=True,
+     )
+     submit_click_event = submit.click(
+         fn=user,
+         inputs=[msg, chatbot],
+         outputs=[msg, chatbot],
+         queue=False,
+     ).then(
+         fn=bot,
+         inputs=[
+             chatbot,
+             temperature,
+             top_p,
+             top_k,
+             repetition_penalty,
+             conversation_id,
+         ],
+         outputs=chatbot,
+         queue=True,
+     )
+     stop.click(
+         fn=None,
+         inputs=None,
+         outputs=None,
+         cancels=[submit_event, submit_click_event],
+         queue=False,
+     )
+     clear.click(lambda: None, None, chatbot, queue=False)
+
+
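+ # Enable request queuing (up to 128 queued requests, 2 concurrent workers) and launch with a public share link.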
+ demo.queue(max_size=128, concurrency_count=2)
+
+ demo.launch(share=True)