"""Streamlit demo: nearest key-phrase lookup with three BERT-family encoders.

The query typed by the user is embedded with each model and compared (cosine
similarity) against a pickled dictionary of key-phrase embeddings per model.
Optionally the results are "diversified": a larger candidate set is retrieved
and the subset of ``k`` candidates with the smallest summed pairwise score
(taken from a precomputed matrix) is shown instead of the plain top ``k``.
"""
from transformers import AutoModelForMaskedLM , AutoModelForSequenceClassification, AutoModel
from transformers import AutoTokenizer
from sklearn.metrics.pairwise import cosine_similarity
import streamlit as st
import torch
import pickle
import numpy as np
import itertools
import tokenizers


@st.cache(hash_funcs={tokenizers.Tokenizer: lambda _: None, AutoModelForMaskedLM: lambda _: None})
def load_bert():
    """Load the cvent-finetuned DistilBERT model and its tokenizer."""
    return (AutoModelForMaskedLM.from_pretrained("vives/distilbert-base-uncased-finetuned-cvent-2019_2022", output_hidden_states=True),
            AutoTokenizer.from_pretrained("vives/distilbert-base-uncased-finetuned-cvent-2019_2022"))


model, tokenizer = load_bert()
kp_dict_checkpoint = "kp_dict_merged.pickle"
kp_cosine_checkpoint = "cosine_kp.pickle"


@st.cache(hash_funcs={tokenizers.Tokenizer: lambda _: None, AutoModelForSequenceClassification: lambda _: None})
def load_finbert():
    """Load the FinBERT model and its tokenizer."""
    return (AutoModelForSequenceClassification.from_pretrained("ProsusAI/finbert", output_hidden_states=True),
            AutoTokenizer.from_pretrained("ProsusAI/finbert"))


model_finbert, tokenizer_finbert = load_finbert()
kp_dict_finbert_checkpoint = "kp_dict_finance.pickle"
kp_cosine_finbert_checkpoint = "cosine_kp_finance.pickle"


@st.cache(hash_funcs={tokenizers.Tokenizer: lambda _: None, AutoModel: lambda _: None})
def load_sapbert():
    """Load the SapBERT model and its tokenizer."""
    return (AutoModel.from_pretrained("cambridgeltl/SapBERT-from-PubMedBERT-fulltext", output_hidden_states=True),
            AutoTokenizer.from_pretrained("cambridgeltl/SapBERT-from-PubMedBERT-fulltext"))


model_sapbert, tokenizer_sapbert = load_sapbert()
kp_dict_sapbert_checkpoint = "kp_dict_medical.pickle"
kp_cosine_sapbert_checkpoint = "cosine_kp_medical.pickle"

# --- UI widgets (declaration order defines the page layout) ---
text = st.text_input("Enter word or key-phrase")

exclude_words = st.radio("exclude_words",[True,False], help="Exclude results that contain any words in the query")

exclude_text = st.radio("exclude_text",[True,False], help="Exclude results that contain the query (i.e exclude 'tomato soup recipe' if the query is 'tomato soup')")

k = st.number_input("Top k nearest key-phrases",1,10,5)

# Only meaningful when "Diversify results" is ticked; kept defined either way
# so later code never hits an unbound name.
k_diversify = None
with st.sidebar:
    diversify_box = st.checkbox("Diversify results",True)
    if diversify_box:
        k_diversify = st.number_input("Set of key-phrases to diversify from",10,30,20)

#columns
col1, col2, col3 = st.columns(3)


def _load_pickle(path):
    """Unpickle and return the object stored at *path*.

    NOTE: ``pickle.load`` can execute arbitrary code; only point this at
    checkpoint files you produced yourself or otherwise trust.
    """
    with open(path, 'rb') as handle:
        return pickle.load(handle)


#load kp dicts
kp_dict = _load_pickle(kp_dict_checkpoint)
keys = list(kp_dict.keys())

kp_dict_finbert = _load_pickle(kp_dict_finbert_checkpoint)
keys_finbert = list(kp_dict_finbert.keys())

