mjwong committed c45c1a2 (parent: 2f18ab7): Update README.md

Files changed (1): README.md (+28 −0)
## How to use the model

### With the zero-shot classification pipeline

The model can be loaded with the `zero-shot-classification` pipeline like so:

```python
# ... (the pipeline setup lines are unchanged by this commit and elided in the diff)
candidate_labels = ['travel', 'cooking', 'dancing', 'exploration']
classifier(sequence_to_classify, candidate_labels, multi_class=True)
```
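For reference, a minimal self-contained version of the elided setup might look like the sketch below. This is an assumption based on the standard `transformers` zero-shot pattern, not the model card's exact code, and the example sequence is illustrative. Note that recent `transformers` releases renamed the `multi_class` argument to `multi_label`.

```python
from transformers import pipeline

# Hypothetical reconstruction of the elided setup; the sequence is illustrative.
classifier = pipeline("zero-shot-classification", model="mjwong/e5-large-mnli-anli")

sequence_to_classify = "one day I will see the world"
candidate_labels = ['travel', 'cooking', 'dancing', 'exploration']

# multi_class=True scores each label independently instead of normalizing
# across labels (newer transformers versions call this multi_label=True)
classifier(sequence_to_classify, candidate_labels, multi_class=True)
```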

### With manual PyTorch

The model can also be applied to NLI tasks like so:

```python
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# Run on GPU when one is available, otherwise fall back to CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

model_name = "mjwong/e5-large-mnli-anli"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name).to(device)

premise = "But I thought you'd sworn off coffee."
hypothesis = "I thought that you vowed to drink more coffee."

# Tokenize the sentence pair and move the tensors to the model's device
inputs = tokenizer(premise, hypothesis, truncation=True, return_tensors="pt").to(device)
with torch.no_grad():
    output = model(**inputs)
prediction = torch.softmax(output.logits[0], -1).tolist()
label_names = ["entailment", "neutral", "contradiction"]
prediction = {name: round(float(pred) * 100, 2) for pred, name in zip(prediction, label_names)}
print(prediction)
```
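The same components can score several premise-hypothesis pairs at once. The following is a hedged sketch, not taken from the model card, showing batched inference with padding; it reuses `tokenizer`, `model`, `device`, and `label_names` from the block above, and the second pair is illustrative.

```python
# Hypothetical batched-inference sketch; reuses tokenizer, model, device,
# and label_names from the previous block. The second pair is made up.
pairs = [
    ("But I thought you'd sworn off coffee.", "I thought that you vowed to drink more coffee."),
    ("The cat sat on the mat.", "An animal is resting."),
]
batch = tokenizer(
    [p for p, _ in pairs],   # premises
    [h for _, h in pairs],   # hypotheses
    padding=True, truncation=True, return_tensors="pt",
).to(device)
with torch.no_grad():
    logits = model(**batch).logits
probs = torch.softmax(logits, dim=-1)
for (premise, hypothesis), row in zip(pairs, probs):
    scores = {name: round(float(p) * 100, 2) for name, p in zip(label_names, row)}
    print(hypothesis, "->", scores)
```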

### Eval results
The model was evaluated using the dev sets for MultiNLI and test sets for ANLI. The metric used is accuracy.
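As a rough guide to how such numbers can be reproduced, the sketch below is an assumption, not the card's actual evaluation script. It computes accuracy on the MultiNLI matched dev set with the `datasets` library, reusing `tokenizer`, `model`, and `device` from above, and assumes the model's entailment/neutral/contradiction label order matches the dataset's 0/1/2 encoding.

```python
# Hypothetical evaluation sketch; reuses tokenizer, model, and device from above.
# Assumes label order entailment/neutral/contradiction matches multi_nli's encoding.
from datasets import load_dataset

dev = load_dataset("multi_nli", split="validation_matched")
correct = 0
for example in dev:
    inputs = tokenizer(example["premise"], example["hypothesis"],
                       truncation=True, return_tensors="pt").to(device)
    with torch.no_grad():
        pred = model(**inputs).logits.argmax(dim=-1).item()
    correct += int(pred == example["label"])
print(f"MultiNLI matched dev accuracy: {correct / len(dev):.4f}")
```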