mjwong commited on
Commit
d5da251
1 Parent(s): 4dd79cf

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +28 -0
README.md CHANGED
@@ -38,6 +38,8 @@ Gautier Izacard, Mathilde Caron, Lucas Hosseini, Sebastian Riedel, Piotr Bojanow
38
 
39
  ## How to use the model
40
 
 
 
41
  The model can be loaded with the `zero-shot-classification` pipeline like so:
42
 
43
  ```python
@@ -61,6 +63,32 @@ candidate_labels = ["politics", "economy", "entertainment", "environment"]
61
  classifier(sequence_to_classify, candidate_labels, multi_label=True)
62
  ```
63
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
  ### Eval results
65
  The model was evaluated using the XNLI test sets on 14 languages: English (en), Arabic (ar), Bulgarian (bg), German (de), Greek (el), Spanish (es), French (fr), Russian (ru), Swahili (sw), Thai (th), Turkish (tr), Urdu (ur), Vietnam (vi) and Chinese (zh). The metric used is accuracy.
66
 
 
38
 
39
  ## How to use the model
40
 
41
+ ### With the zero-shot classification pipeline
42
+
43
  The model can be loaded with the `zero-shot-classification` pipeline like so:
44
 
45
  ```python
 
63
  classifier(sequence_to_classify, candidate_labels, multi_label=True)
64
  ```
65
 
66
+ ### With manual PyTorch
67
+
68
+ The model can also be applied on NLI tasks like so:
69
+
70
+ ```python
71
+ import torch
72
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
73
+
74
+ # device = "cuda:0" or "cpu"
75
+ device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
76
+
77
+ model_name = "mjwong/mcontriever-msmarco-xnli"
78
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
79
+ model = AutoModelForSequenceClassification.from_pretrained(model_name)
80
+
81
+ premise = "But I thought you'd sworn off coffee."
82
+ hypothesis = "I thought that you vowed to drink more coffee."
83
+
84
+ input = tokenizer(premise, hypothesis, truncation=True, return_tensors="pt")
85
+ output = model(input["input_ids"].to(device))
86
+ prediction = torch.softmax(output["logits"][0], -1).tolist()
87
+ label_names = ["entailment", "neutral", "contradiction"]
88
+ prediction = {name: round(float(pred) * 100, 2) for pred, name in zip(prediction, label_names)}
89
+ print(prediction)
90
+ ```
91
+
92
  ### Eval results
93
  The model was evaluated using the XNLI test sets on 14 languages: English (en), Arabic (ar), Bulgarian (bg), German (de), Greek (el), Spanish (es), French (fr), Russian (ru), Swahili (sw), Thai (th), Turkish (tr), Urdu (ur), Vietnam (vi) and Chinese (zh). The metric used is accuracy.
94