Spaces:
Runtime error
Runtime error
Upload 6 files
Browse files
- __init__.py +6 -0
- get_dataset.py +68 -0
- logger.py +60 -0
- prompt_concat.py +170 -0
- utils.py +59 -0
__init__.py
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
from .get_dataset import *
|
3 |
+
from .logger import *
|
4 |
+
from .prompt_concat import *
|
5 |
+
from .retrieve_dialog import *
|
6 |
+
from .utils import *
|
get_dataset.py
ADDED
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
import sys
|
3 |
+
sys.path.append("../")
|
4 |
+
|
5 |
+
from collections import defaultdict
|
6 |
+
from .utils import is_float, load_txt
|
7 |
+
|
8 |
+
import random
|
9 |
+
|
10 |
+
random.seed(1234)
|
11 |
+
|
12 |
+
|
13 |
+
class CreateDataset:
    """Hold the dataset prompt template and select few-shot examples for a prompt.

    Only ``choose_examples`` carries logic; the constructor loads the prompt
    template from disk and creates empty containers.
    """

    def __init__(self, max_input_len=1500, prompt_path="../prompt/dataset_character.txt"):
        """Load the prompt template and initialise empty dataset containers.

        Args:
            max_input_len: Input length limit; it has to stay below
                (seq-length) - (max-gen-length).
            prompt_path: Location of the prompt template. The default is the
                path that used to be hard-coded and is resolved against the
                current working directory.
        """
        self.prompt = load_txt(prompt_path)
        self.max_input_len = max_input_len  # below (seq-length) - (max-gen-length)
        self.example_split_flag = f"\n{'-' * 20}\n"

        self.dataset = defaultdict(list)
        self.manual_dataset = []

    @staticmethod
    def choose_examples(similar_examples,
                        max_length,
                        train_flag=False,
                        dialog=None,
                        example_split_flag=f"\n{'-' * 20}\n"):
        """Join as many examples as fit into ``max_length`` characters.

        Args:
            similar_examples: Either one string holding examples separated by
                ``example_split_flag``, or an iterable whose items are a dialog
                (string, or sequence of line strings) or a 2-item
                ``(score, dialog)`` pair.
            max_length: Character budget for the joined result, separators
                included.
            train_flag: When true, examples that contain / are contained in
                ``dialog`` or in an already kept example are dropped; otherwise
                only exact duplicates are dropped.
            dialog: Current dialog used by the ``train_flag`` overlap filter.
            example_split_flag: Separator placed between examples.

        Returns:
            The kept examples joined by ``example_split_flag`` and stripped.

        Raises:
            TypeError: If an item cannot be joined into a string.
        """
        if isinstance(similar_examples, str):
            # An already joined string is only split back; no de-duplication here.
            new_similar_examples = [x.strip() for x in similar_examples.split(example_split_flag)]
        else:
            # De-duplicate while normalising every item to a single string.
            new_similar_examples = []
            for example in similar_examples:
                if isinstance(example, (list, tuple)) and len(example) == 2 and is_float(example[0]):
                    # (score, dialog) pair: keep only the dialog.
                    example = example[1]

                if isinstance(example, str):
                    # Joining a str would insert a newline between every character.
                    example = example.strip()
                else:
                    try:
                        example = "\n".join(example).strip()
                    except TypeError as err:
                        raise TypeError(f"example: {example}") from err

                if train_flag and dialog and (example in dialog or dialog in example):
                    continue

                if train_flag:
                    # Partial overlap with a kept example also counts as a duplicate.
                    if any(example in kept or kept in example for kept in new_similar_examples):
                        continue
                    new_similar_examples.append(example)
                elif example not in new_similar_examples:
                    new_similar_examples.append(example)

        results = []
        total_length = 0
        for example in new_similar_examples:
            # Every example after the first one also costs one separator.
            total_length += len(example) + (len(example_split_flag) if results else 0)
            if total_length > max_length:
                break
            results.append(example)
        return example_split_flag.join(results).strip()
