vives committed on
Commit
3bba4cb
1 Parent(s): 77b655a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -7
app.py CHANGED
@@ -6,6 +6,7 @@ import torch
6
 
7
  model_checkpoint = "vives/distilbert-base-uncased-finetuned-imdb-accelerate"
8
  model = AutoModelForMaskedLM.from_pretrained(model_checkpoint,output_hidden_states=True)
 
9
  tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
10
  text1 = st.text_area("Enter first sentence")
11
  text2 = st.text_area("Enter second sentence")
@@ -37,10 +38,21 @@ def pool_embeddings(out, tok):
37
  return mean_pooled
38
 
39
  if text1 and text2:
40
- tokens = concat_tokens(text1,text2)
41
- outputs = model(**tokens)
42
- mean_pooled = pool_embeddings(outputs,tokens).detach().numpy()
43
- st.write(cosine_similarity(
44
- [mean_pooled[0]],
45
- mean_pooled[1:]
46
- ))
 
 
 
 
 
 
 
 
 
 
 
 
6
 
# Hub id of the fine-tuned checkpoint; the tokenizer below is loaded from the same id.
model_checkpoint = "vives/distilbert-base-uncased-finetuned-imdb-accelerate"

# Fine-tuned masked-LM, configured to return per-layer hidden states.
model = AutoModelForMaskedLM.from_pretrained(
    model_checkpoint,
    output_hidden_states=True,
)

# Baseline model used for the side-by-side comparison. It is loaded with the
# same output_hidden_states flag as `model` so both models return the same
# output fields when they are passed through the same pooling helper.
# NOTE(review): pool_embeddings is not visible here; presumably it reads the
# hidden states -- confirm.
model_base = AutoModelForMaskedLM.from_pretrained(
    "distilbert-base-uncased",
    output_hidden_states=True,
)

tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

# Two free-text inputs whose embeddings are compared further down.
text1 = st.text_area("Enter first sentence")
text2 = st.text_area("Enter second sentence")
 
38
  return mean_pooled
39
 
# Run the comparison only once both inputs are non-empty.
if text1 and text2:
    # Inference only: no_grad skips autograd bookkeeping for both forward passes.
    with torch.no_grad():
        tokens = concat_tokens(text1, text2)

        # Fine-tuned model: similarity of the first pooled row against the rest.
        outputs = model(**tokens)
        mean_pooled = pool_embeddings(outputs, tokens).detach().numpy()
        fine_tuned_out = cosine_similarity(
            [mean_pooled[0]],
            mean_pooled[1:],
        )[0]

        # Base model: same tokens, but the similarity must be computed from the
        # base model's own pooled embeddings. Reusing `mean_pooled` here would
        # make the "base" figure always identical to the fine-tuned one.
        # NOTE(review): pool_embeddings is defined outside this view; confirm
        # model_base returns whatever output fields it reads.
        outputs_base = model_base(**tokens)
        mean_pooled_base = pool_embeddings(outputs_base, tokens).detach().numpy()
        base_out = cosine_similarity(
            [mean_pooled_base[0]],
            mean_pooled_base[1:],
        )[0]

    st.write(f">>>Similarity for fine-tuned {fine_tuned_out}")
    st.write(f">>>Similarity for base {base_out}")