Shakshi3104 committed on
Commit
3e4b2ef
1 Parent(s): 85c3441

[add] Implement vector search with Ruri and Voyager

Browse files
Files changed (1) hide show
  1. model/search/ruri.py +165 -0
model/search/ruri.py ADDED
@@ -0,0 +1,165 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ from typing import List, Union, Optional
4
+ from pathlib import Path
5
+
6
+ import numpy as np
7
+ import pandas as pd
8
+
9
+ from copy import deepcopy
10
+ from dotenv import load_dotenv
11
+ from loguru import logger
12
+ from tqdm import tqdm
13
+
14
+ import sentence_transformers as st
15
+ from sentence_transformers import util as st_util
16
+
17
+ import voyager
18
+
19
+ from model.search.base import BaseSearchClient
20
+
21
+
22
def array_to_string(array: np.ndarray) -> str:
    """
    Convert an np.ndarray to its string representation.

    The array is first turned into (nested) Python lists, so the result
    looks like a Python list literal, e.g. ``"[1, 2, 3]"``.

    Parameters
    ----------
    array:
        np.ndarray

    Returns
    -------
    array_string:
        str
    """
    return str(array.tolist())
38
+
39
+
40
class RuriEmbedder:
    """Thin wrapper around a SentenceTransformer model used to embed text.

    When no model is injected, the ``ruri-large`` model is loaded from the
    local directory ``models/ruri/ruri-large`` if it exists; otherwise it is
    fetched as ``cl-nagoya/ruri-large`` and then saved to that directory.
    """

    def __init__(self, model: Optional[st.SentenceTransformer] = None):
        """
        Parameters
        ----------
        model:
            Optional[st.SentenceTransformer], a ready-made model to use.
            If None, the model is loaded from disk or downloaded.
        """
        # Load .env so that HF_TOKEN is visible through os.getenv below.
        load_dotenv()

        # Where the model is stored locally
        self.model_dir = Path("models/ruri")
        local_copy = self.model_dir / "ruri-large"

        # An injected model always wins; nothing is loaded or saved.
        if model is not None:
            self.model = model
            return

        # Prefer the local copy when one exists.
        if local_copy.exists():
            logger.info(f"🚦 [RuriEmbedder] load ruri-large from local path: {local_copy}")
            self.model = st.SentenceTransformer(str(local_copy))
            return

        # No local copy: fetch the model, then save it under the local path.
        logger.info("🚦 [RuriEmbedder] load ruri-large from HuggingFace🤗")
        self.model = st.SentenceTransformer("cl-nagoya/ruri-large", token=os.getenv("HF_TOKEN"))
        logger.info("🚦 [RuriEmbedder] save model ...")
        self.model.save(str(local_copy))

    def embed(self, text: Union[str, list[str]]) -> np.ndarray:
        """
        Encode text into embedding vectors.

        Parameters
        ----------
        text:
            Union[str, list[str]], the string(s) to vectorize

        Returns
        -------
        embedding:
            np.ndarray, the embedding representation (the original author's
            note gives the vector size as 1024)
        """
        return self.model.encode(text, convert_to_numpy=True)
79
+
80
+
81
class RuriVoyagerSearchClient(BaseSearchClient):
    """Vector search over a DataFrame column using Ruri embeddings and a Voyager index."""

    def __init__(self, dataset: pd.DataFrame, target: str,
                 index: voyager.Index,
                 model: RuriEmbedder):
        """
        Parameters
        ----------
        dataset:
            pd.DataFrame, the original data; search results are rows of it
        target:
            str, name of the column holding the searchable text
        index:
            voyager.Index, index whose item positions match the row
            positions of ``dataset``
        model:
            RuriEmbedder, embedder used to vectorize queries
        """
        load_dotenv()
        # Original corpus
        self.dataset = dataset
        self.corpus = dataset[target].values.tolist()

        # Embedding model
        self.embedder = model

        # Voyager index
        self.index = index

    @classmethod
    def from_dataframe(cls, _data: pd.DataFrame, _target: str):
        """
        Build a client by embedding a DataFrame column and indexing the vectors.

        Parameters
        ----------
        _data:
            pd.DataFrame, the data to search over
        _target:
            str, name of the column holding the searchable text

        Returns
        -------
        client:
            RuriVoyagerSearchClient
        """
        logger.info("🚦 [RuriVoyagerSearchClient] Initialize from DataFrame")

        search_field = _data[_target]
        corpus = search_field.values.tolist()

        # Initialize the embedding model
        embedder = RuriEmbedder()

        # Ruri preprocessing: documents are prefixed before embedding
        corpus = [f"文章: {c}" for c in corpus]

        # Vectorize
        embeddings = embedder.embed(corpus)

        # Dimensionality of the embedding vectors
        num_dim = embeddings.shape[1]
        logger.debug(f"🚦⚓️ [RuriVoyagerSearchClient] Number of dimensions of Embedding vector is {num_dim}")

        # Initialize the Voyager index (cosine space)
        index = voyager.Index(voyager.Space.Cosine, num_dimensions=num_dim)
        # Add the vectors in corpus order, so the i-th vector corresponds
        # to the i-th row of _data
        index.add_items(embeddings)

        return cls(_data, _target, index, embedder)

    def search_top_n(self, _query: Union[List[str], str], n: int = 10) -> List[pd.DataFrame]:
        """
        Retrieve the top-n search results for each query.

        Parameters
        ----------
        _query:
            Union[List[str], str], search query or queries
        n:
            int, number of results to return per query. Default 10.
            If the index holds fewer than ``n`` vectors, all of them are
            returned instead of failing.

        Returns
        -------
        results:
            List[pd.DataFrame], one ranking per query. Each frame holds the
            matching rows of the dataset plus two columns: ``score``, the
            distance reported by the index (cosine distance for indexes
            built by ``from_dataframe``; smaller means closer), and
            ``rank``, the 0-based position in the ranking.
        """

        logger.info(f"🚦 [RuriVoyagerSearchClient] Search top {n} | {_query}")

        # Normalize a single query to a list
        if isinstance(_query, str):
            _query = [_query]

        # Ruri preprocessing: queries are prefixed before embedding
        _query = [f"クエリ: {q}" for q in _query]

        # Vectorize
        embeddings_queries = self.embedder.embed(_query)

        # Never ask the index for more neighbours than it contains: voyager's
        # query fails when it cannot retrieve k results (see voyager docs).
        k = max(0, min(n, len(self.index)))

        # Collect the top-n ranking for every query
        result = []
        for embeddings_query in tqdm(embeddings_queries):
            if k > 0:
                # Search the Voyager index
                neighbors_indices, distances = self.index.query(embeddings_query, k=k)
            else:
                # Nothing to retrieve (empty index or non-positive n)
                neighbors_indices = np.empty(0, dtype=np.int64)
                distances = np.empty(0, dtype=np.float32)

            # Copy the selected rows so the caller's dataset is never mutated
            df_res = self.dataset.iloc[neighbors_indices].copy()
            # Distance to the query, as reported by the index
            df_res["score"] = distances
            # Rank (0-based, in the order returned by the index)
            df_res["rank"] = np.arange(len(df_res))

            result.append(df_res)

        return result