李新豪 committed on
Commit
14b8b1d
1 Parent(s): e535922
Files changed (6)
  1. get_dataset.py +0 -68
  2. logger.py +0 -60
  3. prompt_concat.py +0 -170
  4. retrieve_dialog.py +0 -135
  5. src/retrieve_dialog.py +3 -2
  6. utils.py +0 -59
get_dataset.py DELETED
@@ -1,68 +0,0 @@
- # coding=utf-8
- import sys
- sys.path.append("../")
-
- from collections import defaultdict
- from .utils import is_float, load_txt
-
- import random
-
- random.seed(1234)
-
-
- class CreateDataset:
-     def __init__(self, max_input_len=1500):
-         self.prompt = load_txt("../prompt/dataset_character.txt")
-         self.max_input_len = max_input_len  # must be smaller than (seq-length) - (max-gen-length)
-         self.example_split_flag = f"\n{'-' * 20}\n"
-
-         self.dataset = defaultdict(list)
-         self.manual_dataset = []
-
-     @staticmethod
-     def choose_examples(similar_examples,
-                         max_length,
-                         train_flag=False,
-                         dialog=None,
-                         example_split_flag=f"\n{'-' * 20}\n"):
-         if isinstance(similar_examples, str):
-             new_similar_examples = [x.strip() for x in similar_examples.split(example_split_flag)]
-         else:
-             # deduplicate
-             new_similar_examples = []
-             for example in similar_examples:
-                 if (isinstance(example, list) or isinstance(example, tuple)) and len(example) == 2 and is_float(
-                         example[0]):
-                     # the entry carries a score
-                     example = example[1]
-
-                 try:
-                     example = "\n".join(example).strip()
-                 except TypeError:
-                     raise TypeError(f"example: {example}")
-                 if train_flag and dialog and (example in dialog or dialog in example):
-                     continue
-
-                 # deduplicate examples
-                 if train_flag:
-                     # drop partially overlapping examples as well
-                     flag = False
-                     for n_example in new_similar_examples:
-                         if example in n_example or n_example in example:
-                             flag = True
-                             break
-                     if not flag:
-                         new_similar_examples.append(example)
-                 else:
-                     if example not in new_similar_examples:
-                         new_similar_examples.append(example)
-
-         results = []
-         total_length = 0
-         for example in new_similar_examples:
-             total_length += len(example) if not total_length else len(example_split_flag) + len(example)
-             if total_length > max_length:
-                 break
-             results.append(example)
-         results = example_split_flag.join(results).strip()
-         return results

logger.py DELETED
@@ -1,60 +0,0 @@
- # coding=utf-8
- from logging.handlers import TimedRotatingFileHandler
-
- import os
- import sys
- import logging
-
-
- class LoggerFactory:
-
-     @staticmethod
-     def create_logger(name=None, level=logging.INFO):
-         """Create a logger.
-
-         Args:
-             name (str): name of the logger
-             level: level of the logger
-
-         Raises:
-             ValueError: if name is None
-         """
-
-         if name is None:
-             raise ValueError("name for logger cannot be None")
-
-         formatter = logging.Formatter("[%(asctime)s] [%(levelname)s] "
-                                       "[%(filename)s:%(lineno)d:%(funcName)s] %(message)s")
-
-         logger_ = logging.getLogger(name)
-         logger_.setLevel(level)
-         logger_.propagate = False
-         ch = logging.StreamHandler(stream=sys.stdout)
-         ch.setLevel(level)
-         ch.setFormatter(formatter)
-         logger_.addHandler(ch)
-         return logger_
-
-     @staticmethod
-     def create_logger_with_file(log_file_path: str = None, logger_level=logging.INFO):
-         logger_inner = logging.getLogger()
-         logger_inner.setLevel(logger_level)
-         logger_inner.propagate = True
-
-         formatter = logging.Formatter(fmt="[%(asctime)s] [%(filename)s:%(lineno)s - %(levelname)s] %(message)s",
-                                       datefmt="%Y-%m-%d %H:%M:%S")
-
-         # TimedRotatingFileHandler
-         if log_file_path:
-             basedir = os.path.dirname(log_file_path)
-             if not os.path.isdir(basedir):
-                 os.makedirs(basedir, exist_ok=True)
-             handler_file = TimedRotatingFileHandler(log_file_path, when="d", interval=1, backupCount=30)
-             handler_file.setFormatter(formatter)
-             logger_inner.addHandler(handler_file)
-
-         # StreamHandler
-         handler_console = logging.StreamHandler()
-         handler_console.setFormatter(formatter)
-         logger_inner.addHandler(handler_console)
-         return logger_inner

