vives committed
Commit 01714af
1 Parent(s): 8f5d8f6

Update app.py

Files changed (1)
  1. app.py +41 -4
app.py CHANGED
@@ -1,9 +1,46 @@
  from transformers import AutoModelForMaskedLM
  from transformers import AutoTokenizer
+ from sklearn.metrics.pairwise import cosine_similarity
+ import torch  # needed below for torch.stack / torch.sum / torch.clamp
  import streamlit as st
+
  model_checkpoint = "vives/distilbert-base-uncased-finetuned-imdb-accelerate"
- #model = AutoModelForMaskedLM.from_pretrained(model_checkpoint,output_hidden_states=True)
- text = st.text_area("enter some text!")
+ model = AutoModelForMaskedLM.from_pretrained(model_checkpoint,output_hidden_states=True)
+ tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
+ text1 = st.text_area("Enter first sentence")
+ text2 = st.text_area("Enter second sentence")
+
+ def concat_tokens(t1,t2):
+     tokens = {'input_ids': [], 'attention_mask': []}
+     sentences = [t1, t2]
+     for sentence in sentences:
+         # encode each sentence and append to dictionary
+         new_tokens = tokenizer.encode_plus(sentence, max_length=128,
+                                            truncation=True, padding='max_length',
+                                            return_tensors='pt')
+         tokens['input_ids'].append(new_tokens['input_ids'][0])
+         tokens['attention_mask'].append(new_tokens['attention_mask'][0])
+
+     # reformat list of tensors into single tensor
+     tokens['input_ids'] = torch.stack(tokens['input_ids'])
+     tokens['attention_mask'] = torch.stack(tokens['attention_mask'])
+     return tokens
 
- if text:
-     st.write(type(text))
+ def pool_embeddings(out, tok):
+     embeddings = out["hidden_states"][-1]
+     attention_mask = tok['attention_mask']
+     mask = attention_mask.unsqueeze(-1).expand(embeddings.size()).float()
+     masked_embeddings = embeddings * mask
+     summed = torch.sum(masked_embeddings, 1)
+     summed_mask = torch.clamp(mask.sum(1), min=1e-9)
+     mean_pooled = summed / summed_mask
+     return mean_pooled
+
+ if text1 and text2:
+     tokens = concat_tokens(text1,text2)
+     outputs = model(**tokens)
+     mean_pooled = pool_embeddings(outputs,tokens).detach().numpy()
+     st.write(cosine_similarity(
+         [mean_pooled[0]],
+         mean_pooled[1:]
+     ))