Shakshi3104 committed
Commit 2cca64b
1 Parent(s): 83ce2de

[add] Implement surface search

Files changed (3)
  1. .gitignore +1 -0
  2. model/search/bm25.py +146 -0
  3. model/utils/tokenizer.py +63 -0
.gitignore CHANGED
@@ -5,6 +5,7 @@
  # Develop
  .venv/
  logs/
+ data/

  # Default
  # Byte-compiled / optimized / DLL files
model/search/bm25.py ADDED
@@ -0,0 +1,146 @@
+ from copy import deepcopy
+ from typing import List, Union
+
+ import pandas as pd
+ import numpy as np
+
+ from loguru import logger
+ from tqdm import tqdm
+
+ from rank_bm25 import BM25Okapi
+
+ from model.search.base import BaseSearchClient
+ from model.utils.tokenizer import MeCabTokenizer
+
+
+ class BM25Wrapper(BM25Okapi):
+     def __init__(self, dataset: pd.DataFrame, target, tokenizer=None, k1=1.5, b=0.75, epsilon=0.25):
+         self.k1 = k1
+         self.b = b
+         self.epsilon = epsilon
+         self.dataset = dataset
+         corpus = dataset[target].values.tolist()
+         super().__init__(corpus, tokenizer)
+
+     def get_top_n(self, query, documents, n=5):
+         assert self.corpus_size == len(documents), "The documents given don't match the index corpus!"
+
+         scores = self.get_scores(query)
+         top_n = np.argsort(scores)[::-1][:n]
+
+         result = deepcopy(self.dataset.iloc[top_n])
+         result["score"] = scores[top_n]
+         return result
+
+
+ class BM25SearchClient(BaseSearchClient):
+     def __init__(self, _model: BM25Okapi, _corpus: List[List[str]]):
+         """
+
+         Parameters
+         ----------
+         _model:
+             BM25Okapi
+         _corpus:
+             List[List[str]], tokenized documents to search over
+         """
+         self.model = _model
+         self.corpus = _corpus
+
+     @staticmethod
+     def tokenize_ja(_text: List[str]):
+         """Build a corpus by tokenizing Japanese text with MeCab
+
+         Args:
+             _text (List[str]): list of corpus sentences
+
+         Returns:
+             List[List[str]]: list of tokenized texts
+         """
+
+         # Tokenize with MeCab
+         parser = MeCabTokenizer.from_tagger("-Owakati")
+
+         corpus = []
+         with tqdm(_text) as pbar:
+             for i, t in enumerate(pbar):
+                 try:
+                     # Split the sentence into tokens
+                     corpus.append(parser.parse(t).split())
+                 except TypeError as e:
+                     if not isinstance(t, str):
+                         logger.info(f"🚦 [BM25SearchClient] Corpus index of {i} is not instance of String.")
+                         corpus.append(["[UNKNOWN]"])
+                     else:
+                         raise e
+         return corpus
+
+     @classmethod
+     def from_dataframe(cls, _data: pd.DataFrame, _target: str):
+         """
+         Initialize from a pd.DataFrame of documents to search
+
+         Parameters
+         ----------
+         _data:
+             pd.DataFrame, DataFrame to search over
+
+         _target:
+             str, name of the column to search
+
+         Returns
+         -------
+             BM25SearchClient
+         """
+
+         logger.info("🚦 [BM25SearchClient] Initialize from DataFrame")
+
+         search_field = _data[_target]
+         corpus = search_field.values.tolist()
+
+         # Tokenize the corpus
+         corpus_tokenized = cls.tokenize_ja(corpus)
+         _data["tokenized"] = corpus_tokenized
+
+         bm25 = BM25Wrapper(_data, "tokenized")
+         return cls(bm25, corpus_tokenized)
+
+     def search_top_n(self, _query: Union[List[str], str], n: int = 10) -> List[pd.DataFrame]:
+         """
+         Retrieve the top-n search results for each query
+
+         Parameters
+         ----------
+         _query:
+             Union[List[str], str], search query
+         n:
+             int, number of results to return. Default is 10.
+
+         Returns
+         -------
+         results:
+             List[pd.DataFrame], ranking results
+         """
+
+         logger.info(f"🚦 [BM25SearchClient] Search top {n} | {_query}")
+
+         # Type check
+         if isinstance(_query, str):
+             _query = [_query]
+
+         # Tokenize the queries
+         query_tokens = self.tokenize_ja(_query)
+
+         # Get the top-n ranking for each query
+         result = []
+         for query in tqdm(query_tokens):
+             query_text = "".join(query)
+             df_res = self.model.get_top_n(query, self.corpus, n)
+             df_res["query"] = [query_text] * len(df_res)
+             df_res["rank"] = deepcopy(df_res.reset_index()).index
+             df_res = df_res.drop(columns=["tokenized"])
+             result.append(df_res)
+
+         logger.success(f"🚦 [BM25SearchClient] Executed")
+
+         return result
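For reference, a minimal usage sketch of the new search client (not part of the commit; it assumes MeCab, rank_bm25, loguru, and tqdm are installed, and the "title" column and toy data are purely illustrative):

import pandas as pd

from model.search.bm25 import BM25SearchClient

# Hypothetical toy corpus; in practice this would be the project's search data.
df = pd.DataFrame({"title": ["東京の天気予報", "大阪のおすすめグルメ", "京都の観光名所"]})

# Build the BM25 index over the "title" column (tokenized with MeCab internally).
client = BM25SearchClient.from_dataframe(df, "title")

# Retrieve the top-2 documents for one query; the result is a list with one
# DataFrame per query, each carrying "score", "query", and "rank" columns.
results = client.search_top_n("東京 観光", n=2)
print(results[0][["title", "score", "rank"]])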
model/utils/tokenizer.py ADDED
@@ -0,0 +1,63 @@
+ import abc
+ from typing import Optional
+
+ import MeCab
+ # from janome.tokenizer import Tokenizer
+
+
+ class BaseTokenizer:
+     @abc.abstractmethod
+     def parse(self, _text: str) -> str:
+         """
+         Return the tokenized text
+
+         Parameters
+         ----------
+         _text:
+             str, input text
+
+         Returns
+         -------
+         parsed:
+             str, tokenized text, space-separated
+         """
+         raise NotImplementedError
+
+
+ class MeCabTokenizer(BaseTokenizer):
+     def __init__(self, _parser: MeCab.Tagger) -> None:
+         self.parser = _parser
+
+     @classmethod
+     def from_tagger(cls, _tagger: Optional[str]):
+         parser = MeCab.Tagger(_tagger)
+         return cls(parser)
+
+     def parse(self, _text: str):
+         return self.parser.parse(_text)
+
+
+ # class JanomeTokenizer(BaseTokenizer):
+ #     def __init__(self, _tokenizer: Tokenizer):
+ #         self.tokenizer = _tokenizer
+ #
+ #     @classmethod
+ #     def from_user_simple_dictionary(cls, _dict_filepath: Optional[str] = None):
+ #         """
+ #         Initializer that takes a user dictionary in the simple dictionary format
+ #
+ #         https://mocobeta.github.io/janome/#v0-2-7
+ #
+ #         Parameters
+ #         ----------
+ #         _dict_filepath:
+ #             str, path to a user dictionary (CSV file) written in the simple dictionary format
+ #         """
+ #
+ #         if _dict_filepath is None:
+ #             return cls(Tokenizer())
+ #         else:
+ #             return cls(Tokenizer(udic=_dict_filepath, udic_type='simpledic'))
+ #
+ #     def parse(self, _text: str) -> str:
+ #         return " ".join(list(self.tokenizer.tokenize(_text, wakati=True)))
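And a minimal sketch of using the tokenizer on its own (not part of the commit; it assumes a working MeCab installation with a default dictionary, so the exact token boundaries may differ):

from model.utils.tokenizer import MeCabTokenizer

# "-Owakati" makes MeCab output surface forms separated by single spaces.
tokenizer = MeCabTokenizer.from_tagger("-Owakati")

# parse() returns the space-separated string; split() yields the token list
# that BM25SearchClient.tokenize_ja feeds into the BM25 index.
tokens = tokenizer.parse("東京都に住んでいます").split()
print(tokens)  # e.g. ['東京', '都', 'に', '住ん', 'で', 'い', 'ます']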