|
logger.py
ADDED
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
from logging.handlers import TimedRotatingFileHandler
|
3 |
+
|
4 |
+
import os
|
5 |
+
import sys
|
6 |
+
import logging
|
7 |
+
|
8 |
+
|
9 |
+
class LoggerFactory:
    """Factory helpers that return pre-configured ``logging`` loggers."""

    # Attribute set on handlers installed by this factory so repeated calls can
    # recognise and reuse them instead of stacking duplicates.
    _HANDLER_MARK = "_logger_factory_owned"

    @staticmethod
    def create_logger(name=None, level=logging.INFO):
        """Create (or reconfigure) a named logger that writes to stdout.

        Calling this again with the same name reuses the handler installed by
        the first call, so messages are not emitted more than once.

        Args:
            name (str): name of the logger
            level: level of the logger and of its stdout handler

        Returns:
            The configured logger; propagation to ancestor loggers is disabled.

        Raises:
            ValueError: if name is None
        """
        if name is None:
            raise ValueError("name for logger cannot be None")

        mark = LoggerFactory._HANDLER_MARK
        logger_ = logging.getLogger(name)
        logger_.setLevel(level)
        logger_.propagate = False

        owned = [h for h in logger_.handlers if getattr(h, mark, False)]
        if owned:
            # Already configured: only bring the handler level up to date.
            for handler in owned:
                handler.setLevel(level)
            return logger_

        formatter = logging.Formatter("[%(asctime)s] [%(levelname)s] "
                                      "[%(filename)s:%(lineno)d:%(funcName)s] %(message)s")
        ch = logging.StreamHandler(stream=sys.stdout)
        ch.setLevel(level)
        ch.setFormatter(formatter)
        setattr(ch, mark, True)
        logger_.addHandler(ch)
        return logger_

    @staticmethod
    def create_logger_with_file(log_file_path: str = None, logger_level=logging.INFO):
        """Configure the root logger with a console handler and an optional file handler.

        Args:
            log_file_path: When given, a daily-rotating file handler (30 backups)
                is attached for this path; missing parent directories are created.
            logger_level: Level set on the root logger.

        Returns:
            The root logger.
        """
        mark = LoggerFactory._HANDLER_MARK
        logger_inner = logging.getLogger()
        logger_inner.setLevel(logger_level)
        logger_inner.propagate = True

        formatter = logging.Formatter(fmt="[%(asctime)s] [%(filename)s:%(lineno)s - %(levelname)s] %(message)s",
                                      datefmt="%Y-%m-%d %H:%M:%S")
        owned = [h for h in logger_inner.handlers if getattr(h, mark, False)]

        # TimedRotatingFileHandler
        if log_file_path:
            basedir = os.path.dirname(log_file_path)
            if basedir:
                # A bare file name has no directory part; os.makedirs("") would raise.
                os.makedirs(basedir, exist_ok=True)
            target = os.path.abspath(log_file_path)
            already_attached = any(
                isinstance(h, TimedRotatingFileHandler) and h.baseFilename == target
                for h in owned
            )
            if not already_attached:
                handler_file = TimedRotatingFileHandler(log_file_path, when="d", interval=1, backupCount=30)
                handler_file.setFormatter(formatter)
                setattr(handler_file, mark, True)
                logger_inner.addHandler(handler_file)

        # StreamHandler (file handlers subclass StreamHandler, so exclude them here)
        has_console = any(
            isinstance(h, logging.StreamHandler) and not isinstance(h, logging.FileHandler)
            for h in owned
        )
        if not has_console:
            handler_console = logging.StreamHandler()
            handler_console.setFormatter(formatter)
            setattr(handler_console, mark, True)
            logger_inner.addHandler(handler_console)
        return logger_inner
