from transformers import AutoModelForMaskedLM
from transformers import AutoTokenizer
from sklearn.metrics.pairwise import cosine_similarity
import streamlit as st
import torch
import pickle
import numpy as np
import itertools

# Fine-tuned DistilBERT checkpoint. Hidden states are requested so that the
# last layer can be mean-pooled into a query embedding (see pool_embeddings).
model_checkpoint = "vives/distilbert-base-uncased-finetuned-cvent-2019_2022"
model = AutoModelForMaskedLM.from_pretrained(model_checkpoint, output_hidden_states=True)
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
# NOTE(review): Streamlit re-executes this whole script on every widget
# interaction, so the model and the pickles below are reloaded each time.
# Consider st.cache_resource if the installed Streamlit version provides it.

text = st.text_input("Enter word or key-phrase")
exclude_words = st.radio(
    "exclude_words",
    [True, False],
    help="Exclude results that contain any words in the query (i.e exclude 'hot coffee' if the query is 'cold coffee')",
)
exclude_text = st.radio(
    "exclude_text",
    [True, False],
    help="Exclude results that contain the query (i.e exclude 'tomato soup recipe' if the query is 'tomato soup')",
)
k = st.number_input("Top k nearest key-phrases", 1, 10, 5)

with st.sidebar:
    diversify_box = st.checkbox("Diversify results", True)
    if diversify_box:
        k_diversify = st.number_input("Set of key-phrases to diversify from", 10, 30, 20)

# Load kp dict: key-phrase -> embedding. The values are torch tensors in the
# pickle and are converted to numpy once here.
# SECURITY: pickle.load can execute arbitrary code; only load trusted files.
with open("kp_dict_merged.pickle", 'rb') as handle:
    kp_dict = pickle.load(handle)
keys = list(kp_dict.keys())
for key in keys:
    kp_dict[key] = kp_dict[key].detach().numpy()

# Load the precomputed pairwise matrix for the kp dict.
# NOTE(review): it is indexed by position in `keys` (see _max_sum_diversify),
# so presumably it was built in the same key order as kp_dict_merged.pickle,
# and it is treated as a similarity matrix (smaller sum = more diverse) despite
# being described as "cosine distances" — confirm against the script that
# produced it.
with open("cosine_kp.pickle", 'rb') as handle:
    cosine_kp = pickle.load(handle)


def calculate_top_k(out, tokens, text, exclude_text=False, exclude_words=False, k=5):
    """Return the ``k`` key-phrases most cosine-similar to the query.

    Args:
        out: model output for the tokenized query; must contain "hidden_states".
        tokens: tokenized query (needs 'attention_mask'), used for mean pooling.
        text: the lower-cased query string.
        exclude_text: skip key-phrases that contain ``text`` as a substring.
        exclude_words: skip key-phrases that contain any whitespace-separated
            word of ``text`` as a substring.
        k: maximum number of results to return.

    Returns:
        dict of key-phrase -> similarity (Python float), ordered from most to
        least similar. The query itself is never included. Empty if every
        key-phrase was excluded.
    """
    # str.split() with no argument drops empty strings; split(" ") would yield
    # "" for repeated/trailing spaces, and `"" in key` is always True, which
    # excluded every result.
    query_words = text.split()

    candidates = []
    for key in kp_dict:
        if key == text:
            continue
        if exclude_text and text in key:
            continue
        if exclude_words and any(word in key for word in query_words):
            continue
        candidates.append(key)
    if not candidates:
        return {}

    pools = pool_embeddings(out, tokens).detach().numpy()
    # One vectorized call instead of one sklearn call per key-phrase.
    sims = cosine_similarity(pools, np.stack([kp_dict[key] for key in candidates]))[0]
    # sorted() is stable, so ties keep kp_dict order as before.
    ranked = sorted(zip(candidates, sims), key=lambda item: item[1], reverse=True)[:k]
    return {key: float(sim) for key, sim in ranked}


