silk-road committed on
Commit
fee0ada
1 Parent(s): 0df524c

Upload 18 files

ChatHaruhi/BaiChuan2GPT.py ADDED
@@ -0,0 +1,83 @@
+ import torch
+ from .BaseLLM import BaseLLM
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+ from transformers.generation.utils import GenerationConfig
+ from peft import PeftModel
+ 
+ tokenizer_BaiChuan = None
+ model_BaiChuan = None
+ 
+ def initialize_BaiChuan2LORA():
+     global model_BaiChuan, tokenizer_BaiChuan
+ 
+     if model_BaiChuan is None:
+         model_BaiChuan = AutoModelForCausalLM.from_pretrained(
+             "baichuan-inc/Baichuan2-13B-Chat",
+             device_map="auto",
+             torch_dtype=torch.bfloat16,
+             trust_remote_code=True,
+         )
+         model_BaiChuan = PeftModel.from_pretrained(
+             model_BaiChuan,
+             "silk-road/Chat-Haruhi-Fusion_Baichuan2_13B"
+         )
+         model_BaiChuan.generation_config = GenerationConfig.from_pretrained(
+             "baichuan-inc/Baichuan2-13B-Chat"
+         )
+ 
+     if tokenizer_BaiChuan is None:
+         tokenizer_BaiChuan = AutoTokenizer.from_pretrained(
+             "baichuan-inc/Baichuan2-13B-Chat",
+             use_fast=True,
+             trust_remote_code=True
+         )
+ 
+     return model_BaiChuan, tokenizer_BaiChuan
+ 
+ def BaiChuan_tokenizer(text):
+     return len(tokenizer_BaiChuan.encode(text))
+ 
+ class BaiChuan2GPT(BaseLLM):
+     def __init__(self, model = "haruhi-fusion-baichuan"):
+         super(BaiChuan2GPT, self).__init__()
+         if model == "baichuan2-13b":
+             self.tokenizer = AutoTokenizer.from_pretrained(
+                 "baichuan-inc/Baichuan2-13B-Chat",
+                 use_fast=True,
+                 trust_remote_code=True
+             )
+             self.model = AutoModelForCausalLM.from_pretrained(
+                 "baichuan-inc/Baichuan2-13B-Chat",
+                 device_map="auto",
+                 torch_dtype=torch.bfloat16,
+                 trust_remote_code=True,
+             )
+             self.model.generation_config = GenerationConfig.from_pretrained(
+                 "baichuan-inc/Baichuan2-13B-Chat"
+             )
+         elif model == "haruhi-fusion-baichuan":
+             self.model, self.tokenizer = initialize_BaiChuan2LORA()
+         else:
+             raise Exception("Unknown BaiChuan model! Currently supported: [baichuan2-13b, haruhi-fusion-baichuan]")
+         self.messages = []
+ 
+     def initialize_message(self):
+         self.messages = []
+ 
+     def ai_message(self, payload):
+         self.messages.append({"role": "assistant", "content": payload})
+ 
+     def system_message(self, payload):
+         self.messages.append({"role": "system", "content": payload})
+ 
+     def user_message(self, payload):
+         self.messages.append({"role": "user", "content": payload})
+ 
+     def get_response(self):
+         with torch.no_grad():
+             response = self.model.chat(self.tokenizer, self.messages)
+         return response
+ 
+     def print_prompt(self):
+         print(type(self.messages))
+         print(self.messages)
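For reference, a minimal usage sketch of the wrapper above; the role-play strings are illustrative, and loading the 13B checkpoint needs a correspondingly large GPU:

from ChatHaruhi.BaiChuan2GPT import BaiChuan2GPT

llm = BaiChuan2GPT(model="haruhi-fusion-baichuan")  # fetches base model + LoRA on first use
llm.initialize_message()
llm.system_message("你是凉宫春日,请用她的语气回复。")   # illustrative prompt
llm.user_message("阿虚:「今天的社团活动取消了。」")
print(llm.get_response())   # model.chat() consumes the accumulated message list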
ChatHaruhi/BaiChuanAPIGPT.py ADDED
@@ -0,0 +1,112 @@
+ import os
+ import json
+ import time
+ import hashlib
+ import requests
+ import copy
+ 
+ from .BaseLLM import BaseLLM
+ 
+ BAICHUAN_API_AK = os.getenv("BAICHUAN_API_AK")
+ BAICHUAN_API_SK = os.getenv("BAICHUAN_API_SK")
+ 
+ def sign(secret_key, data):
+     json_data = json.dumps(data)
+     time_stamp = int(time.time())
+     input_string = secret_key + json_data + str(time_stamp)
+     md5 = hashlib.md5()
+     md5.update(input_string.encode('utf-8'))
+     encrypted = md5.hexdigest()
+     return encrypted
+ 
+ def do_request(messages, api_key, secret_key):
+     url = "https://api.baichuan-ai.com/v1/chat"
+ 
+     data = {
+         "model": "Baichuan2-53B",
+         "messages": messages
+     }
+ 
+     signature = sign(secret_key, data)
+ 
+     headers = {
+         "Content-Type": "application/json",
+         "Authorization": "Bearer " + api_key,
+         "X-BC-Request-Id": "your requestId",
+         "X-BC-Timestamp": str(int(time.time())),
+         "X-BC-Signature": signature,
+         "X-BC-Sign-Algo": "MD5",
+     }
+ 
+     response = requests.post(url, data=json.dumps(data), headers=headers)
+     if response.status_code == 200:
+         return response.json()
+     else:
+         return None
+ 
+ class BaiChuanAPIGPT(BaseLLM):
+     def __init__(self, model="baichuan-api", api_key=None, secret_key=None, verbose=False, if_trick = True):
+         self.if_trick = if_trick
+         super(BaiChuanAPIGPT, self).__init__()
+         self.api_key = api_key or BAICHUAN_API_AK
+         self.secret_key = secret_key or BAICHUAN_API_SK
+         self.verbose = verbose
+         self.model_name = model
+         self.messages = []
+         if self.verbose:
+             print('model name, ', self.model_name)
+             if self.api_key is None or self.secret_key is None:
+                 print('Please set BAICHUAN_API_AK and BAICHUAN_API_SK')
+ 
+     def initialize_message(self):
+         self.messages = []
+ 
+     # the Baichuan API expects strictly alternating user/assistant turns,
+     # so consecutive same-role payloads are merged into the previous turn
+     def ai_message(self, payload):
+         if len(self.messages) == 0:
+             self.user_message("请根据我的要求进行角色扮演:")
+         elif len(self.messages) % 2 == 1:
+             self.messages.append({"role": "assistant", "content": payload})
+         elif len(self.messages) % 2 == 0:
+             self.messages[-1]["content"] += "\n" + payload
+ 
+     def system_message(self, payload):
+         self.messages.append({"role": "user", "content": payload})
+ 
+     def user_message(self, payload):
+         if len(self.messages) % 2 == 0:
+             self.messages.append({"role": "user", "content": payload})
+         elif len(self.messages) % 2 == 1:
+             self.messages[-1]["content"] += "\n" + payload
+ 
+     def get_response(self):
+         max_try = 5
+         sleep_interval = 3
+ 
+         chat_messages = copy.deepcopy(self.messages)
+ 
+         if self.if_trick:
+             lines = chat_messages[-1]["content"].split('\n')
+             lines.insert(-1, '请模仿上述经典桥段进行回复\n')
+             chat_messages[-1]["content"] = '\n'.join(lines)
+ 
+         for i in range(max_try):
+             response = do_request(chat_messages, self.api_key, self.secret_key)
+             if response is not None:
+                 if self.verbose:
+                     print('Get Baichuan API response success')
+                 messages = response['data']['messages']
+                 if len(messages) > 0:
+                     return messages[-1]['content'].strip("\"'")
+             else:
+                 if self.verbose:
+                     print('Get Baichuan API response failed, retrying...')
+                 time.sleep(sleep_interval)
+         return None  # all retries failed
+ 
+     def print_prompt(self):
+         for message in self.messages:
+             print(f"{message['role']}: {message['content']}")
+ 
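One caveat worth noting: sign() reads the clock internally, while the X-BC-Timestamp header is computed by a second time.time() call, so the signature and header can disagree if the second ticks over between them. A hypothetical variant that derives both from a single timestamp:

import json, time, hashlib

def signed_headers(api_key, secret_key, data):
    # compute the timestamp once so the signature and the header always agree
    ts = int(time.time())
    digest = hashlib.md5((secret_key + json.dumps(data) + str(ts)).encode('utf-8')).hexdigest()
    return {
        "Content-Type": "application/json",
        "Authorization": "Bearer " + api_key,
        "X-BC-Timestamp": str(ts),
        "X-BC-Signature": digest,
        "X-BC-Sign-Algo": "MD5",
    }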
ChatHaruhi/BaseDB.py ADDED
@@ -0,0 +1,27 @@
+ # BaseDB.py
+ 
+ from abc import ABC, abstractmethod
+ 
+ class BaseDB(ABC):
+ 
+     @abstractmethod
+     def init_db(self):
+         pass
+ 
+     @abstractmethod
+     def save(self, file_path):
+         pass
+ 
+     @abstractmethod
+     def load(self, file_path):
+         pass
+ 
+     @abstractmethod
+     def search(self, vector, n_results):
+         pass
+ 
+     @abstractmethod
+     def init_from_docs(self, vectors, documents):
+         pass
+ 
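To illustrate the contract, a hypothetical in-memory implementation (brute-force cosine similarity instead of a real vector store; for testing only):

import math
import pickle
from ChatHaruhi.BaseDB import BaseDB

class ListDB(BaseDB):
    def init_db(self):
        self.vectors, self.documents = [], []

    def save(self, file_path):
        with open(file_path, 'wb') as f:
            pickle.dump((self.vectors, self.documents), f)

    def load(self, file_path):
        with open(file_path, 'rb') as f:
            self.vectors, self.documents = pickle.load(f)

    def search(self, vector, n_results):
        def cos(a, b):
            dot = sum(x * y for x, y in zip(a, b))
            na = math.sqrt(sum(x * x for x in a))
            nb = math.sqrt(sum(x * x for x in b))
            return dot / (na * nb + 1e-8)
        # rank every stored document against the query vector
        ranked = sorted(zip(self.vectors, self.documents),
                        key=lambda pair: cos(pair[0], vector), reverse=True)
        return [doc for _, doc in ranked[:n_results]]

    def init_from_docs(self, vectors, documents):
        self.init_db()
        self.vectors, self.documents = list(vectors), list(documents)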
ChatHaruhi/BaseLLM.py ADDED
@@ -0,0 +1,56 @@
+ # ChatHaruhi: Reviving Anime Character in Reality via Large Language Model
+ #
+ # ChatHaruhi 2.0, built by Cheng Li and Weishi Mi
+ #
+ # chengli.thu@gmail.com, mws22@mails.tsinghua.edu.cn
+ #
+ # Weishi Mi is a second-year graduate student at Tsinghua University, majoring in computer science.
+ # Weishi Mi is pursuing a job or a PhD position, and will be available next year.
+ #
+ # homepage https://github.com/LC1332/Chat-Haruhi-Suzumiya
+ #
+ # ChatHaruhi is a chatbot that can revive anime characters in reality.
+ # The 2.0 version was built by Cheng Li and Weishi Mi.
+ #
+ # Please cite our paper if you use this code for research:
+ #
+ # @misc{li2023chatharuhi,
+ #       title={ChatHaruhi: Reviving Anime Character in Reality via Large Language Model},
+ #       author={Cheng Li and Ziang Leng and Chenxi Yan and Junyi Shen and Hao Wang and Weishi MI and Yaying Fei and Xiaoyang Feng and Song Yan and HaoSheng Wang and Linkang Zhan and Yaokai Jia and Pingyu Wu and Haozhen Sun},
+ #       year={2023},
+ #       eprint={2308.09597},
+ #       archivePrefix={arXiv},
+ #       primaryClass={cs.CL}
+ # }
+ from abc import ABC, abstractmethod
+ 
+ class BaseLLM(ABC):
+ 
+     def __init__(self):
+         pass
+ 
+     @abstractmethod
+     def initialize_message(self):
+         pass
+ 
+     @abstractmethod
+     def ai_message(self, payload):
+         pass
+ 
+     @abstractmethod
+     def system_message(self, payload):
+         pass
+ 
+     @abstractmethod
+     def user_message(self, payload):
+         pass
+ 
+     @abstractmethod
+     def get_response(self):
+         pass
+ 
+     @abstractmethod
+     def print_prompt(self):
+         pass
+ 
ChatHaruhi/ChatGLM2GPT.py ADDED
@@ -0,0 +1,79 @@
+ import torch
+ from .BaseLLM import BaseLLM
+ from transformers import AutoTokenizer, AutoModel
+ from peft import PeftModel
+ 
+ tokenizer_GLM = None
+ model_GLM = None
+ 
+ def initialize_GLM2LORA():
+     global model_GLM, tokenizer_GLM
+ 
+     if model_GLM is None:
+         model_GLM = AutoModel.from_pretrained(
+             "THUDM/chatglm2-6b",
+             torch_dtype=torch.float16,
+             device_map="auto",
+             trust_remote_code=True
+         )
+         model_GLM = PeftModel.from_pretrained(
+             model_GLM,
+             "silk-road/Chat-Haruhi-Fusion_B"
+         )
+ 
+     if tokenizer_GLM is None:
+         tokenizer_GLM = AutoTokenizer.from_pretrained(
+             "THUDM/chatglm2-6b",
+             use_fast=True,
+             trust_remote_code=True
+         )
+ 
+     return model_GLM, tokenizer_GLM
+ 
+ def GLM_tokenizer(text):
+     return len(tokenizer_GLM.encode(text))
+ 
+ class ChatGLM2GPT(BaseLLM):
+     def __init__(self, model = "haruhi-fusion"):
+         super(ChatGLM2GPT, self).__init__()
+         if model == "glm2-6b":
+             self.tokenizer = AutoTokenizer.from_pretrained(
+                 "THUDM/chatglm2-6b",
+                 use_fast=True,
+                 trust_remote_code=True
+             )
+             self.model = AutoModel.from_pretrained(
+                 "THUDM/chatglm2-6b",
+                 torch_dtype=torch.float16,
+                 device_map="auto",
+                 trust_remote_code=True
+             )
+         elif model == "haruhi-fusion":
+             self.model, self.tokenizer = initialize_GLM2LORA()
+         else:
+             raise Exception("Unknown GLM model")
+         self.messages = ""
+ 
+     def initialize_message(self):
+         self.messages = ""
+ 
+     def ai_message(self, payload):
+         self.messages = self.messages + "\n " + payload
+ 
+     def system_message(self, payload):
+         self.messages = self.messages + "\n " + payload
+ 
+     def user_message(self, payload):
+         self.messages = self.messages + "\n " + payload
+ 
+     def get_response(self):
+         with torch.no_grad():
+             response, history = self.model.chat(self.tokenizer, self.messages, history=[])
+             # print(response)
+         return response
+ 
+     def print_prompt(self):
+         print(type(self.messages))
+         print(self.messages)
+ 
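Unlike the BaiChuan wrappers, this class keeps self.messages as one flat string: system, user and AI payloads are simply appended line by line and handed to model.chat() as a single turn. A quick illustration of what the model receives (the checkpoints must be reachable; the role-play strings are illustrative):

from ChatHaruhi.ChatGLM2GPT import ChatGLM2GPT

llm = ChatGLM2GPT(model="glm2-6b")
llm.initialize_message()
llm.system_message("你是凉宫春日")
llm.user_message("阿虚:「早上好」")
llm.print_prompt()          # prints one concatenated string, not a role-tagged list
print(llm.get_response())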
ChatHaruhi/ChatHaruhi.py ADDED
@@ -0,0 +1,450 @@
+ from .ChromaDB import ChromaDB
+ import os
+ 
+ from .utils import luotuo_openai_embedding, tiktokenizer
+ 
+ from .utils import response_postprocess
+ 
+ def get_text_from_data( data ):
+     if "text" in data:
+         return data['text']
+     elif "enc_text" in data:
+         from .utils import base64_to_string
+         return base64_to_string( data['enc_text'] )
+     else:
+         print("warning! failed to get text from data ", data)
+         return ""
+ 
+ class ChatHaruhi:
+ 
+     def __init__(self, system_prompt = None, \
+                  role_name = None, role_from_hf = None,
+                  role_from_jsonl = None, \
+                  story_db=None, story_text_folder = None, \
+                  llm = 'openai', \
+                  embedding = 'luotuo_openai', \
+                  max_len_story = None, max_len_history = None,
+                  verbose = False):
+         super(ChatHaruhi, self).__init__()
+         self.verbose = verbose
+ 
+         # constants
+         self.story_prefix_prompt = "Classic scenes for the role are as follows:\n"
+         self.k_search = 19
+         self.narrator = ['旁白', '', 'scene','Scene','narrator' , 'Narrator']
+         self.dialogue_divide_token = '\n###\n'
+         self.dialogue_bra_token = '「'
+         self.dialogue_ket_token = '」'
+ 
+         if system_prompt:
+             self.system_prompt = self.check_system_prompt( system_prompt )
+ 
+         # TODO: the embedding should be defined separately, so refactor this part later
+         if llm == 'openai':
+             # self.llm = LangChainGPT()
+             self.llm, self.tokenizer = self.get_models('openai')
+         elif llm == 'debug':
+             self.llm, self.tokenizer = self.get_models('debug')
+         elif llm == 'spark':
+             self.llm, self.tokenizer = self.get_models('spark')
+         elif llm == 'GLMPro':
+             self.llm, self.tokenizer = self.get_models('GLMPro')
+         elif llm == 'ChatGLM2GPT':
+             self.llm, self.tokenizer = self.get_models('ChatGLM2GPT')
+             self.story_prefix_prompt = '\n'
+         elif llm == "BaiChuan2GPT":
+             self.llm, self.tokenizer = self.get_models('BaiChuan2GPT')
+         elif llm == "BaiChuanAPIGPT":
+             self.llm, self.tokenizer = self.get_models('BaiChuanAPIGPT')
+         elif llm == "ernie3.5":
+             self.llm, self.tokenizer = self.get_models('ernie3.5')
+         elif llm == "ernie4.0":
+             self.llm, self.tokenizer = self.get_models('ernie4.0')
+         elif "qwen" in llm:
+             self.llm, self.tokenizer = self.get_models(llm)
+         else:
+             print(f'warning! undefined llm {llm}, use openai instead.')
+             self.llm, self.tokenizer = self.get_models('openai')
+ 
+         if embedding == 'luotuo_openai':
+             self.embedding = luotuo_openai_embedding
+         elif embedding == 'bge_en':
+             from .utils import get_bge_embedding
+             self.embedding = get_bge_embedding
+         elif embedding == 'bge_zh':
+             from .utils import get_bge_zh_embedding
+             self.embedding = get_bge_zh_embedding
+         else:
+             print(f'warning! undefined embedding {embedding}, use luotuo_openai instead.')
+             self.embedding = luotuo_openai_embedding
+ 
+         if role_name:
+             # TODO: move into a function
+             from .role_name_to_file import get_folder_role_name
+             # correct role_name to folder_role_name
+             role_name, url = get_folder_role_name(role_name)
+ 
+             unzip_folder = f'./temp_character_folder/temp_{role_name}'
+             db_folder = os.path.join(unzip_folder, f'content/{role_name}')
+             system_prompt = os.path.join(unzip_folder, f'content/system_prompt.txt')
+ 
+             if not os.path.exists(unzip_folder):
+                 # not yet downloaded
+                 # url = f'https://github.com/LC1332/Haruhi-2-Dev/raw/main/data/character_in_zip/{role_name}.zip'
+                 import requests, zipfile, io
+                 r = requests.get(url)
+                 z = zipfile.ZipFile(io.BytesIO(r.content))
+                 z.extractall(unzip_folder)
+ 
+             if self.verbose:
+                 print(f'loading pre-defined character {role_name}...')
+ 
+             self.db = ChromaDB()
+             self.db.load(db_folder)
+             self.system_prompt = self.check_system_prompt(system_prompt)
+         elif role_from_hf:
+             # TODO: move into a function
+             from datasets import load_dataset
+ 
+             if role_from_hf.count("/") == 1:
+                 dataset = load_dataset(role_from_hf)
+                 datas = dataset["train"]
+             elif role_from_hf.count("/") >= 2:
+                 split_index = role_from_hf.index('/')
+                 second_split_index = role_from_hf.index('/', split_index+1)
+                 dataset_name = role_from_hf[:second_split_index]
+                 split_name = role_from_hf[second_split_index+1:]
+ 
+                 fname = split_name + '.jsonl'
+                 dataset = load_dataset(dataset_name, data_files={'train': fname})
+                 datas = dataset["train"]
+ 
+             if embedding == 'luotuo_openai':
+                 embed_name = 'luotuo_openai'
+             elif embedding == 'bge_en':
+                 embed_name = 'bge_en_s15'
+             elif embedding == 'bge_zh':
+                 embed_name = 'bge_zh_s15'
+             else:
+                 print('warning! unknown embedding name ', embedding ,' while loading role')
+                 embed_name = 'luotuo_openai'
+ 
+             texts, vecs, self.system_prompt = self.extract_text_vec_from_datas(datas, embed_name)
+ 
+             self.build_story_db_from_vec( texts, vecs )
+ 
+         elif role_from_jsonl:
+             import json
+             datas = []
+             with open( role_from_jsonl , encoding="utf-8") as f:
+                 for line in f:
+                     try:
+                         data = json.loads(line)
+                         # process the JSON data line by line
+                         datas.append(data)
+                     except:
+                         print("warning! failed to load json line ", line)
+ 
+             if embedding == 'luotuo_openai':
+                 embed_name = 'luotuo_openai'
+             elif embedding == 'bge_en':
+                 embed_name = 'bge_en_s15'
+             elif embedding == 'bge_zh':
+                 embed_name = 'bge_zh_s15'
+             else:
+                 print('warning! unknown embedding name ', embedding ,' while loading role')
+                 embed_name = 'luotuo_openai'
+ 
+             texts, vecs, self.system_prompt = self.extract_text_vec_from_datas(datas, embed_name)
+ 
+             self.build_story_db_from_vec( texts, vecs )
+ 
+         elif story_db:
+             self.db = ChromaDB()
+             self.db.load(story_db)
+         elif story_text_folder:
+             # print("Building story database from texts...")
+             self.db = self.build_story_db(story_text_folder)
+         else:
+             self.db = None
+             print('warning! database not yet figured out, neither story_db nor story_text_folder was given.')
+             # raise ValueError("Either story_db or story_text_folder must be provided")
+ 
+         self.max_len_story, self.max_len_history = self.get_tokenlen_setting('openai')
+ 
+         if max_len_history is not None:
+             self.max_len_history = max_len_history
+             # a user setting overrides the default setting
+ 
+         if max_len_story is not None:
+             self.max_len_story = max_len_story
+             # a user setting overrides the default setting
+ 
+         self.dialogue_history = []
+ 
+     def extract_text_vec_from_datas( self, datas, embed_name ):
+         # extract texts and vectors from a huggingface dataset
+         # returns texts, vecs, system_prompt
+         from .utils import base64_to_float_array
+ 
+         texts = []
+         vecs = []
+         for data in datas:
+             if data[embed_name] == 'system_prompt':
+                 system_prompt = get_text_from_data( data )
+             elif data[embed_name] == 'config':
+                 pass
+             else:
+                 vec = base64_to_float_array( data[embed_name] )
+                 text = get_text_from_data( data )
+                 vecs.append( vec )
+                 texts.append( text )
+         return texts, vecs, system_prompt
+ 
+     def check_system_prompt(self, system_prompt):
+         # if system_prompt ends with .txt, read the file with utf-8
+         # else, return the string directly
+         if system_prompt.endswith('.txt'):
+             with open(system_prompt, 'r', encoding='utf-8') as f:
+                 return f.read()
+         else:
+             return system_prompt
+ 
+     def get_models(self, model_name):
+ 
+         # TODO: if the caller only needs the tokenizer, there is no need to initialize the llm
+ 
+         # return the combination of llm and tokenizer
+         if model_name == 'openai':
+             from .LangChainGPT import LangChainGPT
+             return (LangChainGPT(), tiktokenizer)
+         elif model_name == 'debug':
+             from .PrintLLM import PrintLLM
+             return (PrintLLM(), tiktokenizer)
+         elif model_name == 'spark':
+             from .SparkGPT import SparkGPT
+             return (SparkGPT(), tiktokenizer)
+         elif model_name == 'GLMPro':
+             from .GLMPro import GLMPro
+             return (GLMPro(), tiktokenizer)
+         elif model_name == 'ernie3.5':
+             from .ErnieGPT import ErnieGPT
+             return (ErnieGPT(), tiktokenizer)
+         elif model_name == 'ernie4.0':
+             from .ErnieGPT import ErnieGPT
+             return (ErnieGPT(model="ernie-bot-4"), tiktokenizer)
+         elif model_name == "ChatGLM2GPT":
+             from .ChatGLM2GPT import ChatGLM2GPT, GLM_tokenizer
+             return (ChatGLM2GPT(), GLM_tokenizer)
+         elif model_name == "BaiChuan2GPT":
+             from .BaiChuan2GPT import BaiChuan2GPT, BaiChuan_tokenizer
+             return (BaiChuan2GPT(), BaiChuan_tokenizer)
+         elif model_name == "BaiChuanAPIGPT":
+             from .BaiChuanAPIGPT import BaiChuanAPIGPT
+             return (BaiChuanAPIGPT(), tiktokenizer)
+         elif "qwen" in model_name:
+             if model_name == "qwen118k_raw":
+                 from .Qwen118k2GPT import Qwen118k2GPT, Qwen_tokenizer
+                 return (Qwen118k2GPT(model = "Qwen/Qwen-1_8B-Chat"), Qwen_tokenizer)
+             # look up fine-tuned qwen models published under the silk-road account
+             from huggingface_hub import HfApi
+             from huggingface_hub.hf_api import ModelFilter
+             qwen_api = HfApi()
+             qwen_models = qwen_api.list_models(
+                 filter = ModelFilter(model_name=model_name),
+                 author = "silk-road"
+             )
+             qwen_models_id = []
+             for qwen_model in qwen_models:
+                 qwen_models_id.append(qwen_model.id)
+                 # print(qwen_model.id)
+             if "silk-road/" + model_name in qwen_models_id:
+                 from .Qwen118k2GPT import Qwen118k2GPT, Qwen_tokenizer
+                 return (Qwen118k2GPT(model = "silk-road/" + model_name), Qwen_tokenizer)
+             else:
+                 print(f'warning! undefined model {model_name}, use openai instead.')
+                 from .LangChainGPT import LangChainGPT
+                 return (LangChainGPT(), tiktokenizer)
+         else:
+             print(f'warning! undefined model {model_name}, use openai instead.')
+             from .LangChainGPT import LangChainGPT
+             return (LangChainGPT(), tiktokenizer)
+ 
+     def get_tokenlen_setting( self, model_name ):
+         # return the story and history token length budgets
+         if model_name == 'openai':
+             return (1500, 1200)
+         else:
+             print(f'warning! undefined model {model_name}, use openai instead.')
+             return (1500, 1200)
+ 
+     def build_story_db_from_vec( self, texts, vecs ):
+         self.db = ChromaDB()
+ 
+         self.db.init_from_docs( vecs, texts)
+ 
+     def build_story_db(self, text_folder):
+         # read the text folder and extract an embedding for each file
+         db = ChromaDB()
+ 
+         strs = []
+ 
+         # scan all txt files in text_folder
+         for file in os.listdir(text_folder):
+             # if the file name ends with txt
+             if file.endswith(".txt"):
+                 file_path = os.path.join(text_folder, file)
+                 with open(file_path, 'r', encoding='utf-8') as f:
+                     strs.append(f.read())
+ 
+         if self.verbose:
+             print(f'starting extract embedding... for { len(strs) } files')
+ 
+         vecs = []
+ 
+         ## TODO: add a unit test for batched embedding
+         ## new embedding code that supports batched lists
+         ## should replace the for loop below
+         ## Luotuo-bert-en has also been released, so openai can be avoided
+ 
+         for mystr in strs:
+             vecs.append(self.embedding(mystr))
+ 
+         db.init_from_docs(vecs, strs)
+ 
+         return db
+ 
+     def save_story_db(self, db_path):
+         self.db.save(db_path)
+ 
+     def generate_prompt( self, text, role):
+         from langchain.schema import (
+             AIMessage,
+             HumanMessage,
+             SystemMessage
+         )
+         messages = self.generate_messages( text, role )
+         prompt = ""
+         for msg in messages:
+             if isinstance(msg, HumanMessage):
+                 prompt += msg.content + "\n"
+             elif isinstance(msg, AIMessage):
+                 prompt += msg.content + "\n"
+             elif isinstance(msg, SystemMessage):
+                 prompt += msg.content + "\n"
+         return prompt
+ 
+     def generate_messages( self, text, role):
+         # add system prompt
+         self.llm.initialize_message()
+         self.llm.system_message(self.system_prompt)
+ 
+         # add story
+         query = self.get_query_string(text, role)
+         self.add_story( query )
+         self.last_query = query
+ 
+         # add query
+         self.llm.user_message(query)
+ 
+         return self.llm.messages
+ 
+     def append_response( self, response, last_query = None ):
+         if last_query is None:
+             last_query_record = ""
+             if hasattr( self, "last_query" ):
+                 last_query_record = self.last_query
+         else:
+             last_query_record = last_query
+ 
+         # record dialogue history
+         self.dialogue_history.append((last_query_record, response))
+ 
+     def chat(self, text, role):
+         # add system prompt
+         self.llm.initialize_message()
+         self.llm.system_message(self.system_prompt)
+ 
+         # add story
+         query = self.get_query_string(text, role)
+         self.add_story( query )
+ 
+         # add history
+         self.add_history()
+ 
+         # add query
+         self.llm.user_message(query)
+ 
+         # get response
+         response_raw = self.llm.get_response()
+ 
+         response = response_postprocess(response_raw, self.dialogue_bra_token, self.dialogue_ket_token)
+ 
+         # record dialogue history
+         self.dialogue_history.append((query, response))
+ 
+         return response
+ 
+     def get_query_string(self, text, role):
+         if role in self.narrator:
+             return role + ":" + text
+         else:
+             return f"{role}:{self.dialogue_bra_token}{text}{self.dialogue_ket_token}"
+ 
+     def add_story(self, query):
+ 
+         if self.db is None:
+             return
+ 
+         query_vec = self.embedding(query)
+ 
+         stories = self.db.search(query_vec, self.k_search)
+ 
+         story_string = self.story_prefix_prompt
+         sum_story_token = self.tokenizer(story_string)
+ 
+         for story in stories:
+             story_token = self.tokenizer(story) + self.tokenizer(self.dialogue_divide_token)
+             if sum_story_token + story_token > self.max_len_story:
+                 break
+             else:
+                 sum_story_token += story_token
+                 story_string += story + self.dialogue_divide_token
+ 
+         self.llm.user_message(story_string)
+ 
+     def add_history(self):
+ 
+         if len(self.dialogue_history) == 0:
+             return
+ 
+         # count backwards from the newest turn until the token budget is exceeded
+         sum_history_token = 0
+         flag = 0
+         for query, response in reversed(self.dialogue_history):
+             current_count = 0
+             if query is not None:
+                 current_count += self.tokenizer(query)
+             if response is not None:
+                 current_count += self.tokenizer(response)
+             sum_history_token += current_count
+             if sum_history_token > self.max_len_history:
+                 break
+             else:
+                 flag += 1
+ 
+         if flag == 0:
+             print('warning! no history added. the last dialogue is too long.')
+             return
+ 
+         for (query, response) in self.dialogue_history[-flag:]:
+             if query is not None:
+                 self.llm.user_message(query)
+             if response is not None:
+                 self.llm.ai_message(response)
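Putting the pieces together: chat() assembles the system prompt, the retrieved classic scenes (add_story), the token-budgeted history (add_history) and the current query, then post-processes the reply. A minimal session sketch, assuming OPENAI_API_KEY is set and a role name that role_name_to_file can resolve ('haruhi' here is illustrative):

from ChatHaruhi import ChatHaruhi   # assumes the package __init__ re-exports the class

chatbot = ChatHaruhi(role_name='haruhi', llm='openai')
print(chatbot.chat(role='阿虚', text='我今天有点累了。'))
# later turns reuse chatbot.dialogue_history, trimmed to max_len_history tokens
print(chatbot.chat(role='阿虚', text='社团活动改到明天吧?'))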
ChatHaruhi/ChatHaruhi_safe.py ADDED
@@ -0,0 +1,337 @@
+ from .ChromaDB import ChromaDB
+ import os
+ 
+ from .utils import luotuo_openai_embedding, tiktokenizer
+ 
+ from .utils import response_postprocess
+ 
+ from .utils import text_censor
+ 
+ class ChatHaruhi_safe:
+ 
+     def __init__(self, system_prompt = None, \
+                  role_name = None, role_from_hf = None, \
+                  story_db=None, story_text_folder = None, \
+                  llm = 'openai', \
+                  embedding = 'luotuo_openai', \
+                  max_len_story = None, max_len_history = None,
+                  verbose = False):
+         super(ChatHaruhi_safe, self).__init__()
+         self.verbose = verbose
+ 
+         # constants
+         self.story_prefix_prompt = "Classic scenes for the role are as follows:\n"
+         self.k_search = 19
+         self.narrator = ['旁白', '', 'scene','Scene','narrator' , 'Narrator']
+         self.dialogue_divide_token = '\n###\n'
+         self.dialogue_bra_token = '「'
+         self.dialogue_ket_token = '」'
+ 
+         if system_prompt:
+             self.system_prompt = self.check_system_prompt( system_prompt )
+ 
+         # TODO: the embedding should be defined separately, so refactor this part later
+         if llm == 'openai':
+             # self.llm = LangChainGPT()
+             self.llm, self.tokenizer = self.get_models('openai')
+         elif llm == 'debug':
+             self.llm, self.tokenizer = self.get_models('debug')
+         elif llm == 'spark':
+             self.llm, self.tokenizer = self.get_models('spark')
+         elif llm == 'GLMPro':
+             self.llm, self.tokenizer = self.get_models('GLMPro')
+         elif llm == 'ChatGLM2GPT':
+             self.llm, self.tokenizer = self.get_models('ChatGLM2GPT')
+             self.story_prefix_prompt = '\n'
+         elif llm == "BaiChuan2GPT":
+             self.llm, self.tokenizer = self.get_models('BaiChuan2GPT')
+         elif llm == "BaiChuanAPIGPT":
+             self.llm, self.tokenizer = self.get_models('BaiChuanAPIGPT')
+         elif llm == "ernie3.5":
+             self.llm, self.tokenizer = self.get_models('ernie3.5')
+         elif llm == "ernie4.0":
+             self.llm, self.tokenizer = self.get_models('ernie4.0')
+         else:
+             print(f'warning! undefined llm {llm}, use openai instead.')
+             self.llm, self.tokenizer = self.get_models('openai')
+ 
+         if embedding == 'luotuo_openai':
+             self.embedding = luotuo_openai_embedding
+         elif embedding == 'bge_en':
+             from .utils import get_bge_embedding
+             self.embedding = get_bge_embedding
+         else:
+             print(f'warning! undefined embedding {embedding}, use luotuo_openai instead.')
+             self.embedding = luotuo_openai_embedding
+ 
+         if role_name:
+             # TODO: move into a function
+             from .role_name_to_file import get_folder_role_name
+             # correct role_name to folder_role_name
+             role_name, url = get_folder_role_name(role_name)
+ 
+             unzip_folder = f'./temp_character_folder/temp_{role_name}'
+             db_folder = os.path.join(unzip_folder, f'content/{role_name}')
+             system_prompt = os.path.join(unzip_folder, f'content/system_prompt.txt')
+ 
+             if not os.path.exists(unzip_folder):
+                 # not yet downloaded
+                 # url = f'https://github.com/LC1332/Haruhi-2-Dev/raw/main/data/character_in_zip/{role_name}.zip'
+                 import requests, zipfile, io
+                 r = requests.get(url)
+                 z = zipfile.ZipFile(io.BytesIO(r.content))
+                 z.extractall(unzip_folder)
+ 
+             if self.verbose:
+                 print(f'loading pre-defined character {role_name}...')
+ 
+             self.db = ChromaDB()
+             self.db.load(db_folder)
+             self.system_prompt = self.check_system_prompt(system_prompt)
+         elif role_from_hf:
+             # TODO: move into a function
+             from datasets import load_dataset
+ 
+             if role_from_hf.count("/") == 1:
+                 dataset = load_dataset(role_from_hf)
+                 datas = dataset["train"]
+             elif role_from_hf.count("/") >= 2:
+                 split_index = role_from_hf.index('/')
+                 second_split_index = role_from_hf.index('/', split_index+1)
+                 dataset_name = role_from_hf[:second_split_index]
+                 split_name = role_from_hf[second_split_index+1:]
+ 
+                 fname = split_name + '.jsonl'
+                 dataset = load_dataset(dataset_name, data_files={'train': fname})
+                 datas = dataset["train"]
+ 
+             from .utils import base64_to_float_array
+ 
+             if embedding == 'luotuo_openai':
+                 embed_name = 'luotuo_openai'
+             elif embedding == 'bge_en':
+                 embed_name = 'bge_en_s15'
+             else:
+                 print('warning! unknown embedding name ', embedding ,' while loading role')
+                 embed_name = 'luotuo_openai'
+ 
+             texts = []
+             vecs = []
+             for data in datas:
+                 if data[embed_name] == 'system_prompt':
+                     self.system_prompt = data['text']
+                 elif data[embed_name] == 'config':
+                     pass
+                 else:
+                     vec = base64_to_float_array( data[embed_name] )
+                     text = data['text']
+                     vecs.append( vec )
+                     texts.append( text )
+ 
+             self.build_story_db_from_vec( texts, vecs )
+ 
+         elif story_db:
+             self.db = ChromaDB()
+             self.db.load(story_db)
+         elif story_text_folder:
+             # print("Building story database from texts...")
+             self.db = self.build_story_db(story_text_folder)
+         else:
+             self.db = None
+             print('warning! database not yet figured out, neither story_db nor story_text_folder was given.')
+             # raise ValueError("Either story_db or story_text_folder must be provided")
+ 
+         self.max_len_story, self.max_len_history = self.get_tokenlen_setting('openai')
+ 
+         if max_len_history is not None:
+             self.max_len_history = max_len_history
+             # a user setting overrides the default setting
+ 
+         if max_len_story is not None:
+             self.max_len_story = max_len_story
+             # a user setting overrides the default setting
+ 
+         self.dialogue_history = []
+ 
+     def check_system_prompt(self, system_prompt):
+         # if system_prompt ends with .txt, read the file with utf-8
+         # else, return the string directly
+         if system_prompt.endswith('.txt'):
+             with open(system_prompt, 'r', encoding='utf-8') as f:
+                 return f.read()
+         else:
+             return system_prompt
+ 
+     def get_models(self, model_name):
+ 
+         # TODO: if the caller only needs the tokenizer, there is no need to initialize the llm
+ 
+         # return the combination of llm and tokenizer
+         if model_name == 'openai':
+             from .LangChainGPT import LangChainGPT
+             return (LangChainGPT(), tiktokenizer)
+         elif model_name == 'debug':
+             from .PrintLLM import PrintLLM
+             return (PrintLLM(), tiktokenizer)
+         elif model_name == 'spark':
+             from .SparkGPT import SparkGPT
+             return (SparkGPT(), tiktokenizer)
+         elif model_name == 'GLMPro':
+             from .GLMPro import GLMPro
+             return (GLMPro(), tiktokenizer)
+         elif model_name == 'ernie3.5':
+             from .ErnieGPT import ErnieGPT
+             return (ErnieGPT(), tiktokenizer)
+         elif model_name == 'ernie4.0':
+             from .ErnieGPT import ErnieGPT
+             return (ErnieGPT(model="ernie-bot-4"), tiktokenizer)
+         elif model_name == "ChatGLM2GPT":
+             from .ChatGLM2GPT import ChatGLM2GPT, GLM_tokenizer
+             return (ChatGLM2GPT(), GLM_tokenizer)
+         elif model_name == "BaiChuan2GPT":
+             from .BaiChuan2GPT import BaiChuan2GPT, BaiChuan_tokenizer
+             return (BaiChuan2GPT(), BaiChuan_tokenizer)
+         elif model_name == "BaiChuanAPIGPT":
+             from .BaiChuanAPIGPT import BaiChuanAPIGPT
+             return (BaiChuanAPIGPT(), tiktokenizer)
+         else:
+             print(f'warning! undefined model {model_name}, use openai instead.')
+             from .LangChainGPT import LangChainGPT
+             return (LangChainGPT(), tiktokenizer)
+ 
+     def get_tokenlen_setting( self, model_name ):
+         # return the story and history token length budgets
+         if model_name == 'openai':
+             return (1500, 1200)
+         else:
+             print(f'warning! undefined model {model_name}, use openai instead.')
+             return (1500, 1200)
+ 
+     def build_story_db_from_vec( self, texts, vecs ):
+         self.db = ChromaDB()
+ 
+         self.db.init_from_docs( vecs, texts)
+ 
+     def build_story_db(self, text_folder):
+         # read the text folder and extract an embedding for each file
+         db = ChromaDB()
+ 
+         strs = []
+ 
+         # scan all txt files in text_folder
+         for file in os.listdir(text_folder):
+             # if the file name ends with txt
+             if file.endswith(".txt"):
+                 file_path = os.path.join(text_folder, file)
+                 with open(file_path, 'r', encoding='utf-8') as f:
+                     strs.append(f.read())
+ 
+         if self.verbose:
+             print(f'starting extract embedding... for { len(strs) } files')
+ 
+         vecs = []
+ 
+         ## TODO: add a unit test for batched embedding
+         ## new embedding code that supports batched lists
+         ## should replace the for loop below
+         ## Luotuo-bert-en has also been released, so openai can be avoided
+ 
+         for mystr in strs:
+             vecs.append(self.embedding(mystr))
+ 
+         db.init_from_docs(vecs, strs)
+ 
+         return db
+ 
+     def save_story_db(self, db_path):
+         self.db.save(db_path)
+ 
+     def chat(self, text, role):
+         # add system prompt
+         self.llm.initialize_message()
+         self.llm.system_message(self.system_prompt)
+ 
+         # add story
+         query = self.get_query_string(text, role)
+         self.add_story( query )
+ 
+         # add history
+         self.add_history()
+ 
+         # add query
+         self.llm.user_message(query)
+ 
+         # get response
+         response_raw = self.llm.get_response()
+ 
+         response = response_postprocess(response_raw, self.dialogue_bra_token, self.dialogue_ket_token)
+ 
+         # record dialogue history
+         self.dialogue_history.append((query, response))
+ 
+         return response
+ 
+     def get_query_string(self, text, role):
+         if role in self.narrator:
+             return role + ":" + text
+         else:
+             return f"{role}:{self.dialogue_bra_token}{text}{self.dialogue_ket_token}"
+ 
+     def add_story(self, query):
+ 
+         if self.db is None:
+             return
+ 
+         query_vec = self.embedding(query)
+ 
+         stories = self.db.search(query_vec, self.k_search)
+ 
+         story_string = self.story_prefix_prompt
+         sum_story_token = self.tokenizer(story_string)
+ 
+         for story in stories:
+             story_token = self.tokenizer(story) + self.tokenizer(self.dialogue_divide_token)
+             if sum_story_token + story_token > self.max_len_story:
+                 break
+             else:
+                 sum_story_token += story_token
+                 story_string += story + self.dialogue_divide_token
+ 
+         # only pass the retrieved scenes on if they clear the content censor
+         if text_censor(story_string):
+             self.llm.user_message(story_string)
+ 
+     def add_history(self):
+ 
+         if len(self.dialogue_history) == 0:
+             return
+ 
+         # count backwards from the newest turn until the token budget is exceeded
+         sum_history_token = 0
+         flag = 0
+         for query, response in reversed(self.dialogue_history):
+             current_count = 0
+             if query is not None:
+                 current_count += self.tokenizer(query)
+             if response is not None:
+                 current_count += self.tokenizer(response)
+             sum_history_token += current_count
+             if sum_history_token > self.max_len_history:
+                 break
+             else:
+                 flag += 1
+ 
+         if flag == 0:
+             print('warning! no history added. the last dialogue is too long.')
+             return
+ 
+         for (query, response) in self.dialogue_history[-flag:]:
+             if query is not None:
+                 self.llm.user_message(query)
+             if response is not None:
+                 self.llm.ai_message(response)
ChatHaruhi/ChromaDB.py ADDED
@@ -0,0 +1,61 @@
+ import chromadb
+ from .BaseDB import BaseDB
+ import random
+ import string
+ import os
+ 
+ class ChromaDB(BaseDB):
+ 
+     def __init__(self):
+         self.client = None
+         self.collection = None
+         self.path = None
+ 
+     def init_db(self):
+ 
+         if self.client is not None:
+             print('ChromaDB has already been initialized')
+             return
+ 
+         folder_name = ''
+ 
+         while os.path.exists(folder_name) or folder_name == '':
+             # try to create a folder named tempdb_<random string> that does not exist yet
+             folder_name = "tempdb_" + ''.join(random.sample(string.ascii_letters + string.digits, 8))
+ 
+         self.path = folder_name
+         self.client = chromadb.PersistentClient(path = folder_name)
+ 
+         self.collection = self.client.get_or_create_collection("search")
+ 
+     def save(self, file_path):
+         if file_path != self.path:
+             # copy all files in self.path to file_path, with overwrite
+             os.system("cp -r " + self.path + " " + file_path)
+             previous_path = self.path
+             self.path = file_path
+             self.client = chromadb.PersistentClient(path = file_path)
+             # remove the previous path if it starts with tempdb
+             if previous_path.startswith("tempdb"):
+                 os.system("rm -rf " + previous_path)
+ 
+     def load(self, file_path):
+         self.path = file_path
+         self.client = chromadb.PersistentClient(path = file_path)
+         self.collection = self.client.get_collection("search")
+ 
+     def search(self, vector, n_results):
+         results = self.collection.query(query_embeddings=[vector], n_results=n_results)
+         return results['documents'][0]
+ 
+     def init_from_docs(self, vectors, documents):
+         if self.client is None:
+             self.init_db()
+ 
+         ids = []
+         for i, doc in enumerate(documents):
+             # build a readable unique id from the index and the first four characters
+             first_four_chars = doc[:min(4, len(doc))]
+             ids.append( str(i) + "_" + first_four_chars )
+         self.collection.add(embeddings=vectors, documents=documents, ids = ids)
+ 
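A small round trip over the wrapper, with toy 3-d vectors standing in for real embeddings; note that save() shells out to cp/rm, so it is POSIX-only as written:

from ChatHaruhi.ChromaDB import ChromaDB

db = ChromaDB()
vectors = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.9, 0.1, 0.0]]
documents = ["scene A", "scene B", "scene C"]
db.init_from_docs(vectors, documents)

print(db.search([1.0, 0.0, 0.0], 2))   # the two nearest documents

db.save("./my_role_db")                # copies the temp folder, then re-opens the client there
db2 = ChromaDB()
db2.load("./my_role_db")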
ChatHaruhi/ErnieGPT.py ADDED
@@ -0,0 +1,72 @@
+ # ErnieGPT.py
+ import erniebot
+ # the following credentials are read from the os environment
+ import os
+ import copy
+ 
+ # appid = os.environ['APPID']
+ # api_secret = os.environ['APISecret']
+ # api_key = os.environ['APIKey']
+ erniebot.api_type = os.environ["APIType"]
+ erniebot.access_token = os.environ["ErnieAccess"]
+ 
+ from .BaseLLM import BaseLLM
+ 
+ class ErnieGPT(BaseLLM):
+ 
+     def __init__(self, model="ernie-bot", ernie_trick = True ):
+         super(ErnieGPT,self).__init__()
+         self.model = model
+         if model not in ["ernie-bot", "ernie-bot-turbo", "ernie-vilg-v2", "ernie-text-embedding", "ernie-bot-8k", "ernie-bot-4"]:
+             raise Exception("Unknown Ernie model")
+         self.messages = []
+ 
+         self.ernie_trick = ernie_trick
+ 
+     def initialize_message(self):
+         self.messages = []
+ 
+     # the Ernie API expects strictly alternating user/assistant turns,
+     # so consecutive same-role payloads are merged into the previous turn
+     def ai_message(self, payload):
+         if len(self.messages) == 0:
+             self.user_message("请根据我的要求进行角色扮演:")
+         elif len(self.messages) % 2 == 1:
+             self.messages.append({"role":"assistant","content":payload})
+         elif len(self.messages) % 2 == 0:
+             self.messages[-1]["content"] += "\n" + payload
+ 
+     def system_message(self, payload):
+         self.messages.append({"role":"user","content":payload})
+ 
+     def user_message(self, payload):
+         if len(self.messages) % 2 == 0:
+             self.messages.append({"role":"user","content":payload})
+         elif len(self.messages) % 2 == 1:
+             self.messages[-1]["content"] += "\n" + payload
+ 
+     def get_response(self):
+         chat_messages = copy.deepcopy(self.messages)
+ 
+         lines = chat_messages[-1]["content"].split('\n')
+ 
+         if self.ernie_trick:
+             lines.insert(-1, '请模仿上述经典桥段进行回复\n')
+ 
+         chat_messages[-1]["content"] = '\n'.join(lines)
+ 
+         response = erniebot.ChatCompletion.create(model=self.model, messages=chat_messages)
+         return response["result"]
+ 
+     def print_prompt(self):
+         for message in self.messages:
+             print(f"{message['role']}: {message['content']}")
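Because of the strict alternation requirement, the system prompt is mapped to a user turn and same-role payloads merge into their predecessor. A dry illustration of the resulting list (constructing the class requires APIType / ErnieAccess in the environment):

llm = ErnieGPT()
llm.initialize_message()
llm.system_message("你是凉宫春日")        # becomes a user turn
llm.user_message("阿虚:「早上好」")        # merged into that same user turn
llm.ai_message("春日:「早!」")            # appended as the assistant turn
print([m["role"] for m in llm.messages])   # ['user', 'assistant']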
ChatHaruhi/GLMPro.py ADDED
@@ -0,0 +1,90 @@
+ from .BaseLLM import BaseLLM
+ import os
+ 
+ zhipu_api = os.environ['ZHIPU_API']
+ 
+ import zhipuai
+ import time
+ 
+ class GLMPro( BaseLLM ):
+     def __init__(self, model="chatglm_pro", verbose = False ):
+         super(GLMPro,self).__init__()
+ 
+         zhipuai.api_key = zhipu_api
+ 
+         self.verbose = verbose
+ 
+         self.model_name = model
+ 
+         self.prompts = []
+ 
+         if self.verbose:
+             print('model name, ', self.model_name )
+             if len( zhipu_api ) > 8:
+                 print( 'found apikey ', zhipu_api[:4], '****', zhipu_api[-4:] )
+             else:
+                 print( 'found apikey but it is too short' )
+ 
+     def initialize_message(self):
+         self.prompts = []
+ 
+     def ai_message(self, payload):
+         self.prompts.append({"role":"assistant","content":payload})
+ 
+     def system_message(self, payload):
+         self.prompts.append({"role":"user","content":payload})
+ 
+     def user_message(self, payload):
+         self.prompts.append({"role":"user","content":payload})
+ 
+     def get_response(self):
+         zhipuai.api_key = zhipu_api
+         max_try = 5
+         sleep_interval = 3
+ 
+         request_id = None
+ 
+         # try to submit the asynchronous request until it succeeds
+         for test_time in range( max_try ):
+             response = zhipuai.model_api.async_invoke(
+                 model = self.model_name,
+                 prompt = self.prompts,
+                 temperature = 0)
+             if response['success'] == True:
+                 request_id = response['data']['task_id']
+ 
+                 if self.verbose:
+                     print('submit request, id = ', request_id )
+                 break
+             else:
+                 print('submit GLM request failed, retrying...')
+                 time.sleep( sleep_interval )
+ 
+         if request_id:
+             # poll for the result until it succeeds
+             for test_time in range( 2 * max_try ):
+                 result = zhipuai.model_api.query_async_invoke_result( request_id )
+                 if result['code'] == 200 and result['data']['task_status'] == 'SUCCESS':
+ 
+                     if self.verbose:
+                         print('get GLM response success' )
+ 
+                     choices = result['data']['choices']
+                     if len( choices ) > 0:
+                         return choices[-1]['content'].strip("\"'")
+ 
+                 # otherwise the request is not finished yet or has failed
+                 if self.verbose:
+                     print('get GLM response failed, retrying...')
+                 # sleep before polling again
+                 time.sleep( sleep_interval )
+         else:
+             print('submit GLM request failed, please check your api key and model name')
+         return ''
+ 
+     def print_prompt(self):
+         for message in self.prompts:
+             print(f"{message['role']}: {message['content']}")
ChatHaruhi/LangChainGPT.py ADDED
@@ -0,0 +1,78 @@
+ # ChatHaruhi: Reviving Anime Character in Reality via Large Language Model
+ #
+ # ChatHaruhi 2.0, built by Cheng Li and Weishi Mi
+ #
+ # chengli.thu@gmail.com, mws22@mails.tsinghua.edu.cn
+ #
+ # Weishi Mi is a second-year graduate student at Tsinghua University, majoring in computer science.
+ # Weishi Mi is pursuing a job or a PhD position, and will be available next year.
+ #
+ # homepage https://github.com/LC1332/Chat-Haruhi-Suzumiya
+ #
+ # ChatHaruhi is a chatbot that can revive anime characters in reality.
+ # The 2.0 version was built by Cheng Li and Weishi Mi.
+ #
+ # Please cite our paper if you use this code for research:
+ #
+ # @misc{li2023chatharuhi,
+ #       title={ChatHaruhi: Reviving Anime Character in Reality via Large Language Model},
+ #       author={Cheng Li and Ziang Leng and Chenxi Yan and Junyi Shen and Hao Wang and Weishi MI and Yaying Fei and Xiaoyang Feng and Song Yan and HaoSheng Wang and Linkang Zhan and Yaokai Jia and Pingyu Wu and Haozhen Sun},
+ #       year={2023},
+ #       eprint={2308.09597},
+ #       archivePrefix={arXiv},
+ #       primaryClass={cs.CL}
+ # }
+ 
+ 
+ from langchain.chat_models import ChatOpenAI
+ from langchain.prompts.chat import (
+     ChatPromptTemplate,
+     SystemMessagePromptTemplate,
+     AIMessagePromptTemplate,
+     HumanMessagePromptTemplate,
+ )
+ from langchain.schema import (
+     AIMessage,
+     HumanMessage,
+     SystemMessage
+ )
+ from .BaseLLM import BaseLLM
+ 
+ import os
+ from dotenv import load_dotenv
+ 
+ 
+ class LangChainGPT(BaseLLM):
+ 
+     def __init__(self, model="gpt-3.5-turbo"):
+         super(LangChainGPT, self).__init__()
+         self.model = model
+         if "OPENAI_API_BASE" in os.environ:
+             # use a custom api base if one is configured
+             load_dotenv()
+             api_base = os.environ["OPENAI_API_BASE"]
+             api_key = os.environ["OPENAI_API_KEY"]
+             self.chat = ChatOpenAI(model=self.model, openai_api_base=api_base, openai_api_key=api_key)
+         else:
+             self.chat = ChatOpenAI(model=self.model)
+         self.messages = []
+ 
+     def initialize_message(self):
+         self.messages = []
+ 
+     def ai_message(self, payload):
+         self.messages.append(AIMessage(content=payload))
+ 
+     def system_message(self, payload):
+         self.messages.append(SystemMessage(content=payload))
+ 
+     def user_message(self, payload):
+         self.messages.append(HumanMessage(content=payload))
+ 
+     def get_response(self):
+         response = self.chat(self.messages)
+         return response.content
+ 
+     def print_prompt(self):
+         for message in self.messages:
+             print(message)
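Because the messages are langchain schema objects rather than plain dicts, ChatHaruhi.generate_prompt above can isinstance-check them. A minimal sketch, assuming OPENAI_API_KEY is set:

from ChatHaruhi.LangChainGPT import LangChainGPT

llm = LangChainGPT()                # picks up OPENAI_API_BASE / OPENAI_API_KEY if present
llm.initialize_message()
llm.system_message("You are a role-play assistant.")
llm.user_message("Say hello in character.")
print([type(m).__name__ for m in llm.messages])   # ['SystemMessage', 'HumanMessage']
print(llm.get_response())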
ChatHaruhi/PrintLLM.py ADDED
@@ -0,0 +1,61 @@
+ # ChatHaruhi: Reviving Anime Character in Reality via Large Language Model
+ #
+ # ChatHaruhi 2.0, built by Cheng Li and Weishi Mi
+ #
+ # chengli.thu@gmail.com, mws22@mails.tsinghua.edu.cn
+ #
+ # Weishi Mi is a second-year graduate student at Tsinghua University, majoring in computer science.
+ # Weishi Mi is pursuing a job or a PhD position, and will be available next year.
+ #
+ # homepage https://github.com/LC1332/Chat-Haruhi-Suzumiya
+ #
+ # ChatHaruhi is a chatbot that can revive anime characters in reality.
+ # The 2.0 version was built by Cheng Li and Weishi Mi.
+ #
+ # Please cite our paper if you use this code for research:
+ #
+ # @misc{li2023chatharuhi,
+ #       title={ChatHaruhi: Reviving Anime Character in Reality via Large Language Model},
+ #       author={Cheng Li and Ziang Leng and Chenxi Yan and Junyi Shen and Hao Wang and Weishi MI and Yaying Fei and Xiaoyang Feng and Song Yan and HaoSheng Wang and Linkang Zhan and Yaokai Jia and Pingyu Wu and Haozhen Sun},
+ #       year={2023},
+ #       eprint={2308.09597},
+ #       archivePrefix={arXiv},
+ #       primaryClass={cs.CL}
+ # }
+ #
+ # This PrintLLM.py is for debugging without any real running LLM,
+ # so you can see the full prompt and copy it into GPT or Claude to debug
+ #
+ 
+ from .BaseLLM import BaseLLM
+ 
+ class PrintLLM(BaseLLM):
+ 
+     def __init__(self ):
+         self.messages = []
+         self.messages.append("Notice: This is a print LLM for debugging.")
+         self.messages.append("But you can also copy the prompt into GPT or Claude for debugging.")
+ 
+     def initialize_message(self):
+         self.messages = []
+         self.messages.append("Notice: This is a print LLM for debugging.")
+         self.messages.append("But you can also copy the prompt into GPT or Claude for debugging.")
+ 
+     def ai_message(self, payload):
+         self.messages.append("AI: \n" + payload)
+ 
+     def system_message(self, payload):
+         self.messages.append("System: \n" + payload)
+ 
+     def user_message(self, payload):
+         self.messages.append("User: \n" + payload)
+ 
+     def get_response(self):
+         for message in self.messages:
+             print(message)
+         response = input("Please input your response: ")
+         return response
+ 
+     def print_prompt(self):
+         for message in self.messages:
+             print(message)
@@ -0,0 +1,85 @@
 
+ import torch
+ from .BaseLLM import BaseLLM
+ from peft import PeftModel
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+ from transformers.generation import GenerationConfig
+ 
+ tokenizer_qwen = None
+ model_qwen = None
+ 
+ 
+ def initialize_Qwen2LORA(model):
+     global model_qwen, tokenizer_qwen
+ 
+     if model_qwen is None:
+         model_qwen = AutoModelForCausalLM.from_pretrained(
+             model,
+             # torch_dtype=torch.float16,
+             device_map="auto",
+             trust_remote_code=True
+         ).half()
+         model_qwen = model_qwen.eval()
+         # model_qwen = PeftModel.from_pretrained(
+         #     model_qwen,
+         #     "silk-road/Chat-Haruhi-Fusion_B"
+         # )
+ 
+     if tokenizer_qwen is None:
+         tokenizer_qwen = AutoTokenizer.from_pretrained(
+             model,
+             # use_fast=True,
+             trust_remote_code=True
+         )
+ 
+     return model_qwen, tokenizer_qwen
+ 
+ def Qwen_tokenizer(text):
+     return len(tokenizer_qwen.encode(text))
+ 
+ class Qwen118k2GPT(BaseLLM):
+     def __init__(self, model):
+         super(Qwen118k2GPT, self).__init__()
+         global model_qwen, tokenizer_qwen
+         if model == "Qwen/Qwen-1_8B-Chat":
+             tokenizer_qwen = AutoTokenizer.from_pretrained(
+                 "Qwen/Qwen-1_8B-Chat",
+                 trust_remote_code=True
+             )
+             model_qwen = AutoModelForCausalLM.from_pretrained(
+                 "Qwen/Qwen-1_8B-Chat",
+                 device_map="auto",
+                 trust_remote_code=True
+             ).eval()
+             self.model = model_qwen
+             self.tokenizer = tokenizer_qwen
+         elif "silk-road/" in model:
+             self.model, self.tokenizer = initialize_Qwen2LORA(model)
+         else:
+             raise Exception("Unknown Qwen model")
+         self.messages = ""
+ 
+     def initialize_message(self):
+         self.messages = ""
+ 
+     def ai_message(self, payload):
+         self.messages = self.messages + "\nAI: " + payload
+ 
+     def system_message(self, payload):
+         self.messages = self.messages + "\nSYSTEM PROMPT: " + payload
+ 
+     def user_message(self, payload):
+         self.messages = self.messages + "\nUser: " + payload
+ 
+     def get_response(self):
+         with torch.no_grad():
+             response, history = self.model.chat(self.tokenizer, self.messages, history=[])
+             # print(response)
+         return response
+ 
+     def print_prompt(self):
+         print(type(self.messages))
+         print(self.messages)
+ 
@@ -0,0 +1,139 @@
 
+ # WebSocket client provided by iFlytek for talking to the Spark (Xinghuo) chat model
+ 
+ import _thread as thread
+ import base64
+ import hashlib
+ import hmac
+ import json
+ import ssl
+ from datetime import datetime
+ from time import mktime
+ from urllib.parse import urlparse, urlencode
+ from wsgiref.handlers import format_date_time
+ 
+ import websocket  # from the websocket-client package
+ 
+ answer = ""
+ 
+ class Ws_Param(object):
+     # Initialization
+     def __init__(self, APPID, APIKey, APISecret, Spark_url):
+         self.APPID = APPID
+         self.APIKey = APIKey
+         self.APISecret = APISecret
+         self.host = urlparse(Spark_url).netloc
+         self.path = urlparse(Spark_url).path
+         self.Spark_url = Spark_url
+ 
+     # Build the signed websocket URL
+     def create_url(self):
+         # RFC 1123 timestamp
+         now = datetime.now()
+         date = format_date_time(mktime(now.timetuple()))
+ 
+         # Assemble the string to be signed
+         signature_origin = "host: " + self.host + "\n"
+         signature_origin += "date: " + date + "\n"
+         signature_origin += "GET " + self.path + " HTTP/1.1"
+ 
+         # Sign it with HMAC-SHA256
+         signature_sha = hmac.new(self.APISecret.encode('utf-8'), signature_origin.encode('utf-8'),
+                                  digestmod=hashlib.sha256).digest()
+ 
+         signature_sha_base64 = base64.b64encode(signature_sha).decode(encoding='utf-8')
+ 
+         authorization_origin = f'api_key="{self.APIKey}", algorithm="hmac-sha256", headers="host date request-line", signature="{signature_sha_base64}"'
+ 
+         authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8')
+ 
+         # Collect the authentication parameters into a dict
+         v = {
+             "authorization": authorization,
+             "date": date,
+             "host": self.host
+         }
+         # Append the authentication parameters to build the final URL
+         url = self.Spark_url + '?' + urlencode(v)
+         # When debugging, print this URL and compare it with the one the official
+         # demo generates for the same parameters.
+         return url
+ 
+ 
+ # Handle websocket errors
+ def on_error(ws, error):
+     print("### error:", error)
+ 
+ 
+ # Handle websocket close
+ def on_close(ws, one, two):
+     print(" ")
+ 
+ 
+ # Handle websocket connection established
+ def on_open(ws):
+     thread.start_new_thread(run, (ws,))
+ 
+ 
+ def run(ws, *args):
+     data = json.dumps(gen_params(appid=ws.appid, domain=ws.domain, question=ws.question))
+     ws.send(data)
+ 
+ 
+ # Handle incoming websocket messages
+ def on_message(ws, message):
+     data = json.loads(message)
+     code = data['header']['code']
+     if code != 0:
+         print(f'Request error: {code}, {data}')
+         ws.close()
+     else:
+         choices = data["payload"]["choices"]
+         status = choices["status"]
+         content = choices["text"][0]["content"]
+         global answer
+         answer += content
+         # status == 2 marks the final frame of the answer
+         if status == 2:
+             ws.close()
+ 
+ 
+ def gen_params(appid, domain, question):
+     """
+     Build the request parameters from the appid and the user's question.
+     """
+     data = {
+         "header": {
+             "app_id": appid,
+             "uid": "1234"
+         },
+         "parameter": {
+             "chat": {
+                 "domain": domain,
+                 "random_threshold": 0.5,
+                 "max_tokens": 2048,
+                 "auditing": "default"
+             }
+         },
+         "payload": {
+             "message": {
+                 "text": question
+             }
+         }
+     }
+     return data
+ 
+ 
+ def main(appid, api_key, api_secret, Spark_url, domain, question):
+     wsParam = Ws_Param(appid, api_key, api_secret, Spark_url)
+     websocket.enableTrace(False)
+     wsUrl = wsParam.create_url()
+     ws = websocket.WebSocketApp(wsUrl, on_message=on_message, on_error=on_error, on_close=on_close, on_open=on_open)
+     ws.appid = appid
+     ws.question = question
+     ws.domain = domain
+     ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE})
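+ 
+ 
+ # A hypothetical smoke test for this module (assumes APPID/APIKey/APISecret are
+ # exported in the environment; the v3.0 endpoint and domain mirror SparkGPT.py below):
+ if __name__ == "__main__":
+     import os
+     answer = ""
+     main(
+         appid=os.environ["APPID"],
+         api_key=os.environ["APIKey"],
+         api_secret=os.environ["APISecret"],
+         Spark_url="ws://spark-api.xf-yun.com/v3.1/chat",
+         domain="generalv3",
+         question=[{"role": "user", "content": "你好"}],
+     )
+     # the reply accumulates in the module-level `answer`
+     print(answer)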
ChatHaruhi/SparkGPT.py ADDED
@@ -0,0 +1,75 @@
+ # SparkGPT.py
+ from . import SparkApi
+ # The credentials below are read from environment variables
+ import os
+ 
+ appid = os.environ['APPID']
+ api_secret = os.environ['APISecret']
+ api_key = os.environ['APIKey']
+ 
+ from .BaseLLM import BaseLLM
+ 
+ 
+ class SparkGPT(BaseLLM):
+ 
+     def __init__(self, model="Spark3.0"):
+         super(SparkGPT, self).__init__()
+         self.model_type = model
+         self.messages = []
+         if self.model_type == "Spark2.0":
+             self.domain = "generalv2"  # v2.0 model
+             self.Spark_url = "ws://spark-api.xf-yun.com/v2.1/chat"  # v2.0 endpoint
+         elif self.model_type == "Spark1.5":
+             self.domain = "general"  # v1.5 model
+             self.Spark_url = "ws://spark-api.xf-yun.com/v1.1/chat"  # v1.5 endpoint
+         elif self.model_type == "Spark3.0":
+             self.domain = "generalv3"  # v3.0 model
+             self.Spark_url = "ws://spark-api.xf-yun.com/v3.1/chat"  # v3.0 endpoint
+         else:
+             raise Exception(f"Unknown Spark model: {model}")
+ 
+     def initialize_message(self):
+         self.messages = []
+ 
+     def ai_message(self, payload):
+         # Keep the message list strictly user/assistant alternating:
+         # pad with a user turn or merge consecutive assistant turns as needed.
+         if len(self.messages) == 0:
+             self.user_message("请根据我的要求进行角色扮演:")
+             self.messages.append({"role": "assistant", "content": payload})
+         elif len(self.messages) % 2 == 1:
+             self.messages.append({"role": "assistant", "content": payload})
+         else:
+             self.messages[-1]["content"] += "\n" + payload
+ 
+     def system_message(self, payload):
+         # Spark has no system role, so the system prompt is sent as a user turn
+         self.messages.append({"role": "user", "content": payload})
+ 
+     def user_message(self, payload):
+         if len(self.messages) % 2 == 0:
+             self.messages.append({"role": "user", "content": payload})
+         else:
+             self.messages[-1]["content"] += "\n" + payload
+ 
+     def get_response(self):
+         # self.domain and self.Spark_url were already selected in __init__
+         SparkApi.answer = ""
+         SparkApi.main(appid, api_key, api_secret, self.Spark_url, self.domain, self.messages)
+         return SparkApi.answer
+ 
+     def print_prompt(self):
+         for message in self.messages:
+             print(f"{message['role']}: {message['content']}")
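+ 
+ 
+ # A hypothetical usage sketch: consecutive same-role messages are merged so the
+ # list stays strictly user/assistant alternating, as the methods above assume.
+ #     llm = SparkGPT(model="Spark3.0")
+ #     llm.initialize_message()
+ #     llm.system_message("你是凉宫春日,请用她的语气与我对话。")  # stored as a user turn
+ #     llm.user_message("你好!")  # merged into that same user turn
+ #     print(llm.get_response())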
ChatHaruhi/__init__.py ADDED
@@ -0,0 +1,26 @@
+ # ChatHaruhi: Reviving Anime Character in Reality via Large Language Model
+ #
+ # ChatHaruhi 2.0, built by Cheng Li and Weishi Mi
+ #
+ # chengli.thu@gmail.com, mws22@mails.tsinghua.edu.cn
+ #
+ # Weishi Mi is a second-year graduate student at Tsinghua University, majoring in computer science.
+ # Weishi Mi is looking for a job or a PhD position and will be available next year.
+ #
+ # homepage https://github.com/LC1332/Chat-Haruhi-Suzumiya
+ #
+ # ChatHaruhi is a chatbot that can revive anime characters in reality.
+ # The 2.0 version was built by Cheng Li and Weishi Mi.
+ #
+ # Please cite our paper if you use this code for research:
+ #
+ # @misc{li2023chatharuhi,
+ #       title={ChatHaruhi: Reviving Anime Character in Reality via Large Language Model},
+ #       author={Cheng Li and Ziang Leng and Chenxi Yan and Junyi Shen and Hao Wang and Weishi MI and Yaying Fei and Xiaoyang Feng and Song Yan and HaoSheng Wang and Linkang Zhan and Yaokai Jia and Pingyu Wu and Haozhen Sun},
+ #       year={2023},
+ #       eprint={2308.09597},
+ #       archivePrefix={arXiv},
+ #       primaryClass={cs.CL}
+ # }
+ 
+ from .ChatHaruhi import ChatHaruhi
ChatHaruhi/role_name_to_file.py ADDED
@@ -0,0 +1,67 @@
+ # ChatHaruhi: Reviving Anime Character in Reality via Large Language Model
+ #
+ # ChatHaruhi 2.0, built by Cheng Li and Weishi Mi
+ #
+ # chengli.thu@gmail.com, mws22@mails.tsinghua.edu.cn
+ #
+ # Weishi Mi is a second-year graduate student at Tsinghua University, majoring in computer science.
+ # Weishi Mi is looking for a job or a PhD position and will be available next year.
+ #
+ # homepage https://github.com/LC1332/Chat-Haruhi-Suzumiya
+ #
+ # ChatHaruhi is a chatbot that can revive anime characters in reality.
+ # The 2.0 version was built by Cheng Li and Weishi Mi.
+ #
+ # Please cite our paper if you use this code for research:
+ #
+ # @misc{li2023chatharuhi,
+ #       title={ChatHaruhi: Reviving Anime Character in Reality via Large Language Model},
+ #       author={Cheng Li and Ziang Leng and Chenxi Yan and Junyi Shen and Hao Wang and Weishi MI and Yaying Fei and Xiaoyang Feng and Song Yan and HaoSheng Wang and Linkang Zhan and Yaokai Jia and Pingyu Wu and Haozhen Sun},
+ #       year={2023},
+ #       eprint={2308.09597},
+ #       archivePrefix={arXiv},
+ #       primaryClass={cs.CL}
+ # }
+ #
+ # If you want to add a new character, please register its role name here.
+ 
+ role_name_Haruhiu = {'汤师爷': 'tangshiye', 'tangshiye': 'tangshiye', 'Tangshiye': 'tangshiye',
+                      '慕容复': 'murongfu', 'murongfu': 'murongfu', 'Murongfu': 'murongfu',
+                      '李云龙': 'liyunlong', 'liyunlong': 'liyunlong', 'Liyunlong': 'liyunlong',
+                      'Luna': 'Luna', '王多鱼': 'wangduoyu', 'wangduoyu': 'wangduoyu',
+                      'Wangduoyu': 'wangduoyu', 'Ron': 'Ron', '鸠摩智': 'jiumozhi',
+                      'jiumozhi': 'jiumozhi', 'Jiumozhi': 'jiumozhi', 'Snape': 'Snape',
+                      '凉宫春日': 'haruhi', 'haruhi': 'haruhi', 'Haruhi': 'haruhi',
+                      'Malfoy': 'Malfoy', '虚竹': 'xuzhu', 'xuzhu': 'xuzhu',
+                      'Xuzhu': 'xuzhu', '萧峰': 'xiaofeng',
+                      'xiaofeng': 'xiaofeng', 'Xiaofeng': 'xiaofeng', '段誉': 'duanyu',
+                      'duanyu': 'duanyu', 'Duanyu': 'duanyu', 'Hermione': 'Hermione',
+                      'Dumbledore': 'Dumbledore', '王语嫣': 'wangyuyan',
+                      'wangyuyan': 'wangyuyan', 'Wangyuyan': 'wangyuyan', 'Harry': 'Harry',
+                      'McGonagall': 'McGonagall', '白展堂': 'baizhantang',
+                      'baizhantang': 'baizhantang', 'Baizhantang': 'baizhantang',
+                      '佟湘玉': 'tongxiangyu', 'tongxiangyu': 'tongxiangyu',
+                      'Tongxiangyu': 'tongxiangyu', '郭芙蓉': 'guofurong',
+                      'guofurong': 'guofurong', 'Guofurong': 'guofurong', '流浪者': 'wanderer',
+                      'wanderer': 'wanderer', 'Wanderer': 'wanderer', '钟离': 'zhongli',
+                      'zhongli': 'zhongli', 'Zhongli': 'zhongli', '胡桃': 'hutao', 'hutao': 'hutao',
+                      'Hutao': 'hutao', 'Sheldon': 'Sheldon', 'Raj': 'Raj',
+                      'Penny': 'Penny', '韦小宝': 'weixiaobao', 'weixiaobao': 'weixiaobao',
+                      'Weixiaobao': 'weixiaobao', '乔峰': 'qiaofeng', 'qiaofeng': 'qiaofeng',
+                      'Qiaofeng': 'qiaofeng', '神里绫华': 'ayaka', 'ayaka': 'ayaka',
+                      'Ayaka': 'ayaka', '雷电将军': 'raidenShogun', 'raidenShogun': 'raidenShogun',
+                      'RaidenShogun': 'raidenShogun', '于谦': 'yuqian', 'yuqian': 'yuqian',
+                      'Yuqian': 'yuqian', 'Professor McGonagall': 'McGonagall',
+                      'Professor Dumbledore': 'Dumbledore'}
+ 
+ # Input: role_name (nicknames are also allowed)
+ # Output: folder_role_name and the download url
+ #         f'https://github.com/LC1332/Haruhi-2-Dev/raw/main/data/character_in_zip/{role_name}.zip'
+ def get_folder_role_name(role_name):
+     if role_name in role_name_Haruhiu:
+         folder_role_name = role_name_Haruhiu[role_name]
+         url = f'https://github.com/LC1332/Haruhi-2-Dev/raw/main/data/character_in_zip/{folder_role_name}.zip'
+         return folder_role_name, url
+     else:
+         print('role_name {} not found, using haruhi as default'.format(role_name))
+         return get_folder_role_name('haruhi')
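+ 
+ # Example: both the Chinese name and its romanized form resolve to the same
+ # folder, and unknown names fall back to haruhi.
+ #     get_folder_role_name('凉宫春日')
+ #     # -> ('haruhi', 'https://github.com/LC1332/Haruhi-2-Dev/raw/main/data/character_in_zip/haruhi.zip')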
ChatHaruhi/utils.py ADDED
@@ -0,0 +1,431 @@
+ from argparse import Namespace
+ 
+ from openai import OpenAI
+ 
+ from transformers import AutoModel, AutoTokenizer
+ import torch
+ import random
+ 
+ import tiktoken
+ import re
+ 
+ import numpy as np
+ 
+ import base64
+ import struct
+ 
+ import math
+ import os
+ 
+ import tqdm
+ 
+ import requests
+ 
+ # The OpenAI client is created lazily, so that importing this module
+ # does not require OPENAI_API_KEY to be set.
+ client = None
+ 
+ def _get_openai_client():
+     global client
+     if client is None:
+         client = OpenAI()  # reads OPENAI_API_KEY from the environment
+     return client
+ 
+ def get_access_token():
+     """
+     Generate an authentication signature (access token) from the AK/SK.
+     :return: access_token, or None on failure
+     """
+     API_KEY = os.getenv("StoryAudit_API_AK")
+     SECRET_KEY = os.getenv("StoryAudit_API_SK")
+ 
+     url = "https://aip.baidubce.com/oauth/2.0/token"
+     params = {"grant_type": "client_credentials", "client_id": API_KEY, "client_secret": SECRET_KEY}
+     return str(requests.post(url, params=params).json().get("access_token"))
+ 
+ '''
+ Text moderation API (Baidu)
+ '''
+ def text_censor(text):
+     request_url = "https://aip.baidubce.com/rest/2.0/solution/v1/text_censor/v2/user_defined"
+ 
+     params = {"text": text}
+     access_token = get_access_token()
+     request_url = request_url + "?access_token=" + access_token
+     headers = {'content-type': 'application/x-www-form-urlencoded'}
+     response = requests.post(request_url, data=params, headers=headers)
+     # "合规" means the text passed moderation
+     return response.json()["conclusion"] == "合规"
+ 
+ def package_role(system_prompt, texts_path, embedding):
+     datas = []
+ 
+     # for now there is only one embedding type: 'luotuo_openai'
+     embed_name = 'luotuo_openai'
+ 
+     datas.append({'text': system_prompt, embed_name: 'system_prompt'})
+     datas.append({'text': 'Reserve Config Setting Here', embed_name: 'config'})
+ 
+     files = os.listdir(texts_path)
+ 
+     for i in tqdm.tqdm(range(len(files))):
+         file = files[i]
+         # embed every .txt file in the folder
+         if file.endswith(".txt"):
+             file_path = os.path.join(texts_path, file)
+             with open(file_path, 'r', encoding='utf-8') as f:
+                 current_str = f.read()
+                 current_vec = embedding(current_str)
+                 encode_vec = float_array_to_base64(current_vec)
+                 datas.append({'text': current_str, embed_name: encode_vec})
+ 
+     return datas
+ 
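+ # A hypothetical call (illustrative paths; luotuo_openai_embedding is defined below):
+ #     system_prompt = open('prompt.txt', encoding='utf-8').read()
+ #     datas = package_role(system_prompt, 'texts/', luotuo_openai_embedding)
+ #     # datas[0] holds the system prompt, datas[1] a config placeholder, and each
+ #     # later entry one story chunk plus its base64-encoded embedding.
+ 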
+ def string_to_base64(text):
+     byte_array = b''
+     for char in text:
+         num_bytes = char.encode('utf-8')
+         byte_array += num_bytes
+ 
+     base64_data = base64.b64encode(byte_array)
+     return base64_data.decode('utf-8')
+ 
+ def base64_to_string(base64_data):
+     byte_array = base64.b64decode(base64_data)
+     text = byte_array.decode('utf-8')
+     return text
+ 
+ 
+ def float_array_to_base64(float_arr):
+ 
+     byte_array = b''
+ 
+     for f in float_arr:
+         # pack each float as 4 big-endian bytes
+         num_bytes = struct.pack('!f', f)
+         byte_array += num_bytes
+ 
+     # base64-encode the byte array
+     base64_data = base64.b64encode(byte_array)
+ 
+     return base64_data.decode('utf-8')
+ 
+ def base64_to_float_array(base64_data):
+ 
+     byte_array = base64.b64decode(base64_data)
+ 
+     float_array = []
+ 
+     # decode every 4 bytes back into one float
+     for i in range(0, len(byte_array), 4):
+         num = struct.unpack('!f', byte_array[i:i+4])[0]
+         float_array.append(num)
+ 
+     return float_array
+ 
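+ # The two float helpers are exact inverses (up to float32 precision), which is
+ # what lets embeddings be stored as text fields. A hypothetical round-trip check:
+ #     vec = [0.25, -1.5, 3.0]
+ #     encoded = float_array_to_base64(vec)   # 'PoAAAL/AAABAQAAA'
+ #     assert base64_to_float_array(encoded) == vec
+ 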
+ 
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+ 
+ _luotuo_model = None
+ 
+ _luotuo_model_en = None
+ _luotuo_en_tokenizer = None
+ 
+ _enc_model = None
+ 
+ # ======== add bge_zh model
+ # by Cheng Li
+ # This time we try to support more models with one set of functions.
+ 
+ _model_pool = {}
+ _tokenizer_pool = {}
+ 
+ # BAAI/bge-small-zh-v1.5
+ 
+ def get_general_embeddings(sentences, model_name="BAAI/bge-small-zh-v1.5"):
+ 
+     global _model_pool
+     global _tokenizer_pool
+ 
+     if model_name not in _model_pool:
+         from transformers import AutoTokenizer, AutoModel
+         _tokenizer_pool[model_name] = AutoTokenizer.from_pretrained(model_name)
+         _model_pool[model_name] = AutoModel.from_pretrained(model_name)
+ 
+     _model_pool[model_name].eval()
+ 
+     # Tokenize sentences
+     encoded_input = _tokenizer_pool[model_name](sentences, padding=True, truncation=True, return_tensors='pt', max_length=512)
+ 
+     # Compute token embeddings
+     with torch.no_grad():
+         model_output = _model_pool[model_name](**encoded_input)
+         # Perform pooling. In this case, cls pooling.
+         sentence_embeddings = model_output[0][:, 0]
+ 
+     # normalize embeddings
+     sentence_embeddings = torch.nn.functional.normalize(sentence_embeddings, p=2, dim=1)
+     return sentence_embeddings.cpu().tolist()
+ 
+ def get_general_embedding(text_or_texts, model_name="BAAI/bge-small-zh-v1.5"):
+     if isinstance(text_or_texts, str):
+         return get_general_embeddings([text_or_texts], model_name)[0]
+     else:
+         return get_general_embeddings_safe(text_or_texts, model_name)
+ 
+ general_batch_size = 16
+ 
+ def get_general_embeddings_safe(sentences, model_name="BAAI/bge-small-zh-v1.5"):
+ 
+     embeddings = []
+ 
+     num_batches = math.ceil(len(sentences) / general_batch_size)
+ 
+     for i in tqdm.tqdm(range(num_batches)):
+         start_index = i * general_batch_size
+         end_index = min(len(sentences), start_index + general_batch_size)
+         batch = sentences[start_index:end_index]
+         embs = get_general_embeddings(batch, model_name)
+         embeddings.extend(embs)
+ 
+     return embeddings
+ 
+ def get_bge_zh_embedding(text_or_texts):
+     return get_general_embedding(text_or_texts, "BAAI/bge-small-zh-v1.5")
+ 
+ ## TODO: refactor the bge_en code below to reuse the general functions
+ 
+ # ======== add bge model
+ # by Cheng Li
+ # for English only right now
+ 
+ _bge_model = None
+ _bge_tokenizer = None
+ 
+ def get_bge_embeddings(sentences):
+     # unsafe: ensure the batch size yourself
+ 
+     global _bge_model
+     global _bge_tokenizer
+ 
+     if _bge_model is None:
+         from transformers import AutoTokenizer, AutoModel
+         _bge_tokenizer = AutoTokenizer.from_pretrained('BAAI/bge-small-en-v1.5')
+         _bge_model = AutoModel.from_pretrained('BAAI/bge-small-en-v1.5')
+ 
+     _bge_model.eval()
+ 
+     # Tokenize sentences
+     encoded_input = _bge_tokenizer(sentences, padding=True, truncation=True, return_tensors='pt', max_length=512)
+ 
+     # Compute token embeddings
+     with torch.no_grad():
+         model_output = _bge_model(**encoded_input)
+         # Perform pooling. In this case, cls pooling.
+         sentence_embeddings = model_output[0][:, 0]
+     # normalize embeddings
+     sentence_embeddings = torch.nn.functional.normalize(sentence_embeddings, p=2, dim=1)
+     return sentence_embeddings.cpu().tolist()
+ 
+ def get_bge_embedding(text_or_texts):
+     if isinstance(text_or_texts, str):
+         return get_bge_embeddings([text_or_texts])[0]
+     else:
+         return get_bge_embeddings_safe(text_or_texts)
+ 
+ bge_batch_size = 32
+ 
+ def get_bge_embeddings_safe(sentences):
+ 
+     embeddings = []
+ 
+     num_batches = math.ceil(len(sentences) / bge_batch_size)
+ 
+     for i in tqdm.tqdm(range(num_batches)):
+         start_index = i * bge_batch_size
+         end_index = min(len(sentences), start_index + bge_batch_size)
+         batch = sentences[start_index:end_index]
+         embs = get_bge_embeddings(batch)
+         embeddings.extend(embs)
+ 
+     return embeddings
+ 
+ # === end of bge model section
+ 
+ def tiktokenizer(text):
+     global _enc_model
+ 
+     if _enc_model is None:
+         _enc_model = tiktoken.get_encoding("cl100k_base")
+ 
+     return len(_enc_model.encode(text))
+ 
+ def response_postprocess(text, dialogue_bra_token='「', dialogue_ket_token='」'):
+     # Keep only the dialogue of the first speaker: scan line by line, merge
+     # consecutive lines spoken by that speaker, and stop at the first line
+     # spoken by anyone else (or the first non-dialogue line).
+     lines = text.split('\n')
+     new_lines = ""
+ 
+     first_name = None
+ 
+     for line in lines:
+         line = line.strip(" ")
+         match = re.match(r'^(.*?)[::]' + dialogue_bra_token + r"(.*?)" + dialogue_ket_token + r"$", line)
+ 
+         if match:
+             curr_name = match.group(1)
+             if first_name is None:
+                 first_name = curr_name
+                 new_lines += (match.group(2))
+             else:
+                 if curr_name != first_name:
+                     return first_name + ":" + dialogue_bra_token + new_lines + dialogue_ket_token
+                 else:
+                     new_lines += (match.group(2))
+         else:
+             if first_name is None:
+                 return text
+             else:
+                 return first_name + ":" + dialogue_bra_token + new_lines + dialogue_ket_token
+     return first_name + ":" + dialogue_bra_token + new_lines + dialogue_ket_token
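+ 
+ # For instance (a hypothetical transcript, using the default dialogue tokens):
+ #     text = '春日:「你好!」\n春日:「今天有什么计划?」\n阿虚:「没有。」'
+ #     response_postprocess(text)
+ #     # -> '春日:「你好!今天有什么计划?」'
+ #     # the second speaker's line and everything after it are dropped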
+ 
+ def download_models():
+     print("Downloading Luotuo-Bert")
+     # Import our models. The package will take care of downloading the models automatically
+     model_args = Namespace(do_mlm=None, pooler_type="cls", temp=0.05, mlp_only_train=False,
+                            init_embeddings_model=None)
+     model = AutoModel.from_pretrained("silk-road/luotuo-bert-medium", trust_remote_code=True, model_args=model_args).to(
+         device)
+     print("Luotuo-Bert downloaded")
+     return model
+ 
+ def get_luotuo_model():
+     global _luotuo_model
+     if _luotuo_model is None:
+         _luotuo_model = download_models()
+     return _luotuo_model
+ 
+ 
+ def luotuo_embedding(model, texts):
+     # Tokenize the texts_source
+     tokenizer = AutoTokenizer.from_pretrained("silk-road/luotuo-bert-medium")
+     inputs = tokenizer(texts, padding=True, truncation=False, return_tensors="pt")
+     inputs = inputs.to(device)
+     # Extract the embeddings
+     with torch.no_grad():
+         embeddings = model(**inputs, output_hidden_states=True, return_dict=True, sent_emb=True).pooler_output
+     return embeddings
+ 
+ def luotuo_en_embedding(texts):
+     # this function implemented by Cheng
+     global _luotuo_model_en
+     global _luotuo_en_tokenizer
+ 
+     if _luotuo_model_en is None:
+         _luotuo_en_tokenizer = AutoTokenizer.from_pretrained("silk-road/luotuo-bert-en")
+         _luotuo_model_en = AutoModel.from_pretrained("silk-road/luotuo-bert-en").to(device)
+ 
+     if _luotuo_en_tokenizer is None:
+         _luotuo_en_tokenizer = AutoTokenizer.from_pretrained("silk-road/luotuo-bert-en")
+ 
+     inputs = _luotuo_en_tokenizer(texts, padding=True, truncation=False, return_tensors="pt")
+     inputs = inputs.to(device)
+ 
+     with torch.no_grad():
+         embeddings = _luotuo_model_en(**inputs, output_hidden_states=True, return_dict=True, sent_emb=True).pooler_output
+ 
+     return embeddings
+ 
+ 
+ def get_embedding_for_chinese(model, texts):
+     model = model.to(device)
+     # accept a single string or a list of strings
+     texts = texts if isinstance(texts, list) else [texts]
+     # truncate to the model's limit
+     for i in range(len(texts)):
+         if len(texts[i]) > 510:
+             texts[i] = texts[i][:510]
+     if len(texts) >= 64:
+         embeddings = []
+         chunk_size = 64
+         for i in range(0, len(texts), chunk_size):
+             embeddings.append(luotuo_embedding(model, texts[i: i + chunk_size]))
+         return torch.cat(embeddings, dim=0)
+     else:
+         return luotuo_embedding(model, texts)
+ 
+ 
+ def is_chinese_or_english(text):
+     # The online OpenAI API is no longer used, so always treat text as Chinese.
+     return "chinese"
+ 
+     # unreachable: the original character-counting heuristic, kept for reference
+     text = list(text)
+     is_chinese, is_english = 0, 0
+ 
+     for char in text:
+         # check whether the character's code point falls in the CJK range
+         if '\u4e00' <= char <= '\u9fa5':
+             is_chinese += 4
+         # check for English characters (upper and lower case letters)
+         elif ('\u0041' <= char <= '\u005a') or ('\u0061' <= char <= '\u007a'):
+             is_english += 1
+     if is_chinese >= is_english:
+         return "chinese"
+     else:
+         return "english"
+ 
+ 
+ def get_embedding_openai(text, model="text-embedding-ada-002"):
+     text = text.replace("\n", " ")
+     return _get_openai_client().embeddings.create(input=[text], model=model).data[0].embedding
+ 
+ def get_embedding_for_english(text, model="text-embedding-ada-002"):
+     text = text.replace("\n", " ")
+     return _get_openai_client().embeddings.create(input=[text], model=model).data[0].embedding
+ 
+ def luotuo_openai_embedding(texts, is_chinese=None):
+     """
+     when input is chinese, use luotuo_embedding
+     when input is english, use openai_embedding
+     texts can be a list or a string
+     when texts is a list, return a list of embeddings, using batch inference
+     when texts is a string, return a single embedding
+     """
+ 
+     openai_key = os.environ.get("OPENAI_API_KEY")
+ 
+     if isinstance(texts, list):
+         # sample one element to guess the language of the whole batch
+         index = random.randint(0, len(texts) - 1)
+         if openai_key is None or is_chinese_or_english(texts[index]) == "chinese":
+             return [embed.cpu().tolist() for embed in get_embedding_for_chinese(get_luotuo_model(), texts)]
+         else:
+             return [get_embedding_for_english(text) for text in texts]
+     else:
+         if openai_key is None or is_chinese_or_english(texts) == "chinese":
+             return get_embedding_for_chinese(get_luotuo_model(), texts)[0].cpu().tolist()
+         else:
+             return get_embedding_for_english(texts)
+ 
+ 
+ # compute the cosine similarity between two vectors
+ def get_cosine_similarity(v1, v2):
+     v1 = torch.tensor(v1).to(device)
+     v2 = torch.tensor(v2).to(device)
+     return torch.cosine_similarity(v1, v2, dim=0).item()
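+ 
+ # Putting the helpers together, a hypothetical retrieval-style comparison
+ # (the `datas` records come from package_role above; the real retrieval
+ # logic lives elsewhere in the package):
+ #     query_vec = luotuo_openai_embedding("今天的SOS团活动是什么?")
+ #     for record in datas[2:]:
+ #         stored_vec = base64_to_float_array(record['luotuo_openai'])
+ #         print(get_cosine_similarity(query_vec, stored_vec), record['text'][:20])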