# coding=utf-8
import logging
import re

import faiss
import torch
from sentence_transformers import SentenceTransformer

logger = logging.getLogger(__name__)


class RetrieveDialog:
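    """Retrieve the most similar stored dialogs for a given role.

    Dialog contexts are embedded with a BGE sentence encoder and indexed with
    FAISS; at query time the current dialog is embedded the same way and the
    nearest stored dialogs are returned.
    """
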
    def __init__(self,
                 role_name,
                 raw_dialog_list: list = None,
                 retrieve_num=20,
                 min_mean_role_utter_length=10):
        if torch.cuda.is_available():
            # Use the first GPU when one is available.
            torch.cuda.set_device(0)

        assert raw_dialog_list, "raw_dialog_list must be a non-empty list"

        self.role_name = role_name
        self.min_mean_role_utter_length = min_mean_role_utter_length
        self.retrieve_num = retrieve_num

        # The embedding model can also be loaded from a local path (e.g. read
        # from a config file or a MODEL_PATH environment variable) instead of
        # being fetched from the Hugging Face Hub.
        self.emb_model = SentenceTransformer("BAAI/bge-large-zh-v1.5")

        self.dialogs, self.context_index = self._get_emb_base_by_list(raw_dialog_list)

        logger.info(f"dialog db num: {len(self.dialogs)}")
        logger.info(f"RetrieveDialog init success.")

    @staticmethod
    def dialog_preprocess(dialog: list, role_name):
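        """Normalize a dialog for retrieval.

        Each utterance is expected to look like "<speaker>: <text>" (ASCII or
        full-width colon). Speakers other than role_name are renamed to
        anonymous placeholders ("user0", "user1", ...), and the lengths of the
        role's utterances are collected for later filtering.
        """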
        dialog_new = []
        # Replace speaker names with anonymous placeholders to reduce their
        # influence on retrieval.
        user_names = []
        role_utter_length = []
        for utter in dialog:
            try:
                user_name, utter_txt = re.split('[::]', utter, maxsplit=1)
            except ValueError:
                logger.error(f"utter: {utter} can't find user_name.")
                return None, None, None

            if user_name != role_name:
                if user_name not in user_names:
                    user_names.append(user_name)
                index = user_names.index(user_name)
                utter = utter.replace(user_name, f"user{index}", 1)
            else:
                role_utter_length.append(len(utter_txt))
            dialog_new.append(utter)
        return dialog_new, user_names, role_utter_length

    def _get_emb_base_by_list(self, raw_dialog_list):
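        """Build the retrieval base: a dialog list plus a FAISS index.

        Dialogs whose mean role-utterance length falls below the configured
        threshold are filtered out; if that leaves fewer than retrieve_num
        dialogs, the unfiltered set is used instead.
        """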
        logger.info(f"raw dialog db num: {len(raw_dialog_list)}")
        new_raw_dialog_list = []
        context_list = []

        # Also keep an unfiltered copy, in case the length filter removes
        # every dialog.
        new_raw_dialog_list_total = []
        context_list_total = []
        for raw_dialog in raw_dialog_list:
            if not raw_dialog:
                continue

            # Trim trailing utterances so each dialog ends with the role's
            # last line.
            end = 0
            for x in raw_dialog[::-1]:
                if x.startswith(self.role_name):
                    break
                end += 1
            raw_dialog = raw_dialog[:len(raw_dialog) - end]
            new_dialog, user_names, role_utter_length = self.dialog_preprocess(raw_dialog, self.role_name)
            if not new_dialog or not role_utter_length:
                continue

            # Skip duplicate dialogs.
            if raw_dialog in new_raw_dialog_list_total:
                continue

            # The last utterance (the role's answer) is not needed when
            # computing the context embedding.
            context = "\n".join(new_dialog) if len(new_dialog) < 2 else "\n".join(new_dialog[:-1])

            new_raw_dialog_list_total.append(raw_dialog)
            context_list_total.append(context)

            # Filter out dialogs whose role utterances are too short on average.
            role_length_mean = sum(role_utter_length) / len(role_utter_length)
            if role_length_mean < self.min_mean_role_utter_length:
                continue
            new_raw_dialog_list.append(raw_dialog)
            context_list.append(context)

        assert len(new_raw_dialog_list) == len(context_list)
        logger.debug(f"new_raw_dialog num: {len(new_raw_dialog_list)}")

        # Fall back to the unfiltered set when filtering leaves fewer dialogs
        # than retrieve_num.
        if len(new_raw_dialog_list) < self.retrieve_num:
            new_raw_dialog_list = new_raw_dialog_list_total
            context_list = context_list_total

        # Build the dialog vector index. Embeddings are L2-normalized, so L2
        # distance is monotonically related to cosine similarity.
        context_vectors = self.emb_model.encode(context_list, normalize_embeddings=True)
        context_index = faiss.IndexFlatL2(context_vectors.shape[1])
        context_index.add(context_vectors)

        return new_raw_dialog_list, context_index

    def get_retrieve_res(self, dialog: list, retrieve_num: int):
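        """Return the stored dialogs most similar to the given dialog.

        The result is a list of (distance, dialog) tuples, where distance is
        the stringified L2 distance between normalized embeddings (smaller
        means more similar).
        """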
        logger.debug(f"dialog: {dialog}")

        # Remove user-name influence, mirroring the preprocessing applied when
        # the index was built.
        dialog, _, _ = self.dialog_preprocess(dialog, self.role_name)
        if dialog is None:
            return []
        dialog_vector = self.emb_model.encode(["\n".join(dialog)], normalize_embeddings=True)

        simi_dialog_distance, simi_dialog_index = self.context_index.search(
            dialog_vector, min(retrieve_num, len(self.dialogs)))
        simi_dialog_results = [
            (str(simi_dialog_distance[0][num]), self.dialogs[index]) for num, index in enumerate(simi_dialog_index[0])
        ]
        logger.debug(f"dialog retrieve res: {simi_dialog_results}")

        return simi_dialog_results
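

if __name__ == "__main__":
    # Minimal usage sketch with made-up sample data. The role name and all
    # utterances below are hypothetical; each one follows the
    # "<speaker>: <text>" format that dialog_preprocess expects.
    logging.basicConfig(level=logging.INFO)

    sample_dialogs = [
        [
            "user_a: How have you been training lately?",
            "Coach: I run ten kilometers every morning before breakfast.",
        ],
        [
            "user_b: What do you eat before a race?",
            "Coach: Mostly rice and vegetables, nothing heavy the night before.",
        ],
    ]

    retriever = RetrieveDialog(
        role_name="Coach",
        raw_dialog_list=sample_dialogs,
        retrieve_num=2,
    )
    # Retrieve the stored dialogs closest to a new conversation snippet.
    results = retriever.get_retrieve_res(
        ["user_c: How do you prepare for a big race?"], retrieve_num=2
    )
    for distance, dialog in results:
        print(distance, dialog)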