|
prompt_concat.py
ADDED
@@ -0,0 +1,170 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
from copy import deepcopy
|
3 |
+
from .get_dataset import CreateDataset
|
4 |
+
from .logger import LoggerFactory
|
5 |
+
from .retrieve_dialog import RetrieveDialog
|
6 |
+
from .utils import load_json, load_txt, save_to_json
|
7 |
+
|
8 |
+
import logging
|
9 |
+
import os
|
10 |
+
|
11 |
+
logger = LoggerFactory.create_logger(name="test", level=logging.INFO)
|
12 |
+
|
13 |
+
|
14 |
+
class GetManualTestSamples:
    """Build retrieval-augmented test samples for one role and save them as JSON."""

    def __init__(
            self,
            role_name,
            role_data_path,
            save_samples_dir,
            save_samples_path=None,
            prompt_path="dataset_character.txt",
            max_seq_len=4000,
            retrieve_num=20,
    ):
        """Load role data and the prompt template, and set up dialog retrieval.

        Args:
            role_name: Name of the role; surrounding whitespace is stripped.
            role_data_path: JSON file with a list of records; the first record
                supplies ``role_info`` and every record supplies ``dialog``.
            save_samples_dir: Directory the samples file is written into.
            save_samples_path: File name inside ``save_samples_dir``; defaults
                to ``"<role_name>.json"``.
            prompt_path: Path of the prompt template.
            max_seq_len: Character budget shared by the prompt and the
                retrieved dialogs.
            retrieve_num: Number of dialogs requested from the retriever.
        """
        self.role_name = role_name.strip()
        self.role_data = load_json(role_data_path)
        self.role_info = self.role_data[0]["role_info"].strip()

        prompt = load_txt(prompt_path)
        prompt = prompt.replace("${role_name}", self.role_name)
        self.prompt = prompt.replace("${role_info}",
                                     f"以下是{self.role_name}的人设:\n{self.role_info}\n").strip()

        self.retrieve_num = retrieve_num
        self.retrieve = RetrieveDialog(role_name=self.role_name,
                                       raw_dialog_list=[d["dialog"] for d in self.role_data],
                                       retrieve_num=retrieve_num)

        self.max_seq_len = max_seq_len
        if not save_samples_path:
            save_samples_path = f"{self.role_name}.json"
        self.save_samples_path = os.path.join(save_samples_dir, save_samples_path)

    def _add_simi_dialog(self, history: list, content_length):
        """Retrieve dialogs for ``history`` and trim them to the remaining budget.

        Returns:
            ``(simi_dialogs, retrieve_results)``: the trimmed, joined examples
            (or the empty retrieval result unchanged) and the raw retrieval
            result.
        """
        retrieve_results = self.retrieve.get_retrieve_res(history, self.retrieve_num)
        # Work on a copy so the raw retrieval result can be returned untouched.
        simi_dialogs = deepcopy(retrieve_results)

        if simi_dialogs:
            simi_dialogs = CreateDataset.choose_examples(simi_dialogs,
                                                         max_length=self.max_seq_len - content_length,
                                                         train_flag=False)
        # Lazy %-formatting: the large payload is only rendered when DEBUG is on.
        logger.debug("retrieve_results: %s\nsimi_dialogs: %s.", retrieve_results, simi_dialogs)
        return simi_dialogs, retrieve_results

    def _build_sample(self, query, user_name, keep_retrieve_results_flag):
        """Build one sample dict for an already formatted ``query`` line."""
        content = self.prompt.replace("${dialog}", query)
        content = content.replace("${user_name}", user_name).strip()

        history = [query]
        simi_dialogs, retrieve_results = self._add_simi_dialog(history, len(content))

        sample = {
            "role_name": self.role_name,
            "role_info": self.role_info,
            "user_name": user_name,
            "dialog": history,
            "simi_dialogs": simi_dialogs,
        }
        if keep_retrieve_results_flag and retrieve_results:
            sample["retrieve_results"] = retrieve_results
        return sample

    def get_qa_samples_by_file(self,
                               questions_path,
                               user_name="user",
                               keep_retrieve_results_flag=False
                               ):
        """Build one sample per non-blank line of ``questions_path`` and save them.

        A literal ``\\n`` inside a line is turned into a real newline. Lines
        without ":" are prefixed with ``"<user_name>:"``.
        """
        samples = []
        for question in load_txt(questions_path).splitlines():
            if not question.strip():
                # A blank line would otherwise become an empty-question sample.
                continue
            question = question.replace('\\n', "\n")
            # NOTE(review): no space after the colon here, unlike
            # get_qa_samples_by_query -- confirm which form is intended.
            query = f"{user_name}:{question}" if ":" not in question else question
            samples.append(self._build_sample(query, user_name, keep_retrieve_results_flag))
        self._save_samples(samples)

    def get_qa_samples_by_query(self,
                                questions_query,
                                user_name="user",
                                keep_retrieve_results_flag=False
                                ):
        """Build a single sample from ``questions_query`` and save it.

        A literal ``\\n`` in the query is turned into a real newline. A query
        without ":" is prefixed with ``"<user_name>: "``.
        """
        question = questions_query.replace('\\n', "\n")
        query = f"{user_name}: {question}" if ":" not in question else question
        samples = [self._build_sample(query, user_name, keep_retrieve_results_flag)]
        self._save_samples(samples)

    def _save_samples(self, samples):
        """Write ``samples`` to ``self.save_samples_path`` as JSON."""
        save_to_json(samples, self.save_samples_path)
