"""
This script contains an example of how to perform semantic search with PyTorch. It performs exact nearest neighbor search.

As dataset, we use the Quora Duplicate Questions dataset, which contains about 500k questions (we only use about 100k):
https://www.quora.com/q/quoradata/First-Quora-Dataset-Release-Question-Pairs

As embeddings model, we use the SBERT model 'quora-distilbert-multilingual',
which is aligned for 100 languages. I.e., you can type in a question in various languages and it will
return the closest questions in the corpus (questions in the corpus are mainly in English).

Google Colab example: https://colab.research.google.com/drive/12cn5Oo0v3HfQQ8Tv6-ukgxXSmT3zl35A?usp=sharing
"""
from sentence_transformers import SentenceTransformer, util
import os
import csv
import pickle
import time

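# Load the multilingual SBERT model; it is used to embed both the corpus questions and the entered queries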
model_name = 'quora-distilbert-multilingual'
model = SentenceTransformer(model_name)

url = "http://qim.fs.quoracdn.net/quora_duplicate_questions.tsv"
dataset_path = "quora_duplicate_questions.tsv"
max_corpus_size = 100000

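# Cache file that stores the corpus sentences together with their pre-computed embeddings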
embedding_cache_path = 'quora-embeddings-{}-size-{}.pkl'.format(model_name.replace('/', '_'), max_corpus_size)

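# Check if the embedding cache exists. If not, download the dataset, encode the corpus and create the cache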
if not os.path.exists(embedding_cache_path):
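    # Download the Quora Duplicate Questions dataset if it is not present yet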
    if not os.path.exists(dataset_path):
        print("Download dataset")
        util.http_get(url, dataset_path)

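    # Read the TSV file and collect unique questions (question1 and question2) until max_corpus_size is reached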
    corpus_sentences = set()
    with open(dataset_path, encoding='utf8') as fIn:
        reader = csv.DictReader(fIn, delimiter='\t', quoting=csv.QUOTE_MINIMAL)
        for row in reader:
            corpus_sentences.add(row['question1'])
            if len(corpus_sentences) >= max_corpus_size:
                break

            corpus_sentences.add(row['question2'])
            if len(corpus_sentences) >= max_corpus_size:
                break

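    # Convert the set to a list so that the corpus_id indices returned by the search map back to sentences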
    corpus_sentences = list(corpus_sentences)
    print("Encode the corpus. This might take a while")
    corpus_embeddings = model.encode(corpus_sentences, show_progress_bar=True, convert_to_tensor=True)

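    # Store sentences and embeddings on disk so that later runs can skip the encoding step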
print("Store file on disc") |
|
with open(embedding_cache_path, "wb") as fOut: |
|
pickle.dump({'sentences': corpus_sentences, 'embeddings': corpus_embeddings}, fOut) |
|
else:
    print("Load pre-computed embeddings from disc")
    with open(embedding_cache_path, "rb") as fIn:
        cache_data = pickle.load(fIn)
        corpus_sentences = cache_data['sentences'][0:max_corpus_size]
        corpus_embeddings = cache_data['embeddings'][0:max_corpus_size]

print("Corpus loaded with {} sentences / embeddings".format(len(corpus_sentences))) |
|
|
|
|
|
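# Move the corpus embeddings to the same device (CPU or GPU) that the model uses for encoding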
corpus_embeddings = corpus_embeddings.to(model._target_device)

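# Interactive loop: embed the entered question and run exact nearest neighbor search over the corpus embeddings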
while True:
    inp_question = input("Please enter a question: ")

    start_time = time.time()
    question_embedding = model.encode(inp_question, convert_to_tensor=True)

    # util.semantic_search computes cosine similarity between the query embedding and all corpus embeddings
    # and returns the best matches (by default the top 10) for each query
    hits = util.semantic_search(question_embedding, corpus_embeddings)
    end_time = time.time()
    hits = hits[0]  # Get the hits for the first (and only) query

    print("Input question:", inp_question)
    print("Results (after {:.3f} seconds):".format(end_time - start_time))
    for hit in hits[0:5]:
        print("\t{:.3f}\t{}".format(hit['score'], corpus_sentences[hit['corpus_id']]))

    print("\n\n========\n")