"""Distractor generation for multiple-choice questions using sense2vec.

Candidate distractors are the sense2vec neighbours of the answer. They are
re-ranked with Maximal Marginal Relevance (MMR) over sentence-transformer
embeddings, so the selected distractors are related to the answer while
staying diverse among themselves.
"""
import logging
import os
import tarfile

import wget
from sense2vec import Sense2Vec
from sentence_transformers import SentenceTransformer

from .mmr import mmr

logger = logging.getLogger(__name__)

# Pretrained sense2vec vectors (Reddit 2015, "md").
url = 'https://github.com/explosion/sense2vec/releases/download/v1.0.0/s2v_reddit_2015_md.tar.gz'
# Kept for backward compatibility only: extraction now uses the tarfile module
# instead of building a shell command string.
cmd = 'tar -xvf {}'
# Directory loaded with Sense2Vec.from_disk after the archive above is extracted.
S2V_DIR = 's2v_old'


def _download_and_extract(archive_url, dest='.'):
    """Download *archive_url* unless it is already on disk, then extract it into *dest*."""
    filename = os.path.basename(archive_url)
    if not os.path.isfile(filename):
        filename = wget.download(archive_url, out=filename)
    with tarfile.open(filename) as archive:
        # The 'data' filter (Python 3.12+, backported to security releases)
        # rejects absolute paths, '..' members and special files.
        if hasattr(tarfile, 'data_filter'):
            archive.extractall(dest, filter='data')
        else:
            archive.extractall(dest)


class S2V:
    """Generate distractors (plausible wrong answers) for a given answer."""

    def __init__(self, model_name='all-MiniLM-L12-v2'):
        """Load the sentence-embedding model and the sense2vec vectors.

        The sense2vec archive is downloaded and extracted only when
        ``S2V_DIR`` does not exist yet. Creating several instances therefore
        no longer re-downloads the archive.

        Args:
            model_name: SentenceTransformer model used to embed candidates.
        """
        self.model = SentenceTransformer(model_name)
        if not os.path.isdir(S2V_DIR):
            _download_and_extract(url)
        self.s2v = Sense2Vec().from_disk(S2V_DIR)

    def removeDuplicates(self, most_similar, originalword):
        """Turn sense2vec neighbours into unique, human-readable phrases.

        Args:
            most_similar: Entries as returned by ``Sense2Vec.most_similar``.
                The first element of each entry is a key whose text part
                precedes ``"|"`` and uses ``"_"`` for spaces.
            originalword: The answer. It is never returned as a distractor.

        Returns:
            Phrases in their original order, with the part after ``"|"``
            dropped and underscores replaced by spaces. Comparison is
            case-insensitive, so case variants of the answer or of an
            earlier phrase are skipped; the first-seen form is kept.
        """
        seen = {originalword.lower()}
        distractors = []
        for entry in most_similar:
            phrase = entry[0].split("|")[0].replace("_", " ")
            folded = phrase.lower()
            if folded not in seen:
                seen.add(folded)
                distractors.append(phrase)
        return distractors

    def get_answer_and_distractor_embeddings(self, answer, candidate_distractors):
        """Encode the answer (as a one-item batch) and the candidate list.

        Returns:
            ``(answer_embedding, distractor_embeddings)`` as produced by
            ``SentenceTransformer.encode``.
        """
        answer_embedding = self.model.encode([answer])
        distractor_embeddings = self.model.encode(candidate_distractors)
        return answer_embedding, distractor_embeddings

    def execute(self, originalword, top_n=5):
        """Return distractors for *originalword*.

        Args:
            originalword: The correct answer (a word or phrase).
            top_n: Number of items MMR selects from the answer plus the
                candidates. The answer is then removed from that selection,
                so normally at most ``top_n - 1`` distractors are returned.

        Returns:
            A list of distractor strings. Returns ``[]`` in these cases,
            matching the original best-effort contract:
            - the input is empty or not a string;
            - sense2vec has no sense for it;
            - no neighbours remain after de-duplication;
            - an error occurs (the error is logged).
        """
        if not isinstance(originalword, str) or not originalword.strip():
            return []
        try:
            # sense2vec keys use underscores for spaces; lookup is lowercased.
            word = originalword.lower().replace(" ", "_")
            # Best-matching sense by availability and frequency; None if absent.
            sense = self.s2v.get_best_sense(word)
            if sense is None:
                return []
            most_similar = self.s2v.most_similar(sense, n=20)
            distractors = self.removeDuplicates(most_similar, originalword)
            if not distractors:
                return []
            # The answer is included as a candidate so MMR can score the
            # distractors relative to it.
            candidates = [originalword] + distractors
            answer_embedding, candidate_embeddings = self.get_answer_and_distractor_embeddings(
                originalword, candidates
            )
            # Maximal Marginal Relevance, origin: https://maartengr.github.io/KeyBERT/api/mmr.html
            selected = mmr(answer_embedding, candidate_embeddings, candidates, top_n)
        except Exception:
            logger.exception("Distractor generation failed for %r", originalword)
            return []
        # Assumes mmr yields (phrase, score) pairs, as the original dist[0]
        # indexing implies. Remove the answer explicitly instead of assuming
        # MMR ranked it first.
        return [entry[0] for entry in selected if entry[0] != originalword]