from transformers import AutoModelForMaskedLM, AutoModelForSequenceClassification
from transformers import AutoTokenizer
from sklearn.metrics.pairwise import cosine_similarity
import streamlit as st
import torch
import pickle
import numpy as np
import itertools

# Load the two models being compared: a DistilBERT fine-tuned on Cvent 2019-2022 data, and FinBERT
model = AutoModelForMaskedLM.from_pretrained("vives/distilbert-base-uncased-finetuned-cvent-2019_2022",
                                             output_hidden_states=True)
tokenizer = AutoTokenizer.from_pretrained("vives/distilbert-base-uncased-finetuned-cvent-2019_2022")
kp_dict_checkpoint = "kp_dict_merged.pickle"
kp_cosine_checkpoint = "cosine_kp.pickle"

model_finbert = AutoModelForSequenceClassification.from_pretrained("ProsusAI/finbert",
                                                                   output_hidden_states=True)
tokenizer_finbert = AutoTokenizer.from_pretrained("ProsusAI/finbert")
kp_dict_finbert_checkpoint = "kp_dict_finbert.pickle"
kp_cosine_finbert_checkpoint = "cosine_kp_finbert.pickle"

# UI controls
text = st.text_input("Enter word or key-phrase")
exclude_words = st.radio("exclude_words", [True, False],
                         help="Exclude results that contain any words in the query")
exclude_text = st.radio("exclude_text", [True, False],
                        help="Exclude results that contain the query (i.e. exclude 'tomato soup recipe' if the query is 'tomato soup')")
k = st.number_input("Top k nearest key-phrases", 1, 10, 5)

with st.sidebar:
    diversify_box = st.checkbox("Diversify results", True)
    if diversify_box:
        k_diversify = st.number_input("Set of key-phrases to diversify from", 10, 30, 20)

# columns for side-by-side model output
col1, col2 = st.columns(2)

# load kp dicts (key-phrase -> precomputed embedding)
with open(kp_dict_checkpoint, 'rb') as handle:
    kp_dict = pickle.load(handle)
keys = list(kp_dict.keys())

with open(kp_dict_finbert_checkpoint, 'rb') as handle:
    kp_dict_finbert = pickle.load(handle)
keys_finbert = list(kp_dict_finbert.keys())

# load precomputed pairwise cosine similarities of the kp dicts
with open(kp_cosine_checkpoint, 'rb') as handle:
    cosine_kp = pickle.load(handle)
with open(kp_cosine_finbert_checkpoint, 'rb') as handle:
    cosine_finbert_kp = pickle.load(handle)


def calculate_top_k(out, tokens, text, kp_dict, exclude_text=False, exclude_words=False, k=5):
    """Return the k key-phrases whose stored embeddings are most cosine-similar to the query."""
    sim_dict = {}
    pools = pool_embeddings(out, tokens).detach().numpy()
    for key in kp_dict.keys():
        if key == text:
            continue
        if exclude_text and text in key:
            continue
        if exclude_words and any(x in key for x in text.split(" ")):
            continue
        sim_dict[key] = cosine_similarity(pools, [kp_dict[key]])[0][0]
    sims = sorted(sim_dict.items(), key=lambda x: x[1], reverse=True)[:k]
    return {x: y for x, y in sims}


def concat_tokens(sentences, tokenizer):
    """Tokenize each sentence and stack the results into single batch tensors."""
    tokens = {'input_ids': [], 'attention_mask': [], 'KPS': []}
    for sentence in sentences:
        # encode each sentence and append to dictionary
        new_tokens = tokenizer.encode_plus(sentence, max_length=64, truncation=True,
                                           padding='max_length', return_tensors='pt')
        tokens['input_ids'].append(new_tokens['input_ids'][0])
        tokens['attention_mask'].append(new_tokens['attention_mask'][0])
        tokens['KPS'].append(sentence)
    # reformat list of tensors into single tensor
    tokens['input_ids'] = torch.stack(tokens['input_ids'])
    tokens['attention_mask'] = torch.stack(tokens['attention_mask'])
    return tokens


def pool_embeddings(out, tok):
    """Mean-pool the last hidden state over non-padding tokens."""
    embeddings = out["hidden_states"][-1]
    attention_mask = tok['attention_mask']
    mask = attention_mask.unsqueeze(-1).expand(embeddings.size()).float()
    masked_embeddings = embeddings * mask
    summed = torch.sum(masked_embeddings, 1)
    summed_mask = torch.clamp(mask.sum(1), min=1e-9)
    mean_pooled = summed / summed_mask
    return mean_pooled


def extract_idxs(top_dict, kp_dict):
    """Map the keys of top_dict to their row indices in kp_dict's key order."""
    idxs = []
    for c, key in enumerate(kp_dict.keys()):
        if key in top_dict:
            idxs.append(c)
    return idxs
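# The pickled artifacts above are assumed to be precomputed offline: each
# kp_dict_* maps a key-phrase to its mean-pooled embedding (a 1-D numpy array),
# and each cosine_kp_* holds the square pairwise cosine-similarity matrix over
# that dict's keys, in key order. A minimal sketch of how they could be built
# with the helpers in this file (the phrase list is a placeholder, and this is
# an assumption about the offline pipeline, not the original code):
#
#   kps = ["virtual event", "event registration", "attendee engagement"]
#   toks = concat_tokens(kps, tokenizer)
#   toks.pop("KPS")
#   with torch.no_grad():
#       out = model(**toks)
#   embs = pool_embeddings(out, toks).detach().numpy()
#   kp_dict = dict(zip(kps, embs))          # key-phrase -> embedding
#   cosine_kp = cosine_similarity(embs)     # square matrix, key order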
if text:
    text = text.lower()
    new_tokens = concat_tokens([text], tokenizer)
    new_tokens.pop("KPS")
    new_tokens_finbert = concat_tokens([text], tokenizer_finbert)
    new_tokens_finbert.pop("KPS")
    with torch.no_grad():
        outputs = model(**new_tokens)
        outputs_finbert = model_finbert(**new_tokens_finbert)
    # When diversifying, retrieve a larger pool of k_diversify candidates
    # from which the k most mutually dissimilar results are selected below
    top_k = k_diversify if diversify_box else k
    sim_dict = calculate_top_k(outputs, new_tokens, text, kp_dict,
                               exclude_text=exclude_text, exclude_words=exclude_words, k=top_k)
    sim_dict_finbert = calculate_top_k(outputs_finbert, new_tokens_finbert, text, kp_dict_finbert,
                                       exclude_text=exclude_text, exclude_words=exclude_words, k=top_k)
    if not diversify_box:
        with col1:
            st.write("distilbert-cvent")
            st.json(sim_dict)
        with col2:
            st.write("finbert")
            st.json(sim_dict_finbert)
    else:
        idxs = extract_idxs(sim_dict, kp_dict)
        idxs_finbert = extract_idxs(sim_dict_finbert, kp_dict_finbert)
        distances_candidates = cosine_kp[np.ix_(idxs, idxs)]
        distances_candidates_finbert = cosine_finbert_kp[np.ix_(idxs_finbert, idxs_finbert)]

        # Brute-force search for the k-subset of candidates with the lowest total
        # pairwise similarity, i.e. the most diverse subset; with the defaults
        # this evaluates (20 choose 5) = 15,504 combinations.
        # First do distilbert
        candidate = None
        min_sim = np.inf
        for combination in itertools.combinations(range(len(idxs)), k):
            sim = sum(distances_candidates[i][j]
                      for i in combination for j in combination if i != j)
            if sim < min_sim:
                candidate = combination
                min_sim = sim

        # then do finbert
        candidate_finbert = None
        min_sim = np.inf
        for combination in itertools.combinations(range(len(idxs_finbert)), k):
            sim = sum(distances_candidates_finbert[i][j]
                      for i in combination for j in combination if i != j)
            if sim < min_sim:
                candidate_finbert = combination
                min_sim = sim

        # distilbert: map the winning indices back to key-phrases, sorted by similarity
        ret = {keys[idxs[idx]]: sim_dict[keys[idxs[idx]]] for idx in candidate}
        ret = sorted(ret.items(), key=lambda x: x[1], reverse=True)
        ret = {x: y for x, y in ret}

        # finbert: same mapping and sort
        ret_finbert = {keys_finbert[idxs_finbert[idx]]: sim_dict_finbert[keys_finbert[idxs_finbert[idx]]]
                       for idx in candidate_finbert}
        ret_finbert = sorted(ret_finbert.items(), key=lambda x: x[1], reverse=True)
        ret_finbert = {x: y for x, y in ret_finbert}

        with col1:
            st.write("distilbert-cvent")
            st.json(ret)
        with col2:
            st.write("finbert")
            st.json(ret_finbert)
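# To run the app locally (assuming this file is saved as app.py, and that the
# four .pickle files loaded above sit in the working directory):
#
#   streamlit run app.py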