|
113 |
+
|
114 |
+
|
115 |
+
class CreateTestDataset:
    """Turn saved role samples into final prompt texts."""

    def __init__(self,
                 role_name,
                 role_samples_path=None,
                 role_data_path=None,
                 prompt_path="dataset_character.txt",
                 max_seq_len=4000):
        """Load the prompt template and the role description.

        Args:
            role_name: Name substituted for ``${role_name}`` in the prompt.
            role_samples_path: JSON file read by ``load_samples``. It may be
                omitted here, but ``load_samples`` then raises.
            role_data_path: JSON file whose first record supplies ``role_info``.
                Required despite the ``None`` default.
            prompt_path: Path of the prompt template.
            max_seq_len: Character limit for each generated prompt.

        Raises:
            ValueError: If ``role_data_path`` is missing or does not exist.
        """
        # Validate before any file is read; previously a missing path only
        # printed a message and then failed inside os.path.exists(None).
        if not role_data_path:
            raise ValueError("need role_data_path, check please!")
        if not os.path.exists(role_data_path):
            raise ValueError(f"{role_name} didn't find role_info.")

        self.max_seq_len = max_seq_len
        self.role_name = role_name
        self.default_simi_dialogs = None

        self.prompt = load_txt(prompt_path)
        self.prompt = self.prompt.replace("${role_name}", role_name).strip()

        data = load_json(role_data_path)
        self.role_info = data[0]["role_info"]
        self.prompt = self.prompt.replace("${role_info}", f"以下是{self.role_name}的人设:\n{self.role_info}\n").strip()

        # Always define the attribute so load_samples can report a clear error.
        self.role_samples_path = role_samples_path if role_samples_path else None
        if not role_samples_path:
            print("check role_samples_path please!")

    def load_samples(self):
        """Render one prompt per sample stored in ``self.role_samples_path``.

        Returns:
            A list of ``{"input_text": <prompt>}`` dicts.

        Raises:
            ValueError: If no samples path was configured, or a sample has no
                ``simi_dialogs`` and no default is set.
            AssertionError: If a rendered prompt is not shorter than
                ``max_seq_len`` (type kept for backward compatibility).
        """
        if not getattr(self, "role_samples_path", None):
            raise ValueError("check role_samples_path please!")

        samples = load_json(self.role_samples_path)
        results = []
        for sample in samples:
            input_text = self.prompt

            simi_dialogs = sample.get("simi_dialogs", None)
            if not simi_dialogs:
                simi_dialogs = self.default_simi_dialogs
            if not simi_dialogs:
                raise ValueError("didn't find simi_dialogs.")
            # NOTE(review): the budget is computed from the template before the
            # dialog is substituted, so a long dialog can still exceed the limit.
            simi_dialogs = CreateDataset.choose_examples(simi_dialogs,
                                                         max_length=self.max_seq_len - len(input_text),
                                                         train_flag=False)

            input_text = input_text.replace("${simi_dialog}", simi_dialogs)
            user_name = sample.get("user_name", "user")
            input_text = input_text.replace("${user_name}", user_name)

            dialog = "\n".join(sample["dialog"]) if isinstance(sample["dialog"], list) else sample["dialog"]
            input_text = input_text.replace("${dialog}", dialog)

            # Explicit raise instead of ``assert`` so the check survives ``python -O``.
            if len(input_text) >= self.max_seq_len:
                raise AssertionError(
                    f"input_text length {len(input_text)} is not below max_seq_len {self.max_seq_len}")
            results.append({
                "input_text": input_text,
            })
        return results
