import random
from math import sqrt


class NaiveDB:
    def __init__(self):
        self.verbose = False
        self.init_db()

    def init_db(self):
        if self.verbose:
            print("call init_db")
        self.stories = []
        self.norms = []
        self.vecs = []
        self.flags = []  # marks whether each story is searchable
        self.metas = []  # stores the meta information for each story
        self.last_search_ids = []  # stores the ids returned by the last search

    def build_db(self, stories, vecs, flags=None, metas=None):
        self.stories = stories
        self.vecs = vecs
        self.flags = flags if flags else [True for _ in self.stories]
        self.metas = metas if metas else [{} for _ in self.stories]
        self.recompute_norm()

    def save(self, file_path):
        print(
            "Warning: saving a folder directly from dbtype NaiveDB is not implemented yet; "
            "try using role_from_hf to load the role instead.")

    def load(self, file_path):
        print(
            "Warning: loading a folder directly from dbtype NaiveDB is not implemented yet; "
            "try using role_from_hf to load the role instead.")

    def recompute_norm(self):
        # self.norms stores the L2 norm of each vector in self.vecs
        self.norms = [sqrt(sum(x ** 2 for x in vec)) for vec in self.vecs]

    def get_stories_with_id(self, ids):
        return [self.stories[i] for i in ids]

    def clean_flag(self):
        self.flags = [True for _ in self.stories]

    def disable_story_with_ids(self, close_ids):
        for story_id in close_ids:
            self.flags[story_id] = False

    def close_last_search(self):
        for story_id in self.last_search_ids:
            self.flags[story_id] = False

    def search(self, query_vector, n_results):

        if self.verbose:
            print("call search")

        if len(self.norms) != len(self.vecs):
            self.recompute_norm()

        # Compute the L2 norm of the query vector
        query_norm = sqrt(sum([x ** 2 for x in query_vector]))

        idxs = list(range(len(self.vecs)))

        # Compute the cosine similarity against every searchable vector
        similarities = []
        for vec, norm, idx in zip(self.vecs, self.norms, idxs):
            if len(self.flags) == len(self.vecs) and not self.flags[idx]:
                continue

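            if query_norm < 1e-20:
                # Degenerate (near-zero) query: fall back to a random score
                similarities.append((random.random(), idx))
                continue
            dot_product = sum(q * v for q, v in zip(query_vector, vec))
            # Guard against zero-norm stored vectors to avoid ZeroDivisionError;
            # scoring them as 0.0 is an assumption, not specified by the original code
            denom = query_norm * norm
            cosine_similarity = dot_product / denom if denom > 0 else 0.0
            similarities.append((cosine_similarity, idx))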

        # Take the n_results most similar entries, sorting by field 0 (the score)
        similarities.sort(key=lambda x: x[0], reverse=True)
        top_indices = [x[1] for x in similarities[:n_results]]
        # Keep a separate copy so close_last_search() can disable these stories later
        self.last_search_ids = list(top_indices)

        return top_indices