File size: 4,418 Bytes
6d23787
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
from typing import Union, List

import pandas as pd
from copy import deepcopy

from dotenv import load_dotenv
from loguru import logger
from tqdm import tqdm

from model.search.base import BaseSearchClient
from model.search.surface import BM25SearchClient
from model.search.vector import RuriVoyagerSearchClient

from model.utils.timer import stop_watch


def reciprocal_rank_fusion(sparse: pd.DataFrame, dense: pd.DataFrame, k=60) -> pd.DataFrame:
    """
    Reciprocal Rank Fusionを計算する

    Notes
    ----------
    RRFの計算は以下の式

    .. math:: RRF = \sum_{i=1}^n \frac{1}{k+r_i}

    Parameters
    ----------
    sparse:
        pd.DataFrame, 表層検索の検索結果
    dense:
        pd.DataFrame, ベクトル検索の結果
    k:
        int,

    Returns
    -------
    rank_results:
        pd.DataFrame, RRFによるリランク結果

    """
    # カラム名を変更
    sparse = sparse.rename(columns={"rank": "rank_sparse"})
    dense = dense.rename(columns={"rank": "rank_dense"})
    # denseはランク以外を落として結合する
    dense_ = dense["rank_dense"]

    # 順位を1からスタートするようにする
    sparse["rank_sparse"] += 1
    dense_ += 1

    # 文書のインデックスをキーに結合する
    rank_results = pd.merge(sparse, dense_, how="left", left_index=True, right_index=True)

    # RRFスコアの計算
    rank_results["rrf_score"] = 1 / (rank_results["rank_dense"] + k) + 1 / (rank_results["rank_sparse"] + k)

    # RRFスコアのスコアが大きい順にソート
    rank_results = rank_results.sort_values(["rrf_score"], ascending=False)
    rank_results["rank"] = deepcopy(rank_results.reset_index()).index

    return rank_results


class HybridSearchClient(BaseSearchClient):
    def __init__(self, dense_model: BaseSearchClient, sparse_model: BaseSearchClient):
        self.dense_model = dense_model
        self.sparse_model = sparse_model

    @classmethod
    @stop_watch
    def from_dataframe(cls, _data: pd.DataFrame, _target: str):
        """
        検索ドキュメントのpd.DataFrameから初期化する

        Parameters
        ----------
        _data:
            pd.DataFrame, 検索対象のDataFrame

        _target:
            str, 検索対象のカラム名

        Returns
        -------

        """
        # 表層検索の初期化
        dense_model = BM25SearchClient.from_dataframe(_data, _target)
        # ベクトル検索の初期化
        sparse_model = RuriVoyagerSearchClient.from_dataframe(_data, _target)

        return cls(dense_model, sparse_model)

    @stop_watch
    def search_top_n(self, _query: Union[List[str], str], n: int = 10) -> List[pd.DataFrame]:
        """
        クエリに対する検索結果をtop-n個取得する

        Parameters
        ----------
        _query:
            Union[List[str], str], 検索クエリ
        n:
            int, top-nの個数. デフォルト 10.

        Returns
        -------
        results:
            List[pd.DataFrame], ランキング結果
        """

        logger.info(f"🚦 [HybridSearchClient] Search top {n} | {_query}")

        # 型チェック
        if isinstance(_query, str):
            _query = [_query]

        # ランキングtop-nをクエリ毎に取得
        result = []
        for query in tqdm(_query):
            assert len(self.sparse_model.corpus) == len(
                self.dense_model.corpus), "The document counts do not match between sparse and dense!"

            # ドキュメント数
            doc_num = len(self.sparse_model.corpus)

            # 表層検索
            logger.info(f"🚦 [HybridSearchClient] run surface search ...")
            sparse_res = self.sparse_model.search_top_n(query, n=doc_num)
            # ベクトル検索
            logger.info(f"🚦 [HybridSearchClient] run vector search ...")
            dense_res = self.dense_model.search_top_n(query, n=doc_num)

            # RRFスコアの計算
            logger.info(f"🚦 [HybridSearchClient] compute RRF scores ...")
            rrf_res = reciprocal_rank_fusion(sparse_res[0], dense_res[0])

            # 結果をtop Nに絞る
            top_num = 10
            rrf_res = rrf_res.head(top_num)
            logger.info(f"🚦 [HybridSearchClient] return {top_num} results")

            result.append(rrf_res)

        return result