from .ChromaDB import ChromaDB
import os

from .utils import luotuo_openai_embedding, tiktokenizer

from .utils import response_postprocess

def get_text_from_data( data ):
    # return the raw text of a record: stored either as plain text
    # under "text" or base64-encoded under "enc_text"
    if "text" in data:
        return data['text']
    elif "enc_text" in data:
        from .utils import base64_to_string
        return base64_to_string( data['enc_text'] )
    else:
        print("warning! failed to get text from data ", data)
        return ""

class ChatHaruhi:

    def __init__(self, system_prompt = None,
                 role_name = None, role_from_hf = None,
                 role_from_jsonl = None,
                 story_db = None, story_text_folder = None,
                 llm = 'openai',
                 embedding = 'luotuo_openai',
                 max_len_story = None, max_len_history = None,
                 verbose = False):
        super().__init__()
        self.verbose = verbose

        # constants
        self.story_prefix_prompt = "Classic scenes for the role are as follows:\n"
        self.k_search = 19
        self.narrator = ['旁白', '', 'scene', 'Scene', 'narrator', 'Narrator']
        self.dialogue_divide_token = '\n###\n'
        self.dialogue_bra_token = '「'
        self.dialogue_ket_token = '」'

        if system_prompt:
            self.system_prompt = self.check_system_prompt( system_prompt )

        # TODO: embedding should be defined separately, so refactor this part later
        if llm == 'openai':
            # self.llm = LangChainGPT()
            self.llm, self.tokenizer = self.get_models('openai')
        elif llm == 'debug':
            self.llm, self.tokenizer = self.get_models('debug')
        elif llm == 'spark':
            self.llm, self.tokenizer = self.get_models('spark')
        elif llm == 'GLMPro':
            self.llm, self.tokenizer = self.get_models('GLMPro')
        elif llm == 'ChatGLM2GPT':
            self.llm, self.tokenizer = self.get_models('ChatGLM2GPT')
            self.story_prefix_prompt = '\n'
        elif llm == "BaiChuan2GPT":
            self.llm, self.tokenizer = self.get_models('BaiChuan2GPT')
        elif llm == "BaiChuanAPIGPT":
            self.llm, self.tokenizer = self.get_models('BaiChuanAPIGPT')
        elif llm == "ernie3.5":
            self.llm, self.tokenizer = self.get_models('ernie3.5')
        elif llm == "ernie4.0":
            self.llm, self.tokenizer = self.get_models('ernie4.0')
        elif "qwen" in llm:
            self.llm, self.tokenizer = self.get_models(llm)
        else:
            print(f'warning! undefined llm {llm}, use openai instead.')
            self.llm, self.tokenizer = self.get_models('openai')

        if embedding == 'luotuo_openai':
            self.embedding = luotuo_openai_embedding
        elif embedding == 'bge_en':
            from .utils import get_bge_embedding
            self.embedding = get_bge_embedding
        elif embedding == 'bge_zh':
            from .utils import get_bge_zh_embedding
            self.embedding = get_bge_zh_embedding
        else:
            print(f'warning! undefined embedding {embedding}, use luotuo_openai instead.')
            self.embedding = luotuo_openai_embedding
        
        if role_name:
            # TODO move into a function
            from .role_name_to_file import get_folder_role_name
            # correct role_name to folder_role_name
            role_name, url = get_folder_role_name(role_name)

            unzip_folder = f'./temp_character_folder/temp_{role_name}'
            db_folder = os.path.join(unzip_folder, f'content/{role_name}')
            system_prompt = os.path.join(unzip_folder, 'content/system_prompt.txt')

            if not os.path.exists(unzip_folder):
                # not yet downloaded
                # url = f'https://github.com/LC1332/Haruhi-2-Dev/raw/main/data/character_in_zip/{role_name}.zip'
                import requests, zipfile, io
                r = requests.get(url)
                r.raise_for_status()  # fail loudly if the download did not succeed
                z = zipfile.ZipFile(io.BytesIO(r.content))
                z.extractall(unzip_folder)

            if self.verbose:
                print(f'loading pre-defined character {role_name}...')
            
            self.db = ChromaDB()
            self.db.load(db_folder)
            self.system_prompt = self.check_system_prompt(system_prompt)
        elif role_from_hf:
            # TODO move into a function
            from datasets import load_dataset

            if role_from_hf.count("/") == 1:
                dataset = load_dataset(role_from_hf)
                datas = dataset["train"]
            elif role_from_hf.count("/") >= 2:
                # "owner/dataset/split" -> load split.jsonl from owner/dataset
                split_index = role_from_hf.index('/')
                second_split_index = role_from_hf.index('/', split_index + 1)
                dataset_name = role_from_hf[:second_split_index]
                split_name = role_from_hf[second_split_index + 1:]

                fname = split_name + '.jsonl'
                dataset = load_dataset(dataset_name, data_files={'train': fname})
                datas = dataset["train"]
            else:
                raise ValueError(f"invalid role_from_hf '{role_from_hf}', expected 'owner/dataset' or 'owner/dataset/split'")
            
            if embedding == 'luotuo_openai':
                embed_name = 'luotuo_openai'
            elif embedding == 'bge_en':
                embed_name = 'bge_en_s15'
            elif embedding == 'bge_zh':
                embed_name = 'bge_zh_s15'
            else:
                print('warning! unknown embedding name', embedding, 'while loading role')
                embed_name = 'luotuo_openai'

            texts, vecs, self.system_prompt = self.extract_text_vec_from_datas(datas, embed_name)

            self.build_story_db_from_vec( texts, vecs )

        elif role_from_jsonl:
            import json
            datas = []
            with open( role_from_jsonl , encoding="utf-8") as f:
                for line in f:
                    try:
                        data = json.loads(line)
                        # process the JSON data line by line
                        datas.append(data)
                    except json.JSONDecodeError:
                        print("warning! failed to load json line ", line)

            if embedding == 'luotuo_openai':
                embed_name = 'luotuo_openai'
            elif embedding == 'bge_en':
                embed_name = 'bge_en_s15'
            elif embedding == 'bge_zh':
                embed_name = 'bge_zh_s15'
            else:
                print('warning! unknown embedding name', embedding, 'while loading role')
                embed_name = 'luotuo_openai'

            texts, vecs, self.system_prompt = self.extract_text_vec_from_datas(datas, embed_name)

            self.build_story_db_from_vec( texts, vecs )
            
        elif story_db:
            self.db = ChromaDB() 
            self.db.load(story_db)
        elif story_text_folder:
            # print("Building story database from texts...")
            self.db = self.build_story_db(story_text_folder) 
        else:
            self.db = None
            print('warning! no story database was built: neither story_db nor story_text_folder was provided.')
            # raise ValueError("Either story_db or story_text_folder must be provided")
        

        self.max_len_story, self.max_len_history = self.get_tokenlen_setting('openai')

        if max_len_history is not None:
            self.max_len_history = max_len_history
            # user setting will override default setting

        if max_len_story is not None:
            self.max_len_story = max_len_story
            # user setting will override default setting

        self.dialogue_history = []

    def extract_text_vec_from_datas( self, datas, embed_name ):
        # extract texts, embedding vectors and the system prompt
        # from a huggingface dataset (or a list of dicts)
        # return texts, vecs, system_prompt
        from .utils import base64_to_float_array

        texts = []
        vecs = []
        system_prompt = ""
        for data in datas:
            if data[embed_name] == 'system_prompt':
                system_prompt = get_text_from_data( data )
            elif data[embed_name] == 'config':
                pass
            else:
                vec = base64_to_float_array( data[embed_name] )
                text = get_text_from_data( data )
                vecs.append( vec )
                texts.append( text )
        return texts, vecs, system_prompt
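    # A sketch of the rows this method consumes (field values hypothetical):
    #   {"luotuo_openai": "system_prompt", "text": "You are Haruhi Suzumiya..."}
    #   {"luotuo_openai": "<base64-encoded float array>", "text": "a classic scene"}
    # i.e. the embed_name column either marks a special row
    # ('system_prompt' / 'config') or carries the base64-encoded embedding.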

        

    def check_system_prompt(self, system_prompt):
        # if system_prompt ends with .txt, read the file with utf-8 encoding
        # else, return the string directly
        if system_prompt.endswith('.txt'):
            with open(system_prompt, 'r', encoding='utf-8') as f:
                return f.read()
        else:
            return system_prompt
    

    def get_models(self, model_name):

        # TODO: if only the tokenizer is needed, there is no need to initialize the llm

        # return the (llm, tokenizer) pair for the given model name
        if model_name == 'openai':
            from .LangChainGPT import LangChainGPT
            return (LangChainGPT(), tiktokenizer)
        elif model_name == 'debug':
            from .PrintLLM import PrintLLM
            return (PrintLLM(), tiktokenizer)
        elif model_name == 'spark':
            from .SparkGPT import SparkGPT
            return (SparkGPT(), tiktokenizer)
        elif model_name == 'GLMPro':
            from .GLMPro import GLMPro
            return (GLMPro(), tiktokenizer)
        elif model_name == 'ernie3.5':
            from .ErnieGPT import ErnieGPT
            return (ErnieGPT(), tiktokenizer)
        elif model_name == 'ernie4.0':
            from .ErnieGPT import ErnieGPT
            return (ErnieGPT(model="ernie-bot-4"), tiktokenizer)
        elif model_name == "ChatGLM2GPT":
            from .ChatGLM2GPT import ChatGLM2GPT, GLM_tokenizer
            return (ChatGLM2GPT(), GLM_tokenizer)
        elif model_name == "BaiChuan2GPT":
            from .BaiChuan2GPT import BaiChuan2GPT, BaiChuan_tokenizer
            return (BaiChuan2GPT(), BaiChuan_tokenizer)
        elif model_name == "BaiChuanAPIGPT":
            from .BaiChuanAPIGPT import BaiChuanAPIGPT
            return (BaiChuanAPIGPT(), tiktokenizer)
        elif "qwen" in model_name:
            if model_name == "qwen118k_raw":
                from .Qwen118k2GPT import Qwen118k2GPT, Qwen_tokenizer
                return (Qwen118k2GPT(model = "Qwen/Qwen-1_8B-Chat"), Qwen_tokenizer)
            from huggingface_hub import HfApi 
            from huggingface_hub.hf_api import ModelFilter
            qwen_api = HfApi()
            qwen_models = qwen_api.list_models(
                filter = ModelFilter(model_name=model_name),
                author = "silk-road"             
            )
            qwen_models_id = [qwen_model.id for qwen_model in qwen_models]
            if "silk-road/" + model_name in qwen_models_id:
                from .Qwen118k2GPT import Qwen118k2GPT, Qwen_tokenizer
                return (Qwen118k2GPT(model = "silk-road/" + model_name), Qwen_tokenizer)
            else:
                print(f'warning! undefined model {model_name}, use openai instead.')
                from .LangChainGPT import LangChainGPT
                return (LangChainGPT(), tiktokenizer)
        else:
            print(f'warning! undefined model {model_name}, use openai instead.')
            from .LangChainGPT import LangChainGPT
            return (LangChainGPT(), tiktokenizer)
        
    def get_tokenlen_setting( self, model_name ):
        # return the setting of story and history token length
        if model_name == 'openai':
            return (1500, 1200)
        else:
            print(f'warning! undefined model {model_name}, use openai instead.')
            return (1500, 1200)
        
    def build_story_db_from_vec( self, texts, vecs ):
        self.db = ChromaDB()

        self.db.init_from_docs( vecs, texts)

    def build_story_db(self, text_folder):
        # read every text file in the folder and extract its embedding vector
        db = ChromaDB()

        strs = []

        # scan all .txt files in text_folder
        for file in os.listdir(text_folder):
            # if the file name ends with .txt
            if file.endswith(".txt"):
                file_path = os.path.join(text_folder, file)
                with open(file_path, 'r', encoding='utf-8') as f:
                    strs.append(f.read())

        if self.verbose:
            print(f'starting to extract embeddings for {len(strs)} files...')

        vecs = []

        ## TODO: add a unit test for batched embedding
        ## new embedding code supporting batched (list) input is available;
        ## replace the for loop below with it
        ## Luotuo-bert-en has also been released, so openai can be avoided
        
        for mystr in strs:
            vecs.append(self.embedding(mystr))

        db.init_from_docs(vecs, strs)

        return db
    
    def save_story_db(self, db_path):
        self.db.save(db_path)

    def generate_prompt( self, text, role):
        from langchain.schema import (
            AIMessage,
            HumanMessage,
            SystemMessage
        )
        messages = self.generate_messages( text, role )
        prompt = ""
        for msg in messages:
            if isinstance(msg, (HumanMessage, AIMessage, SystemMessage)):
                prompt += msg.content + "\n"
        return prompt


    def generate_messages( self, text, role):
        # add system prompt
        self.llm.initialize_message()
        self.llm.system_message(self.system_prompt)

        # add story
        query = self.get_query_string(text, role)
        self.add_story( query )
        self.last_query = query

        # add query
        self.llm.user_message(query)

        return self.llm.messages
    
    def append_response( self, response, last_query = None ):
        if last_query is None:
            last_query_record = ""
            if hasattr( self, "last_query" ):
                last_query_record = self.last_query
        else:
            last_query_record = last_query

        # record dialogue history
        self.dialogue_history.append((last_query_record, response))
        
    def chat(self, text, role):
        # add system prompt
        self.llm.initialize_message()
        self.llm.system_message(self.system_prompt)
    

        # add story
        query = self.get_query_string(text, role)
        self.add_story( query )

        # add history
        self.add_history()

        # add query
        self.llm.user_message(query)
        
        # get response
        response_raw = self.llm.get_response()

        response = response_postprocess(response_raw, self.dialogue_bra_token, self.dialogue_ket_token)

        # record dialogue history
        self.dialogue_history.append((query, response))

        return response
    
    def get_query_string(self, text, role):
        if role in self.narrator:
            return role + ":" + text
        else:
            return f"{role}:{self.dialogue_bra_token}{text}{self.dialogue_ket_token}"
        
    def add_story(self, query):

        if self.db is None:
            return
        
        query_vec = self.embedding(query)

        stories = self.db.search(query_vec, self.k_search)
        
        story_string = self.story_prefix_prompt
        sum_story_token = self.tokenizer(story_string)
        
        for story in stories:
            story_token = self.tokenizer(story) + self.tokenizer(self.dialogue_divide_token)
            if sum_story_token + story_token > self.max_len_story:
                break
            else:
                sum_story_token += story_token
                story_string += story + self.dialogue_divide_token

        self.llm.user_message(story_string)
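
    # add_story greedily fills the story token budget; e.g. with
    # max_len_story = 1500 and retrieved scenes of 600, 500 and 700 tokens,
    # only the first two scenes are appended (ignoring the prefix prompt's
    # own token count), since 600 + 500 + 700 exceeds the budget.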
        
    def add_history(self):

        if len(self.dialogue_history) == 0:
            return

        # walk the history backwards and count how many (query, response)
        # pairs fit into the history token budget
        sum_history_token = 0
        flag = 0
        for query, response in reversed(self.dialogue_history):
            current_count = 0
            if query is not None:
                current_count += self.tokenizer(query)
            if response is not None:
                current_count += self.tokenizer(response)
            sum_history_token += current_count
            if sum_history_token > self.max_len_history:
                break
            else:
                flag += 1

        if flag == 0:
            print('warning! no history added. the last dialogue is too long.')
            return  # dialogue_history[-0:] would wrongly include the whole history

        for (query, response) in self.dialogue_history[-flag:]:
            if query is not None:
                self.llm.user_message(query)
            if response is not None:
                self.llm.ai_message(response)
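

if __name__ == '__main__':
    # A minimal usage sketch, not part of the library itself (run with
    # `python -m` so the relative imports resolve). It assumes an OpenAI key
    # is configured for the default 'openai' backend and that './haruhi_texts'
    # is a folder of .txt scene files you supply; the folder name, role name
    # and messages below are all hypothetical.
    chatbot = ChatHaruhi(
        system_prompt = 'You are Haruhi Suzumiya. Stay in character.',
        story_text_folder = './haruhi_texts',
        llm = 'openai',
        verbose = True,
    )
    # optionally persist the vector DB; a later run can then pass story_db = './haruhi_db'
    chatbot.save_story_db('./haruhi_db')
    print(chatbot.chat(text = 'Haruhi, do you remember me?', role = '阿虚'))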