|
import os
import shlex
import subprocess

import wget
from sense2vec import Sense2Vec
from sentence_transformers import SentenceTransformer

from .mmr import mmr
|
|
|
# Release archive containing the sense2vec vectors that S2V loads at start-up.
url = (
    'https://github.com/explosion/sense2vec/releases/download/'
    'v1.0.0/s2v_reddit_2015_md.tar.gz'
)

# Command template used to unpack the downloaded archive; `{}` is replaced
# with the archive's file name.
cmd = 'tar -xvf {}'
|
|
|
class S2V:
    """Suggest distractor words for an answer using sense2vec neighbours.

    Candidates come from the sense2vec ``most_similar`` query, are
    de-duplicated, embedded with a sentence-transformer model and finally
    re-ranked by ``mmr``.
    """

    # Directory the sense2vec vectors are loaded from. The downloaded
    # archive is presumably unpacked into this directory -- verify against
    # the archive layout if the release asset changes.
    _MODEL_DIR = 's2v_old'

    def __init__(self):
        """Load the sentence encoder and the sense2vec vectors.

        The vector archive is downloaded and unpacked only when the model
        directory is not already present, so repeated construction does not
        fetch it again.

        Raises:
            subprocess.CalledProcessError: if unpacking the archive fails.
        """
        self.model = SentenceTransformer('all-MiniLM-L12-v2')

        if not os.path.isdir(self._MODEL_DIR):
            filename = wget.download(url)
            # Quote the name and run without a shell: the downloader may
            # return names with spaces or parentheses, which would break a
            # plain shell command string.
            unpack_args = shlex.split(cmd.format(shlex.quote(filename)))
            subprocess.run(unpack_args, check=True)

        self.s2v = Sense2Vec().from_disk(self._MODEL_DIR)

    def removeDuplicates(self, most_similar, originalword):
        """Return unique, human-readable candidate words.

        Args:
            most_similar: iterable of items whose first element is a
                ``"word|SENSE"`` key.
            originalword: the answer; candidates equal to it are dropped.

        Returns:
            Candidate words in their original order, with the sense suffix
            removed and underscores replaced by spaces. Comparison is
            case-insensitive, so a candidate that differs from the answer
            (or from an earlier candidate) only by letter case is treated
            as a duplicate.
        """
        seen = {originalword.casefold()}
        distractors = []

        for each_word in most_similar:
            append_word = each_word[0].split("|")[0].replace("_", " ")
            key = append_word.casefold()
            if key not in seen:
                seen.add(key)
                distractors.append(append_word)

        return distractors

    def get_answer_and_distractor_embeddings(self, answer, candidate_distractors):
        """Encode the answer and the candidates with the sentence model.

        The answer is encoded as a one-element batch.

        Returns:
            A ``(answer_embedding, distractor_embeddings)`` pair as produced
            by ``self.model.encode``.
        """
        answer_embedding = self.model.encode([answer])
        distractor_embeddings = self.model.encode(candidate_distractors)
        return answer_embedding, distractor_embeddings

    def execute(self, originalword):
        """Return distractors for ``originalword``.

        This is best-effort: any ordinary failure while looking up or
        ranking candidates yields an empty list instead of an exception.

        Args:
            originalword: the answer to generate distractors for.

        Returns:
            A list of distractor strings, possibly empty.
        """
        try:
            word = originalword.lower().replace(" ", "_")

            sense = self.s2v.get_best_sense(word)
            if sense is None:
                # The word has no entry in the vectors; nothing to suggest.
                return []

            most_similar = self.s2v.most_similar(sense, n=20)

            distractors = self.removeDuplicates(most_similar, originalword)
            # The answer is put first so it is ranked together with the
            # candidates.
            distractors.insert(0, originalword)

            answer_embedd, distractor_embedds = (
                self.get_answer_and_distractor_embeddings(originalword, distractors)
            )

            final_distractors = mmr(answer_embedd, distractor_embedds, distractors, 5)

            # Drop the first ranked entry -- presumably the answer itself,
            # which was inserted at the head above (confirm against mmr).
            return [dist[0] for dist in final_distractors][1:]
        except Exception:
            # Deliberate best-effort fallback. Unlike a bare `except`, this
            # lets KeyboardInterrupt and SystemExit propagate.
            return []
|
|