muelletm commited on
Commit
0f0c375
1 Parent(s): f7ccce6

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +3 -2
README.md CHANGED
@@ -23,9 +23,10 @@ import numpy as np
23
  model = AutoModelForSequenceClassification.from_pretrained("symanto/mpnet-base-snli-mnli")
24
  tokenizer = AutoTokenizer.from_pretrained("symanto/mpnet-base-snli-mnli")
25
 
26
- inputs = tokenizer(["I like this pizza. [SEP] The sentence is positive.", "I like this pizza. [SEP] The sentence is negative."], return_tensors="pt")
 
27
  logits = model(**inputs).logits
28
  probs = torch.softmax(logits, dim=1).tolist()
29
  print("probs", probs)
30
- np.testing.assert_almost_equal(probs, [[0.66, 0.33, 0.01], [0.08, 0.22, 0.70]], decimal=2)
31
  ```
 
23
  model = AutoModelForSequenceClassification.from_pretrained("symanto/mpnet-base-snli-mnli")
24
  tokenizer = AutoTokenizer.from_pretrained("symanto/mpnet-base-snli-mnli")
25
 
26
+ input_pairs = [("I like this pizza.", "The sentence is positive."), ("I like this pizza.", "The sentence is negative.")]
27
+ inputs = tokenizer(["</s></s>".join(input_pair) for input_pair in input_pairs], return_tensors="pt")
28
  logits = model(**inputs).logits
29
  probs = torch.softmax(logits, dim=1).tolist()
30
  print("probs", probs)
31
+ np.testing.assert_almost_equal(probs, [[0.86, 0.14, 0.00], [0.16, 0.15, 0.69]], decimal=2)
32
  ```