kp_dict_sapbert = _load_pickle(kp_dict_sapbert_checkpoint)
keys_sapbert = list(kp_dict_sapbert.keys())

#load cosine distances of kp dict
# Assumed to be square matrices whose row/column order matches the key order
# of the corresponding kp dict (they are indexed that way below) -- TODO confirm.
cosine_kp = _load_pickle(kp_cosine_checkpoint)
cosine_finbert_kp = _load_pickle(kp_cosine_finbert_checkpoint)
cosine_sapbert_kp = _load_pickle(kp_cosine_sapbert_checkpoint)


def calculate_top_k(out, tokens, text, kp_dict, exclude_text=False, exclude_words=False, k=5, pooler=True):
    """Return the *k* key-phrases most similar to the query embedding.

    Args:
        out: Model output for the query; must expose ``hidden_states`` (when
            ``pooler`` is true) or ``pooler_output`` (when it is false).
        tokens: Tokenizer output for the query; its ``attention_mask`` is used
            for mean pooling when ``pooler`` is true.
        text: The query string. The key equal to it is always skipped.
        kp_dict: Mapping of key-phrase -> embedding vector.
        exclude_text: Skip key-phrases that contain the whole query.
        exclude_words: Skip key-phrases that contain any word of the query.
        k: Maximum number of results.
        pooler: Despite the name, ``True`` means "mean-pool the last hidden
            state via ``pool_embeddings``" and ``False`` means "use the
            model's own ``pooler_output``".

    Returns:
        dict of key-phrase -> cosine similarity, ordered from most to least
        similar, with at most *k* entries.
    """
    if pooler:
        pools = pool_embeddings(out, tokens).detach().numpy()
    else:
        pools = out["pooler_output"].detach().numpy()

    # split() (not split(" ")) so stray/double spaces do not produce empty
    # "words" -- an empty string is a substring of every key and would
    # exclude everything.
    query_words = text.split()

    sim_dict = {}
    for key, vector in kp_dict.items():
        if key == text:
            continue
        if exclude_text and text in key:
            continue
        if exclude_words and any(word in key for word in query_words):
            continue
        sim_dict[key] = cosine_similarity(pools, [vector])[0][0]

    ranked = sorted(sim_dict.items(), key=lambda item: item[1], reverse=True)[:k]
    return dict(ranked)


def concat_tokens(sentences, tokenizer):
    """Tokenize *sentences* and stack them into batch tensors.

    Each sentence is padded/truncated to 64 tokens.

    Returns:
        dict with ``input_ids`` and ``attention_mask`` (stacked tensors, one
        row per sentence) and ``KPS`` (the original sentences as a list).
    """
    tokens = {'input_ids': [], 'attention_mask': [], 'KPS': []}
    for sentence in sentences:
        # encode each sentence and append to dictionary
        new_tokens = tokenizer.encode_plus(sentence, max_length=64,
                                           truncation=True, padding='max_length',
                                           return_tensors='pt')
        tokens['input_ids'].append(new_tokens['input_ids'][0])
        tokens['attention_mask'].append(new_tokens['attention_mask'][0])
        tokens['KPS'].append(sentence)
    # reformat list of tensors into single tensor
    tokens['input_ids'] = torch.stack(tokens['input_ids'])
    tokens['attention_mask'] = torch.stack(tokens['attention_mask'])
    return tokens


def pool_embeddings(out, tok):
    """Mean-pool the last hidden state over non-padding tokens.

    Args:
        out: Model output exposing ``hidden_states``.
        tok: Tokenizer output exposing ``attention_mask``.

    Returns:
        Tensor with the token axis (dim 1) averaged out, counting only
        positions where the attention mask is non-zero.
    """
    embeddings = out["hidden_states"][-1]
    attention_mask = tok['attention_mask']
    mask = attention_mask.unsqueeze(-1).expand(embeddings.size()).float()
    masked_embeddings = embeddings * mask
    summed = torch.sum(masked_embeddings, 1)
    # clamp avoids division by zero for an all-padding row
    summed_mask = torch.clamp(mask.sum(1), min=1e-9)
    mean_pooled = summed / summed_mask
    return mean_pooled


