MoritzLaurer HF staff commited on
Commit
a708ccc
1 Parent(s): 2e3cc95

Create README.md

Browse files
Files changed (1) hide show
  1. README.md +72 -0
README.md ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ language:
3
+ - en
4
+ tags:
5
+ - text-classification
6
+ - zero-shot-classification
7
+ metrics:
8
+ - accuracy
9
+ widget:
10
+ - text: "70-85% of the population needs to get vaccinated against the novel coronavirus to achieve herd immunity."
11
+
12
+ ---
13
+ # DeBERTa-v3-base-mnli-fever-anli
14
+ ## Model description
15
+ This model was trained on the MultiNLI, Fever-NLI and Adversarial-NLI (ANLI) datasets, which comprise 763 913 NLI hypothesis-premise pairs. This base model outperforms almost all large models on the [ANLI benchmark](https://github.com/facebookresearch/anli).
16
+ The base model is [DeBERTa-v3-base from Microsoft](https://huggingface.co/microsoft/deberta-v3-base). The v3 variant substantially outperforms previous versions of the model by including a different pre-training objective, see annex 11 of the original [DeBERTa paper](https://arxiv.org/pdf/2006.03654.pdf).
17
+
18
+ ## Intended uses & limitations
19
+ #### How to use the model
20
+ ```python
21
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
22
+ import torch
23
+ model_name = "MoritzLaurer/DeBERTa-v3-base-mnli-fever-anli"
24
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
25
+ model = AutoModelForSequenceClassification.from_pretrained(model_name)
26
+ text = "The new variant first detected in southern England in September is blamed for sharp rises in levels of positive tests in recent weeks in London, south-east England and the east of England"
27
+ input = tokenizer(text, truncation=True, return_tensors="pt")
28
+ output = model(input["input_ids"])
29
+ prediction = torch.softmax(output["logits"][0], -1).tolist()
30
+ label_names = ["entailment", "neutral", "contradiction"]
31
+ prediction = {name: round(float(pred) * 100, 1) for pred, name in zip(prediction, label_names)}
32
+ print(prediction)
33
+ ```
34
+
35
+ ### Training data
36
+ DeBERTa-v3-base-mnli-fever-anli was trained on the MultiNLI, Fever-NLI and Adversarial-NLI (ANLI) datasets, which comprise 763 913 NLI hypothesis-premise pairs.
37
+
38
+ ### Training procedure
39
+ DeBERTa-v3-base-mnli-fever-anli was trained using the Hugging Face trainer with the following hyperparameters.
40
+ ```
41
+ training_args = TrainingArguments(
42
+ num_train_epochs=3, # total number of training epochs
43
+ learning_rate=2e-05,
44
+ per_device_train_batch_size=32, # batch size per device during training
45
+ per_device_eval_batch_size=32, # batch size for evaluation
46
+ warmup_ratio=0.1, # number of warmup steps for learning rate scheduler
47
+ weight_decay=0.06, # strength of weight decay
48
+ fp16=True # mixed precision training
49
+ )
50
+ ```
51
+
52
+ ### Eval results
53
+ The model was evaluated using the test sets for MultiNLI and ANLI and the dev set for Fever-NLI
54
+ dataset | accuracy
55
+ -------|---------
56
+ mnli_m/mm | 0.903/0.903
57
+ fever-nli | 0.777
58
+ anli-all | 0.579
59
+ anli-r3 | 0.495
60
+
61
+ ## Limitations and bias
62
+ Please consult the original DeBERTa paper and literature on different NLI datasets for potential biases.
63
+
64
+ ### BibTeX entry and citation info
65
+ ```bibtex
66
+ @unpublished{
67
+ title={DeBERTa-v3-base-mnli-fever-anli},
68
+ author={Moritz Laurer},
69
+ year={2021},
70
+ note={Unpublished paper}
71
+ }
72
+ ```