silk-road commited on
Commit
d319ff8
1 Parent(s): 3168d1a

Upload 15 files

Browse files
ChatHaruhi/ChatHaruhi.py ADDED
@@ -0,0 +1,471 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .utils import base64_to_float_array, base64_to_string
2
+
3
+ def get_text_from_data( data ):
4
+ if "text" in data:
5
+ return data['text']
6
+ elif "enc_text" in data:
7
+ # from .utils import base64_to_string
8
+ return base64_to_string( data['enc_text'] )
9
+ else:
10
+ print("warning! failed to get text from data ", data)
11
+ return ""
12
+
13
+ def parse_rag(text):
14
+ lines = text.split("\n")
15
+ ans = []
16
+
17
+ for i, line in enumerate(lines):
18
+ if "{{RAG对话}}" in line:
19
+ ans.append({"n": 1, "max_token": -1, "query": "default", "lid": i})
20
+ elif "{{RAG对话|" in line:
21
+ query_info = line.split("|")[1].rstrip("}}")
22
+ ans.append({"n": 1, "max_token": -1, "query": query_info, "lid": i})
23
+ elif "{{RAG多对话|" in line:
24
+ parts = line.split("|")
25
+ max_token = int(parts[1].split("<=")[1])
26
+ max_n = int(parts[2].split("<=")[1].rstrip("}}"))
27
+ ans.append({"n": max_n, "max_token": max_token, "query": "default", "lid": i})
28
+
29
+ return ans
30
+
31
+ class ChatHaruhi:
32
+ def __init__(self,
33
+ role_name = None,
34
+ user_name = None,
35
+ persona = None,
36
+ stories = None,
37
+ story_vecs = None,
38
+ role_from_hf = None,
39
+ role_from_jsonl = None,
40
+ llm = None, # 默认的message2response的函数
41
+ llm_async = None, # 默认的message2response的async函数
42
+ user_name_in_message = "default",
43
+ verbose = None,
44
+ embed_name = None,
45
+ embedding = None,
46
+ db = None,
47
+ token_counter = "default",
48
+ max_input_token = 1800,
49
+ max_len_story_haruhi = 1000,
50
+ max_story_n_haruhi = 5
51
+ ):
52
+
53
+ self.verbose = True if verbose is None or verbose else False
54
+
55
+ self.db = db
56
+
57
+ self.embed_name = embed_name
58
+
59
+ self.max_len_story_haruhi = max_len_story_haruhi # 这个设置只对过往Haruhi的sugar角色有效
60
+ self.max_story_n_haruhi = max_story_n_haruhi # 这个设置只对过往Haruhi的sugar角色有效
61
+
62
+ self.last_query_msg = None
63
+
64
+ if embedding is None:
65
+ self.embedding = self.set_embedding_with_name( embed_name )
66
+
67
+ if persona and role_name and stories and story_vecs and len(stories) == len(story_vecs):
68
+ # 完全从外部设置,这个时候要求story_vecs和embedding的返回长度一致
69
+ self.persona, self.role_name, self.user_name = persona, role_name, user_name
70
+ self.build_db(stories, story_vecs)
71
+ elif persona and role_name and stories:
72
+ # 从stories中提取story_vecs,重新用self.embedding进行embedding
73
+ story_vecs = self.extract_story_vecs(stories)
74
+ self.persona, self.role_name, self.user_name = persona, role_name, user_name
75
+ self.build_db(stories, story_vecs)
76
+ elif role_from_hf:
77
+ # 从hf加载role
78
+ self.persona, new_role_name, self.stories, self.story_vecs = self.load_role_from_hf(role_from_hf)
79
+ if new_role_name:
80
+ self.role_name = new_role_name
81
+ else:
82
+ self.role_name = role_name
83
+ self.user_name = user_name
84
+ self.build_db(self.stories, self.story_vecs)
85
+ elif role_from_jsonl:
86
+ # 从jsonl加载role
87
+ self.persona, new_role_name, self.stories, self.story_vecs = self.load_role_from_jsonl(role_from_jsonl)
88
+ if new_role_name:
89
+ self.role_name = new_role_name
90
+ else:
91
+ self.role_name = role_name
92
+ self.user_name = user_name
93
+ self.build_db(self.stories, self.story_vecs)
94
+ elif persona and role_name:
95
+ # 这个时候也就是说没有任何的RAG,
96
+ self.persona, self.role_name, self.user_name = persona, role_name, user_name
97
+ self.db = None
98
+ elif role_name and self.check_sugar( role_name ):
99
+ # 这个时候是sugar的role
100
+ self.persona, self.role_name, self.stories, self.story_vecs = self.load_role_from_sugar( role_name )
101
+ self.build_db(self.stories, self.story_vecs)
102
+ # 与 江YH讨论 所有的载入方式都要在外部使用 add_rag_prompt_after_persona() 防止混淆
103
+ # self.add_rag_prompt_after_persona()
104
+ else:
105
+ raise ValueError("persona和role_name必须同时设置,或者role_name是ChatHaruhi的预设人物")
106
+
107
+ self.llm, self.llm_async = llm, llm_async
108
+ if not self.llm and self.verbose:
109
+ print("warning, llm没有设置,仅get_message起作用,调用chat将回复idle message")
110
+
111
+ self.user_name_in_message = user_name_in_message
112
+ self.previous_user_pool = set([user_name]) if user_name else set()
113
+ self.current_user_name_in_message = user_name_in_message.lower() == "add"
114
+
115
+ self.idle_message = "idel message, you see this because self.llm has not been set."
116
+
117
+ if token_counter.lower() == "default":
118
+ # TODO change load from util
119
+ from .utils import tiktoken_counter
120
+ self.token_counter = tiktoken_counter
121
+ elif token_counter == None:
122
+ self.token_counter = lambda x: 0
123
+ else:
124
+ self.token_counter = token_counter
125
+ if self.verbose:
126
+ print("user set costomized token_counter")
127
+
128
+ self.max_input_token = max_input_token
129
+
130
+ self.history = []
131
+
132
+ def check_sugar(self, role_name):
133
+ from .sugar_map import sugar_role_names, enname2zhname
134
+ return role_name in sugar_role_names
135
+
136
+ def load_role_from_sugar(self, role_name):
137
+ from .sugar_map import sugar_role_names, enname2zhname
138
+ en_role_name = sugar_role_names[role_name]
139
+ new_role_name = enname2zhname[en_role_name]
140
+ role_from_hf = "silk-road/ChatHaruhi-RolePlaying/" + en_role_name
141
+ persona, _, stories, story_vecs = self.load_role_from_hf(role_from_hf)
142
+
143
+ return persona, new_role_name, stories, story_vecs
144
+
145
+ def add_rag_prompt_after_persona( self ):
146
+ rag_sentence = "{{RAG多对话|token<=" + str(self.max_len_story_haruhi) + "|n<=" + str(self.max_story_n_haruhi) + "}}"
147
+ self.persona += "Classic scenes for the role are as follows:\n" + rag_sentence + "\n"
148
+
149
+ def set_embedding_with_name(self, embed_name):
150
+ if embed_name is None or embed_name == "bge_zh":
151
+ from .embeddings import get_bge_zh_embedding
152
+ self.embed_name = "bge_zh"
153
+ return get_bge_zh_embedding
154
+ elif embed_name == "foo":
155
+ from .embeddings import foo_embedding
156
+ return foo_embedding
157
+ elif embed_name == "bce":
158
+ from .embeddings import foo_bce
159
+ return foo_bce
160
+ elif embed_name == "openai" or embed_name == "luotuo_openai":
161
+ from .embeddings import foo_openai
162
+ return foo_openai
163
+
164
+ def set_new_user(self, user):
165
+ if len(self.previous_user_pool) > 0 and user not in self.previous_user_pool:
166
+ if self.user_name_in_message.lower() == "default":
167
+ if self.verbose:
168
+ print(f'new user {user} included in conversation')
169
+ self.current_user_name_in_message = True
170
+ self.user_name = user
171
+ self.previous_user_pool.add(user)
172
+
173
+ def chat(self, user, text):
174
+ self.set_new_user(user)
175
+ message = self.get_message(user, text)
176
+ if self.llm:
177
+ response = self.llm(message)
178
+ self.append_message(response)
179
+ return response
180
+ return None
181
+
182
+ async def async_chat(self, user, text):
183
+ self.set_new_user(user)
184
+ message = self.get_message(user, text)
185
+ if self.llm_async:
186
+ response = await self.llm_async(message)
187
+ self.append_message(response)
188
+ return response
189
+
190
+ def parse_rag_from_persona(self, persona, text = None):
191
+ #每个query_rag需要饱含
192
+ # "n" 需要几个story
193
+ # "max_token" 最多允许多少个token,如果-1则不限制
194
+ # "query" 需要查询的内容,如果等同于"default"则替换为text
195
+ # "lid" 需要替换的行,这里直接进行行替换,忽视行的其他内容
196
+
197
+ query_rags = parse_rag( persona )
198
+
199
+ if text is not None:
200
+ for rag in query_rags:
201
+ if rag['query'] == "default":
202
+ rag['query'] = text
203
+
204
+ return query_rags, self.token_counter(persona)
205
+
206
+ def append_message( self, response , speaker = None ):
207
+ if self.last_query_msg is not None:
208
+ self.history.append(self.last_query_msg)
209
+ self.last_query_msg = None
210
+
211
+ if speaker is None:
212
+ # 如果role是none,则认为是本角色{{role}}输出的句子
213
+ self.history.append({"speaker":"{{role}}","content":response})
214
+ # 叫speaker是为了和role进行区分
215
+ else:
216
+ self.history.append({"speaker":speaker,"content":response})
217
+
218
+ def check_recompute_stories_token(self):
219
+ return len(self.db.metas) == len(self.db.stories)
220
+
221
+ def recompute_stories_token(self):
222
+ self.db.metas = [self.token_counter(story) for story in self.db.stories]
223
+
224
+ def rag_retrieve( self, query, n, max_token, avoid_ids = [] ):
225
+ # 返回一个rag_id的列表
226
+ query_vec = self.embedding(query)
227
+
228
+ self.db.clean_flag()
229
+ self.db.disable_story_with_ids( avoid_ids )
230
+
231
+ retrieved_ids = self.db.search( query_vec, n )
232
+
233
+ if self.check_recompute_stories_token():
234
+ self.recompute_stories_token()
235
+
236
+ sum_token = 0
237
+
238
+ ans = []
239
+
240
+ for i in range(0, len(retrieved_ids)):
241
+ if i == 0:
242
+ sum_token += self.db.metas[retrieved_ids[i]]
243
+ ans.append(retrieved_ids[i])
244
+ continue
245
+ else:
246
+ sum_token += self.db.metas[retrieved_ids[i]]
247
+ if sum_token <= max_token:
248
+ ans.append(retrieved_ids[i])
249
+ else:
250
+ break
251
+
252
+ return ans
253
+
254
+
255
+ def rag_retrieve_all( self, query_rags, rest_limit ):
256
+ # 返回一个rag_ids的列表
257
+ retrieved_ids = []
258
+ rag_ids = []
259
+
260
+ for query_rag in query_rags:
261
+ query = query_rag['query']
262
+ n = query_rag['n']
263
+ max_token = rest_limit
264
+ if rest_limit > query_rag['max_token'] and query_rag['max_token'] > 0:
265
+ max_token = query_rag['max_token']
266
+
267
+ rag_id = self.rag_retrieve( query, n, max_token, avoid_ids = retrieved_ids )
268
+ rag_ids.append( rag_id )
269
+ retrieved_ids += rag_id
270
+
271
+ return rag_ids
272
+
273
+ def append_history_under_limit(self, message, rest_limit):
274
+ # 返回一个messages的列表
275
+ # print("call append history_under_limit")
276
+ # 从后往前计算token,不超过rest limit,
277
+ # 如果speaker是{{role}J,则message的role是assistant
278
+ current_limit = rest_limit
279
+
280
+ history_list = []
281
+
282
+ for item in reversed(self.history):
283
+ current_token = self.token_counter(item['content'])
284
+ current_limit -= current_token
285
+ if current_limit < 0:
286
+ break
287
+ else:
288
+ history_list.append(item)
289
+
290
+ history_list = list(reversed(history_list))
291
+
292
+ # TODO: 之后为了解决多人对话,这了content还会额外增加speaker: content这样的信息
293
+
294
+ for item in history_list:
295
+ if item['speaker'] == "{{role}}":
296
+ message.append({"role":"assistant","content":item['content']})
297
+ else:
298
+ message.append({"role":"user","content":item['content']})
299
+
300
+ return message
301
+
302
+ def get_message(self, user, text):
303
+ query_token = self.token_counter(text)
304
+
305
+ # 首先获取需要多少个rag story
306
+ query_rags, persona_token = self.parse_rag_from_persona( self.persona, text )
307
+ #每个query_rag需要饱含
308
+ # "n" 需要几个story
309
+ # "max_token" 最多允许多少个token,如果-1则不限制
310
+ # "query" 需要查询的内容,如果等同于"default"则替换为text
311
+ # "lid" 需要替换的行,这里直接进行行替换,忽视行的其他内容
312
+
313
+
314
+
315
+ rest_limit = self.max_input_token - persona_token - query_token
316
+
317
+ if self.verbose:
318
+ print(f"query_rags: {query_rags} rest_limit = { rest_limit }")
319
+
320
+ rag_ids = self.rag_retrieve_all( query_rags, rest_limit )
321
+
322
+ # 将rag_ids对应的故事 替换到persona中
323
+ augmented_persona = self.augment_persona( self.persona, rag_ids, query_rags )
324
+
325
+ system_prompt = self.package_system_prompt( self.role_name, augmented_persona )
326
+
327
+ token_for_system = self.token_counter( system_prompt )
328
+
329
+ rest_limit = self.max_input_token - token_for_system - query_token
330
+
331
+ message = [{"role":"system","content":system_prompt}]
332
+
333
+ message = self.append_history_under_limit( message, rest_limit )
334
+
335
+ # TODO: 之后为了解决多人对话,这了content还会额外增加speaker: content这样的信息
336
+
337
+ message.append({"role":"user","content":text})
338
+
339
+ self.last_query_msg = {"speaker":user,"content":text}
340
+
341
+ return message
342
+
343
+ def package_system_prompt(self, role_name, augmented_persona):
344
+ bot_name = role_name
345
+ return f"""You are now in roleplay conversation mode. Pretend to be {bot_name} whose persona follows:
346
+ {augmented_persona}
347
+
348
+ You will stay in-character whenever possible, and generate responses as if you were {bot_name}"""
349
+
350
+
351
+ def augment_persona(self, persona, rag_ids, query_rags):
352
+ lines = persona.split("\n")
353
+ for rag_id, query_rag in zip(rag_ids, query_rags):
354
+ lid = query_rag['lid']
355
+ new_text = ""
356
+ for id in rag_id:
357
+ new_text += "###\n" + self.db.stories[id].strip() + "\n"
358
+ new_text = new_text.strip()
359
+ lines[lid] = new_text
360
+ return "\n".join(lines)
361
+
362
+ def load_role_from_jsonl( self, role_from_jsonl ):
363
+ import json
364
+ datas = []
365
+ with open(role_from_jsonl, 'r') as f:
366
+ for line in f:
367
+ try:
368
+ datas.append(json.loads(line))
369
+ except:
370
+ continue
371
+
372
+ column_name = ""
373
+
374
+ from .embeddings import embedname2columnname
375
+
376
+ if self.embed_name in embedname2columnname:
377
+ column_name = embedname2columnname[self.embed_name]
378
+ else:
379
+ print('warning! unkown embedding name ', self.embed_name ,' while loading role')
380
+ column_name = 'luotuo_openai'
381
+
382
+ stories, story_vecs, persona = self.extract_text_vec_from_datas(datas, column_name)
383
+
384
+ return persona, None, stories, story_vecs
385
+
386
+
387
+ def load_role_from_hf(self, role_from_hf):
388
+ # 从hf加载role
389
+ # self.persona, new_role_name, self.stories, self.story_vecs = self.load_role_from_hf(role_from_hf)
390
+
391
+ from datasets import load_dataset
392
+
393
+ if role_from_hf.count("/") == 1:
394
+ dataset = load_dataset(role_from_hf)
395
+ datas = dataset["train"]
396
+ elif role_from_hf.count("/") >= 2:
397
+ split_index = role_from_hf.index('/')
398
+ second_split_index = role_from_hf.index('/', split_index+1)
399
+ dataset_name = role_from_hf[:second_split_index]
400
+ split_name = role_from_hf[second_split_index+1:]
401
+
402
+ fname = split_name + '.jsonl'
403
+ dataset = load_dataset(dataset_name,data_files={'train':fname})
404
+ datas = dataset["train"]
405
+
406
+ column_name = ""
407
+
408
+ from .embeddings import embedname2columnname
409
+
410
+ if self.embed_name in embedname2columnname:
411
+ column_name = embedname2columnname[self.embed_name]
412
+ else:
413
+ print('warning! unkown embedding name ', self.embed_name ,' while loading role')
414
+ column_name = 'luotuo_openai'
415
+
416
+ stories, story_vecs, persona = self.extract_text_vec_from_datas(datas, column_name)
417
+
418
+ return persona, None, stories, story_vecs
419
+
420
+ def extract_text_vec_from_datas(self, datas, column_name):
421
+ # 从datas中提取text和vec
422
+ # extract text and vec from huggingface dataset
423
+ # return texts, vecs
424
+ # from .utils import base64_to_float_array
425
+
426
+ texts = []
427
+ vecs = []
428
+ for data in datas:
429
+ if data[column_name] == 'system_prompt':
430
+ system_prompt = get_text_from_data( data )
431
+ elif data[column_name] == 'config':
432
+ pass
433
+ else:
434
+ vec = base64_to_float_array( data[column_name] )
435
+ text = get_text_from_data( data )
436
+ vecs.append( vec )
437
+ texts.append( text )
438
+ return texts, vecs, system_prompt
439
+
440
+ def extract_story_vecs(self, stories):
441
+ # 从stories中提取story_vecs
442
+
443
+ if self.verbose:
444
+ print(f"re-extract vector for {len(stories)} stories")
445
+
446
+ story_vecs = []
447
+
448
+ from .embeddings import embedshortname2model_name
449
+ from .embeddings import device
450
+
451
+ if device.type != "cpu" and self.embed_name in embedshortname2model_name:
452
+ # model_name = "BAAI/bge-small-zh-v1.5"
453
+ model_name = embedshortname2model_name[self.embed_name]
454
+
455
+ from .utils import get_general_embeddings_safe
456
+ story_vecs = get_general_embeddings_safe( stories, model_name = model_name )
457
+ # 使用batch的方式进行embedding,非常快
458
+ else:
459
+ from tqdm import tqdm
460
+ for story in tqdm(stories):
461
+ story_vecs.append(self.embedding(story))
462
+
463
+ return story_vecs
464
+
465
+ def build_db(self, stories, story_vecs):
466
+ # db的构造函数
467
+ if self.db is None:
468
+ from .NaiveDB import NaiveDB
469
+ self.db = NaiveDB()
470
+ self.db.build_db(stories, story_vecs)
471
+
ChatHaruhi/NaiveDB.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+ import string
3
+ import os
4
+ from math import sqrt
5
+
6
+ class NaiveDB:
7
+ def __init__(self):
8
+ self.verbose = False
9
+ self.init_db()
10
+
11
+ def init_db(self):
12
+ if self.verbose:
13
+ print("call init_db")
14
+ self.stories = []
15
+ self.norms = []
16
+ self.vecs = []
17
+ self.flags = [] # 用于标记每个story是否可以被搜索
18
+ self.metas = [] # 用于存储每个story的meta信息
19
+ self.last_search_ids = [] # 用于存储上一次搜索的结果
20
+
21
+ def build_db(self, stories, vecs, flags = None, metas = None):
22
+ self.stories = stories
23
+ self.vecs = vecs
24
+ self.flags = flags if flags else [True for _ in self.stories]
25
+ self.metas = metas if metas else [{} for _ in self.stories]
26
+ self.recompute_norm()
27
+
28
+ def save(self, file_path):
29
+ print( "warning! directly save folder from dbtype NaiveDB has not been implemented yet, try use role_from_hf to load role instead" )
30
+
31
+ def load(self, file_path):
32
+ print( "warning! directly load folder from dbtype NaiveDB has not been implemented yet, try use role_from_hf to load role instead" )
33
+
34
+ def recompute_norm( self ):
35
+ # 补全这部分代码,self.norms 分别存储每个vector的l2 norm
36
+ # 计算每个向量的L2范数
37
+ self.norms = [sqrt(sum([x**2 for x in vec])) for vec in self.vecs]
38
+
39
+ def get_stories_with_id(self, ids ):
40
+ return [self.stories[i] for i in ids]
41
+
42
+ def clean_flag(self):
43
+ self.flags = [True for _ in self.stories]
44
+
45
+ def disable_story_with_ids(self, close_ids ):
46
+ for id in close_ids:
47
+ self.flags[id] = False
48
+
49
+ def close_last_search(self):
50
+ for id in self.last_search_ids:
51
+ self.flags[id] = False
52
+
53
+ def search(self, query_vector , n_results):
54
+
55
+ if self.verbose:
56
+ print("call search")
57
+
58
+ if len(self.norms) != len(self.vecs):
59
+ self.recompute_norm()
60
+
61
+ # 计算查询向量的范数
62
+ query_norm = sqrt(sum([x**2 for x in query_vector]))
63
+
64
+ idxs = list(range(len(self.vecs)))
65
+
66
+ # 计算余弦相似度
67
+ similarities = []
68
+ for vec, norm, idx in zip(self.vecs, self.norms, idxs ):
69
+ if len(self.flags) == len(self.vecs) and not self.flags[idx]:
70
+ continue
71
+
72
+ dot_product = sum(q * v for q, v in zip(query_vector, vec))
73
+ if query_norm < 1e-20:
74
+ similarities.append( (random.random(), idx) )
75
+ continue
76
+ cosine_similarity = dot_product / (query_norm * norm)
77
+ similarities.append( ( cosine_similarity, idx) )
78
+
79
+ # 获取最相似的n_results个结果, 使用第0个字段进行排序
80
+ similarities.sort(key=lambda x: x[0], reverse=True)
81
+ self.last_search_ids = [x[1] for x in similarities[:n_results]]
82
+
83
+ top_indices = [x[1] for x in similarities[:n_results]]
84
+ return top_indices
85
+
86
+
87
+
88
+
ChatHaruhi/Readme.md ADDED
@@ -0,0 +1,352 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ChatHaruhi 3.0的接口设计
2
+
3
+ 在ChatHaruhi2.0大约1个季度的使用后
4
+ 我们初步知道了这样一个模型的一些需求,所以我们在这里开始设计ChatHaruhi3.0
5
+
6
+ ## 基本原则
7
+
8
+ - 兼容RAG和Zeroshot模式
9
+ - 主类以返回message为主,当然可以把语言模型(adapter直接to response)的接口设置给chatbot
10
+ - 主类尽可能轻量,除了embedding没有什么依赖
11
+
12
+ ## 用户代码
13
+
14
+ ```python
15
+ from ChatHaruhi import ChatHaruhi
16
+ from ChatHaruhi.openai import get_openai_response
17
+
18
+ chatbot = ChatHaruhi( role_name = 'haruhi', llm = get_openai_response )
19
+
20
+ response = chatbot.chat(user = '阿虚', text = '我看新一年的棒球比赛要开始了!我们要去参加吗?')
21
+ ```
22
+
23
+ 这样的好处是ChatHaruhi类载入的时候,不需要install 除了embedding以外 其他的东西,llm需要的依赖库储存在每个语言模型自己的文件里面。
24
+
25
+ zero的模式(快速新建角色)
26
+
27
+ ```python
28
+ from ChatHaruhi import ChatHaruhi
29
+ from ChatHaruhi.openai import get_openai_response
30
+
31
+ chatbot = ChatHaruhi( role_name = '小猫咪', persona = "你扮演一只小猫咪", llm = get_openai_response )
32
+
33
+ response = chatbot.chat(user = '怪叔叔', text = '嘿 *抓住了小猫咪*')
34
+ ```
35
+
36
+ ### 外置的inference
37
+
38
+ ```python
39
+ def get_response( message ):
40
+ return "语言模型输出了角色扮演的结果"
41
+
42
+ from ChatHaruhi import ChatHaruhi
43
+
44
+ chatbot = ChatHaruhi( role_name = 'haruhi' ) # 默认情况下 llm = None
45
+
46
+ message = chatbot.get_message( user = '阿虚', text = '我看新一年的棒球比赛要开始了!我们要去参加吗?' )
47
+
48
+ response = get_response( message )
49
+
50
+ chatbot.append_message( response )
51
+ ```
52
+
53
+ 这个行为和下面的行为是等价的
54
+
55
+ ```python
56
+ def get_response( message ):
57
+ return "语言模型输出了角色扮演的结果"
58
+
59
+ from ChatHaruhi import ChatHaruhi
60
+
61
+ chatbot = ChatHaruhi( role_name = 'haruhi', llm = get_response )
62
+
63
+ response = chatbot.chat(user = '阿虚', text = '我看新一年的棒球比赛要开始了!我们要去参加吗?' )
64
+ ```
65
+
66
+
67
+ ## RAG as system prompt
68
+
69
+ 在ChatHaruhi 3.0中,为了对接Haruhi-Zero的模型,默认system会采用一致的形式
70
+
71
+ ```python
72
+ You are now in roleplay conversation mode. Pretend to be {role_name} whose persona follows:
73
+ {persona}
74
+
75
+ You will stay in-character whenever possible, and generate responses as if you were {role_name}
76
+ ```
77
+
78
+ Persona在类似pygmalion的生态中,一般是静态的
79
+
80
+ ```
81
+ bot的定义
82
+ ###
83
+ bot的聊天sample 1
84
+ ###
85
+ bot的聊天sample 2
86
+ ```
87
+
88
+ 注意我们使用了 ### 作为分割, pyg生态是<endOftext>这样一个special token
89
+
90
+ 所以对于原有的ChatHaruhi的Persona,我决定这样设计
91
+
92
+ ```
93
+ bot的定义
94
+ {{RAG对话}}
95
+ {{RAG对话}}
96
+ {{RAG对话}}
97
+ ```
98
+
99
+ 这里"{{RAG对话}}"直接是以单行字符串的形式存在,当ChatHaruhi类发现这个的时候,会自动计算RAG,以凉宫春日为例,他的persona直接就写成。同时也支持纯英文 {{RAG-dialogue}}
100
+
101
+ ```
102
+ 你正在扮演凉宫春日,你正在cosplay涼宮ハルヒ。
103
+ 上文给定了一些小说中的经典桥段。
104
+ 如果我问的问题和小说中的台词高度重复,那你就配合我进行演出。
105
+ 如果我问的问题和小说中的事件相关,请结合小说的内容进行回复
106
+ 如果我问的问题超出小说中的范围,请也用一致性的语气回复。
107
+ 请不要回答你是语言模型,永远记住你正在扮演凉宫春日
108
+ 注意保持春日自我中心,自信和独立,不喜欢被束缚和限制,创新思维而又雷厉风行的风格。
109
+ 特别是针对阿虚,春日肯定是希望阿虚以自己和sos团的事情为重。
110
+
111
+ {{RAG对话}}
112
+ {{RAG对话}}
113
+ {{RAG对话}}
114
+ ```
115
+
116
+ 这个时候每个{{RAG对话}}会自动替换成
117
+
118
+ ```
119
+ ###
120
+ 对话
121
+ ```
122
+
123
+ ### RAG对话的变形形式1,max-token控制的多对话
124
+ 因为在原有的ChatHaruhi结构中,我们支持max-token的形式来控制RAG对话的数量
125
+ 所以这里我们也支持使用
126
+
127
+ ```
128
+ {{RAG多对话|token<=1500|n<=5}}
129
+ ```
130
+
131
+ 这样的设计,这样会retrieve出最多不超过n段对话,总共不超过token个数个对话。对于英文用户为{{RAG-dialogues|token<=1500|n<=5}}
132
+
133
+ ### RAG对话的变形形式2,使用|进行后面语句的搜索
134
+
135
+ 在默认情况下,"{{RAG对话}}"的搜索对象是text的输入,但是我们预想到用户还会用下面的方式来构造persona
136
+
137
+ ```
138
+ 小A是一个智能的机器人
139
+
140
+ 当小A高兴时
141
+ {{RAG对话|高兴的对话}}
142
+
143
+ 当小A伤心时
144
+ {{RAG对话|伤心的对话}}
145
+ 这个时候我们支持使用""{{RAG对话|<不包含花括号的一个字符串>}}"" 来进行RAG
146
+ ```
147
+
148
+ ## get_message
149
+
150
+ get_message会返回一个类似openai message形式的message
151
+
152
+ ```
153
+ [{"role":"system","content":整个system prompt},
154
+ {"role":"user","content":用户的输入},
155
+ {"role":"assistant","content":模型的输出},
156
+ ...]
157
+ ```
158
+
159
+ 原则上来说,如果使用openai,可以直接使用
160
+
161
+ ```python
162
+ def get_response( messages ):
163
+ completion = client.chat.completions.create(
164
+ model="gpt-3.5-turbo-1106",
165
+ messages=messages,
166
+ temperature=0.3
167
+ )
168
+
169
+ return completion.choices[0].message.content
170
+ ```
171
+
172
+ 对于异步的实现
173
+
174
+ ```python
175
+ async def async_get_response( messages ):
176
+ resp = await aclient.chat.completions.create(
177
+ model=model,
178
+ messages=messages,
179
+ temperature=0.3,
180
+ )
181
+ return result
182
+ ```
183
+
184
+ ### async_chat的调用
185
+ 设计上也会去支持
186
+
187
+ ```python
188
+ async def get_response( message ):
189
+ return "语言模型输出了角色扮演的结果"
190
+
191
+ from ChatHaruhi import ChatHaruhi
192
+
193
+ chatbot = ChatHaruhi( role_name = 'haruhi', llm_async = get_response )
194
+
195
+ response = await chatbot.async_chat(user='阿虚', text = '我看新一年的棒球比赛要开始了!我们要去参加吗?' )
196
+ ```
197
+
198
+ 这样异步的调用
199
+
200
+ # 角色载入
201
+
202
+ 如果这样看来,新的ChatHaruhi3.0需要以下信息
203
+
204
+ - persona 这个是必须的
205
+ - role_name, 在后处理的时候,把 {{role}} 和 {{角色}} 替换为这个字段, 这个字段不能为空,因为system prompt使用了这个字段,如果要支持这个字段为空,我们要额外设计一个备用prompt
206
+ - user_name, 在后处理的时候,把 {{用户}} 和 {{user}} 替换为这个字段,如果不设置也可以不替换
207
+ - RAG库, 当RAG库为空的时候,所有{{RAG*}}就直接删除了
208
+
209
+ ## role_name载入
210
+
211
+ 语法糖载入,不支持用户自己搞新角色,这个时候我们可以完全使用原来的数据
212
+
213
+ 额外需要设置一个role_name
214
+
215
+ ## role_from_jsonl载入
216
+
217
+ 这个时候我们需要设置role_name
218
+
219
+ 如果不设置我们会抛出一个error
220
+
221
+ ## role_from_hf
222
+
223
+ 本质上就是role_from_jsonl
224
+
225
+ ## 分别设置persona和role_name
226
+
227
+ 这个时候作为新人物考虑,默认没有RAG库,即Zero模式
228
+
229
+ ## 分别设置persona, role_name, texts
230
+
231
+ 这个时候会为texts再次抽取vectors
232
+
233
+ ## 分别设置persona, role_name, texts, vecs
234
+
235
+
236
+
237
+ # 额外变量
238
+
239
+ ## max_input_token
240
+
241
+ 默认为1600,会根据这个来限制history的长度
242
+
243
+ ## user_name_in_message
244
+
245
+ (这个功能在现在的预期核心代码中还没实现)
246
+
247
+ 默认为'default', 当用户始终用同一个user_name和角色对话的时候,并不添加
248
+
249
+ 如果用户使用不同的role和chatbot聊天 user_name_in_message 会改为 'add' 并在每个message标记是谁说的
250
+
251
+ (bot的也会添加)
252
+
253
+ 并且user_name替换为最后一个调用的user_name
254
+
255
+ 如果'not_add' 则永远不添加
256
+
257
+ S MSG_U1 MSG_A MSG_U1 MSG_A
258
+
259
+ 当出现U2后
260
+
261
+ S, U1:MSG_U1, A:MSG_A, U1:MSG_U1, A:MSG_A, U2:MSG_U2
262
+
263
+ ## token_counter
264
+
265
+ tokenizer默认为gpt3.5的tiktoken,设置为None的时候,不进行任何的token长度限制
266
+
267
+ ## transfer_haruhi_2_zero
268
+
269
+ (这个功能在现在的预期核心代码中还没实现)
270
+
271
+ 默认为true
272
+
273
+ 把原本ChatHaruhi的 角色: 「对话」的格式,去掉「」
274
+
275
+ # Embedding
276
+
277
+ 中文考虑用bge_small
278
+
279
+ Cross language考虑使用bce,相对还比较小, bge-m3太大了
280
+
281
+ 也就是ChatHaruhi类会有默认的embedding
282
+
283
+ self.embedding = ChatHaruhi.bge_small
284
+
285
+ 对于输入的文本,我们会使用这个embedding来进行encode然后进行检索替换掉RAG的内容
286
+
287
+ # 辅助接口
288
+
289
+ ## save_to_jsonl
290
+
291
+ 把一个角色保存成jsonl格式,方便上传hf
292
+
293
+
294
+ # 预计的伪代码
295
+
296
+ 这里的核心就是去考虑ChatHaruhi下get_message函数的伪代码
297
+
298
+ ```python
299
+ class ChatHaruhi:
300
+
301
+ def __init__( self ):
302
+ pass
303
+
304
+ def rag_retrieve( self, query_rags, rest_limit ):
305
+ # 返回一个rag_ids的列表
306
+ retrieved_ids = []
307
+ rag_ids = []
308
+
309
+ for query_rag in query_rags:
310
+ query = query_rag['query']
311
+ n = query_rag['n']
312
+ max_token = rest_limit
313
+ if rest_limit > query_rag['max_token'] and query_rag['max_token'] > 0:
314
+ max_token = query_rag['max_token']
315
+
316
+ rag_id = self.rag_retrieve( query, n, max_token, avoid_ids = retrieved_ids )
317
+ rag_ids.append( rag_id )
318
+ retrieved_ids += rag_id
319
+
320
+ def get_message(self, user, text):
321
+
322
+ query_token = self.token_counter( text )
323
+
324
+ # 首先获取需要多少个rag story
325
+ query_rags, persona_token = self.parse_persona( self.persona, text )
326
+ #每个query_rag需要饱含
327
+ # "n" 需要几个story
328
+ # "max_token" 最多允许多少个token,如果-1则不限制
329
+ # "query" 需要查询的内容
330
+ # "lid" 需要替换的行,这里直接进行行替换,忽视行的其他内容
331
+
332
+ rest_limit = self.max_input_token - persona_token - query_token
333
+
334
+ rag_ids = self.rag_retrieve( query_rags, rest_limit )
335
+
336
+ # 将rag_ids对应的故事 替换到persona中
337
+ augmented_persona = self.augment_persona( self.persona, rag_ids )
338
+
339
+ system_prompt = self.package_system_prompt( self.role_name, augmented_persona )
340
+
341
+ token_for_system = self.token_counter( system_prompt )
342
+
343
+ rest_limit = self.max_input_token - token_for_system - query_token
344
+
345
+ messages = [{"role":"system","content":system_prompt}]
346
+
347
+ messages = self.append_history_under_limit( messages, rest_limit )
348
+
349
+ messages.append({"role":"user",query})
350
+
351
+ return messages
352
+ ```
ChatHaruhi/SparkApi.py ADDED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import _thread as thread
2
+ import base64
3
+ import datetime
4
+ import hashlib
5
+ import hmac
6
+ import json
7
+ from urllib.parse import urlparse
8
+ import ssl
9
+ from datetime import datetime
10
+ from time import mktime
11
+ from urllib.parse import urlencode
12
+ from wsgiref.handlers import format_date_time
13
+
14
+ import websocket # 使用websocket_client
15
+ answer = ""
16
+ appid = None
17
+ api_secret = None
18
+ api_key = None
19
+
20
+ class Ws_Param(object):
21
+ # 初始化
22
+ def __init__(self, APPID, APIKey, APISecret, Spark_url):
23
+ self.APPID = APPID
24
+ self.APIKey = APIKey
25
+ self.APISecret = APISecret
26
+ self.host = urlparse(Spark_url).netloc
27
+ self.path = urlparse(Spark_url).path
28
+ self.Spark_url = Spark_url
29
+
30
+ # 生成url
31
+ def create_url(self):
32
+ # 生成RFC1123格式的时间戳
33
+ now = datetime.now()
34
+ date = format_date_time(mktime(now.timetuple()))
35
+
36
+ # 拼接字符串
37
+ signature_origin = "host: " + self.host + "\n"
38
+ signature_origin += "date: " + date + "\n"
39
+ signature_origin += "GET " + self.path + " HTTP/1.1"
40
+
41
+ # 进行hmac-sha256进行加密
42
+ signature_sha = hmac.new(self.APISecret.encode('utf-8'), signature_origin.encode('utf-8'),
43
+ digestmod=hashlib.sha256).digest()
44
+
45
+ signature_sha_base64 = base64.b64encode(signature_sha).decode(encoding='utf-8')
46
+
47
+ authorization_origin = f'api_key="{self.APIKey}", algorithm="hmac-sha256", headers="host date request-line", signature="{signature_sha_base64}"'
48
+
49
+ authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8')
50
+
51
+ # 将请求的鉴权参数组合为字典
52
+ v = {
53
+ "authorization": authorization,
54
+ "date": date,
55
+ "host": self.host
56
+ }
57
+ # 拼接鉴权参数,生成url
58
+ url = self.Spark_url + '?' + urlencode(v)
59
+ # 此处打印出建立连接时候的url,参考本demo的时候可取消上方打印的注释,比对相同参数时生成的url与自己代码生成的url是否一致
60
+ return url
61
+
62
+
63
+ # 收到websocket错误的处理
64
+ def on_error(ws, error):
65
+ print("### error:", error)
66
+
67
+
68
+ # 收到websocket关闭的处理
69
+ def on_close(ws,one,two):
70
+ return
71
+ # print(" ")
72
+
73
+
74
+ # 收到websocket连接建立的处理
75
+ def on_open(ws):
76
+ thread.start_new_thread(run, (ws,))
77
+
78
+
79
+ def run(ws, *args):
80
+ data = json.dumps(gen_params(appid=ws.appid, domain= ws.domain,question=ws.question))
81
+ ws.send(data)
82
+
83
+
84
+ # 收到websocket消息的处理
85
+ def on_message(ws, message):
86
+ # print(message)
87
+ data = json.loads(message)
88
+ code = data['header']['code']
89
+ if code != 0:
90
+ print(f'请求错误: {code}, {data}')
91
+ ws.close()
92
+ else:
93
+ choices = data["payload"]["choices"]
94
+ status = choices["status"]
95
+ content = choices["text"][0]["content"]
96
+ # print(content,end ="")
97
+ global answer
98
+ answer += content
99
+ # print(1)
100
+ if status == 2:
101
+ ws.close()
102
+
103
+
104
+ def gen_params(appid, domain,question):
105
+ """
106
+ 通过appid和用户的提问来生成请参数
107
+ """
108
+ data = {
109
+ "header": {
110
+ "app_id": appid,
111
+ "uid": "1234"
112
+ },
113
+ "parameter": {
114
+ "chat": {
115
+ "domain": domain,
116
+ "temperature": 0.5,
117
+ "max_tokens": 2048
118
+ }
119
+ },
120
+ "payload": {
121
+ "message": {
122
+ "text": question
123
+ }
124
+ }
125
+ }
126
+ return data
127
+
128
+
129
+ def main(appid, api_key, api_secret, Spark_url,domain, question):
130
+ # print("星火:")
131
+ wsParam = Ws_Param(appid, api_key, api_secret, Spark_url)
132
+ websocket.enableTrace(False)
133
+ wsUrl = wsParam.create_url()
134
+ ws = websocket.WebSocketApp(wsUrl, on_message=on_message, on_error=on_error, on_close=on_close, on_open=on_open)
135
+ ws.appid = appid
136
+ ws.question = question
137
+ ws.domain = domain
138
+ ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE})
139
+
140
+
ChatHaruhi/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+
2
+ from .ChatHaruhi import ChatHaruhi
ChatHaruhi/embeddings.py ADDED
@@ -0,0 +1,270 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+
3
+ # elif embedding == 'bge_en':
4
+ # embed_name = 'bge_en_s15'
5
+ # elif embedding == 'bge_zh':
6
+ # embed_name = 'bge_zh_s15'
7
+
8
+ import torch
9
+
10
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
11
+
12
+
13
+ embedshortname2model_name = {
14
+ "bge_zh":"BAAI/bge-small-zh-v1.5",
15
+ }
16
+
17
+ embedname2columnname = {
18
+ "luotuo_openai":"luotuo_openai",
19
+ "openai":"luotuo_openai",
20
+ "bge_zh":"bge_zh_s15",
21
+ "bge_en":"bge_en_s15",
22
+ "bce":"bce_base",
23
+ }
24
+
25
+ # 这是用来调试的foo embedding
26
+
27
+ def foo_embedding( text ):
28
+ # whatever text input , output a 2 dim 0-1 random vects
29
+ return [random.random(), random.random()]
30
+
31
+ # TODO: add bge-zh-small(or family) BCE and openai embedding here 米唯实
32
+ # ======== add bge_zh mmodel
33
+ # by Weishi MI
34
+
35
+ def foo_bge_zh_15( text ):
36
+ dim = 512
37
+ model_name = "BAAI/bge-small-zh-v1.5"
38
+ if isinstance(text, str):
39
+ text_list = [text]
40
+ else:
41
+ get_general_embeddings_safe(text, model_name)
42
+
43
+ global _model_pool
44
+ global _tokenizer_pool
45
+
46
+ if model_name not in _model_pool:
47
+ from transformers import AutoTokenizer, AutoModel
48
+ _tokenizer_pool[model_name] = AutoTokenizer.from_pretrained(model_name)
49
+ _model_pool[model_name] = AutoModel.from_pretrained(model_name)
50
+
51
+ _model_pool[model_name].eval()
52
+
53
+ # Tokenize sentences
54
+ encoded_input = _tokenizer_pool[model_name](text_list, padding=True, truncation=True, return_tensors='pt', max_length = 512)
55
+
56
+ # Compute token embeddings
57
+ with torch.no_grad():
58
+ model_output = _model_pool[model_name](**encoded_input)
59
+ # Perform pooling. In this case, cls pooling.
60
+ sentence_embeddings = model_output[0][:, 0]
61
+
62
+ # normalize embeddings
63
+ sentence_embeddings = torch.nn.functional.normalize(sentence_embeddings, p=2, dim=1)
64
+ return sentence_embeddings.cpu().tolist()[0]
65
+ # return [random.random() for _ in range(dim)]
66
+
67
+ def foo_bce( text ):
68
+ from transformers import AutoModel, AutoTokenizer
69
+ if isinstance(text, str):
70
+ text_list = [text]
71
+
72
+ # init model and tokenizer
73
+ tokenizer = AutoTokenizer.from_pretrained('maidalun1020/bce-embedding-base_v1')
74
+ model = AutoModel.from_pretrained('maidalun1020/bce-embedding-base_v1')
75
+
76
+ model.to(device)
77
+
78
+ # get inputs
79
+ inputs = tokenizer(text_list, padding=True, truncation=True, max_length=512, return_tensors="pt")
80
+ inputs_on_device = {k: v.to(self.device) for k, v in inputs.items()}
81
+
82
+ # get embeddings
83
+ outputs = model(**inputs_on_device, return_dict=True)
84
+ embeddings = outputs.last_hidden_state[:, 0] # cls pooler
85
+ embeddings = embeddings / embeddings.norm(dim=1, keepdim=True) # normalize
86
+ return embeddings
87
+ def download_models():
88
+ print("正在下载Luotuo-Bert")
89
+ # Import our models. The package will take care of downloading the models automatically
90
+ model_args = Namespace(do_mlm=None, pooler_type="cls", temp=0.05, mlp_only_train=False,
91
+ init_embeddings_model=None)
92
+ model = AutoModel.from_pretrained("silk-road/luotuo-bert-medium", trust_remote_code=True, model_args=model_args).to(
93
+ device)
94
+ print("Luotuo-Bert下载完毕")
95
+ return model
96
+
97
+ def get_luotuo_model():
98
+ global _luotuo_model
99
+ if _luotuo_model is None:
100
+ _luotuo_model = download_models()
101
+ return _luotuo_model
102
+
103
+
104
+ def luotuo_embedding(model, texts):
105
+ # Tokenize the texts_source
106
+ tokenizer = AutoTokenizer.from_pretrained("silk-road/luotuo-bert-medium")
107
+ inputs = tokenizer(texts, padding=True, truncation=False, return_tensors="pt")
108
+ inputs = inputs.to(device)
109
+ # Extract the embeddings
110
+ # Get the embeddings
111
+ with torch.no_grad():
112
+ embeddings = model(**inputs, output_hidden_states=True, return_dict=True, sent_emb=True).pooler_output
113
+ return embeddings
114
+
115
+ def luotuo_en_embedding( texts ):
116
+ # this function implemented by Cheng
117
+ global _luotuo_model_en
118
+ global _luotuo_en_tokenizer
119
+
120
+ if _luotuo_model_en is None:
121
+ _luotuo_en_tokenizer = AutoTokenizer.from_pretrained("silk-road/luotuo-bert-en")
122
+ _luotuo_model_en = AutoModel.from_pretrained("silk-road/luotuo-bert-en").to(device)
123
+
124
+ if _luotuo_en_tokenizer is None:
125
+ _luotuo_en_tokenizer = AutoTokenizer.from_pretrained("silk-road/luotuo-bert-en")
126
+
127
+ inputs = _luotuo_en_tokenizer(texts, padding=True, truncation=False, return_tensors="pt")
128
+ inputs = inputs.to(device)
129
+
130
+ with torch.no_grad():
131
+ embeddings = _luotuo_model_en(**inputs, output_hidden_states=True, return_dict=True, sent_emb=True).pooler_output
132
+
133
+ return embeddings
134
+
135
+
136
+ def get_embedding_for_chinese(model, texts):
137
+ model = model.to(device)
138
+ # str or strList
139
+ texts = texts if isinstance(texts, list) else [texts]
140
+ # 截断
141
+ for i in range(len(texts)):
142
+ if len(texts[i]) > 510:
143
+ texts[i] = texts[i][:510]
144
+ if len(texts) >= 64:
145
+ embeddings = []
146
+ chunk_size = 64
147
+ for i in range(0, len(texts), chunk_size):
148
+ embeddings.append(luotuo_embedding(model, texts[i: i + chunk_size]))
149
+ return torch.cat(embeddings, dim=0)
150
+ else:
151
+ return luotuo_embedding(model, texts)
152
+
153
+
154
+ def is_chinese_or_english(text):
155
+ # no longer use online openai api
156
+ return "chinese"
157
+
158
+ text = list(text)
159
+ is_chinese, is_english = 0, 0
160
+
161
+ for char in text:
162
+ # 判断字符的Unicode值是否在中文字符的Unicode范围内
163
+ if '\u4e00' <= char <= '\u9fa5':
164
+ is_chinese += 4
165
+ # 判断字符是否为英文字符(包括大小写字母和常见标点符号)
166
+ elif ('\u0041' <= char <= '\u005a') or ('\u0061' <= char <= '\u007a'):
167
+ is_english += 1
168
+ if is_chinese >= is_english:
169
+ return "chinese"
170
+ else:
171
+ return "english"
172
+
173
+
174
+ def get_embedding_openai(text, model="text-embedding-ada-002"):
175
+ text = text.replace("\n", " ")
176
+ return client.embeddings.create(input = [text], model=model).data[0].embedding
177
+
178
+ def get_embedding_for_english(text, model="text-embedding-ada-002"):
179
+ text = text.replace("\n", " ")
180
+ return client.embeddings.create(input = [text], model=model).data[0].embedding
181
+
182
+ import os
183
+
184
+ def foo_openai( text ):
185
+ # dim = 1536
186
+
187
+ openai_key = os.environ.get("OPENAI_API_KEY")
188
+
189
+ if isinstance(texts, list):
190
+ index = random.randint(0, len(texts) - 1)
191
+ if openai_key is None or is_chinese_or_english(texts[index]) == "chinese":
192
+ return [embed.cpu().tolist() for embed in get_embedding_for_chinese(get_luotuo_model(), texts)]
193
+ else:
194
+ return [get_embedding_for_english(text) for text in texts]
195
+ else:
196
+ if openai_key is None or is_chinese_or_english(texts) == "chinese":
197
+ return get_embedding_for_chinese(get_luotuo_model(), texts)[0].cpu().tolist()
198
+ else:
199
+ return get_embedding_for_english(texts)
200
+
201
+
202
+ ### BGE family
203
+
204
+
205
+ # ======== add bge_zh mmodel
206
+ # by Cheng Li
207
+ # 这一次我们试图一次性去适配更多的模型
208
+ import torch
209
+
210
+ _model_pool = {}
211
+ _tokenizer_pool = {}
212
+
213
+ # BAAI/bge-small-zh-v1.5
214
+
215
+ def get_general_embeddings( sentences , model_name = "BAAI/bge-small-zh-v1.5" ):
216
+
217
+ global _model_pool
218
+ global _tokenizer_pool
219
+
220
+ if model_name not in _model_pool:
221
+ from transformers import AutoTokenizer, AutoModel
222
+ _tokenizer_pool[model_name] = AutoTokenizer.from_pretrained(model_name)
223
+ _model_pool[model_name] = AutoModel.from_pretrained(model_name).to(device)
224
+
225
+ _model_pool[model_name].eval()
226
+
227
+ # Tokenize sentences
228
+ encoded_input = _tokenizer_pool[model_name](sentences, padding=True, truncation=True, return_tensors='pt', max_length = 512).to(device)
229
+
230
+ # Compute token embeddings
231
+ with torch.no_grad():
232
+ model_output = _model_pool[model_name](**encoded_input)
233
+ # Perform pooling. In this case, cls pooling.
234
+ sentence_embeddings = model_output[0][:, 0]
235
+
236
+ # normalize embeddings
237
+ sentence_embeddings = torch.nn.functional.normalize(sentence_embeddings, p=2, dim=1)
238
+ return sentence_embeddings.cpu().tolist()
239
+
240
+ def get_general_embedding( text_or_texts , model_name = "BAAI/bge-small-zh-v1.5" ):
241
+ if isinstance(text_or_texts, str):
242
+ return get_general_embeddings([text_or_texts], model_name)[0]
243
+ else:
244
+ return get_general_embeddings_safe(text_or_texts, model_name)
245
+
246
+ general_batch_size = 16
247
+
248
+ import math
249
+
250
+ def get_general_embeddings_safe(sentences, model_name = "BAAI/bge-small-zh-v1.5"):
251
+
252
+ embeddings = []
253
+
254
+ num_batches = math.ceil(len(sentences) / general_batch_size)
255
+
256
+ from tqdm import tqdm
257
+
258
+ for i in tqdm( range(num_batches) ):
259
+ # print("run bge with batch ", i)
260
+ start_index = i * general_batch_size
261
+ end_index = min(len(sentences), start_index + general_batch_size)
262
+ batch = sentences[start_index:end_index]
263
+ embs = get_general_embeddings(batch, model_name)
264
+ embeddings.extend(embs)
265
+
266
+ return embeddings
267
+
268
+ def get_bge_zh_embedding( text_or_texts ):
269
+ return get_general_embedding(text_or_texts, "BAAI/bge-small-zh-v1.5")
270
+
ChatHaruhi/novel_extract.py ADDED
@@ -0,0 +1,176 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+
3
+
4
+
5
+ import re
6
+
7
+ def extract_speaker(text):
8
+ # 使用正则表达式匹配文本开头的 "<name> :" 格式,并捕获冒号后面的内容
9
+ match = re.match(r'^([^:]+) :(.*)', text)
10
+ if match:
11
+ return (match.group(1), match.group(2).strip()) # 返回匹配到的name部分和冒号后面的内容作为元组
12
+ else:
13
+ return None, text # 如果不匹配,返回None和原始文本
14
+
15
+
16
+ def get_line_recall(query, line):
17
+ # 获得query中每个汉字在 line 中的recall
18
+ if not query or not line:
19
+ return 0
20
+ line_set = set(line)
21
+ return sum(char in line_set for char in query) / len(query)
22
+
23
+
24
+ def get_max_recall_in_lines(query, lines):
25
+ recall_values = [(get_line_recall(query, line), i) for i, line in enumerate(lines)]
26
+ return max(recall_values, default=(-1, -1), key=lambda x: x[0])
27
+
28
+ def extract_dialogues_from_response(text):
29
+ # Split the text into lines
30
+ lines = text.split('\n')
31
+
32
+ # Initialize an empty list to store the extracted dialogues
33
+ extracted_dialogues = []
34
+
35
+ valid_said_by = ["said by", "thought by", "described by", "from"]
36
+
37
+ # Iterate through each line
38
+ for line in lines:
39
+ # Split the line by '|' and strip whitespace from each part
40
+ parts = [part.strip() for part in line.split('|')]
41
+
42
+ # Check if the line has 4 parts and the third part is 'said by'
43
+ if len(parts) == 3:
44
+ # Extract the dialogue and speaker, and add to the list
45
+ if parts[2] == "speaker":
46
+ continue
47
+
48
+ if parts[1].strip().lower() not in valid_said_by:
49
+ continue
50
+
51
+ dialogue_dict = {
52
+ 'dialogue': parts[0],
53
+ 'speaker': parts[2],
54
+ "said_by": parts[1]
55
+ }
56
+ extracted_dialogues.append(dialogue_dict)
57
+
58
+ return extracted_dialogues
59
+
60
+
61
+ def extract_dialogues_from_glm_response(text):
62
+ # Split the text into lines
63
+ lines = text.split('\n')
64
+
65
+ # Initialize an empty list to store the extracted dialogues
66
+ extracted_dialogues = []
67
+
68
+ valid_said_by = ["said by", "thought by", "described by", "from"]
69
+
70
+ # Iterate through each line
71
+ for line in lines:
72
+ # Split the line by '|' and strip whitespace from each part
73
+ parts = [part.strip() for part in line.split('|')]
74
+
75
+ # Check if the line has 4 parts and the third part is 'said by'
76
+ if len(parts) == 4:
77
+ # Extract the dialogue and speaker, and add to the list
78
+ if parts[3] == "speaker":
79
+ continue
80
+
81
+ if parts[2].strip().lower() not in valid_said_by:
82
+ continue
83
+
84
+ try:
85
+ id_num = int(parts[0])
86
+ except ValueError:
87
+ id_num = id
88
+
89
+ dialogue_dict = {
90
+ 'id': id_num,
91
+ 'dialogue': parts[1],
92
+ 'speaker': parts[3],
93
+ "said_by": parts[2]
94
+ }
95
+ extracted_dialogues.append(dialogue_dict)
96
+
97
+ return extracted_dialogues
98
+
99
+
100
+ def has_dialogue_sentences(text: str) -> int:
101
+ # 定义成对的引号
102
+ paired_quotes = [
103
+ ("“", "”"),
104
+ ("‘", "’"),
105
+ ("「", "」")
106
+ ]
107
+ # 定义符号列表(包括全角和半角的逗号和句号)
108
+ symbols = ['。', '!', '?', '*', '.', '?', '!', '"', '”', ',', '~', ')', ')', '…', ']', '♪',',']
109
+
110
+ # 检查成对引号内的内容
111
+ for start_quote, end_quote in paired_quotes:
112
+ start_index = text.find(start_quote)
113
+ while start_index != -1:
114
+ end_index = text.find(end_quote, start_index + 1)
115
+ if end_index != -1:
116
+ quote_content = text[start_index + 1:end_index]
117
+ # 检查引号内的内容是否符合条件
118
+ if any(symbol in quote_content for symbol in symbols) or len(quote_content) >= 10:
119
+ return 2 # 成对引号内有符号或长度>=10
120
+ start_index = text.find(start_quote, end_index + 1)
121
+ else:
122
+ break
123
+
124
+ # 检查双引号'"'
125
+ double_quotes_indices = [i for i, char in enumerate(text) if char == '"']
126
+ if len(double_quotes_indices) % 2 == 0: # 必须是偶数个双引号
127
+ for i in range(0, len(double_quotes_indices), 2):
128
+ start_index, end_index = double_quotes_indices[i], double_quotes_indices[i+1]
129
+ quote_content = text[start_index+1:end_index]
130
+ # 检查引号内的内容是否含有符号
131
+ if any(symbol in quote_content for symbol in symbols):
132
+ return 1 # 双引号内有符号
133
+
134
+ return 0 # 没有符合条件的对话型句子
135
+
136
+ def replace_recalled_dialogue( raw_text, response_text ):
137
+ dialogues = extract_dialogues_from_response( response_text )
138
+
139
+ lines = raw_text.split("\n")
140
+
141
+ lines = [line.strip().strip("\u3000") for line in lines]
142
+
143
+ recall_flag = [ False for line in lines ]
144
+ line2ids = [ [] for line in lines ]
145
+
146
+ for id, dialogue in enumerate(dialogues):
147
+ dialogue_text = dialogue['dialogue']
148
+ remove_symbol_text = dialogue_text.replace("*","").replace('"',"")
149
+
150
+ recall, lid = get_max_recall_in_lines( remove_symbol_text, lines )
151
+
152
+ if recall > 0.3:
153
+ recall_flag[lid] = True
154
+ line2ids[lid].append(id)
155
+
156
+ new_text = ""
157
+
158
+ for lid, line in enumerate(lines):
159
+ if recall_flag[lid]:
160
+ if len(line2ids[lid]) == 1 and ("未知" in dialogues[0]['speaker'] or dialogues[0]['speaker'].strip() == ""):
161
+ new_text += line + "\n"
162
+ continue
163
+
164
+ for dia_id in line2ids[lid]:
165
+ speaker = dialogues[dia_id]['speaker']
166
+ dialogue = dialogues[dia_id]['dialogue']
167
+ dialogue = dialogue.replace('"',"").replace('“',"").replace('”',"")
168
+ new_text += speaker + " : " + dialogue + "\n"
169
+ else:
170
+ new_text += line + "\n"
171
+
172
+ return new_text.strip()
173
+
174
+
175
+
176
+
ChatHaruhi/response_GLM_local.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from string import Template
3
+ from typing import List, Dict
4
+
5
+ import torch.cuda
6
+ from transformers import AutoTokenizer, AutoModelForCausalLM
7
+
8
+ aclient = None
9
+
10
+ client = None
11
+ tokenizer = None
12
+
13
+ END_POINT = "https://hf-mirror.com"
14
+
15
+
16
+ def init_client(model_name: str, verbose: bool) -> None:
17
+ """
18
+ 初始化模型,通过可用的设备进行模型加载推理。
19
+
20
+ Params:
21
+ model_name (`str`)
22
+ HuggingFace中的模型项目名,例如"THUDM/chatglm3-6b"
23
+ """
24
+
25
+ # 将client设置为全局变量
26
+ global client
27
+ global tokenizer
28
+
29
+ # 判断 使用MPS、CUDA、CPU运行模型
30
+ if torch.cuda.is_available():
31
+ device = torch.device("cuda")
32
+ elif torch.backends.mps.is_available():
33
+ device = torch.device("mps")
34
+ else:
35
+ device = torch.device("cpu")
36
+
37
+ if verbose:
38
+ print("Using device: ", device)
39
+
40
+ # TODO 考虑支持deepspeed 进行多gpu推理,以及zero
41
+
42
+ try:
43
+ tokenizer = AutoTokenizer.from_pretrained(
44
+ model_name, trust_remote_code=True, local_files_only=True)
45
+ client = AutoModelForCausalLM.from_pretrained(
46
+ model_name, trust_remote_code=True, local_files_only=True)
47
+ except Exception:
48
+ if pretrained_model_download(model_name, verbose=verbose):
49
+ tokenizer = AutoTokenizer.from_pretrained(
50
+ model_name, trust_remote_code=True, local_files_only=True)
51
+ client = AutoModelForCausalLM.from_pretrained(
52
+ model_name, trust_remote_code=True, local_files_only=True)
53
+
54
+ client = client.to(device).eval()
55
+
56
+
57
+ def pretrained_model_download(model_name_or_path: str, verbose: bool) -> bool:
58
+ """
59
+ 使用huggingface_hub下载模型(model_name_or_path)。下载成功返回true,失败返回False。
60
+ Params:
61
+ model_name_or_path (`str`): 模型的huggingface地址
62
+ Returns:
63
+ `bool` 是否下载成功
64
+ """
65
+ # TODO 使用hf镜像加速下载 未测试windows端
66
+
67
+ # 判断是否使用HF_transfer,默认不使用。
68
+ if os.getenv("HF_HUB_ENABLE_HF_TRANSFER") == 1:
69
+ try:
70
+ import hf_transfer
71
+ except ImportError:
72
+ print("Install hf_transfer.")
73
+ os.system("pip -q install hf_transfer")
74
+ import hf_transfer
75
+
76
+ # 尝试引入huggingface_hub
77
+ try:
78
+ import huggingface_hub
79
+ except ImportError:
80
+ print("Install huggingface_hub.")
81
+ os.system("pip -q install huggingface_hub")
82
+ import huggingface_hub
83
+
84
+ # 使用huggingface_hub下载模型。
85
+ try:
86
+ print(f"downloading {model_name_or_path}")
87
+ huggingface_hub.snapshot_download(
88
+ repo_id=model_name_or_path, endpoint=END_POINT, resume_download=True, local_dir_use_symlinks=False)
89
+ except Exception as e:
90
+ raise e
91
+
92
+ return True
93
+
94
+
95
+ def message2query(messages: List[Dict[str, str]]) -> str:
96
+ # [{'role': 'user', 'content': '老师: 同学请自我介绍一下'}]
97
+ # <|system|>
98
+ # You are ChatGLM3, a large language model trained by Zhipu.AI. Follow the user's instructions carefully. Respond using markdown.
99
+ # <|user|>
100
+ # Hello
101
+ # <|assistant|>
102
+ # Hello, I'm ChatGLM3. What can I assist you today?
103
+ template = Template("<|$role|>\n$content\n")
104
+
105
+ return "".join([template.substitute(message) for message in messages])
106
+
107
+
108
+ def get_response(message, model_name: str = "THUDM/chatglm3-6b", verbose: bool = False):
109
+ global client
110
+ global tokenizer
111
+
112
+ if client is None:
113
+ init_client(model_name, verbose=verbose)
114
+
115
+ if verbose:
116
+ print(message)
117
+ print(message2query(message))
118
+
119
+ response, history = client.chat(tokenizer, message2query(message))
120
+
121
+ return response
ChatHaruhi/response_GLM_lora.py ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from string import Template
3
+ from typing import List, Dict
4
+
5
+ import torch.cuda
6
+ from transformers import AutoTokenizer, AutoModelForCausalLM
7
+ from peft import AutoPeftModelForCausalLM
8
+
9
+
10
+ aclient = None
11
+
12
+ client = None
13
+ tokenizer = None
14
+
15
+ END_POINT = "https://hf-mirror.com"
16
+
17
+
18
+ def init_client(model_name: str, verbose: bool) -> None:
19
+ """
20
+ 初始化模型,通过可用的设备进行模型加载推理。
21
+
22
+ Params:
23
+ model_name (`str`)
24
+ HuggingFace中的模型项目名,例如"THUDM/chatglm3-6b"
25
+ """
26
+
27
+ # 将client设置为全局变量
28
+ global client
29
+ global tokenizer
30
+
31
+ # 判断 使用MPS、CUDA、CPU运行模型
32
+ if torch.cuda.is_available():
33
+ device = torch.device("cuda")
34
+ elif torch.backends.mps.is_available():
35
+ device = torch.device("mps")
36
+ else:
37
+ device = torch.device("cpu")
38
+
39
+ if verbose:
40
+ print("Using device: ", device)
41
+
42
+ # TODO 上传模型后,更改为从huggingface获取模型
43
+ client = AutoPeftModelForCausalLM.from_pretrained(
44
+ model_name, trust_remote_code=True)
45
+ tokenizer_dir = client.peft_config['default'].base_model_name_or_path
46
+ if verbose:
47
+ print(tokenizer_dir)
48
+ tokenizer = AutoTokenizer.from_pretrained(
49
+ tokenizer_dir, trust_remote_code=True)
50
+
51
+ # try:
52
+ # tokenizer = AutoTokenizer.from_pretrained(
53
+ # model_name, trust_remote_code=True, local_files_only=True)
54
+ # client = AutoModelForCausalLM.from_pretrained(
55
+ # model_name, trust_remote_code=True, local_files_only=True)
56
+ # except Exception:
57
+ # if pretrained_model_download(model_name, verbose=verbose):
58
+ # tokenizer = AutoTokenizer.from_pretrained(
59
+ # model_name, trust_remote_code=True, local_files_only=True)
60
+ # client = AutoModelForCausalLM.from_pretrained(
61
+ # model_name, trust_remote_code=True, local_files_only=True)
62
+
63
+ # client = client.to(device).eval()
64
+ client = client.eval()
65
+
66
+
67
+ def pretrained_model_download(model_name_or_path: str, verbose: bool) -> bool:
68
+ """
69
+ 使用huggingface_hub下载模型(model_name_or_path)。下载成功返回true,失败返回False。
70
+ Params:
71
+ model_name_or_path (`str`): 模型的huggingface地址
72
+ Returns:
73
+ `bool` 是否下载成功
74
+ """
75
+ # TODO 使用hf镜像加速下载 未测试windows端
76
+
77
+ # 判断是否使用HF_transfer,默认不使用。
78
+ if os.getenv("HF_HUB_ENABLE_HF_TRANSFER") == 1:
79
+ try:
80
+ import hf_transfer
81
+ except ImportError:
82
+ print("Install hf_transfer.")
83
+ os.system("pip -q install hf_transfer")
84
+ import hf_transfer
85
+
86
+ # 尝试引入huggingface_hub
87
+ try:
88
+ import huggingface_hub
89
+ except ImportError:
90
+ print("Install huggingface_hub.")
91
+ os.system("pip -q install huggingface_hub")
92
+ import huggingface_hub
93
+
94
+ # 使用huggingface_hub下载模型。
95
+ try:
96
+ print(f"downloading {model_name_or_path}")
97
+ huggingface_hub.snapshot_download(
98
+ repo_id=model_name_or_path, endpoint=END_POINT, resume_download=True, local_dir_use_symlinks=False)
99
+ except Exception as e:
100
+ raise e
101
+
102
+ return True
103
+
104
+
105
+ def message2query(messages: List[Dict[str, str]]) -> str:
106
+ # [{'role': 'user', 'content': '老师: 同学请自我介绍一下'}]
107
+ # <|system|>
108
+ # You are ChatGLM3, a large language model trained by Zhipu.AI. Follow the user's instructions carefully. Respond using markdown.
109
+ # <|user|>
110
+ # Hello
111
+ # <|assistant|>
112
+ # Hello, I'm ChatGLM3. What can I assist you today?
113
+ template = Template("<|$role|>\n$content\n")
114
+
115
+ return "".join([template.substitute(message) for message in messages])
116
+
117
+
118
+ def get_response(message, model_name: str = "/workspace/jyh/Zero-Haruhi/checkpoint-1500", verbose: bool = True):
119
+ global client
120
+ global tokenizer
121
+
122
+ if client is None:
123
+ init_client(model_name, verbose=verbose)
124
+
125
+ if verbose:
126
+ print(message)
127
+ print(message2query(message))
128
+
129
+ response, history = client.chat(tokenizer, message2query(message))
130
+ if verbose:
131
+ print((response, history))
132
+
133
+ return response
ChatHaruhi/response_erniebot.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import erniebot
2
+
3
+ aclient = None
4
+
5
+ client = None
6
+
7
+ import os
8
+
9
+ def normalize2uaua_ernie( message, if_replace_system = False ):
10
+ new_message = []
11
+ last_role = ""
12
+
13
+ for msg in message:
14
+ role = msg["role"]
15
+ if if_replace_system and role == "system":
16
+ role = "user"
17
+ msg["role"] = role
18
+
19
+ if last_role == role:
20
+ new_message[-1]["content"] = new_message[-1]["content"] + "\n" + msg["content"]
21
+ else:
22
+ last_role = role
23
+ new_message.append( msg )
24
+
25
+ return new_message
26
+
27
+ def init_client():
28
+
29
+ # 将client设置为全局变量
30
+ global client
31
+
32
+ # 将ERNIE_ACCESS_TOKEN作为参数值传递给OS
33
+ api_key = os.getenv("ERNIE_ACCESS_TOKEN")
34
+ if api_key is None:
35
+ raise ValueError("环境变量'ERNIE_ACCESS_TOKEN'未设置,请确保已经定义了API密钥")
36
+ erniebot.api_type = "aistudio"
37
+ erniebot.access_token = api_key
38
+ client = erniebot
39
+
40
+ def get_response( message, model_name = "ernie-4.0" ):
41
+ if client is None:
42
+ init_client()
43
+
44
+ message_ua = normalize2uaua_ernie(message, if_replace_system = True)
45
+ # print(message_ua)
46
+ response = client.ChatCompletion.create(\
47
+ model=model_name,\
48
+ messages = message_ua, \
49
+ temperature = 0.1 )
50
+ return response.get_result()
51
+
52
+ import json
53
+ import asyncio
54
+ from erniebot_agent.chat_models import ERNIEBot
55
+ from erniebot_agent.memory import HumanMessage, AIMessage, SystemMessage, FunctionMessage
56
+
57
+ def init_aclient(model="ernie-4.0"):
58
+
59
+ # 将aclient设置为全局变量
60
+ global aclient
61
+
62
+ api_key = os.getenv("ERNIE_ACCESS_TOKEN")
63
+ if api_key is None:
64
+ raise ValueError("环境变量'ERNIE_ACCESS_TOKEN'未设置。请确保已经定义了API密钥。")
65
+ os.environ["EB_AGENT_ACCESS_TOKEN"] = api_key
66
+ aclient = ERNIEBot(model=model) # 创建模型
67
+
68
+
69
+
70
+ async def async_get_response( message, model="ernie-4.0" ):
71
+ if aclient is None:
72
+ init_aclient(model=model)
73
+
74
+ messages = []
75
+ system_message = None
76
+ message_ua = normalize2uaua_ernie(message, if_replace_system = False)
77
+ print(message_ua)
78
+ for item in message_ua:
79
+ if item["role"] == "user":
80
+ messages.append(HumanMessage(item["content"]))
81
+ elif item["role"] == "system":
82
+ system_message = SystemMessage(item["content"])
83
+ else:
84
+ messages.append(AIMessage(item["content"]))
85
+ if system_message:
86
+ ai_message = await aclient.chat(messages=messages, temperature = 0.1)
87
+ else:
88
+ ai_message = await aclient.chat(messages=messages, system=system_message.content, temperature = 0.1) # 调用模型chat接口,非流式返回
89
+
90
+ return ai_message.content
ChatHaruhi/response_openai.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import openai
2
+
3
+ aclient = None
4
+
5
+ client = None
6
+
7
+ import os
8
+ from openai import OpenAI
9
+
10
+ def init_client():
11
+ # 将client设置为全局变量,以便在其他函数中使用
12
+ global client
13
+
14
+ # 检查是否存在API_KEY环境变量
15
+ api_key = os.getenv("OPENAI_API_KEY")
16
+ if api_key is None:
17
+ raise ValueError("环境变量'OPENAI_API_KEY'未设置。请确保已经定义了API密钥。")
18
+
19
+ # 检查是否存在API_BASE环境变量,并据此设置base_url参数
20
+ api_base = os.getenv("OPENAI_API_BASE")
21
+ if api_base:
22
+ client = OpenAI(base_url=api_base, api_key=api_key)
23
+ else:
24
+ client = OpenAI(api_key=api_key)
25
+
26
+
27
+
28
+ def get_response( message ):
29
+ if client is None:
30
+ init_client()
31
+ response = client.chat.completions.create(\
32
+ model="gpt-3.5-turbo",\
33
+ messages = message, \
34
+ max_tokens = 300, \
35
+ temperature = 0.1 )
36
+ return response.choices[0].message.content
37
+
38
+ from openai import AsyncOpenAI
39
+
40
+ def init_aclient():
41
+ # 将aclient设置为全局变量,以便在其他函数中使用
42
+ global aclient
43
+
44
+ # 检查是否存在API_KEY环境变量
45
+ api_key = os.getenv("OPENAI_API_KEY")
46
+ if api_key is None:
47
+ raise ValueError("环境变量'OPENAI_API_KEY'未设置。请确保已经定义了API密钥。")
48
+
49
+ # 检查是否存在API_BASE环境变量,并据此设置base_url参数
50
+ api_base = os.getenv("OPENAI_API_BASE")
51
+ if api_base:
52
+ aclient = AsyncOpenAI(base_url=api_base, api_key=api_key)
53
+ else:
54
+ aclient = AsyncOpenAI(api_key=api_key)
55
+
56
+ async def async_get_response( message ):
57
+ if aclient is None:
58
+ init_aclient()
59
+ response = await aclient.chat.completions.create(\
60
+ model="gpt-3.5-turbo",\
61
+ messages = message, \
62
+ max_tokens = 300, \
63
+ temperature = 0.1 )
64
+ return response.choices[0].message.content
65
+
ChatHaruhi/response_spark.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from . import SparkApi
2
+
3
+ aclient = None
4
+
5
+ client = None
6
+
7
+ import os
8
+
9
+ def init_client():
10
+
11
+ # 将client设置为全局变量
12
+ global client
13
+
14
+ # 将ERNIE_ACCESS_TOKEN作为参数值传递给OS
15
+ appid = os.getenv("SPARK_APPID")
16
+ api_secret = os.getenv("SPARK_API_SECRET")
17
+ api_key = os.getenv("SPARK_API_KEY")
18
+ if appid is None:
19
+ raise ValueError("环境变量'SPARK_APPID'未设置,请确保已经定义了API密钥")
20
+ if api_secret is None:
21
+ raise ValueError("环境变量'SPARK_API_SECRET'未设置,请确保已经定义了API密钥")
22
+ if api_key is None:
23
+ raise ValueError("环境变量'SPARK_API_KEY'未设置,请确保已经定义了API密钥")
24
+ SparkApi.appid = appid
25
+ SparkApi.api_secret = api_secret
26
+ SparkApi.api_key = api_key
27
+ client = SparkApi
28
+
29
+ def get_response(message, model_name = "Spark3.5"):
30
+ if client is None:
31
+ init_client()
32
+
33
+ if model_name == "Spark2.0":
34
+ domain = "generalv2" # v2.0版本
35
+ Spark_url = "ws://spark-api.xf-yun.com/v2.1/chat" # v2.0环境的地址
36
+ elif model_name == "Spark1.5":
37
+ domain = "general" # v1.5版本
38
+ Spark_url = "ws://spark-api.xf-yun.com/v1.1/chat" # v1.5环境的地址
39
+ elif model_name == "Spark3.0":
40
+ domain = "generalv3" # v3.0版本
41
+ Spark_url = "ws://spark-api.xf-yun.com/v3.1/chat" # v3.0环境的地址
42
+ elif model_name == "Spark3.5":
43
+ domain = "generalv3.5" # v3.5版本
44
+ Spark_url = "ws://spark-api.xf-yun.com/v3.5/chat" # v3.5环境的地址
45
+ else:
46
+ raise Exception("Unknown Spark model")
47
+ # print(message_ua)
48
+ client.answer = ""
49
+ client.main(client.appid,client.api_key,client.api_secret,Spark_url,domain,message)
50
+ return client.answer
51
+
ChatHaruhi/response_zhipu.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import zhipuai
2
+
3
+ aclient = None
4
+
5
+ client = None
6
+
7
+ import os
8
+ from zhipuai import ZhipuAI
9
+
10
+ def init_client():
11
+
12
+ # 将client设置为全局变量
13
+ global client
14
+
15
+ # 将ZHIPUAI_API_KEY作为参数值传递给OS
16
+ api_key = os.getenv("ZHIPUAI_API_KEY")
17
+ if api_key is None:
18
+ raise ValueError("环境变量'ZHIPUAI_API_KEY'未设置,请确保已经定义了API密钥")
19
+
20
+ client = ZhipuAI(api_key=api_key)
21
+
22
+
23
+ def init_aclient():
24
+
25
+ # 将aclient设置为全局变量
26
+ global aclient
27
+
28
+ # 将ZHIPUAI_API_KEY作为参数值传递给OS
29
+ api_key = os.getenv("ZHIPUAI_API_KEY")
30
+ if api_key is None:
31
+ raise ValueError("环境变量'ZHIPUAI_API_KEY'未设置,请确保已经定义了API密钥")
32
+
33
+ def get_response( message, model_name = "glm-3-turbo" ):
34
+ if client is None:
35
+ init_client()
36
+ response = client.chat.completions.create(\
37
+ model=model_name,\
38
+ messages = message, \
39
+ max_tokens = 300, \
40
+ temperature = 0.1 )
41
+ return response.choices[0].message.content
42
+
43
+
44
+
45
+
46
+
47
+
48
+
49
+
50
+
51
+
52
+
53
+
54
+
ChatHaruhi/sugar_map.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ sugar_role_names = {'汤师爷': 'tangshiye', 'tangshiye': 'tangshiye', 'Tangshiye': 'tangshiye',
2
+ '慕容复': 'murongfu', 'murongfu': 'murongfu', 'Murongfu': 'murongfu',
3
+ '李云龙': 'liyunlong', 'liyunlong': 'liyunlong', 'Liyunlong': 'liyunlong',
4
+ 'Luna': 'Luna', '王多鱼': 'wangduoyu', 'wangduoyu': 'wangduoyu',
5
+ 'Wangduoyu': 'wangduoyu', 'Ron': 'Ron', '鸠摩智': 'jiumozhi',
6
+ 'jiumozhi': 'jiumozhi', 'Jiumozhi': 'jiumozhi', 'Snape': 'Snape',
7
+ '凉宫春日': 'haruhi', 'haruhi': 'haruhi', 'Haruhi': 'haruhi',
8
+ 'Malfoy': 'Malfoy', '虚竹': 'xuzhu', 'xuzhu': 'xuzhu',
9
+ 'Xuzhu': 'xuzhu', '萧峰': 'xiaofeng',
10
+ 'xiaofeng': 'xiaofeng', 'Xiaofeng': 'xiaofeng', '段誉': 'duanyu',
11
+ 'duanyu': 'duanyu', 'Duanyu': 'duanyu', 'Hermione': 'Hermione',
12
+ 'Dumbledore': 'Dumbledore', '王语嫣': 'wangyuyan', 'wangyuyan':
13
+ 'wangyuyan', 'Wangyuyan': 'wangyuyan', 'Harry': 'Harry',
14
+ 'McGonagall': 'McGonagall', '白展堂': 'baizhantang',
15
+ 'baizhantang': 'baizhantang', 'Baizhantang': 'baizhantang',
16
+ '佟湘玉': 'tongxiangyu', 'tongxiangyu': 'tongxiangyu',
17
+ 'Tongxiangyu': 'tongxiangyu', '郭芙蓉': 'guofurong',
18
+ 'guofurong': 'guofurong', 'Guofurong': 'guofurong', '流浪者': 'wanderer',
19
+ 'wanderer': 'wanderer', 'Wanderer': 'wanderer', '钟离': 'zhongli',
20
+ 'zhongli': 'zhongli', 'Zhongli': 'zhongli', '胡桃': 'hutao', 'hutao': 'hutao',
21
+ 'Hutao': 'hutao', 'Sheldon': 'Sheldon', 'Raj': 'Raj',
22
+ 'Penny': 'Penny', '韦小宝': 'weixiaobao', 'weixiaobao': 'weixiaobao',
23
+ 'Weixiaobao': 'weixiaobao', '乔峰': 'qiaofeng', 'qiaofeng': 'qiaofeng',
24
+ 'Qiaofeng': 'qiaofeng', '神里绫华': 'ayaka', 'ayaka': 'ayaka',
25
+ 'Ayaka': 'ayaka', '雷电将军': 'raidenShogun', 'raidenShogun': 'raidenShogun',
26
+ 'RaidenShogun': 'raidenShogun', '于谦': 'yuqian', 'yuqian': 'yuqian',
27
+ 'Yuqian': 'yuqian', 'Professor McGonagall': 'McGonagall',
28
+ 'Professor Dumbledore': 'Dumbledore'}
29
+
30
+ enname2zhname = {'tangshiye': '汤师爷', 'murongfu': '慕容复', 'liyunlong': '李云龙', 'Luna': 'Luna', 'wangduoyu': '王多鱼', 'Ron': 'Ron', 'jiumozhi': '鸠摩智', 'Snape': 'Snape', 'haruhi': '凉宫春日', 'Malfoy': 'Malfoy', 'xuzhu': '虚竹', 'xiaofeng': '萧峰', 'duanyu': '段誉', 'Hermione': 'Hermione', 'Dumbledore': 'Dumbledore', 'wangyuyan': '王语嫣', 'Harry': 'Harry', 'McGonagall': 'McGonagall', 'baizhantang': '白展堂', 'tongxiangyu': '佟湘玉', 'guofurong': '郭芙蓉', 'wanderer': '流浪者', 'zhongli': '钟离', 'hutao': '胡桃', 'Sheldon': 'Sheldon', 'Raj': 'Raj', 'Penny': 'Penny', 'weixiaobao': '韦小宝', 'qiaofeng': '乔峰', 'ayaka': '神里绫华', 'raidenShogun': '雷电将军', 'yuqian': '于谦'}
ChatHaruhi/utils.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import tiktoken
2
+
3
+ _enc_model = None
4
+
5
+ def normalize2uaua( message, if_replace_system = False ):
6
+ new_message = []
7
+ last_role = ""
8
+
9
+ for msg in message:
10
+ role = msg["role"]
11
+ if if_replace_system and role == "system":
12
+ role = "user"
13
+
14
+ if last_role == role:
15
+ new_message[-1]["content"] = new_message[-1]["content"] + "\n" + msg["content"]
16
+ else:
17
+ last_role = role
18
+ new_message.append( msg )
19
+
20
+ return new_message
21
+
22
+ def tiktoken_counter( text ):
23
+ global _enc_model
24
+
25
+ if _enc_model is None:
26
+ _enc_model = tiktoken.get_encoding("cl100k_base")
27
+
28
+ return len(_enc_model.encode(text))
29
+
30
+
31
+ def string_to_base64(text):
32
+ import base64
33
+ byte_array = b''
34
+ for char in text:
35
+ num_bytes = char.encode('utf-8')
36
+ byte_array += num_bytes
37
+
38
+ base64_data = base64.b64encode(byte_array)
39
+ return base64_data.decode('utf-8')
40
+
41
+ def base64_to_string(base64_data):
42
+ import base64
43
+ byte_array = base64.b64decode(base64_data)
44
+ text = byte_array.decode('utf-8')
45
+ return text
46
+
47
+
48
+ def float_array_to_base64(float_arr):
49
+ import struct
50
+ import base64
51
+ byte_array = b''
52
+
53
+ for f in float_arr:
54
+ # 将每个浮点数打包为4字节
55
+ num_bytes = struct.pack('!f', f)
56
+ byte_array += num_bytes
57
+
58
+ # 将字节数组进行base64编码
59
+ base64_data = base64.b64encode(byte_array)
60
+
61
+ return base64_data.decode('utf-8')
62
+
63
+ def base64_to_float_array(base64_data):
64
+ import struct
65
+ import base64
66
+ byte_array = base64.b64decode(base64_data)
67
+
68
+ float_array = []
69
+
70
+ # 每 4 个字节解析为一个浮点数
71
+ for i in range(0, len(byte_array), 4):
72
+ num = struct.unpack('!f', byte_array[i:i+4])[0]
73
+ float_array.append(num)
74
+
75
+ return float_array
76
+
77
+ def load_datas_from_jsonl( file_path ):
78
+ import json
79
+ datas = []
80
+ with open(file_path, 'r', encoding = 'utf-8') as f:
81
+ for line in f:
82
+ datas.append(json.loads(line))
83
+ return datas
84
+
85
+ def save_datas_to_jsonl( file_path, datas ):
86
+ import json
87
+ with open(file_path, 'w', encoding = 'utf-8') as f:
88
+ for data in datas:
89
+ f.write(json.dumps(data, ensure_ascii=False) + '\n')