prompt_concat.py DELETED
@@ -1,170 +0,0 @@
- # coding=utf-8
- from copy import deepcopy
- from .get_dataset import CreateDataset
- from .logger import LoggerFactory
- from .retrieve_dialog import RetrieveDialog
- from .utils import load_json, load_txt, save_to_json
-
- import logging
- import os
-
- logger = LoggerFactory.create_logger(name="test", level=logging.INFO)
-
-
- class GetManualTestSamples:
-     def __init__(
-             self,
-             role_name,
-             role_data_path,
-             save_samples_dir,
-             save_samples_path=None,
-             prompt_path="dataset_character.txt",
-             max_seq_len=4000,
-             retrieve_num=20,
-     ):
-         self.role_name = role_name.strip()
-         self.role_data = load_json(role_data_path)
-         self.role_info = self.role_data[0]["role_info"].strip()
-
-         self.prompt = load_txt(prompt_path)
-         self.prompt = self.prompt.replace("${role_name}", self.role_name)
-         self.prompt = self.prompt.replace("${role_info}",
-                                           f"以下是{self.role_name}的人设:\n{self.role_info}\n").strip()
-
-         self.retrieve_num = retrieve_num
-         self.retrieve = RetrieveDialog(role_name=self.role_name,
-                                        raw_dialog_list=[d["dialog"] for d in self.role_data],
-                                        retrieve_num=retrieve_num)
-
-         self.max_seq_len = max_seq_len
-         if not save_samples_path:
-             save_samples_path = f"{self.role_name}.json"
-         self.save_samples_path = os.path.join(save_samples_dir, save_samples_path)
-
-     def _add_simi_dialog(self, history: list, content_length):
-         retrieve_results = self.retrieve.get_retrieve_res(history, self.retrieve_num)
-         simi_dialogs = deepcopy(retrieve_results)
-
-         if simi_dialogs:
-             simi_dialogs = CreateDataset.choose_examples(simi_dialogs,
-                                                          max_length=self.max_seq_len - content_length,
-                                                          train_flag=False)
-         logger.debug(f"retrieve_results: {retrieve_results}\nsimi_dialogs: {simi_dialogs}.")
-         return simi_dialogs, retrieve_results
-
-     def get_qa_samples_by_file(self,
-                                questions_path,
-                                user_name="user",
-                                keep_retrieve_results_flag=False
-                                ):
-         questions = load_txt(questions_path).splitlines()
-         samples = []
-         for question in questions:
-             question = question.replace('\\n', "\n")
-             query = f"{user_name}:{question}" if ":" not in question else question
-             content = self.prompt.replace("${dialog}", query)
-             content = content.replace("${user_name}", user_name).strip()
-
-             history = [query]
-             simi_dialogs, retrieve_results = self._add_simi_dialog(history, len(content))
-
-             sample = {
-                 "role_name": self.role_name,
-                 "role_info": self.role_info,
-                 "user_name": user_name,
-                 "dialog": history,
-                 "simi_dialogs": simi_dialogs,
-             }
-             if keep_retrieve_results_flag and retrieve_results:
-                 sample["retrieve_results"] = retrieve_results
-             samples.append(sample)
-         self._save_samples(samples)
-
-     def get_qa_samples_by_query(self,
-                                 questions_query,
-                                 user_name="user",
-                                 keep_retrieve_results_flag=False
-                                 ):
-         question = questions_query
-         samples = []
-         question = question.replace('\\n', "\n")
-         query = f"{user_name}: {question}" if ":" not in question else question
-         content = self.prompt.replace("${dialog}", query)
-         content = content.replace("${user_name}", user_name).strip()
-
-         history = [query]
-         simi_dialogs, retrieve_results = self._add_simi_dialog(history, len(content))
-
-         sample = {
-             "role_name": self.role_name,
-             "role_info": self.role_info,
-             "user_name": user_name,
-             "dialog": history,
-             "simi_dialogs": simi_dialogs,
-         }
-         if keep_retrieve_results_flag and retrieve_results:
-             sample["retrieve_results"] = retrieve_results
-         samples.append(sample)
-         self._save_samples(samples)
-
-     def _save_samples(self, samples):
-         data = samples
-         save_to_json(data, self.save_samples_path)
-
-
- class CreateTestDataset:
-     def __init__(self,
-                  role_name,
-                  role_samples_path=None,
-                  role_data_path=None,
-                  prompt_path="dataset_character.txt",
-                  max_seq_len=4000):
-         self.max_seq_len = max_seq_len
-         self.role_name = role_name
-
-         self.prompt = load_txt(prompt_path)
-         self.prompt = self.prompt.replace("${role_name}", role_name).strip()
-
-         if not role_data_path:
-             print("need role_data_path, check please!")
-             self.default_simi_dialogs = None
-         if os.path.exists(role_data_path):
-             data = load_json(role_data_path)
-             role_info = data[0]["role_info"]
-         else:
-             raise ValueError(f"{self.role_name} didn't find role_info.")
-         self.role_info = role_info
-         self.prompt = self.prompt.replace("${role_info}", f"以下是{self.role_name}的人设:\n{self.role_info}\n").strip()
-
-         if role_samples_path:
-             self.role_samples_path = role_samples_path
-         else:
-             print("check role_samples_path please!")
-
-     def load_samples(self):
-         samples = load_json(self.role_samples_path)
-         results = []
-         for sample in samples:
-             input_text = self.prompt
-
-             simi_dialogs = sample.get("simi_dialogs", None)
-             if not simi_dialogs:
-                 simi_dialogs = self.default_simi_dialogs
-             if not simi_dialogs:
-                 raise ValueError(f"didn't find simi_dialogs.")
-             simi_dialogs = CreateDataset.choose_examples(simi_dialogs,
-                                                          max_length=self.max_seq_len - len(input_text),
-                                                          train_flag=False)
-
-             input_text = input_text.replace("${simi_dialog}", simi_dialogs)
-             user_name = sample.get("user_name", "user")
-             input_text = input_text.replace("${user_name}", user_name)
-
-             dialog = "\n".join(sample["dialog"]) if isinstance(sample["dialog"], list) else sample["dialog"]
-             input_text = input_text.replace("${dialog}", dialog)
-
-             assert len(input_text) < self.max_seq_len
-             results.append({
-                 "input_text": input_text,
-             })
-         return results

