---
language: multilingual
tags:
- pytorch
license: apache-2.0
datasets:
- multi_nli
- xnli
metrics:
- xnli
widget:
- text: "xnli: premise: ¿A quién vas a votar en 2020? hypothesis: Este ejemplo es política."
---

# mt5-large-finetuned-mnli-xtreme-xnli

## Model Description

This model takes a pretrained large [multilingual-t5](https://github.com/google-research/multilingual-t5) (also available on the [Hugging Face Hub](https://huggingface.co/google/mt5-large)) and fine-tunes it on English MNLI and the [xtreme_xnli](https://www.tensorflow.org/datasets/catalog/xtreme_xnli) training set. It is intended to be used for zero-shot text classification, inspired by [xlm-roberta-large-xnli](https://huggingface.co/joeddav/xlm-roberta-large-xnli).

## Intended Use

This model is intended to be used for zero-shot text classification, especially in languages other than English. It is fine-tuned on English MNLI and the [xtreme_xnli](https://www.tensorflow.org/datasets/catalog/xtreme_xnli) training set, a multilingual NLI dataset. The model can therefore be used with any of the languages in the XNLI corpus:

- Arabic
- Bulgarian
- Chinese
- English
- French
- German
- Greek
- Hindi
- Russian
- Spanish
- Swahili
- Thai
- Turkish
- Urdu
- Vietnamese

As per the recommendations in [xlm-roberta-large-xnli](https://huggingface.co/joeddav/xlm-roberta-large-xnli), for English-only classification you might want to check out:

- [bart-large-mnli](https://huggingface.co/facebook/bart-large-mnli)
- [a distilled BART MNLI model](https://huggingface.co/models?filter=pipeline_tag%3Azero-shot-classification&search=valhalla)

### Zero-shot example

The model retains its text-to-text characteristic after fine-tuning, which means that its outputs are text. During fine-tuning, the model learns to respond to the NLI task with a single-token response that maps to entailment, neutral, or contradiction. The NLI task is indicated with a fixed prefix, "xnli:".

Below is an example, using PyTorch, of using the model in a similar fashion to the `zero-shot-classification` pipeline. We use the logits from the LM output at the first generated token to represent confidence.

```python
from torch.nn.functional import softmax
from transformers import MT5ForConditionalGeneration, MT5Tokenizer

model_name = "alan-turing-institute/mt5-large-finetuned-mnli-xtreme-xnli"

tokenizer = MT5Tokenizer.from_pretrained(model_name)
model = MT5ForConditionalGeneration.from_pretrained(model_name)
model.eval()

# "Who are you going to vote for in 2020?", with labels Europe / public health / politics
sequence_to_classify = "¿A quién vas a votar en 2020?"
candidate_labels = ["Europa", "salud pública", "política"]
hypothesis_template = "Este ejemplo es {}."  # "This example is {}."

# single-token outputs the model was fine-tuned to produce for each NLI class
ENTAILS_LABEL = "▁0"
NEUTRAL_LABEL = "▁1"
CONTRADICTS_LABEL = "▁2"

label_inds = tokenizer.convert_tokens_to_ids(
    [ENTAILS_LABEL, NEUTRAL_LABEL, CONTRADICTS_LABEL])


def process_nli(premise: str, hypothesis: str):
    """ process to required xnli format with task prefix """
    return "".join(['xnli: premise: ', premise, ' hypothesis: ', hypothesis])


# construct sequence of premise, hypothesis pairs
seqs = [(sequence_to_classify, hypothesis_template.format(label)) for label in
        candidate_labels]
# format for mt5 xnli task
seqs = [process_nli(premise=premise, hypothesis=hypothesis) for
        premise, hypothesis in seqs]
print(seqs)
# ['xnli: premise: ¿A quién vas a votar en 2020? hypothesis: Este ejemplo es Europa.',
#  'xnli: premise: ¿A quién vas a votar en 2020? hypothesis: Este ejemplo es salud pública.',
#  'xnli: premise: ¿A quién vas a votar en 2020? hypothesis: Este ejemplo es política.']

inputs = tokenizer.batch_encode_plus(seqs, return_tensors="pt", padding=True)

out = model.generate(**inputs, output_scores=True, return_dict_in_generate=True,
                     num_beams=1)

# sanity check that our sequences are expected length (1 + start token + end token = 3)
for i, seq in enumerate(out.sequences):
    assert len(seq) == 3, f"generated sequence {i} not of expected length, 3." \
                          f" Actual length: {len(seq)}"

# get the scores for our only token of interest
# we'll now treat these like the output logits of a `*ForSequenceClassification` model
scores = out.scores[0]

# scores has a size of the model's vocab.
# However, for this task we have a fixed set of labels
# sanity check that these labels are always the top 3 scoring
for i, sequence_scores in enumerate(scores):
    top_scores = sequence_scores.argsort()[-3:]
    assert set(top_scores.tolist()) == set(label_inds), \
        f"top scoring tokens are not expected for this task." \
        f" Expected: {label_inds}. Got: {top_scores.tolist()}."

# cut down scores to our task labels
scores = scores[:, label_inds]
print(scores)
# tensor([[-2.5697,  1.0618,  0.2088],
#         [-5.4492, -2.1805, -0.1473],
#         [ 2.2973,  3.7595, -0.1769]])

# new indices of entailment and contradiction in scores
entailment_ind = 0
contradiction_ind = 2

# we can show, per item, the entailment vs contradiction probas
entail_vs_contra_scores = scores[:, [entailment_ind, contradiction_ind]]
entail_vs_contra_probas = softmax(entail_vs_contra_scores, dim=1)
print(entail_vs_contra_probas)
# tensor([[0.0585, 0.9415],
#         [0.0050, 0.9950],
#         [0.9223, 0.0777]])

# or we can show probas similar to `ZeroShotClassificationPipeline`
# this gives a zero-shot classification style output across labels
entail_scores = scores[:, 0]
entail_probas = softmax(entail_scores, dim=0)
print(entail_probas)
# tensor([7.6341e-03, 4.2873e-04, 9.9194e-01])

print(dict(zip(candidate_labels, entail_probas.tolist())))
# {'Europa': 0.007634134963154793,
#  'salud pública': 0.0004287279152777046,
#  'política': 0.9919371604919434}
```

Unfortunately, the `generate` function for the TF equivalent model doesn't exactly mirror the PyTorch version, so the above code won't directly transfer.
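
One workaround in TensorFlow is to skip `generate` entirely and read the first-step logits from a single decoder forward pass. The snippet below is only an untested sketch, not part of the original example: it assumes TF weights are available via `TFMT5ForConditionalGeneration`, and it reuses `tokenizer`, `model_name`, `seqs` and `label_inds` from the PyTorch code above.

```python
# Rough sketch (assumption, untested): reproduce the first-token logits in TensorFlow
# by running one manual decoder step instead of `generate`.
import tensorflow as tf
from transformers import TFMT5ForConditionalGeneration

tf_model = TFMT5ForConditionalGeneration.from_pretrained(model_name)  # assumes TF weights exist

tf_inputs = tokenizer(seqs, return_tensors="tf", padding=True)

# feed only the decoder start token so the output logits correspond to the
# first generated token, mirroring `out.scores[0]` in the PyTorch example
batch_size = tf_inputs["input_ids"].shape[0]
decoder_input_ids = tf.fill((batch_size, 1), tf_model.config.decoder_start_token_id)

tf_out = tf_model(**tf_inputs, decoder_input_ids=decoder_input_ids)
first_token_logits = tf_out.logits[:, 0, :]

# restrict to the three NLI label tokens, as in the PyTorch example
tf_scores = tf.gather(first_token_logits, label_inds, axis=-1)
print(tf_scores)
```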

The model is currently not compatible with the existing `zero-shot-classification` pipeline.
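
Until that changes, the PyTorch logic above can be wrapped in a small helper that mimics the pipeline's output format. This is only an illustrative sketch built from the code in this card: the `classify` name and return format are ours, not an official API, and it reuses `model`, `tokenizer`, `process_nli` and `label_inds` defined above.

```python
from torch.nn.functional import softmax


def classify(sequence: str, labels: list, hypothesis_template: str = "Este ejemplo es {}."):
    """Zero-shot classify `sequence` against `labels`, returning a
    pipeline-style dict of labels sorted by descending score."""
    seqs = [process_nli(premise=sequence, hypothesis=hypothesis_template.format(label))
            for label in labels]
    inputs = tokenizer.batch_encode_plus(seqs, return_tensors="pt", padding=True)
    out = model.generate(**inputs, output_scores=True, return_dict_in_generate=True,
                         num_beams=1)
    # entailment logit for each candidate label, softmaxed across labels
    entail_scores = out.scores[0][:, label_inds][:, 0]
    probas = softmax(entail_scores, dim=0)
    ranked = sorted(zip(labels, probas.tolist()), key=lambda x: x[1], reverse=True)
    return {
        "sequence": sequence,
        "labels": [label for label, _ in ranked],
        "scores": [score for _, score in ranked],
    }


print(classify("¿A quién vas a votar en 2020?", ["Europa", "salud pública", "política"]))
```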

## Training

This model was pre-trained on mC4, which covers 101 languages, as described in [the mT5 paper](https://arxiv.org/abs/2010.11934). It was then fine-tuned on the [mt5_xnli_translate_train](https://github.com/google-research/multilingual-t5/blob/78d102c830d76bd68f27596a97617e2db2bfc887/multilingual_t5/tasks.py#L190) task for 8k steps in a similar manner to that described in the [official repo](https://github.com/google-research/multilingual-t5#fine-tuning), with guidance from [Stephen Mayhew's notebook](https://github.com/mayhewsw/multilingual-t5/blob/master/notebooks/mt5-xnli.ipynb). The resulting model was then converted to 🤗 Transformers format.

## Eval results

Accuracy over XNLI test set:

| ar | bg | de | el | en | es | fr | hi | ru | sw | th | tr | ur | vi | zh | average |
|------|------|------|------|------|------|------|------|------|------|------|------|------|------|------|------|
| 81.0 | 85.0 | 84.3 | 84.3 | 88.8 | 85.3 | 83.9 | 79.9 | 82.6 | 78.0 | 81.0 | 81.6 | 76.4 | 81.7 | 82.3 | 82.4 |
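
As a rough illustration (not the script used to produce the numbers above), per-language accuracy can be estimated with the 🤗 `datasets` library and the first-token logits approach from the zero-shot example. The sketch below assumes the XNLI label order (0 = entailment, 1 = neutral, 2 = contradiction) matches the model's "▁0" / "▁1" / "▁2" output tokens, and reuses `model`, `tokenizer`, `process_nli` and `label_inds` from above.

```python
# Illustrative sketch only: estimate XNLI test accuracy for one language.
import torch
from datasets import load_dataset

xnli_test = load_dataset("xnli", "es", split="test")  # 0=entailment, 1=neutral, 2=contradiction

preds = []
batch_size = 32
with torch.no_grad():
    for i in range(0, len(xnli_test), batch_size):
        batch = xnli_test[i:i + batch_size]
        seqs = [process_nli(premise=p, hypothesis=h)
                for p, h in zip(batch["premise"], batch["hypothesis"])]
        inputs = tokenizer.batch_encode_plus(seqs, return_tensors="pt", padding=True)
        out = model.generate(**inputs, output_scores=True,
                             return_dict_in_generate=True, num_beams=1)
        # argmax over the three NLI label tokens at the first generated position
        preds.extend(out.scores[0][:, label_inds].argmax(dim=1).tolist())

accuracy = sum(int(p == y) for p, y in zip(preds, xnli_test["label"])) / len(xnli_test)
print(f"es accuracy: {accuracy:.3f}")
```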