vives committed
Commit ed162b2
1 Parent(s): 5012338

Create app.py

Files changed (1)
  1. app.py +66 -0
app.py ADDED
@@ -0,0 +1,66 @@
+ from transformers import AutoModelForMaskedLM
+ from transformers import AutoTokenizer
+ from sklearn.metrics.pairwise import cosine_similarity
+ import streamlit as st
+ import torch
+ import pickle
+
+ model_checkpoint = "vives/distilbert-base-uncased-finetuned-cvent-2019_2022"
+ model = AutoModelForMaskedLM.from_pretrained(model_checkpoint, output_hidden_states=True)
+ tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
+ text = st.text_input("Enter word or key-phrase")
+ exclude_text = st.radio("exclude_text", [True, False])
+ exclude_words = st.radio("exclude_words", [True, False])
+ k = st.number_input("Top k nearest key-phrases", 1, 10)
+
+ with open("kp_dict_merged.pickle", 'rb') as handle:
+     kp_dict = pickle.load(handle)
+
+ def calculate_top_k(out, tokens, text, exclude_text=False, exclude_words=False, k=5):
+     sim_dict = {}
+     pools = pool_embeddings(out, tokens).detach().numpy()
+     for key in kp_dict.keys():
+         if key == text:
+             continue
+         if exclude_text and text in key:
+             continue
+         if exclude_words and True in [x in key for x in text.split(" ")]:
+             continue
+         sim_dict[key] = cosine_similarity(
+             pools,
+             [kp_dict[key]]
+         )[0][0]
+     sims = sorted(sim_dict.items(), key=lambda x: x[1], reverse=True)[:k]
+     return {x: y for x, y in sims}
+ def concat_tokens(sentences):
+     tokens = {'input_ids': [], 'attention_mask': [], 'KPS': []}
+     for sentence in sentences:
+         # encode each sentence and append to dictionary
+         new_tokens = tokenizer.encode_plus(sentence, max_length=64,
+                                            truncation=True, padding='max_length',
+                                            return_tensors='pt')
+         tokens['input_ids'].append(new_tokens['input_ids'][0])
+         tokens['attention_mask'].append(new_tokens['attention_mask'][0])
+         tokens['KPS'].append(sentence)
+     # reformat list of tensors into single tensor
+     tokens['input_ids'] = torch.stack(tokens['input_ids'])
+     tokens['attention_mask'] = torch.stack(tokens['attention_mask'])
+     return tokens
+
+ def pool_embeddings(out, tok):
+     embeddings = out["hidden_states"][-1]
+     attention_mask = tok['attention_mask']
+     mask = attention_mask.unsqueeze(-1).expand(embeddings.size()).float()
+     masked_embeddings = embeddings * mask
+     summed = torch.sum(masked_embeddings, 1)
+     summed_mask = torch.clamp(mask.sum(1), min=1e-9)
+     mean_pooled = summed / summed_mask
+     return mean_pooled
+
+ if text:
+     new_tokens = concat_tokens([text])
+     new_tokens.pop("KPS")
+     with torch.no_grad():
+         outputs = model(**new_tokens)
+     sim_dict = calculate_top_k(outputs, new_tokens, text, exclude_text=exclude_text, exclude_words=exclude_words, k=k)
+     st.json(sim_dict)
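
The app expects kp_dict_merged.pickle to map each key-phrase string to a pre-computed embedding vector produced with the same checkpoint and mean-pooling that app.py applies to the query. That build step is not part of this commit; the sketch below is a hypothetical reconstruction under that assumption, with a placeholder key-phrase list.

# Hypothetical sketch (not included in this commit): building kp_dict_merged.pickle
# with the same checkpoint and mean-pooling that app.py uses for the query text.
import pickle
import torch
from transformers import AutoModelForMaskedLM, AutoTokenizer

model_checkpoint = "vives/distilbert-base-uncased-finetuned-cvent-2019_2022"
model = AutoModelForMaskedLM.from_pretrained(model_checkpoint, output_hidden_states=True)
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

# placeholder phrases; the actual list behind the merged dictionary is unknown
key_phrases = ["event registration", "attendee engagement"]

kp_dict = {}
with torch.no_grad():
    for kp in key_phrases:
        tokens = tokenizer(kp, max_length=64, truncation=True,
                           padding="max_length", return_tensors="pt")
        out = model(**tokens)
        embeddings = out["hidden_states"][-1]            # last hidden layer
        mask = tokens["attention_mask"].unsqueeze(-1).expand(embeddings.size()).float()
        summed = torch.sum(embeddings * mask, 1)         # sum over non-padding tokens
        counts = torch.clamp(mask.sum(1), min=1e-9)      # number of real tokens
        kp_dict[kp] = (summed / counts)[0].numpy()       # 1-D mean-pooled vector

with open("kp_dict_merged.pickle", "wb") as handle:
    pickle.dump(kp_dict, handle)

With that pickle in the working directory, the app can be launched with: streamlit run app.py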