---
license: mit
language:
- en
tags:
- t5
model-index:
- name: metro_t0_base
  results:
  - task:
      type: natural-language-inference
    dataset:
      type: super_glue
      name: RTE
      config: rte
      split: validation
    metrics:
      - type: accuracy
        value: 61.6245487364621
  - task:
      type: natural-language-inference
    dataset:
      type: super_glue
      name: CB
      config: cb
      split: validation
    metrics:
      - type: accuracy
        value: 52.73809523809525
  - task:
      type: natural-language-inference
    dataset:
      type: anli
      name: ANLI R1
      split: dev_r1
    metrics:
      - type: accuracy
        value: 31.706666666666667
  - task:
      type: natural-language-inference
    dataset:
      type: anli
      name: ANLI R2
      split: dev_r2
    metrics:
      - type: accuracy
        value: 33.486666666666665
  - task:
      type: natural-language-inference
    dataset:
      type: anli
      name: ANLI R3
      split: dev_r3
    metrics:
      - type: accuracy
        value: 33.44444444444444
  - task:
      type: coreference-resolution
    dataset:
      type: super_glue
      name: WSC
      config: wsc.fixed
      split: validation
    metrics:
      - type: accuracy
        value: 58.75
  - task:
      type: coreference-resolution
    dataset:
      type: winogrande
      name: Winogrande XL
      config: winogrande_xl
      split: validation
    metrics:
      - type: accuracy
        value: 50.95501183898973
  - task:
      type: multiple-choice-qa
    dataset:
      type: super_glue
      name: COPA
      config: copa
      split: validation
    metrics:
      - type: accuracy
        value: 66.25
  - task:
      type: multiple-choice-qa
    dataset:
      type: story_cloze
      name: StoryCloze 2016
      config: '2016'
      split: validation
    metrics:
      - type: accuracy
        value: 82.40513094601816
  - task:
      type: multiple-choice-qa
    dataset:
      type: hellaswag
      name: HellaSwag
      split: validation
    metrics:
      - type: accuracy
        value: 25.647281418044216
  - task:
      type: word-sense-disambiguation
    dataset:
      type: super_glue
      name: WiC
      config: wic
      split: validation
    metrics:
      - type: accuracy
        value: 50.423197492163006
---

Official repository: https://github.com/gonglinyuan/metro_t0

# METRO-T0

Paper: [Model-Generated Pretraining Signals Improves Zero-Shot Generalization of Text-to-Text Transformers](https://arxiv.org/abs/2305.12567) (ACL 2023)

METRO-T0 is a T5-style text-to-text Transformer pretrained with model-generated pretraining signals and then prompt-finetuned on the family of public NLP tasks proposed in [T0](https://arxiv.org/abs/2110.08207).
METRO-T0 is highly parameter-efficient: for example, METRO-T0-Large++ (775M parameters) outperforms GPT-3 (175B parameters) and T0-3B (3B parameters) on a wide range of NLP tasks.
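
To put the parameter counts above in perspective, the ratios follow directly from the stated sizes (775M for METRO-T0-Large++, 175B for GPT-3, 3B for T0-3B). GPT-3 has roughly 226 times as many parameters as METRO-T0-Large++, and T0-3B roughly 3.9 times as many:

$$
\frac{175 \times 10^{9}}{775 \times 10^{6}} \approx 226, \qquad \frac{3 \times 10^{9}}{775 \times 10^{6}} \approx 3.9
$$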

![The architecture of METRO-T0 during pretraining using BERT as the auxiliary model to generate signals](https://github.com/gonglinyuan/metro_t0/raw/main/assets/metro_t0_method.png)
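
The figure shows an auxiliary BERT model generating the pretraining signals for the main model. The sketch below illustrates only the general ELECTRA-style idea behind such signals, under simplifying assumptions: a hypothetical `auxiliary_fill` function stands in for the auxiliary masked language model, and the output is a corrupted sequence plus per-token labels marking which tokens were replaced. It is not the exact METRO-T0 pretraining objective; the paper describes that in full.

```python
import random
from typing import Callable, List, Optional, Sequence, Tuple

MASK = "[MASK]"

def corrupt_with_auxiliary(
    tokens: Sequence[str],
    auxiliary_fill: Callable[[List[str], int], str],
    mask_prob: float = 0.15,
    rng: Optional[random.Random] = None,
) -> Tuple[List[str], List[int]]:
    """Return (corrupted_tokens, replaced_labels) for one sequence.

    Illustrative only: `auxiliary_fill(masked_tokens, position)` is a
    hypothetical stand-in for an auxiliary masked LM (e.g. a small BERT)
    that proposes a token for a masked position.
    """
    rng = rng or random.Random()
    # 1. Choose positions to mask independently with probability mask_prob.
    positions = [i for i in range(len(tokens)) if rng.random() < mask_prob]
    masked = list(tokens)
    for i in positions:
        masked[i] = MASK
    # 2. Let the auxiliary model fill each masked position.
    corrupted = list(tokens)
    labels = [0] * len(tokens)
    for i in positions:
        proposal = auxiliary_fill(masked, i)
        corrupted[i] = proposal
        # 3. Label 1 only if the proposal differs from the original token,
        #    so a correct reconstruction counts as "not replaced".
        labels[i] = int(proposal != tokens[i])
    return corrupted, labels

if __name__ == "__main__":
    # A toy "auxiliary model" that always proposes "the".
    toy_fill = lambda masked, i: "the"
    print(corrupt_with_auxiliary(["the", "cat", "sat"], toy_fill, mask_prob=1.0))
    # (['the', 'the', 'the'], [0, 1, 1])
```

Labelling a correctly reconstructed token as "not replaced" follows the ELECTRA convention; whether METRO-T0 does the same is not stated here.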

![Prompt learning results of METRO-T0 versus our T0 baseline and T0-3B by Sanh et al. (2022) on 4 tasks in the T0 Eval benchmark. Each point denotes the accuracy using one prompt template, except that the blue point indicates the median accuracy of T0-3B over all templates. Plots for the other tasks are in our paper.](https://github.com/gonglinyuan/metro_t0/raw/main/assets/metro_t0_selected_results.png)
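
The figure reports accuracy per prompt template and summarizes T0-3B by the median accuracy over all templates. A minimal sketch of that aggregation, assuming predictions and gold labels are grouped by a hypothetical template identifier; the toy data is made up purely to exercise the functions and is not taken from the paper:

```python
from statistics import median
from typing import Dict, Sequence, Tuple

# template name -> (predictions, gold labels)
Results = Dict[str, Tuple[Sequence[str], Sequence[str]]]

def per_template_accuracy(results: Results) -> Dict[str, float]:
    """Accuracy (in %) of each prompt template."""
    accs = {}
    for template, (preds, gold) in results.items():
        if len(preds) != len(gold) or not gold:
            raise ValueError(f"template {template!r}: need equal-length, non-empty lists")
        correct = sum(p == g for p, g in zip(preds, gold))
        accs[template] = 100.0 * correct / len(gold)
    return accs

def median_accuracy(results: Results) -> float:
    """Median of the per-template accuracies."""
    return median(per_template_accuracy(results).values())

if __name__ == "__main__":
    gold = ["yes", "no", "no", "no"]
    toy = {
        "template_a": (["yes", "no", "yes", "no"], gold),   # 3/4 correct
        "template_b": (["no", "yes", "no", "no"], gold),    # 2/4 correct
        "template_c": (["yes", "yes", "yes", "yes"], gold), # 1/4 correct
    }
    print(per_template_accuracy(toy))
    print(median_accuracy(toy))  # 50.0
```

Reporting the median rather than the mean keeps a single unusually good or bad template from dominating the summary.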

## Use METRO-T0-Base

To use METRO-T0-Base in PyTorch (Python 3.7+, PyTorch 1.12+ and transformers 4.17+ are prerequisites), refer to the code snippet below:

```python
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

model = AutoModelForSeq2SeqLM.from_pretrained("gonglinyuan/metro_t0_base", trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained("gonglinyuan/metro_t0_base", trust_remote_code=True)

input_text = "Is this review positive or negative? Review: this is the best cast iron skillet you will ever buy"
inputs = tokenizer([input_text], max_length=512, truncation=True, add_special_tokens=True, return_tensors="pt").input_ids
outputs = model.generate(inputs, max_new_tokens=256, do_sample=False)

print(tokenizer.decode(outputs[0], skip_special_tokens=True))  # expected: positive
```
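
Before running the snippet above, it may help to confirm the stated prerequisites (Python 3.7+, PyTorch 1.12+, transformers 4.17+). A minimal sketch using only the standard library; it assumes plain numeric `major.minor[.patch]` release strings and ignores any suffix such as `+cu113` or `.dev0`. The check itself uses `importlib.metadata`, which needs Python 3.8+.

```python
import re
import sys
from typing import Tuple

# Minimum versions stated in this model card.
REQUIREMENTS = {"torch": "1.12", "transformers": "4.17"}
MIN_PYTHON = (3, 7)

def parse_version(text: str) -> Tuple[int, ...]:
    """Parse the leading numeric release part, e.g. '1.12.1+cu113' -> (1, 12, 1)."""
    match = re.match(r"\d+(?:\.\d+)*", text.strip())
    if match is None:
        raise ValueError(f"unrecognized version string: {text!r}")
    return tuple(int(part) for part in match.group(0).split("."))

def meets_minimum(installed: str, required: str) -> bool:
    """Compare release tuples, padding the shorter one with zeros."""
    a, b = parse_version(installed), parse_version(required)
    width = max(len(a), len(b))
    a += (0,) * (width - len(a))
    b += (0,) * (width - len(b))
    return a >= b

def check_environment() -> None:
    from importlib.metadata import PackageNotFoundError, version  # Python 3.8+
    print("python ok" if sys.version_info[:2] >= MIN_PYTHON else "python too old")
    for package, required in REQUIREMENTS.items():
        try:
            installed = version(package)
        except PackageNotFoundError:
            print(f"{package}: not installed (need >= {required})")
            continue
        status = "ok" if meets_minimum(installed, required) else "too old"
        print(f"{package} {installed}: {status} (need >= {required})")

if __name__ == "__main__":
    check_environment()
```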

## Other METRO-T0 Models

| Model              | # Parameters | Pretraining Data | Prompt-Finetuning Data |
|--------------------|--------------|------------------|------------------------|
| [METRO-T0-Base](https://huggingface.co/gonglinyuan/metro_t0_base)      | 226M         | Wikibook (16G)   | T0 Train               |
| [METRO-T0+-Base](https://huggingface.co/gonglinyuan/metro_t0p_base)     | 226M         | Wikibook (16G)   | T0+ Train              |
| [METRO-T0++-Base](https://huggingface.co/gonglinyuan/metro_t0pp_base)    | 226M         | Wikibook (16G)   | T0++ Train             |
| [METRO-T0-Base++](https://huggingface.co/gonglinyuan/metro_t0_basepp)    | 256M         | 160G corpus      | T0 Train               |
| [METRO-T0+-Base++](https://huggingface.co/gonglinyuan/metro_t0p_basepp)   | 256M         | 160G corpus      | T0+ Train              |
| [METRO-T0++-Base++](https://huggingface.co/gonglinyuan/metro_t0pp_basepp)  | 256M         | 160G corpus      | T0++ Train             |
| [METRO-T0-Large++](https://huggingface.co/gonglinyuan/metro_t0_largepp)   | 775M         | 160G corpus      | T0 Train               |
| [METRO-T0+-Large++](https://huggingface.co/gonglinyuan/metro_t0p_largepp)  | 775M         | 160G corpus      | T0+ Train              |
| [METRO-T0++-Large++](https://huggingface.co/gonglinyuan/metro_t0pp_largepp) | 775M         | 160G corpus      | T0++ Train             |
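
The repository names in the table follow a regular pattern: `metro_t0`, `metro_t0p`, or `metro_t0pp` for T0 / T0+ / T0++ finetuning data, followed by `_base`, `_basepp`, or `_largepp`. The helper below is a hypothetical convenience inferred from these nine links only; it is not part of the official repository. Its output can be passed to `from_pretrained` as in the usage example.

```python
from itertools import product

# Suffixes read off the model table above.
FINETUNE_SUFFIX = {"T0": "", "T0+": "p", "T0++": "pp"}
SIZE_SUFFIX = {"Base": "base", "Base++": "basepp", "Large++": "largepp"}

def metro_t0_repo_id(finetune: str, size: str) -> str:
    """Hugging Face repo id for a METRO-T0 variant, e.g. ("T0++", "Large++")."""
    if finetune not in FINETUNE_SUFFIX or size not in SIZE_SUFFIX:
        raise ValueError(f"unknown variant: {finetune!r}, {size!r}")
    return f"gonglinyuan/metro_t0{FINETUNE_SUFFIX[finetune]}_{SIZE_SUFFIX[size]}"

if __name__ == "__main__":
    # List all nine variants in the same order as the table.
    for size, finetune in product(SIZE_SUFFIX, FINETUNE_SUFFIX):
        print(f"METRO-{finetune}-{size}: {metro_t0_repo_id(finetune, size)}")
```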


## Citation

If you find the code and models useful for your research, please cite the following paper:

```
@misc{gong2023modelgenerated,
      title={Model-Generated Pretraining Signals Improves Zero-Shot Generalization of Text-to-Text Transformers}, 
      author={Linyuan Gong and Chenyan Xiong and Xiaodong Liu and Payal Bajaj and Yiqing Xie and Alvin Cheung and Jianfeng Gao and Xia Song},
      year={2023},
      eprint={2305.12567},
      archivePrefix={arXiv},
      primaryClass={cs.CL},
      url={https://arxiv.org/abs/2305.12567}
}
```