from .utils import base64_to_float_array, base64_to_string

def get_text_from_data(data):
    # prefer the plain-text field; fall back to the base64-encoded one
    if "text" in data:
        return data["text"]
    elif "enc_text" in data:
        return base64_to_string(data["enc_text"])
    else:
        print("warning! failed to get text from data ", data)
        return ""

def parse_rag(text):
    """Scan text line by line for RAG markers and return one request per marker.

    Each request is a dict with:
      "n"         how many stories to retrieve
      "max_token" token budget for the stories (-1 means unlimited)
      "query"     the retrieval query; "default" is later replaced by the user text
      "lid"       index of the marker line, which will be overwritten wholesale
    """
    lines = text.split("\n")
    ans = []

    for i, line in enumerate(lines):
        if "{{RAG对话}}" in line:
            ans.append({"n": 1, "max_token": -1, "query": "default", "lid": i})
        elif "{{RAG对话|" in line:
            query_info = line.split("|")[1].rstrip("}}")
            ans.append({"n": 1, "max_token": -1, "query": query_info, "lid": i})
        elif "{{RAG多对话|" in line:
            parts = line.split("|")
            max_token = int(parts[1].split("<=")[1])
            max_n = int(parts[2].split("<=")[1].rstrip("}}"))
            ans.append({"n": max_n, "max_token": max_token, "query": "default", "lid": i})

    return ans
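
# Worked example (hypothetical persona text) covering the three marker forms:
#   parse_rag("I am Haruhi.\n{{RAG对话}}\n{{RAG对话|SOS团}}\n{{RAG多对话|token<=1000|n<=5}}")
#   -> [{"n": 1, "max_token": -1,   "query": "default", "lid": 1},
#       {"n": 1, "max_token": -1,   "query": "SOS团",    "lid": 2},
#       {"n": 5, "max_token": 1000, "query": "default", "lid": 3}]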

class ChatHaruhi:
    def __init__(self,
                 role_name = None,
                 user_name = None,
                 persona = None,
                 stories = None,
                 story_vecs = None,
                 role_from_hf = None,
                 role_from_jsonl = None,
                 llm = None, # default function mapping a message list to a response
                 llm_async = None, # async counterpart of llm
                 user_name_in_message = "default",
                 verbose = None,
                 embed_name = None,
                 embedding = None,
                 db = None,
                 token_counter = "default",
                 max_input_token = 1800,
                 max_len_story_haruhi = 1000,
                 max_story_n_haruhi = 5
                 ):

        self.verbose = verbose is None or bool(verbose)  # verbose defaults to True

        self.db = db

        self.embed_name = embed_name

        self.max_len_story_haruhi = max_len_story_haruhi # only used for the preset (sugar) Haruhi roles
        self.max_story_n_haruhi = max_story_n_haruhi # only used for the preset (sugar) Haruhi roles

        self.last_query_msg = None

        if embedding is None:
            self.embedding = self.set_embedding_with_name( embed_name )
        else:
            self.embedding = embedding

        if persona and role_name and stories and story_vecs and len(stories) == len(story_vecs):
            # everything supplied externally; story_vecs must match what self.embedding returns
            self.persona, self.role_name, self.user_name = persona, role_name, user_name
            self.build_db(stories, story_vecs)
        elif persona and role_name and stories:
            # stories supplied without vectors; re-embed them with self.embedding
            story_vecs = self.extract_story_vecs(stories)
            self.persona, self.role_name, self.user_name = persona, role_name, user_name
            self.build_db(stories, story_vecs)
        elif role_from_hf:
            # load the role from Hugging Face
            self.persona, new_role_name, self.stories, self.story_vecs = self.load_role_from_hf(role_from_hf)
            self.role_name = new_role_name if new_role_name else role_name
            self.user_name = user_name
            self.build_db(self.stories, self.story_vecs)
        elif role_from_jsonl:
            # load the role from a local jsonl file
            self.persona, new_role_name, self.stories, self.story_vecs = self.load_role_from_jsonl(role_from_jsonl)
            self.role_name = new_role_name if new_role_name else role_name
            self.user_name = user_name
            self.build_db(self.stories, self.story_vecs)
        elif persona and role_name:
            # persona only, without any RAG stories
            self.persona, self.role_name, self.user_name = persona, role_name, user_name
            self.db = None
        elif role_name and self.check_sugar( role_name ):
            # role_name refers to one of the preset (sugar) roles
            self.persona, self.role_name, self.stories, self.story_vecs = self.load_role_from_sugar( role_name )
            self.build_db(self.stories, self.story_vecs)
            # agreed with 江YH: every loading path requires the caller to invoke
            # add_rag_prompt_after_persona() explicitly, to avoid confusion
            # self.add_rag_prompt_after_persona()
        else:
            raise ValueError("persona and role_name must be set together, or role_name must be one of ChatHaruhi's preset characters")

        self.llm, self.llm_async = llm, llm_async
        if not self.llm and self.verbose:
            print("warning: llm is not set; only get_message works, and chat will return the idle message")

        self.user_name_in_message = user_name_in_message
        self.previous_user_pool = set([user_name]) if user_name else set()
        self.current_user_name_in_message = user_name_in_message.lower() == "add"

        self.idle_message = "idle message, you see this because self.llm has not been set."

        if token_counter is None:
            self.token_counter = lambda x: 0
        elif isinstance(token_counter, str) and token_counter.lower() == "default":
            # TODO: change to load from utils
            from .utils import tiktoken_counter
            self.token_counter = tiktoken_counter
        else:
            self.token_counter = token_counter
            if self.verbose:
                print("user set customized token_counter")

        self.max_input_token = max_input_token

        self.history = []

    def check_sugar(self, role_name):
        from .sugar_map import sugar_role_names
        return role_name in sugar_role_names

    def load_role_from_sugar(self, role_name):
        from .sugar_map import sugar_role_names, enname2zhname
        en_role_name = sugar_role_names[role_name]
        new_role_name = enname2zhname[en_role_name]
        role_from_hf = "silk-road/ChatHaruhi-RolePlaying/" + en_role_name
        persona, _, stories, story_vecs = self.load_role_from_hf(role_from_hf)

        return persona, new_role_name, stories, story_vecs

    def add_rag_prompt_after_persona( self ):
        # append a {{RAG多对话}} marker so get_message will splice classic scenes into the persona
        rag_sentence = "{{RAG多对话|token<=" + str(self.max_len_story_haruhi) + "|n<=" + str(self.max_story_n_haruhi) + "}}"
        if not self.persona.endswith("\n"):
            self.persona += "\n"
        self.persona += "Classic scenes for the role are as follows:\n" + rag_sentence + "\n"

    def set_embedding_with_name(self, embed_name):
        if embed_name is None or embed_name == "bge_zh":
            from .embeddings import get_bge_zh_embedding
            self.embed_name = "bge_zh"
            return get_bge_zh_embedding
        elif embed_name == "foo":
            from .embeddings import foo_embedding
            return foo_embedding
        elif embed_name == "bce":
            from .embeddings import foo_bce
            return foo_bce
        elif embed_name == "openai" or embed_name == "luotuo_openai":
            from .embeddings import foo_openai
            return foo_openai

    def set_new_user(self, user):
        if len(self.previous_user_pool) > 0 and user not in self.previous_user_pool:
            if self.user_name_in_message.lower() == "default":
                if self.verbose:
                    print(f'new user {user} included in conversation')
                self.current_user_name_in_message = True
        self.user_name = user
        self.previous_user_pool.add(user)

    def chat(self, user, text):
        self.set_new_user(user)
        message = self.get_message(user, text)
        if self.llm:
            response = self.llm(message)
            self.append_message(response)
            return response
        return self.idle_message

    async def async_chat(self, user, text):
        self.set_new_user(user)
        message = self.get_message(user, text)
        if self.llm_async:
            response = await self.llm_async(message)
            self.append_message(response)
            return response
        return self.idle_message

    def parse_rag_from_persona(self, persona, text = None):
        # each query_rag contains:
        # "n"         how many stories to retrieve
        # "max_token" token budget for the stories (-1 means unlimited)
        # "query"     the retrieval query; "default" is replaced by the user text
        # "lid"       index of the marker line, which is overwritten wholesale

        query_rags = parse_rag( persona )

        if text is not None:
            for rag in query_rags:
                if rag['query'] == "default":
                    rag['query'] = text

        return query_rags, self.token_counter(persona)

    def append_message( self, response , speaker = None ):
        if self.last_query_msg is not None:
            self.history.append(self.last_query_msg)
            self.last_query_msg = None

        if speaker is None:
            # a None speaker means the sentence was produced by the role itself ({{role}})
            self.history.append({"speaker":"{{role}}","content":response})
            # the field is called "speaker" to avoid confusion with the chat-API "role" field
        else:
            self.history.append({"speaker":speaker,"content":response})

    def check_recompute_stories_token(self):
        # token counts need recomputing when the cached metas don't cover every story
        return len(self.db.metas) != len(self.db.stories)

    def recompute_stories_token(self):
        self.db.metas = [self.token_counter(story) for story in self.db.stories]

    def rag_retrieve( self, query, n, max_token, avoid_ids = None ):
        # retrieve up to n story ids for the query, skipping avoid_ids and
        # keeping the total token count within max_token
        avoid_ids = avoid_ids or []
        query_vec = self.embedding(query)

        self.db.clean_flag()
        self.db.disable_story_with_ids( avoid_ids )

        retrieved_ids = self.db.search( query_vec, n )

        if self.check_recompute_stories_token():
            self.recompute_stories_token()

        sum_token = 0
        ans = []

        for i, sid in enumerate(retrieved_ids):
            sum_token += self.db.metas[sid]
            # the first hit is always kept, even if it alone exceeds the budget
            if i == 0 or sum_token <= max_token:
                ans.append(sid)
            else:
                break

        return ans
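
    # Budget example (hypothetical token counts): if search returns ids
    # [7, 2, 9] whose metas are 400, 500 and 300 tokens and max_token is 800,
    # the running sums are 400 then 900, so only [7] survives: id 2 pushes
    # the total past the budget and the loop breaks before reaching id 9.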


    def rag_retrieve_all( self, query_rags, rest_limit ):
        # run every query_rag in turn, avoiding stories already retrieved by
        # earlier queries; returns one list of story ids per query_rag
        retrieved_ids = []
        rag_ids = []

        for query_rag in query_rags:
            query = query_rag['query']
            n = query_rag['n']
            # per-query budget: the smaller of rest_limit and the marker's own cap
            max_token = rest_limit
            if 0 < query_rag['max_token'] < rest_limit:
                max_token = query_rag['max_token']

            rag_id = self.rag_retrieve( query, n, max_token, avoid_ids = retrieved_ids )
            rag_ids.append( rag_id )
            retrieved_ids += rag_id

        return rag_ids

    def append_history_under_limit(self, message, rest_limit):
        # append as much history as the token budget allows, counting from the
        # most recent message backwards; a "{{role}}" speaker becomes an
        # assistant message, every other speaker becomes a user message
        current_limit = rest_limit

        history_list = []

        for item in reversed(self.history):
            current_token = self.token_counter(item['content'])
            current_limit -= current_token
            if current_limit < 0:
                break
            else:
                history_list.append(item)

        history_list = list(reversed(history_list))

        # TODO: for multi-party chat, content will additionally carry
        # "speaker: content" style prefixes

        for item in history_list:
            if item['speaker'] == "{{role}}":
                message.append({"role":"assistant","content":item['content']})
            else:
                message.append({"role":"user","content":item['content']})

        return message
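
    # Shape example (hypothetical speaker "Kyon"): with history
    # [{"speaker": "Kyon", "content": "hi"}, {"speaker": "{{role}}", "content": "hello"}]
    # and a sufficient budget, message gains {"role": "user", "content": "hi"}
    # followed by {"role": "assistant", "content": "hello"}, oldest first.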

    def get_message(self, user, text):
        query_token = self.token_counter(text)

        # first work out which rag stories the persona asks for
        # (see parse_rag_from_persona for the query_rag fields)
        query_rags, persona_token = self.parse_rag_from_persona( self.persona, text )

        rest_limit = self.max_input_token - persona_token - query_token

        if self.verbose:
            print(f"query_rags: {query_rags} rest_limit = { rest_limit }")

        rag_ids = self.rag_retrieve_all( query_rags, rest_limit )

        # splice the retrieved stories into the persona, replacing the marker lines
        augmented_persona = self.augment_persona( self.persona, rag_ids, query_rags )

        system_prompt = self.package_system_prompt( self.role_name, augmented_persona )

        token_for_system = self.token_counter( system_prompt )

        rest_limit = self.max_input_token - token_for_system - query_token

        message = [{"role":"system","content":system_prompt}]

        message = self.append_history_under_limit( message, rest_limit )

        # TODO: for multi-party chat, content will additionally carry "speaker: content" style prefixes

        message.append({"role":"user","content":text})

        self.last_query_msg = {"speaker":user,"content":text}

        return message
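
    # The returned list is an OpenAI-style chat-completion payload:
    #   [{"role": "system", "content": <augmented persona prompt>},
    #    ...as much history as fits the remaining token budget...,
    #    {"role": "user", "content": text}]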

    def package_system_prompt(self, role_name, augmented_persona):
        bot_name = role_name
        return f"""You are now in roleplay conversation mode. Pretend to be {bot_name} whose persona follows:
{augmented_persona}

You will stay in-character whenever possible, and generate responses as if you were {bot_name}."""


    def augment_persona(self, persona, rag_ids, query_rags):
        # overwrite each marker line in the persona with the retrieved
        # stories, separated by "###" dividers
        lines = persona.split("\n")
        for rag_id, query_rag in zip(rag_ids, query_rags):
            lid = query_rag['lid']
            new_text = ""
            for sid in rag_id:
                new_text += "###\n" + self.db.stories[sid].strip() + "\n"
            new_text = new_text.strip()
            lines[lid] = new_text
        return "\n".join(lines)

    def load_role_from_jsonl( self, role_from_jsonl ):
        # load the role from a local jsonl file, skipping malformed lines
        import json
        datas = []
        with open(role_from_jsonl, 'r', encoding='utf-8') as f:
            for line in f:
                try:
                    datas.append(json.loads(line))
                except json.JSONDecodeError:
                    continue

        from .embeddings import embedname2columnname

        if self.embed_name in embedname2columnname:
            column_name = embedname2columnname[self.embed_name]
        else:
            print('warning! unknown embedding name ', self.embed_name ,' while loading role')
            column_name = 'luotuo_openai'

        stories, story_vecs, persona = self.extract_text_vec_from_datas(datas, column_name)

        return persona, None, stories, story_vecs


    def load_role_from_hf(self, role_from_hf):
        # load the role from a Hugging Face dataset: either "org/dataset"
        # (use its train split) or "org/dataset/role" (load role.jsonl)

        from datasets import load_dataset

        if role_from_hf.count("/") == 1:
            dataset = load_dataset(role_from_hf)
            datas = dataset["train"]
        elif role_from_hf.count("/") >= 2:
            split_index = role_from_hf.index('/')
            second_split_index = role_from_hf.index('/', split_index+1)
            dataset_name = role_from_hf[:second_split_index]
            split_name = role_from_hf[second_split_index+1:]

            fname = split_name + '.jsonl'
            dataset = load_dataset(dataset_name,data_files={'train':fname})
            datas = dataset["train"]
        else:
            raise ValueError(f"invalid hf role path: {role_from_hf}")

        from .embeddings import embedname2columnname

        if self.embed_name in embedname2columnname:
            column_name = embedname2columnname[self.embed_name]
        else:
            print('warning! unknown embedding name ', self.embed_name ,' while loading role')
            column_name = 'luotuo_openai'

        stories, story_vecs, persona = self.extract_text_vec_from_datas(datas, column_name)

        return persona, None, stories, story_vecs
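
    # Path examples: "silk-road/ChatHaruhi-RolePlaying" loads the dataset's
    # train split directly, while "silk-road/ChatHaruhi-RolePlaying/tangshiye"
    # loads tangshiye.jsonl from that repository (the form load_role_from_sugar
    # builds; "tangshiye" here is just an illustrative role name).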

    def extract_text_vec_from_datas(self, datas, column_name):
        # extract texts, vecs and the system prompt from the dataset rows;
        # the vector column doubles as a row-type tag
        texts = []
        vecs = []
        system_prompt = ""  # initialized so a role without a system_prompt row doesn't raise

        for data in datas:
            if data[column_name] == 'system_prompt':
                system_prompt = get_text_from_data( data )
            elif data[column_name] == 'config':
                pass
            else:
                vec = base64_to_float_array( data[column_name] )
                text = get_text_from_data( data )
                vecs.append( vec )
                texts.append( text )
        return texts, vecs, system_prompt

    def extract_story_vecs(self, stories):
        # embed the stories to obtain story_vecs

        if self.verbose:
            print(f"re-extract vector for {len(stories)} stories")

        story_vecs = []

        from .embeddings import embedshortname2model_name
        from .embeddings import device

        if device.type != "cpu" and self.embed_name in embedshortname2model_name:
            # e.g. model_name = "BAAI/bge-small-zh-v1.5"
            model_name = embedshortname2model_name[self.embed_name]

            from .utils import get_general_embeddings_safe
            # batch embedding on GPU, much faster than one story at a time
            story_vecs = get_general_embeddings_safe( stories, model_name = model_name )
        else:
            from tqdm import tqdm
            for story in tqdm(stories):
                story_vecs.append(self.embedding(story))

        return story_vecs

    def build_db(self, stories, story_vecs):
        # build (or rebuild) the vector store backing retrieval
        if self.db is None:
            from .NaiveDB import NaiveDB
            self.db = NaiveDB()
        self.db.build_db(stories, story_vecs)
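

if __name__ == "__main__":
    # Minimal usage sketch, not part of the library (run as a module so the
    # relative imports resolve). "haruhi" is assumed to be one of the preset
    # names in .sugar_map, and loading it needs network access to Hugging
    # Face; swap in any valid sugar role name.
    def echo_llm(messages):
        # stand-in for a real chat-completion backend: echoes the last user
        # message so the prompt-building pipeline can be exercised offline
        return "echo: " + messages[-1]["content"]

    bot = ChatHaruhi(role_name="haruhi", llm=echo_llm)
    bot.add_rag_prompt_after_persona()  # opt in to RAG scene injection
    print(bot.chat(user="reader", text="今天的计划是什么?"))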