retrieve_dialog.py DELETED
@@ -1,135 +0,0 @@
- # coding=utf-8
- from sentence_transformers import SentenceTransformer
- from .utils import load_json
-
- import faiss
- import logging
- import os
- import re
- import torch
-
- logger = logging.getLogger(__name__)
-
-
- class RetrieveDialog:
-     def __init__(self,
-                  role_name,
-                  raw_dialog_list: list = None,
-                  retrieve_num=20,
-                  min_mean_role_utter_length=10):
-         if torch.cuda.is_available():
-             gpu_id = 0
-             torch.cuda.set_device(gpu_id)
-
-         assert raw_dialog_list
-
-         self.role_name = role_name
-         self.min_mean_role_utter_length = min_mean_role_utter_length
-         self.retrieve_num = retrieve_num
-
-         # config = load_json("config/config.json")
-         # local_dir = config["bge_local_path"]
-         local_dir = os.environ.get('MODEL_PATH', 'IndexTeam/Index-1.9B-Character')
-
-         if not os.path.exists(local_dir):
-             print("Please download bge-large-zh-v1.5 first!")
-         self.emb_model = SentenceTransformer(local_dir)
-
-         self.dialogs, self.context_index = self._get_emb_base_by_list(raw_dialog_list)
-
-         logger.info(f"dialog db num: {len(self.dialogs)}")
-         logger.info(f"RetrieveDialog init success.")
-
-     @staticmethod
-     def dialog_preprocess(dialog: list, role_name):
-         dialog_new = []
-         # replace user names to reduce their influence on retrieval
-         user_names = []
-         role_utter_length = []
-         for num in range(len(dialog)):
-             utter = dialog[num]
-             try:
-                 user_name, utter_txt = re.split('[::]', utter, maxsplit=1)
-             except ValueError as e:
-                 logging.error(f"utter:{utter} can't find user_name.")
-                 return None, None
-
-             if user_name != role_name:
-                 if user_name not in user_names:
-                     user_names.append(user_name)
-                 index = user_names.index(user_name)
-                 utter = utter.replace(user_name, f"user{index}", 1)
-             else:
-                 role_utter_length.append(len(utter_txt))
-             dialog_new.append(utter)
-         return dialog_new, user_names, role_utter_length
-
-     def _get_emb_base_by_list(self, raw_dialog_list):
-         logger.info(f"raw dialog db num: {len(raw_dialog_list)}")
-         new_raw_dialog_list = []
-         context_list = []
-
-         # keep an unfiltered fallback in case length filtering removes every dialog
-         new_raw_dialog_list_total = []
-         context_list_total = []
-         for raw_dialog in raw_dialog_list:
-             if not raw_dialog:
-                 continue
-
-             end = 0
-             for x in raw_dialog[::-1]:
-                 if x.startswith(self.role_name):
-                     break
-                 end += 1
-
-             raw_dialog = raw_dialog[:len(raw_dialog) - end]
-             new_dialog, user_names, role_utter_length = self.dialog_preprocess(raw_dialog, self.role_name)
-             if not new_dialog or not role_utter_length:
-                 continue
-
-             if raw_dialog in new_raw_dialog_list_total:
-                 continue
-
-             # the final reply is not needed when computing the embedding
-             context = "\n".join(new_dialog) if len(new_dialog) < 2 else "\n".join(new_dialog[:-1])
-
-             new_raw_dialog_list_total.append(raw_dialog)
-             context_list_total.append(context)
-
-             # filter by mean role utterance length
-             role_length_mean = sum(role_utter_length) / len(role_utter_length)
-             if role_length_mean < self.min_mean_role_utter_length:
-                 continue
-             new_raw_dialog_list.append(raw_dialog)
-             context_list.append(context)
-
-         assert len(new_raw_dialog_list) == len(context_list)
-         logger.debug(f"new_raw_dialog num: {len(new_raw_dialog_list)}")
-
-         # fall back to the unfiltered set when too few samples remain
-         if len(new_raw_dialog_list) < self.retrieve_num:
-             new_raw_dialog_list = new_raw_dialog_list_total
-             context_list = context_list_total
-
-         # build the dialog vector index
-         context_vectors = self.emb_model.encode(context_list, normalize_embeddings=True)
-         context_index = faiss.IndexFlatL2(context_vectors.shape[1])
-         context_index.add(context_vectors)
-
-         return new_raw_dialog_list, context_index
-
-     def get_retrieve_res(self, dialog: list, retrieve_num: int):
-         logger.debug(f"dialog: {dialog}")
-
-         # likewise remove the influence of user names
-         dialog, _, _ = self.dialog_preprocess(dialog, self.role_name)
-         dialog_vector = self.emb_model.encode(["\n".join(dialog)], normalize_embeddings=True)
-
-         simi_dialog_distance, simi_dialog_index = self.context_index.search(
-             dialog_vector, min(retrieve_num, len(self.dialogs)))
-         simi_dialog_results = [
-             (str(simi_dialog_distance[0][num]), self.dialogs[index]) for num, index in enumerate(simi_dialog_index[0])
-         ]
-         logger.debug(f"dialog retrieve res: {simi_dialog_results}")
-
-         return simi_dialog_results

