bingnoi committed
Commit
aca2cb2
1 Parent(s): e151a5f

Upload 2 files

Files changed (2)
  1. app.py +0 -2
  2. retrieve_dialog.py +135 -0
app.py CHANGED
@@ -31,8 +31,6 @@ warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is
 MODEL_PATH = os.environ.get('MODEL_PATH', 'IndexTeam/Index-1.9B-Character')
 TOKENIZER_PATH = os.environ.get("TOKENIZER_PATH", MODEL_PATH)
 
-print(MODEL_PATH,TOKENIZER_PATH)
-
 tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_PATH, trust_remote_code=True)
 model = AutoModelForCausalLM.from_pretrained(MODEL_PATH, torch_dtype=torch.float16, device_map="auto",
                                              trust_remote_code=True)
 
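
For context, a quick smoke test of the tokenizer/model pair loaded above could look like the sketch below. It is not taken from app.py (whose chat logic is outside this hunk); the prompt text and generation settings are illustrative assumptions.

# Hypothetical smoke test for the loaded tokenizer/model pair (not part of app.py);
# the prompt and generation settings are illustrative assumptions.
prompt = "Hello"
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=64, do_sample=True, temperature=0.8)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))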
retrieve_dialog.py ADDED
@@ -0,0 +1,135 @@
+# coding=utf-8
+from sentence_transformers import SentenceTransformer
+from .utils import load_json
+
+import faiss
+import logging
+import os
+import re
+import torch
+
+logger = logging.getLogger(__name__)
+
+
+class RetrieveDialog:
+    def __init__(self,
+                 role_name,
+                 raw_dialog_list: list = None,
+                 retrieve_num=20,
+                 min_mean_role_utter_length=10):
+        if torch.cuda.is_available():
+            gpu_id = 0
+            torch.cuda.set_device(gpu_id)
+
+        assert raw_dialog_list
+
+        self.role_name = role_name
+        self.min_mean_role_utter_length = min_mean_role_utter_length
+        self.retrieve_num = retrieve_num
+
+        # config = load_json("config/config.json")
+        # local_dir = config["bge_local_path"]
+        local_dir = os.environ.get('MODEL_PATH', 'IndexTeam/Index-1.9B-Character')
+
+        if not os.path.exists(local_dir):
+            print("Please download bge-large-zh-v1.5 first!")
+        self.emb_model = SentenceTransformer(local_dir)
+
+        self.dialogs, self.context_index = self._get_emb_base_by_list(raw_dialog_list)
+
+        logger.info(f"dialog db num: {len(self.dialogs)}")
+        logger.info("RetrieveDialog init success.")
+
+    @staticmethod
+    def dialog_preprocess(dialog: list, role_name):
+        dialog_new = []
+        # Replace user names so they do not affect retrieval
+        user_names = []
+        role_utter_length = []
+        for num in range(len(dialog)):
+            utter = dialog[num]
+            try:
+                user_name, utter_txt = re.split('[::]', utter, maxsplit=1)
+            except ValueError as e:
+                logging.error(f"utter:{utter} can't find user_name.")
+                # Callers unpack three values, so return three Nones on failure
+                return None, None, None
+
+            if user_name != role_name:
+                if user_name not in user_names:
+                    user_names.append(user_name)
+                index = user_names.index(user_name)
+                utter = utter.replace(user_name, f"user{index}", 1)
+            else:
+                role_utter_length.append(len(utter_txt))
+            dialog_new.append(utter)
+        return dialog_new, user_names, role_utter_length
+
+    def _get_emb_base_by_list(self, raw_dialog_list):
+        logger.info(f"raw dialog db num: {len(raw_dialog_list)}")
+        new_raw_dialog_list = []
+        context_list = []
+
+        # Keep an unfiltered copy in case the length filter drops every dialogue
+        new_raw_dialog_list_total = []
+        context_list_total = []
+        for raw_dialog in raw_dialog_list:
+            if not raw_dialog:
+                continue
+
+            # Trim trailing turns so the dialogue ends with the role's utterance
+            end = 0
+            for x in raw_dialog[::-1]:
+                if x.startswith(self.role_name):
+                    break
+                end += 1
+
+            raw_dialog = raw_dialog[:len(raw_dialog) - end]
+            new_dialog, user_names, role_utter_length = self.dialog_preprocess(raw_dialog, self.role_name)
+            if not new_dialog or not role_utter_length:
+                continue
+
+            if raw_dialog in new_raw_dialog_list_total:
+                continue
+
+            # The final reply is not needed when building the embedding context
+            context = "\n".join(new_dialog) if len(new_dialog) < 2 else "\n".join(new_dialog[:-1])
+
+            new_raw_dialog_list_total.append(raw_dialog)
+            context_list_total.append(context)
+
+            # Filter by the role's mean utterance length
+            role_length_mean = sum(role_utter_length) / len(role_utter_length)
+            if role_length_mean < self.min_mean_role_utter_length:
+                continue
+            new_raw_dialog_list.append(raw_dialog)
+            context_list.append(context)
+
+        assert len(new_raw_dialog_list) == len(context_list)
+        logger.debug(f"new_raw_dialog num: {len(new_raw_dialog_list)}")
+
+        # Fall back to the unfiltered dialogues when too few samples survive
+        if len(new_raw_dialog_list) < self.retrieve_num:
+            new_raw_dialog_list = new_raw_dialog_list_total
+            context_list = context_list_total
+
+        # Dialogue vector store
+        context_vectors = self.emb_model.encode(context_list, normalize_embeddings=True)
+        context_index = faiss.IndexFlatL2(context_vectors.shape[1])
+        context_index.add(context_vectors)
+
+        return new_raw_dialog_list, context_index
+
+    def get_retrieve_res(self, dialog: list, retrieve_num: int):
+        logger.debug(f"dialog: {dialog}")
+
+        # Strip user names here as well so the query matches the indexed contexts
+        dialog, _, _ = self.dialog_preprocess(dialog, self.role_name)
+        dialog_vector = self.emb_model.encode(["\n".join(dialog)], normalize_embeddings=True)
+
+        simi_dialog_distance, simi_dialog_index = self.context_index.search(
+            dialog_vector, min(retrieve_num, len(self.dialogs)))
+        simi_dialog_results = [
+            (str(simi_dialog_distance[0][num]), self.dialogs[index]) for num, index in enumerate(simi_dialog_index[0])
+        ]
+        logger.debug(f"dialog retrieve res: {simi_dialog_results}")
+
+        return simi_dialog_results
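
A hedged usage sketch for RetrieveDialog follows. The role name, dialogues, and paths are invented; the import assumes the module is reachable within its package (note the relative `from .utils import load_json`); and MODEL_PATH is assumed to point at a local bge-large-zh-v1.5 SentenceTransformer checkpoint, as the warning in __init__ suggests. Real data would likely be Chinese, matching that embedding model.

# Hypothetical usage of RetrieveDialog; role name, dialogues, and paths are assumptions.
import os
os.environ["MODEL_PATH"] = "/models/bge-large-zh-v1.5"  # assumed local embedding model path

from retrieve_dialog import RetrieveDialog  # adjust to the actual package layout

role = "RoleA"  # invented role name
raw_dialogs = [
    # Each utterance is "speaker:text"; every dialogue ends with the role's reply.
    ["user_a:Hi, how have you been lately?",
     "RoleA:I have been doing well, thanks for asking. I plan to read at the library today."],
    ["user_b:What is the weather like today?",
     "RoleA:It is sunny and warm, a good day for a walk outside."],
]

retriever = RetrieveDialog(role_name=role, raw_dialog_list=raw_dialogs, retrieve_num=2)
results = retriever.get_retrieve_res(["user_c:Hi, what are you up to today?"], retrieve_num=2)
for distance, dialog in results:
    print(distance, dialog)

Because the index is faiss.IndexFlatL2 over vectors encoded with normalize_embeddings=True, smaller returned distances correspond to higher cosine similarity between the query context and the stored dialogue contexts.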