def concat_tokens(sentences):
    """Tokenize ``sentences`` into padded, stacked model inputs.

    Each sentence is truncated/padded to 64 tokens.

    Returns:
        dict with 'input_ids' and 'attention_mask' as stacked tensors (one row
        per sentence), plus 'KPS' holding the original sentences.
    """
    tokens = {'input_ids': [], 'attention_mask': [], 'KPS': []}
    for sentence in sentences:
        # encode each sentence and append to dictionary
        new_tokens = tokenizer.encode_plus(sentence, max_length=64,
                                           truncation=True, padding='max_length',
                                           return_tensors='pt')
        tokens['input_ids'].append(new_tokens['input_ids'][0])
        tokens['attention_mask'].append(new_tokens['attention_mask'][0])
        tokens['KPS'].append(sentence)
    # reformat list of tensors into single tensor
    tokens['input_ids'] = torch.stack(tokens['input_ids'])
    tokens['attention_mask'] = torch.stack(tokens['attention_mask'])
    return tokens


def pool_embeddings(out, tok):
    """Mean-pool the last hidden layer over non-padding tokens.

    The attention mask is broadcast over the hidden dimension, so padding
    positions contribute nothing. The token count is clamped to avoid
    division by zero.
    """
    embeddings = out["hidden_states"][-1]
    attention_mask = tok['attention_mask']
    mask = attention_mask.unsqueeze(-1).expand(embeddings.size()).float()
    masked_embeddings = embeddings * mask
    summed = torch.sum(masked_embeddings, 1)
    summed_mask = torch.clamp(mask.sum(1), min=1e-9)
    mean_pooled = summed / summed_mask
    return mean_pooled


def extract_idxs(top_dict, kp_dict):
    """Return the positions (in ``kp_dict`` iteration order) of keys in ``top_dict``."""
    return [position for position, key in enumerate(kp_dict) if key in top_dict]


def _max_sum_diversify(sim_dict, k):
    """Pick the ``k`` entries of ``sim_dict`` that are least similar to each other.

    Exhaustively searches every size-``k`` subset for the smallest sum of
    pairwise ``cosine_kp`` values (MaxSum diversification). The result is
    ordered by query similarity, descending.

    ``k`` is clamped to the number of candidates. Previously, fewer candidates
    than ``k`` (e.g. after exclusions) left no combination to pick and crashed.
    """
    idxs = extract_idxs(sim_dict, kp_dict)
    k = min(k, len(idxs))
    if k == 0:
        return {}
    distances_candidates = cosine_kp[np.ix_(idxs, idxs)]
    # NOTE(review): this is C(len(idxs), k) subsets; with the UI maxima
    # (30 choose 10) that is ~30M subsets and will be very slow.
    min_sim = np.inf
    candidate = None
    for combination in itertools.combinations(range(len(idxs)), k):
        sim = sum(distances_candidates[i][j]
                  for i in combination for j in combination if i != j)
        if sim < min_sim:
            candidate = combination
            min_sim = sim
    chosen = [keys[idxs[idx]] for idx in candidate]
    ranked = sorted(chosen, key=lambda key: sim_dict[key], reverse=True)
    return {key: sim_dict[key] for key in ranked}


# Strip so that surrounding whitespace doesn't defeat the equality/substring
# exclusion checks in calculate_top_k.
text = text.strip().lower()
if text:
    new_tokens = concat_tokens([text])
    new_tokens.pop("KPS")
    with torch.no_grad():
        outputs = model(**new_tokens)
    if not diversify_box:
        sim_dict = calculate_top_k(outputs, new_tokens, text,
                                   exclude_text=exclude_text,
                                   exclude_words=exclude_words, k=k)
        st.json(sim_dict)
    else:
        sim_dict = calculate_top_k(outputs, new_tokens, text,
                                   exclude_text=exclude_text,
                                   exclude_words=exclude_words, k=k_diversify)
        st.json(_max_sum_diversify(sim_dict, k))