|
utils.py
ADDED
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
import csv
|
3 |
+
import json
|
4 |
+
import os
|
5 |
+
|
6 |
+
|
7 |
+
def read_csv_to_json(file_path, role_name, role_info):
    """Convert a dialog CSV into a list of role records.

    The first row is treated as a header and skipped. The second column of
    every following row is split on newlines into the ``dialog`` list.

    Args:
        file_path: Path of the UTF-8 encoded CSV file.
        role_name: Value stored under ``"role_name"`` in every record.
        role_info: Value stored under ``"role_info"`` in every record.

    Returns:
        A list of dicts with keys ``role_name``, ``role_info`` and ``dialog``.
        Rows without a second column (e.g. blank lines) are skipped.
    """
    json_list = []

    with open(file_path, mode="r", newline="", encoding="utf-8") as csvfile:
        csv_reader = csv.reader(csvfile)
        # Default of None keeps an empty file from raising StopIteration.
        next(csv_reader, None)

        for row in csv_reader:
            if len(row) < 2:
                # No dialog column in this row; indexing row[1] would fail.
                continue
            json_list.append({
                "role_name": role_name,
                "role_info": role_info,
                "dialog": row[1].split("\n"),
            })

    return json_list
|
23 |
+
|
24 |
+
|
25 |
+
def save_json(json_list, output_path):
    """Write ``json_list`` to ``output_path`` as UTF-8 JSON.

    Non-ASCII characters are written as-is and the output is indented by four
    spaces. The parent directory must already exist.
    """
    with open(output_path, "w", encoding="utf-8") as jsonfile:
        json.dump(json_list, jsonfile, ensure_ascii=False, indent=4)
|
28 |
+
|
29 |
+
|
30 |
+
def decode_csv_to_json(role_data_path, role_name, role_info, json_output_path):
    """Read the role CSV at ``role_data_path`` and write it out as JSON.

    Thin wrapper: ``read_csv_to_json`` builds the records and ``save_json``
    writes them to ``json_output_path``.
    """
    json_data = read_csv_to_json(role_data_path, role_name, role_info)
    save_json(json_data, json_output_path)
|
33 |
+
|
34 |
+
|
35 |
+
def load_txt(path):
    """Return the whole content of the text file at ``path``.

    The file is decoded as UTF-8; bytes that cannot be decoded are dropped
    silently (``errors="ignore"``).
    """
    with open(path, "r", encoding="utf-8", errors="ignore") as file:
        return file.read()
|
39 |
+
|
40 |
+
|
41 |
+
def load_json(path):
    """Parse the UTF-8 encoded JSON file at ``path`` and return the result."""
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)
|
45 |
+
|
46 |
+
|
47 |
+
def save_to_json(data, filepath, flag="w"):
    """Serialise ``data`` as JSON into ``filepath``, creating parent directories.

    Args:
        data: JSON-serialisable object.
        filepath: Destination path; a bare file name writes into the current
            working directory.
        flag: Mode passed to ``open`` (``"w"`` overwrites, ``"a"`` appends).
    """
    parent = os.path.dirname(filepath)
    if parent:
        # A bare file name has no directory part; os.makedirs("") would raise.
        # exist_ok avoids a race between checking and creating.
        os.makedirs(parent, exist_ok=True)
    with open(filepath, flag, encoding="utf-8") as f:
        f.write(json.dumps(data, ensure_ascii=False, indent=3))
|
52 |
+
|
53 |
+
|
54 |
+
def is_float(my_str):
    """Return True when ``float(my_str)`` succeeds, otherwise False.

    Values that ``float`` rejects by type (``None``, lists, ...) or by size
    yield False instead of raising.
    """
    try:
        float(my_str)
    except (TypeError, ValueError, OverflowError):
        return False
    return True
|