Spaces:
Running
Running
singletongue
commited on
Commit
•
5b29d9a
1
Parent(s):
27212ed
Add files
Browse files- app.py +97 -0
- outputs_unsup_simcse/embedded_paragraphs/data-00000-of-00010.arrow +3 -0
- outputs_unsup_simcse/embedded_paragraphs/data-00001-of-00010.arrow +3 -0
- outputs_unsup_simcse/embedded_paragraphs/data-00002-of-00010.arrow +3 -0
- outputs_unsup_simcse/embedded_paragraphs/data-00003-of-00010.arrow +3 -0
- outputs_unsup_simcse/embedded_paragraphs/data-00004-of-00010.arrow +3 -0
- outputs_unsup_simcse/embedded_paragraphs/data-00005-of-00010.arrow +3 -0
- outputs_unsup_simcse/embedded_paragraphs/data-00006-of-00010.arrow +3 -0
- outputs_unsup_simcse/embedded_paragraphs/data-00007-of-00010.arrow +3 -0
- outputs_unsup_simcse/embedded_paragraphs/data-00008-of-00010.arrow +3 -0
- outputs_unsup_simcse/embedded_paragraphs/data-00009-of-00010.arrow +3 -0
- outputs_unsup_simcse/embedded_paragraphs/dataset_info.json +83 -0
- outputs_unsup_simcse/embedded_paragraphs/state.json +40 -0
- outputs_unsup_simcse/encoder/config.json +25 -0
- outputs_unsup_simcse/encoder/pytorch_model.bin +3 -0
- outputs_unsup_simcse/encoder/special_tokens_map.json +7 -0
- outputs_unsup_simcse/encoder/tokenizer_config.json +21 -0
- outputs_unsup_simcse/encoder/vocab.txt +0 -0
- requirements.txt +5 -0
app.py
ADDED
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import datasets
|
2 |
+
import faiss
|
3 |
+
import numpy as np
|
4 |
+
import streamlit as st
|
5 |
+
import torch
|
6 |
+
from datasets import Dataset
|
7 |
+
from transformers import FeatureExtractionPipeline, pipeline
|
8 |
+
|
9 |
+
|
10 |
+
@st.cache_resource
|
11 |
+
def load_encoder_pipeline(encoder_path: str) -> FeatureExtractionPipeline:
|
12 |
+
"""訓練済みの教師なしSimCSEのエンコーダを読み込む"""
|
13 |
+
encoder_pipeline = pipeline("feature-extraction", model=encoder_path)
|
14 |
+
return encoder_pipeline
|
15 |
+
|
16 |
+
|
17 |
+
@st.cache_resource
|
18 |
+
def load_dataset(dataset_dir: str) -> Dataset:
|
19 |
+
"""文埋め込み適用済みのデータセットを読み込み、Faissのインデックスを構築"""
|
20 |
+
# ディスクに保存されたデータセットを読み込む
|
21 |
+
dataset = datasets.load_from_disk(dataset_dir)
|
22 |
+
|
23 |
+
# データセットの"embeddings"フィールドの値からFaissのインデックスを構築する
|
24 |
+
emb_dim = len(dataset[0]["embeddings"])
|
25 |
+
index = faiss.IndexFlatIP(emb_dim)
|
26 |
+
dataset.add_faiss_index("embeddings", custom_index=index)
|
27 |
+
|
28 |
+
return dataset
|
29 |
+
|
30 |
+
|
31 |
+
def embed_text(
|
32 |
+
text: str, encoder_pipeline: FeatureExtractionPipeline
|
33 |
+
) -> np.ndarray:
|
34 |
+
"""教師なしSimCSEのエンコーダを用いてテキストの埋め込みを計算"""
|
35 |
+
with torch.inference_mode():
|
36 |
+
# encoder_pipelineが返すTensorのsizeは(1, トークン数, 埋め込みの次元数)
|
37 |
+
encoded_text = encoder_pipeline(text, return_tensors="pt")[0][0]
|
38 |
+
|
39 |
+
# ベクトルをNumPyのarrayに変換
|
40 |
+
emb = encoded_text.cpu().numpy().astype(np.float32)
|
41 |
+
# ベクトルのノルムが1になるように正規化
|
42 |
+
emb = emb / np.linalg.norm(emb)
|
43 |
+
return emb
|
44 |
+
|
45 |
+
|
46 |
+
def search_similar_texts(
|
47 |
+
query_text: str,
|
48 |
+
dataset: Dataset,
|
49 |
+
encoder_pipeline: FeatureExtractionPipeline,
|
50 |
+
k: int = 5,
|
51 |
+
) -> list[dict[str, float | str]]:
|
52 |
+
"""モデルとデータセットを用いてクエリの類似文検索を実行"""
|
53 |
+
# クエリに対して類似テキストをk件取得する
|
54 |
+
scores, retrieved_examples = dataset.get_nearest_examples(
|
55 |
+
"embeddings", embed_text(query_text, encoder_pipeline), k=k
|
56 |
+
)
|
57 |
+
titles = retrieved_examples["title"]
|
58 |
+
texts = retrieved_examples["text"]
|
59 |
+
|
60 |
+
# 検索された類似テキストをdictのlistにして返す
|
61 |
+
results = [
|
62 |
+
{"score": score, "title": title, "text": text}
|
63 |
+
for score, title, text in zip(scores, titles, texts)
|
64 |
+
]
|
65 |
+
return results
|
66 |
+
|
67 |
+
|
68 |
+
# 訓練済みの教師なしSimCSEのモデルを読み込む
|
69 |
+
encoder_pipeline = load_encoder_pipeline("outputs_unsup_simcse/encoder")
|
70 |
+
|
71 |
+
# 文埋め込み適用済みのデータセットを読み込む
|
72 |
+
dataset = load_dataset("outputs_unsup_simcse/embedded_paragraphs")
|
73 |
+
|
74 |
+
# デモページのタイトルを表示する
|
75 |
+
st.title(":mag: Wikipedia Paragraph Search")
|
76 |
+
|
77 |
+
# デモページのフォームを表示する
|
78 |
+
with st.form("input_form"):
|
79 |
+
# クエリの入力欄を表示し、入力された値を受け取る
|
80 |
+
query_text = st.text_input(
|
81 |
+
"クエリを入力:", value="日本語は、主に日本で話されている言語である。", max_chars=150
|
82 |
+
)
|
83 |
+
# 検索する段落数のスライダーを表示し、設定された値を受け取る
|
84 |
+
k = st.slider("検索する段落数:", min_value=1, max_value=100, value=10)
|
85 |
+
# 検索を実行するボタンを表示し、押下されたらTrueを受け取る
|
86 |
+
is_submitted = st.form_submit_button("Search")
|
87 |
+
|
88 |
+
# 検索結果を表示する
|
89 |
+
if is_submitted and len(query_text) > 0:
|
90 |
+
# クエリに対して類似文検索を実行し、検索結果を受け取る
|
91 |
+
serach_results = search_similar_texts(
|
92 |
+
query_text, dataset, encoder_pipeline, k=k
|
93 |
+
)
|
94 |
+
# 検索結果を表示する
|
95 |
+
st.subheader("検索結果")
|
96 |
+
st.dataframe(serach_results, use_container_width=True)
|
97 |
+
st.caption("セルのダブルクリックで全体が表示されます")
|
outputs_unsup_simcse/embedded_paragraphs/data-00000-of-00010.arrow
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:57b90263da12e6f9eaa44d91172dd1f5f015ef6c0c54d61e7d54bccc6b79b759
|
3 |
+
size 458351816
|
outputs_unsup_simcse/embedded_paragraphs/data-00001-of-00010.arrow
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:b3fde1ef4827099de7d8689bdbf02e3180c5227f0ca2b03e2c24da46bacbb49d
|
3 |
+
size 458002304
|
outputs_unsup_simcse/embedded_paragraphs/data-00002-of-00010.arrow
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:91f72a3f71b068f9008a289e4a361cba0a880bb25aa4da5453bb2463d3b3f454
|
3 |
+
size 456771176
|
outputs_unsup_simcse/embedded_paragraphs/data-00003-of-00010.arrow
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:5807ae91181b76a6836b65cf5c1092314cc126935be10bafc7e85b79500bc76a
|
3 |
+
size 457297584
|
outputs_unsup_simcse/embedded_paragraphs/data-00004-of-00010.arrow
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:305256fabd2246f74dcd4a980d9ab6c3dced5327e64f7f992f2ee0eebb8a8d18
|
3 |
+
size 456882896
|
outputs_unsup_simcse/embedded_paragraphs/data-00005-of-00010.arrow
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:76cb3d35b85b02119e0d5de32782f1b53ce20e166b312f028288b95fdce6e2e5
|
3 |
+
size 456954640
|
outputs_unsup_simcse/embedded_paragraphs/data-00006-of-00010.arrow
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:3936501890f32e56a54c5ec091891421e1511e6b0c3d43d7a5511c326182998f
|
3 |
+
size 458542088
|
outputs_unsup_simcse/embedded_paragraphs/data-00007-of-00010.arrow
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:81167192200edbb2e2b947d5a6bd0437ddf42b01679bf8f34e3b5067f86ed53a
|
3 |
+
size 457251296
|
outputs_unsup_simcse/embedded_paragraphs/data-00008-of-00010.arrow
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:7a2bf1d222cf15cb91d8789b7a1bbf17e349c39292303d70a0cdc2d29966d29f
|
3 |
+
size 458474520
|
outputs_unsup_simcse/embedded_paragraphs/data-00009-of-00010.arrow
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:027b919374c985209a89cd99a849de70fb75ffcf3c5c4b610cac21d938c59d3e
|
3 |
+
size 458407928
|
outputs_unsup_simcse/embedded_paragraphs/dataset_info.json
ADDED
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"builder_name": "jawiki-paragraphs",
|
3 |
+
"citation": "",
|
4 |
+
"config_name": "default",
|
5 |
+
"dataset_size": 4417130987,
|
6 |
+
"description": "\u66f8\u7c4d\u300e\u5927\u898f\u6a21\u8a00\u8a9e\u30e2\u30c7\u30eb\u5165\u9580\u300f\u3067\u4f7f\u7528\u3059\u308b Wikipedia \u6bb5\u843d\u306e\u30c7\u30fc\u30bf\u30bb\u30c3\u30c8\u3067\u3059\u3002GitHub \u30ea\u30dd\u30b8\u30c8\u30ea singletongue/wikipedia-utils \u3067\u516c\u958b\u3055\u308c\u3066\u3044\u308b\u30c7\u30fc\u30bf\u30bb\u30c3\u30c8\u3092\u5229\u7528\u3057\u3066\u3044\u307e\u3059\u3002",
|
7 |
+
"download_checksums": {
|
8 |
+
"https://github.com/singletongue/wikipedia-utils/releases/download/2023-04-03/paragraphs-jawiki-20230403.json.gz": {
|
9 |
+
"num_bytes": 1489512230,
|
10 |
+
"checksum": null
|
11 |
+
}
|
12 |
+
},
|
13 |
+
"download_size": 1489512230,
|
14 |
+
"features": {
|
15 |
+
"id": {
|
16 |
+
"dtype": "string",
|
17 |
+
"_type": "Value"
|
18 |
+
},
|
19 |
+
"pageid": {
|
20 |
+
"dtype": "int64",
|
21 |
+
"_type": "Value"
|
22 |
+
},
|
23 |
+
"revid": {
|
24 |
+
"dtype": "int64",
|
25 |
+
"_type": "Value"
|
26 |
+
},
|
27 |
+
"paragraph_index": {
|
28 |
+
"dtype": "int64",
|
29 |
+
"_type": "Value"
|
30 |
+
},
|
31 |
+
"title": {
|
32 |
+
"dtype": "string",
|
33 |
+
"_type": "Value"
|
34 |
+
},
|
35 |
+
"section": {
|
36 |
+
"dtype": "string",
|
37 |
+
"_type": "Value"
|
38 |
+
},
|
39 |
+
"text": {
|
40 |
+
"dtype": "string",
|
41 |
+
"_type": "Value"
|
42 |
+
},
|
43 |
+
"html_tag": {
|
44 |
+
"dtype": "string",
|
45 |
+
"_type": "Value"
|
46 |
+
},
|
47 |
+
"embeddings": {
|
48 |
+
"feature": {
|
49 |
+
"dtype": "float32",
|
50 |
+
"_type": "Value"
|
51 |
+
},
|
52 |
+
"_type": "Sequence"
|
53 |
+
}
|
54 |
+
},
|
55 |
+
"homepage": "https://github.com/singletongue/wikipedia-utils",
|
56 |
+
"license": "\u672c\u30c7\u30fc\u30bf\u30bb\u30c3\u30c8\u3067\u4f7f\u7528\u3057\u3066\u3044\u308b Wikipedia \u306e\u30b3\u30f3\u30c6\u30f3\u30c4\u306f\u3001\u30af\u30ea\u30a8\u30a4\u30c6\u30a3\u30d6\u30fb\u30b3\u30e2\u30f3\u30ba\u8868\u793a\u30fb\u7d99\u627f\u30e9\u30a4\u30bb\u30f3\u30b9 3.0 (CC BY-SA 3.0) \u304a\u3088\u3073 GNU \u81ea\u7531\u6587\u66f8\u30e9\u30a4\u30bb\u30f3\u30b9 (GFDL) \u306e\u4e0b\u306b\u914d\u5e03\u3055\u308c\u3066\u3044\u308b\u3082\u306e\u3067\u3059\u3002",
|
57 |
+
"size_in_bytes": 5906643217,
|
58 |
+
"splits": {
|
59 |
+
"train": {
|
60 |
+
"name": "train",
|
61 |
+
"num_bytes": 4417130987,
|
62 |
+
"num_examples": 9668476,
|
63 |
+
"shard_lengths": [
|
64 |
+
984321,
|
65 |
+
1031799,
|
66 |
+
1101914,
|
67 |
+
1132906,
|
68 |
+
1123001,
|
69 |
+
1143878,
|
70 |
+
1138063,
|
71 |
+
1139173,
|
72 |
+
873421
|
73 |
+
],
|
74 |
+
"dataset_name": "jawiki-paragraphs"
|
75 |
+
}
|
76 |
+
},
|
77 |
+
"version": {
|
78 |
+
"version_str": "1.0.0",
|
79 |
+
"major": 1,
|
80 |
+
"minor": 0,
|
81 |
+
"patch": 0
|
82 |
+
}
|
83 |
+
}
|
outputs_unsup_simcse/embedded_paragraphs/state.json
ADDED
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"_data_files": [
|
3 |
+
{
|
4 |
+
"filename": "data-00000-of-00010.arrow"
|
5 |
+
},
|
6 |
+
{
|
7 |
+
"filename": "data-00001-of-00010.arrow"
|
8 |
+
},
|
9 |
+
{
|
10 |
+
"filename": "data-00002-of-00010.arrow"
|
11 |
+
},
|
12 |
+
{
|
13 |
+
"filename": "data-00003-of-00010.arrow"
|
14 |
+
},
|
15 |
+
{
|
16 |
+
"filename": "data-00004-of-00010.arrow"
|
17 |
+
},
|
18 |
+
{
|
19 |
+
"filename": "data-00005-of-00010.arrow"
|
20 |
+
},
|
21 |
+
{
|
22 |
+
"filename": "data-00006-of-00010.arrow"
|
23 |
+
},
|
24 |
+
{
|
25 |
+
"filename": "data-00007-of-00010.arrow"
|
26 |
+
},
|
27 |
+
{
|
28 |
+
"filename": "data-00008-of-00010.arrow"
|
29 |
+
},
|
30 |
+
{
|
31 |
+
"filename": "data-00009-of-00010.arrow"
|
32 |
+
}
|
33 |
+
],
|
34 |
+
"_fingerprint": "8ff2a1214e978197",
|
35 |
+
"_format_columns": null,
|
36 |
+
"_format_kwargs": {},
|
37 |
+
"_format_type": null,
|
38 |
+
"_output_all_columns": false,
|
39 |
+
"_split": "train"
|
40 |
+
}
|
outputs_unsup_simcse/encoder/config.json
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"_name_or_path": "cl-tohoku/bert-base-japanese-v3",
|
3 |
+
"architectures": [
|
4 |
+
"BertModel"
|
5 |
+
],
|
6 |
+
"attention_probs_dropout_prob": 0.1,
|
7 |
+
"classifier_dropout": null,
|
8 |
+
"hidden_act": "gelu",
|
9 |
+
"hidden_dropout_prob": 0.1,
|
10 |
+
"hidden_size": 768,
|
11 |
+
"initializer_range": 0.02,
|
12 |
+
"intermediate_size": 3072,
|
13 |
+
"layer_norm_eps": 1e-12,
|
14 |
+
"max_position_embeddings": 512,
|
15 |
+
"model_type": "bert",
|
16 |
+
"num_attention_heads": 12,
|
17 |
+
"num_hidden_layers": 12,
|
18 |
+
"pad_token_id": 0,
|
19 |
+
"position_embedding_type": "absolute",
|
20 |
+
"torch_dtype": "float32",
|
21 |
+
"transformers_version": "4.30.2",
|
22 |
+
"type_vocab_size": 2,
|
23 |
+
"use_cache": true,
|
24 |
+
"vocab_size": 32768
|
25 |
+
}
|
outputs_unsup_simcse/encoder/pytorch_model.bin
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:aca39ff56e5bdf8e331de99f48bc049bd2763b327f64457aa98c79bc8e98367e
|
3 |
+
size 444899885
|
outputs_unsup_simcse/encoder/special_tokens_map.json
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cls_token": "[CLS]",
|
3 |
+
"mask_token": "[MASK]",
|
4 |
+
"pad_token": "[PAD]",
|
5 |
+
"sep_token": "[SEP]",
|
6 |
+
"unk_token": "[UNK]"
|
7 |
+
}
|
outputs_unsup_simcse/encoder/tokenizer_config.json
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"clean_up_tokenization_spaces": true,
|
3 |
+
"cls_token": "[CLS]",
|
4 |
+
"do_lower_case": false,
|
5 |
+
"do_subword_tokenize": true,
|
6 |
+
"do_word_tokenize": true,
|
7 |
+
"jumanpp_kwargs": null,
|
8 |
+
"mask_token": "[MASK]",
|
9 |
+
"mecab_kwargs": {
|
10 |
+
"mecab_dic": "unidic_lite"
|
11 |
+
},
|
12 |
+
"model_max_length": 512,
|
13 |
+
"never_split": null,
|
14 |
+
"pad_token": "[PAD]",
|
15 |
+
"sep_token": "[SEP]",
|
16 |
+
"subword_tokenizer_type": "wordpiece",
|
17 |
+
"sudachi_kwargs": null,
|
18 |
+
"tokenizer_class": "BertJapaneseTokenizer",
|
19 |
+
"unk_token": "[UNK]",
|
20 |
+
"word_tokenizer_type": "mecab"
|
21 |
+
}
|
outputs_unsup_simcse/encoder/vocab.txt
ADDED
The diff for this file is too large to render.
See raw diff
|
|
requirements.txt
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
datasets
|
2 |
+
faiss-cpu
|
3 |
+
numpy
|
4 |
+
torch
|
5 |
+
transformers[ja,torch]
|