def extract_idxs(top_dict, kp_dict):
    """Return the positions (in *kp_dict* key order) of the keys in *top_dict*."""
    return [position for position, key in enumerate(kp_dict) if key in top_dict]


def _pick_most_diverse(pairwise, k):
    """Return *k* local indices whose summed off-diagonal score is smallest.

    Exhaustively scores every size-*k* subset of ``range(len(pairwise))``.
    When there are not more than *k* candidates, all of them are returned.
    """
    n = len(pairwise)
    if n <= k:
        return tuple(range(n))
    best, best_score = None, np.inf
    for combination in itertools.combinations(range(n), k):
        score = sum(pairwise[i][j] for i in combination for j in combination if i != j)
        if score < best_score:
            best, best_score = combination, score
    # best stays None only if no score compared below inf (e.g. all NaN)
    return best if best is not None else tuple(range(k))


def _diversify(candidates, kp_dict, kp_keys, pairwise_all, k):
    """Reduce *candidates* to the *k* most mutually diverse, sorted by similarity.

    Args:
        candidates: key-phrase -> similarity, as returned by ``calculate_top_k``.
        kp_dict: The full key-phrase dictionary the candidates came from.
        kp_keys: ``list(kp_dict.keys())``.
        pairwise_all: Precomputed pairwise matrix indexed by key position.
        k: Number of key-phrases to keep.
    """
    idxs = extract_idxs(candidates, kp_dict)
    pairwise = pairwise_all[np.ix_(idxs, idxs)]
    chosen = _pick_most_diverse(pairwise, k)
    picked = {kp_keys[idxs[local]]: candidates[kp_keys[idxs[local]]] for local in chosen}
    return dict(sorted(picked.items(), key=lambda item: item[1], reverse=True))


def _show_results(column, title, results):
    """Render *title* and the *results* dict as JSON inside *column*."""
    with column:
        st.write(title)
        st.json(results)


if text:
    text = text.lower()

    new_tokens = concat_tokens([text], tokenizer)
    new_tokens.pop("KPS")
    new_tokens_finbert = concat_tokens([text], tokenizer_finbert)
    new_tokens_finbert.pop("KPS")
    new_tokens_sapbert = concat_tokens([text], tokenizer_sapbert)
    new_tokens_sapbert.pop("KPS")

    with torch.no_grad():
        outputs = model(**new_tokens)
        outputs_finbert = model_finbert(**new_tokens_finbert)
        outputs_sapbert = model_sapbert(**new_tokens_sapbert)

    k = int(k)
    # When diversifying, retrieve the larger candidate pool first and narrow
    # it down to k afterwards; otherwise the plain top-k is the final answer.
    n_candidates = int(k_diversify) if diversify_box else k

    sim_dict = calculate_top_k(outputs, new_tokens, text, kp_dict,
                               exclude_text=exclude_text, exclude_words=exclude_words,
                               k=n_candidates)
    sim_dict_finbert = calculate_top_k(outputs_finbert, new_tokens_finbert, text, kp_dict_finbert,
                                       exclude_text=exclude_text, exclude_words=exclude_words,
                                       k=n_candidates)
    sim_dict_sapbert = calculate_top_k(outputs_sapbert, new_tokens_sapbert, text, kp_dict_sapbert,
                                       exclude_text=exclude_text, exclude_words=exclude_words,
                                       k=n_candidates, pooler=False)

    if diversify_box:
        ret = _diversify(sim_dict, kp_dict, keys, cosine_kp, k)
        ret_finbert = _diversify(sim_dict_finbert, kp_dict_finbert, keys_finbert, cosine_finbert_kp, k)
        ret_sapbert = _diversify(sim_dict_sapbert, kp_dict_sapbert, keys_sapbert, cosine_sapbert_kp, k)
    else:
        ret, ret_finbert, ret_sapbert = sim_dict, sim_dict_finbert, sim_dict_sapbert

    _show_results(col1, "distilbert-cvent", ret)
    _show_results(col2, "finbert", ret_finbert)
    _show_results(col3, "sapbert", ret_sapbert)