Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -6,6 +6,7 @@ import torch
|
|
6 |
|
7 |
model_checkpoint = "vives/distilbert-base-uncased-finetuned-imdb-accelerate"
|
8 |
model = AutoModelForMaskedLM.from_pretrained(model_checkpoint,output_hidden_states=True)
|
|
|
9 |
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
|
10 |
text1 = st.text_area("Enter first sentence")
|
11 |
text2 = st.text_area("Enter second sentence")
|
@@ -37,10 +38,21 @@ def pool_embeddings(out, tok):
|
|
37 |
return mean_pooled
|
38 |
|
39 |
if text1 and text2:
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
6 |
|
7 |
# Checkpoint of the DistilBERT model fine-tuned on IMDB; also used for the tokenizer.
model_checkpoint = "vives/distilbert-base-uncased-finetuned-imdb-accelerate"

# Fine-tuned model, configured to return per-layer hidden states.
model = AutoModelForMaskedLM.from_pretrained(
    model_checkpoint, output_hidden_states=True
)

# Baseline (not fine-tuned) model used for comparison. Its outputs are fed to
# the same pool_embeddings() helper as the fine-tuned model's outputs, so it is
# loaded with the same output_hidden_states flag to expose the same fields.
# NOTE(review): pool_embeddings presumably reads the hidden states -- confirm.
model_base = AutoModelForMaskedLM.from_pretrained(
    "distilbert-base-uncased", output_hidden_states=True
)

# One tokenizer is shared by both models.
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

# Streamlit inputs: the two sentences whose similarity is compared.
text1 = st.text_area("Enter first sentence")
text2 = st.text_area("Enter second sentence")
|
|
|
38 |
return mean_pooled
|
39 |
|
40 |
# Run the comparison only once both sentences have been entered.
if text1 and text2:
    # Inference only: no gradients are needed for either model.
    with torch.no_grad():
        tokens = concat_tokens(text1, text2)

        # Similarity according to the fine-tuned model: the first pooled
        # embedding is compared against the remaining one(s).
        outputs = model(**tokens)
        mean_pooled = pool_embeddings(outputs, tokens).detach().numpy()
        fine_tuned_out = cosine_similarity(
            [mean_pooled[0]],
            mean_pooled[1:],
        )[0]

        # Similarity according to the base model. This must use the base
        # model's own pooled embeddings; reusing `mean_pooled` here would just
        # repeat the fine-tuned score.
        outputs_base = model_base(**tokens)
        mean_pooled_base = pool_embeddings(outputs_base, tokens).detach().numpy()
        base_out = cosine_similarity(
            [mean_pooled_base[0]],
            mean_pooled_base[1:],
        )[0]

    st.write(f">>>Similarity for fine-tuned {fine_tuned_out}")
    st.write(f">>>Similarity for base {base_out}")
|