src/retrieve_dialog.py CHANGED
@@ -27,8 +27,9 @@ class RetrieveDialog:
          self.min_mean_role_utter_length = min_mean_role_utter_length
          self.retrieve_num = retrieve_num
 
-         config = load_json("config/config.json")
-         local_dir = config["bge_local_path"]
+         # config = load_json("config/config.json")
+         # local_dir = config["bge_local_path"]
+         local_dir = os.environ.get('MODEL_PATH', 'IndexTeam/Index-1.9B-Character')
 
          if not os.path.exists(local_dir):
              print("Please download bge-large-zh-v1.5 first!")
utils.py DELETED
@@ -1,59 +0,0 @@
- # coding=utf-8
- import csv
- import json
- import os
-
-
- def read_csv_to_json(file_path, role_name, role_info):
-     json_list = []
-
-     with open(file_path, mode="r", newline="", encoding="utf-8") as csvfile:
-         csv_reader = csv.reader(csvfile)
-         _ = next(csv_reader)
-
-         for row in csv_reader:
-             json_object = {
-                 "role_name": role_name,
-                 "role_info": role_info,
-                 "dialog": row[1].split("\n"),
-             }
-             json_list.append(json_object)
-
-     return json_list
-
-
- def save_json(json_list, output_path):
-     with open(output_path, "w", encoding="utf-8") as jsonfile:
-         json.dump(json_list, jsonfile, ensure_ascii=False, indent=4)
-
-
- def decode_csv_to_json(role_data_path, role_name, role_info, json_output_path):
-     json_data = read_csv_to_json(role_data_path, role_name, role_info)
-     save_json(json_data, json_output_path)
-
-
- def load_txt(path):
-     with open(path, "r", encoding="utf-8", errors="ignore") as file:
-         text = file.read()
-     return text
-
-
- def load_json(path):
-     with open(path, "r", encoding="utf-8") as f:
-         data = json.load(f)
-     return data
-
-
- def save_to_json(data, filepath, flag="w"):
-     if not os.path.exists(os.path.dirname(filepath)):
-         os.makedirs(os.path.dirname(filepath))
-     with open(filepath, flag, encoding="utf-8") as f:
-         f.write(json.dumps(data, ensure_ascii=False, indent=3))
-
-
- def is_float(my_str):
-     try:
-         num = float(my_str)
-         return True
-     except ValueError:
-         return False