simcse-demo / app.py
singletongue's picture
Add files
5b29d9a
raw
history blame
No virus
3.86 kB
import datasets
import faiss
import numpy as np
import streamlit as st
import torch
from datasets import Dataset
from transformers import FeatureExtractionPipeline, pipeline
@st.cache_resource
def load_encoder_pipeline(encoder_path: str) -> FeatureExtractionPipeline:
"""訓練済みの教師なしSimCSEのエンコーダを読み込む"""
encoder_pipeline = pipeline("feature-extraction", model=encoder_path)
return encoder_pipeline
@st.cache_resource
def load_dataset(dataset_dir: str) -> Dataset:
"""文埋め込み適用済みのデータセットを読み込み、Faissのインデックスを構築"""
# ディスクに保存されたデータセットを読み込む
dataset = datasets.load_from_disk(dataset_dir)
# データセットの"embeddings"フィールドの値からFaissのインデックスを構築する
emb_dim = len(dataset[0]["embeddings"])
index = faiss.IndexFlatIP(emb_dim)
dataset.add_faiss_index("embeddings", custom_index=index)
return dataset
def embed_text(
text: str, encoder_pipeline: FeatureExtractionPipeline
) -> np.ndarray:
"""教師なしSimCSEのエンコーダを用いてテキストの埋め込みを計算"""
with torch.inference_mode():
# encoder_pipelineが返すTensorのsizeは(1, トークン数, 埋め込みの次元数)
encoded_text = encoder_pipeline(text, return_tensors="pt")[0][0]
# ベクトルをNumPyのarrayに変換
emb = encoded_text.cpu().numpy().astype(np.float32)
# ベクトルのノルムが1になるように正規化
emb = emb / np.linalg.norm(emb)
return emb
def search_similar_texts(
query_text: str,
dataset: Dataset,
encoder_pipeline: FeatureExtractionPipeline,
k: int = 5,
) -> list[dict[str, float | str]]:
"""モデルとデータセットを用いてクエリの類似文検索を実行"""
# クエリに対して類似テキストをk件取得する
scores, retrieved_examples = dataset.get_nearest_examples(
"embeddings", embed_text(query_text, encoder_pipeline), k=k
)
titles = retrieved_examples["title"]
texts = retrieved_examples["text"]
# 検索された類似テキストをdictのlistにして返す
results = [
{"score": score, "title": title, "text": text}
for score, title, text in zip(scores, titles, texts)
]
return results
# 訓練済みの教師なしSimCSEのモデルを読み込む
encoder_pipeline = load_encoder_pipeline("outputs_unsup_simcse/encoder")
# 文埋め込み適用済みのデータセットを読み込む
dataset = load_dataset("outputs_unsup_simcse/embedded_paragraphs")
# デモページのタイトルを表示する
st.title(":mag: Wikipedia Paragraph Search")
# デモページのフォームを表示する
with st.form("input_form"):
# クエリの入力欄を表示し、入力された値を受け取る
query_text = st.text_input(
"クエリを入力:", value="日本語は、主に日本で話されている言語である。", max_chars=150
)
# 検索する段落数のスライダーを表示し、設定された値を受け取る
k = st.slider("検索する段落数:", min_value=1, max_value=100, value=10)
# 検索を実行するボタンを表示し、押下されたらTrueを受け取る
is_submitted = st.form_submit_button("Search")
# 検索結果を表示する
if is_submitted and len(query_text) > 0:
# クエリに対して類似文検索を実行し、検索結果を受け取る
serach_results = search_similar_texts(
query_text, dataset, encoder_pipeline, k=k
)
# 検索結果を表示する
st.subheader("検索結果")
st.dataframe(serach_results, use_container_width=True)
st.caption("セルのダブルクリックで全体が表示されます")