Thomas Müller commited on
Commit
2e3ce14
2 Parent(s): 10b1ab8 0f0c375

Merge branch 'main' of https://huggingface.co/symanto/mpnet-base-snli-mnli into main

Browse files
Files changed (1) hide show
  1. README.md +12 -2
README.md CHANGED
@@ -1,3 +1,12 @@
 
 
 
 
 
 
 
 
 
1
  This is a small cross attention entailment model trained for zero-shot and few-shot text classification experiments.
2
 
3
  The base model is [mpnet-base](https://huggingface.co/microsoft/mpnet-base) and it has been trained with the code from [here](https://github.com/facebookresearch/anli).
@@ -14,9 +23,10 @@ import numpy as np
14
  model = AutoModelForSequenceClassification.from_pretrained("symanto/mpnet-base-snli-mnli")
15
  tokenizer = AutoTokenizer.from_pretrained("symanto/mpnet-base-snli-mnli")
16
 
17
- inputs = tokenizer(["I like this pizza. [SEP] The sentence is positive.", "I like this pizza. [SEP] The sentence is negative."], return_tensors="pt")
 
18
  logits = model(**inputs).logits
19
  probs = torch.softmax(logits, dim=1).tolist()
20
  print("probs", probs)
21
- np.testing.assert_almost_equal(probs, [[0.66, 0.33, 0.01], [0.08, 0.22, 0.70]], decimal=2)
22
  ```
 
1
+ ---
2
+ language:
3
+ - en
4
+ datasets:
5
+ - SNLI
6
+ - MNLI
7
+ ---
8
+
9
+
10
  This is a small cross attention entailment model trained for zero-shot and few-shot text classification experiments.
11
 
12
  The base model is [mpnet-base](https://huggingface.co/microsoft/mpnet-base) and it has been trained with the code from [here](https://github.com/facebookresearch/anli).
 
23
  model = AutoModelForSequenceClassification.from_pretrained("symanto/mpnet-base-snli-mnli")
24
  tokenizer = AutoTokenizer.from_pretrained("symanto/mpnet-base-snli-mnli")
25
 
26
+ input_pairs = [("I like this pizza.", "The sentence is positive."), ("I like this pizza.", "The sentence is negative.")]
27
+ inputs = tokenizer(["</s></s>".join(input_pair) for input_pair in input_pairs], return_tensors="pt")
28
  logits = model(**inputs).logits
29
  probs = torch.softmax(logits, dim=1).tolist()
30
  print("probs", probs)
31
+ np.testing.assert_almost_equal(probs, [[0.86, 0.14, 0.00], [0.16, 0.15, 0.69]], decimal=2)
32
  ```