from .ChromaDB import ChromaDB
import os

from .utils import luotuo_openai_embedding, tiktokenizer
from .utils import response_postprocess, text_censor

class ChatHaruhi_safe:
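    """Retrieval-augmented role-playing chat bot with output censoring.

    Each call to chat() assembles a system prompt, retrieved "classic scene"
    snippets from a vector database, and recent dialogue history, queries the
    configured LLM, and post-processes the reply. Story snippets are only
    added to the prompt when they pass text_censor.

    Minimal usage sketch (not from the original file; assumes the package is
    importable as `chatharuhi`, that 'haruhi' is a packaged role name, and
    that credentials for the chosen LLM, e.g. OPENAI_API_KEY, are set):

        from chatharuhi import ChatHaruhi_safe

        chatbot = ChatHaruhi_safe(role_name='haruhi', llm='openai')
        print(chatbot.chat(role='阿虚', text='Haruhi, 你好啊'))
    """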

    def __init__(self, system_prompt=None,
                 role_name=None, role_from_hf=None,
                 story_db=None, story_text_folder=None,
                 llm='openai',
                 embedding='luotuo_openai',
                 max_len_story=None, max_len_history=None,
                 verbose=False):
        super().__init__()
        self.verbose = verbose

        # constants
        self.story_prefix_prompt = "Classic scenes for the role are as follows:\n"
        self.k_search = 19  # number of story snippets retrieved per query
        self.narrator = ['旁白', '', 'scene', 'Scene', 'narrator', 'Narrator']
        self.dialogue_divide_token = '\n###\n'
        self.dialogue_bra_token = '「'
        self.dialogue_ket_token = '」'

        if system_prompt:
            self.system_prompt = self.check_system_prompt(system_prompt)

        # TODO: the embedding should be defined separately; refactor this part later
        known_llms = ('openai', 'debug', 'spark', 'GLMPro', 'ChatGLM2GPT',
                      'BaiChuan2GPT', 'BaiChuanAPIGPT', 'ernie3.5', 'ernie4.0')
        if llm not in known_llms:
            print(f'warning! undefined llm {llm}, use openai instead.')
            llm = 'openai'
        self.llm, self.tokenizer = self.get_models(llm)
        if llm == 'ChatGLM2GPT':
            # ChatGLM2GPT uses a bare newline in place of the story prefix
            self.story_prefix_prompt = '\n'

        if embedding == 'luotuo_openai':
            self.embedding = luotuo_openai_embedding
        elif embedding == 'bge_en':
            from .utils import get_bge_embedding
            self.embedding = get_bge_embedding
        else:
            print(f'warning! undefined embedding {embedding}, use luotuo_openai instead.')
            self.embedding = luotuo_openai_embedding
        
        if role_name:
            # TODO move into a function
            from .role_name_to_file import get_folder_role_name
            # correct role_name to folder_role_name
            role_name, url = get_folder_role_name(role_name)

            unzip_folder = f'./temp_character_folder/temp_{role_name}'
            db_folder = os.path.join(unzip_folder, f'content/{role_name}')
            system_prompt = os.path.join(unzip_folder, 'content/system_prompt.txt')

            if not os.path.exists(unzip_folder):
                # character not yet downloaded; fetch the zip and unpack it
                import requests, zipfile, io
                r = requests.get(url)
                r.raise_for_status()  # fail early on a bad download
                z = zipfile.ZipFile(io.BytesIO(r.content))
                z.extractall(unzip_folder)

            if self.verbose:
                print(f'loading pre-defined character {role_name}...')
            
            self.db = ChromaDB()
            self.db.load(db_folder)
            self.system_prompt = self.check_system_prompt(system_prompt)
        elif role_from_hf:
            # TODO: move into a function
            from datasets import load_dataset

            if role_from_hf.count("/") == 1:
                # "owner/dataset" -> load the default train split
                dataset = load_dataset(role_from_hf)
                datas = dataset["train"]
            elif role_from_hf.count("/") >= 2:
                # "owner/dataset/split" -> load <split>.jsonl as the train split
                first_slash = role_from_hf.index('/')
                second_slash = role_from_hf.index('/', first_slash + 1)
                dataset_name = role_from_hf[:second_slash]
                split_name = role_from_hf[second_slash + 1:]

                fname = split_name + '.jsonl'
                dataset = load_dataset(dataset_name, data_files={'train': fname})
                datas = dataset["train"]
            else:
                # avoid a NameError on `datas` below when the input is malformed
                raise ValueError(f'role_from_hf should look like "owner/dataset" '
                                 f'or "owner/dataset/split", got {role_from_hf}')


            from .utils import base64_to_float_array
            
            if embedding == 'luotuo_openai':
                embed_name = 'luotuo_openai'
            elif embedding == 'bge_en':
                embed_name = 'bge_en_s15'
            else:
                print('warning! unknown embedding name', embedding, 'while loading role')
                embed_name = 'luotuo_openai'

            texts = []
            vecs = []
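            # Expected row schema (inferred from the loop below): a 'text'
            # column plus an embedding column named after embed_name; rows
            # whose embedding column is the literal string 'system_prompt'
            # or 'config' are metadata rather than story snippets.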
            for data in datas:
                if data[embed_name] == 'system_prompt':
                    self.system_prompt = data['text']
                elif data[embed_name] == 'config':
                    pass
                else:
                    vec = base64_to_float_array(data[embed_name])
                    text = data['text']
                    vecs.append(vec)
                    texts.append(text)

            self.build_story_db_from_vec(texts, vecs)
            
        elif story_db:
            self.db = ChromaDB()
            self.db.load(story_db)
        elif story_text_folder:
            self.db = self.build_story_db(story_text_folder)
        else:
            self.db = None
            print('warning! no story database configured: neither story_db nor story_text_folder was provided.')
            # raise ValueError("Either story_db or story_text_folder must be provided")
        

        # token budgets default to the openai setting;
        # explicit user settings override the defaults
        self.max_len_story, self.max_len_history = self.get_tokenlen_setting('openai')

        if max_len_history is not None:
            self.max_len_history = max_len_history

        if max_len_story is not None:
            self.max_len_story = max_len_story

        self.dialogue_history = []

        

    def check_system_prompt(self, system_prompt):
        # if system_prompt ends with .txt, read the file as utf-8;
        # otherwise return the string directly
        if system_prompt.endswith('.txt'):
            with open(system_prompt, 'r', encoding='utf-8') as f:
                return f.read()
        else:
            return system_prompt
    

    def get_models(self, model_name):
        # TODO: if the caller only needs the tokenizer, skip initializing the llm

        # return the (llm, tokenizer) pair for the given model name
        if model_name == 'openai':
            from .LangChainGPT import LangChainGPT
            return (LangChainGPT(), tiktokenizer)
        elif model_name == 'debug':
            from .PrintLLM import PrintLLM
            return (PrintLLM(), tiktokenizer)
        elif model_name == 'spark':
            from .SparkGPT import SparkGPT
            return (SparkGPT(), tiktokenizer)
        elif model_name == 'GLMPro':
            from .GLMPro import GLMPro
            return (GLMPro(), tiktokenizer)
        elif model_name == 'ernie3.5':
            from .ErnieGPT import ErnieGPT
            return (ErnieGPT(), tiktokenizer)
        elif model_name == 'ernie4.0':
            from .ErnieGPT import ErnieGPT
            return (ErnieGPT(model="ernie-bot-4"), tiktokenizer)
        elif model_name == "ChatGLM2GPT":
            from .ChatGLM2GPT import ChatGLM2GPT, GLM_tokenizer
            return (ChatGLM2GPT(), GLM_tokenizer)
        elif model_name == "BaiChuan2GPT":
            from .BaiChuan2GPT import BaiChuan2GPT, BaiChuan_tokenizer
            return (BaiChuan2GPT(), BaiChuan_tokenizer)
        elif model_name == "BaiChuanAPIGPT":
            from .BaiChuanAPIGPT import BaiChuanAPIGPT
            return (BaiChuanAPIGPT(), tiktokenizer)
        else:
            print(f'warning! undefined model {model_name}, use openai instead.')
            from .LangChainGPT import LangChainGPT
            return (LangChainGPT(), tiktokenizer)
        
    def get_tokenlen_setting(self, model_name):
        # return the (story, history) token length budget for the model
        if model_name == 'openai':
            return (1500, 1200)
        else:
            print(f'warning! undefined model {model_name}, use openai instead.')
            return (1500, 1200)
        
    def build_story_db_from_vec(self, texts, vecs):
        self.db = ChromaDB()
        self.db.init_from_docs(vecs, texts)

    def build_story_db(self, text_folder):
        # build the story vector database from a folder of text files
        db = ChromaDB()

        strs = []

        # scan all .txt files under text_folder
        for file in os.listdir(text_folder):
            if file.endswith(".txt"):
                file_path = os.path.join(text_folder, file)
                with open(file_path, 'r', encoding='utf-8') as f:
                    strs.append(f.read())

        if self.verbose:
            print(f'start extracting embeddings for {len(strs)} files...')

        vecs = []

        ## TODO: add a unit test for batched embedding extraction
        ## newer embedding code supports list (batch) input,
        ## which could replace the for loop below.
        ## Luotuo-bert-en has also been released, so openai can be avoided.
        
        for mystr in strs:
            vecs.append(self.embedding(mystr))

        db.init_from_docs(vecs, strs)

        return db
    
    def save_story_db(self, db_path):
        self.db.save(db_path)
        
    def chat(self, text, role):
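        """Run one chat turn.

        Assembles the system prompt, retrieved story snippets, and truncated
        dialogue history, sends the formatted query to the LLM, post-processes
        the reply, records the turn in dialogue_history, and returns it.
        """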
        # add system prompt
        self.llm.initialize_message()
        self.llm.system_message(self.system_prompt)
    

        # add story
        query = self.get_query_string(text, role)
        self.add_story( query )

        # add history
        self.add_history()

        # add query
        self.llm.user_message(query)
        
        # get response
        response_raw = self.llm.get_response()

        response = response_postprocess(response_raw, self.dialogue_bra_token, self.dialogue_ket_token)

        # record dialogue history
        self.dialogue_history.append((query, response))



        return response
    
    def get_query_string(self, text, role):
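        """Format one user turn: narrator roles produce a plain `role:text`
        line, while character roles wrap the text in 「」 dialogue brackets."""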
        if role in self.narrator:
            return role + ":" + text
        else:
            return f"{role}:{self.dialogue_bra_token}{text}{self.dialogue_ket_token}"
        
    def add_story(self, query):
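        """Embed the query, retrieve up to k_search snippets from the story
        DB, and pack them into the prompt under the max_len_story token
        budget; the block is only sent when it passes text_censor."""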

        if self.db is None:
            return
        
        query_vec = self.embedding(query)

        stories = self.db.search(query_vec, self.k_search)
        
        story_string = self.story_prefix_prompt
        sum_story_token = self.tokenizer(story_string)
        
        for story in stories:
            story_token = self.tokenizer(story) + self.tokenizer(self.dialogue_divide_token)
            if sum_story_token + story_token > self.max_len_story:
                break
            else:
                sum_story_token += story_token
                story_string += story + self.dialogue_divide_token
        
        if text_censor(story_string):
            self.llm.user_message(story_string)
        
    def add_history(self):
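        """Select the most recent dialogue turns that fit within the
        max_len_history token budget, then send them to the LLM in
        chronological order."""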

        if len(self.dialogue_history) == 0:
            return
        
        sum_history_token = 0
        flag = 0
        for query, response in reversed(self.dialogue_history):
            current_count = 0
            if query is not None:
                current_count += self.tokenizer(query) 
            if response is not None:
                current_count += self.tokenizer(response)
            sum_history_token += current_count
            if sum_history_token > self.max_len_history:
                break
            else:
                flag += 1

        if flag == 0:
            print('warning! no history added. the last dialogue is too long.')
            return  # without this, [-0:] below would replay the whole history

        for (query, response) in self.dialogue_history[-flag:]:
            if query is not None:
                self.llm.user_message(query)
            if response is not None:
                self.llm.ai_message(response)