---
license: mit
language:
  - en
tags:
  - t5
model-index:
  - name: metro_t0_basepp
    results:
      - task:
          type: natural-language-inference
        dataset:
          type: super_glue
          name: RTE
          config: rte
          split: validation
        metrics:
          - type: accuracy
            value: 68.15884476534298
      - task:
          type: natural-language-inference
        dataset:
          type: super_glue
          name: CB
          config: cb
          split: validation
        metrics:
          - type: accuracy
            value: 63.2142857142857
      - task:
          type: natural-language-inference
        dataset:
          type: anli
          name: ANLI R1
          split: dev_r1
        metrics:
          - type: accuracy
            value: 34.92
      - task:
          type: natural-language-inference
        dataset:
          type: anli
          name: ANLI R2
          split: dev_r2
        metrics:
          - type: accuracy
            value: 33.806666666666665
      - task:
          type: natural-language-inference
        dataset:
          type: anli
          name: ANLI R3
          split: dev_r3
        metrics:
          - type: accuracy
            value: 36.81666666666667
      - task:
          type: coreference-resolution
        dataset:
          type: super_glue
          name: WSC
          config: wsc.fixed
          split: validation
        metrics:
          - type: accuracy
            value: 60.480769230769226
      - task:
          type: coreference-resolution
        dataset:
          type: winogrande
          name: Winogrande XL
          config: winogrande_xl
          split: validation
        metrics:
          - type: accuracy
            value: 52.028413575374906
      - task:
          type: multiple-choice-qa
        dataset:
          type: super_glue
          name: COPA
          config: copa
          split: validation
        metrics:
          - type: accuracy
            value: 78.5
      - task:
          type: multiple-choice-qa
        dataset:
          type: story_cloze
          name: StoryCloze 2016
          config: '2016'
          split: validation
        metrics:
          - type: accuracy
            value: 89.22501336183858
      - task:
          type: multiple-choice-qa
        dataset:
          type: hellaswag
          name: HellaSwag
          split: validation
        metrics:
          - type: accuracy
            value: 27.67625970922127
      - task:
          type: word-sense-disambiguation
        dataset:
          type: super_glue
          name: WiC
          config: wic
          split: validation
        metrics:
          - type: accuracy
            value: 50.877742946708466
---

Official repository: https://github.com/gonglinyuan/metro_t0

# METRO-T0

Paper: [Model-Generated Pretraining Signals Improves Zero-Shot Generalization of Text-to-Text Transformers](https://arxiv.org/abs/2305.12567) (ACL 2023)

METRO-T0 is a T5-style text-to-text Transformer pretrained using model-generated pretraining signals and then prompt-finetuned on the family of public NLP tasks proposed in T0. METRO-T0 is highly parameter-efficient: for example, METRO-T0-Large++ (775M parameters) outperforms GPT-3 (175B parameters) and T0-3B (3B parameters) on a wide range of NLP tasks.

*Figure: The architecture of METRO-T0 during pretraining, using BERT as the auxiliary model to generate pretraining signals.*

*Figure: Prompt learning results of METRO-T0 versus our T0 baseline and T0-3B by Sanh et al. (2022) on 4 tasks in the T0 Eval benchmark. Each point denotes the accuracy using one prompt template; the blue point marks the median accuracy of T0-3B over all templates. Plots for the other tasks are in our paper.*
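To make the aggregation in the plot concrete: each point is the accuracy under a single prompt template, and the median is taken across templates. A tiny sketch with made-up accuracy values (the real per-template numbers are not listed here):

```python
import statistics

# Made-up per-template accuracies (%) for one task: one value per prompt template.
# Each value would be one point in the plot; the median summarizes the task.
template_accuracies = [66.4, 68.2, 70.1, 67.5, 69.0]

print(f"median over templates: {statistics.median(template_accuracies):.1f}%")
```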

## Use METRO-T0-Base++

To use METRO-T0-Base++ in PyTorch (prerequisites: Python 3.7+, PyTorch 1.12+, and transformers 4.17+), refer to the code snippet below:

```python
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

model = AutoModelForSeq2SeqLM.from_pretrained("gonglinyuan/metro_t0_basepp", trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained("gonglinyuan/metro_t0_basepp", trust_remote_code=True)

input_text = "Is this review positive or negative? Review: this is the best cast iron skillet you will ever buy"
inputs = tokenizer([input_text], max_length=512, truncation=True, add_special_tokens=True, return_tensors="pt").input_ids
outputs = model.generate(inputs, max_new_tokens=256, do_sample=False)

print(tokenizer.decode(outputs[0], skip_special_tokens=True))  # expected: positive
```
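For the classification-style tasks reported above, T0-style evaluation typically uses rank classification rather than free-form generation: the model scores each answer choice of a prompt and picks the most likely one. A minimal sketch of that idea (the prompt text and the `choice_log_likelihood` helper below are illustrative, not part of this repo; T0 Eval itself uses promptsource templates):

```python
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

model = AutoModelForSeq2SeqLM.from_pretrained("gonglinyuan/metro_t0_basepp", trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained("gonglinyuan/metro_t0_basepp", trust_remote_code=True)

def choice_log_likelihood(prompt: str, choice: str) -> float:
    """Sum of log-probabilities the model assigns to the tokens of `choice` given `prompt`."""
    enc = tokenizer([prompt], return_tensors="pt")
    labels = tokenizer([choice], return_tensors="pt").input_ids
    with torch.no_grad():
        logits = model(input_ids=enc.input_ids, attention_mask=enc.attention_mask, labels=labels).logits
    log_probs = torch.log_softmax(logits, dim=-1)  # [1, target_len, vocab]
    return log_probs.gather(-1, labels.unsqueeze(-1)).sum().item()

# Made-up RTE-style prompt; the answer choices depend on the prompt template.
prompt = "Does 'A cat sits on the mat.' imply 'There is a cat on the mat.'? Yes or No?"
prediction = max(["Yes", "No"], key=lambda c: choice_log_likelihood(prompt, c))
print(prediction)
```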

## Other METRO-T0 Models

| Model | # Parameters | Pretraining Data | Prompt-Finetuning Data |
|---|---|---|---|
| METRO-T0-Base | 226M | Wikibook (16G) | T0 Train |
| METRO-T0+-Base | 226M | Wikibook (16G) | T0+ Train |
| METRO-T0++-Base | 226M | Wikibook (16G) | T0++ Train |
| METRO-T0-Base++ | 256M | 160G corpus | T0 Train |
| METRO-T0+-Base++ | 256M | 160G corpus | T0+ Train |
| METRO-T0++-Base++ | 256M | 160G corpus | T0++ Train |
| METRO-T0-Large++ | 775M | 160G corpus | T0 Train |
| METRO-T0+-Large++ | 775M | 160G corpus | T0+ Train |
| METRO-T0++-Large++ | 775M | 160G corpus | T0++ Train |
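All checkpoints load the same way as METRO-T0-Base++; only the repository name changes. A hedged example (the repository id below is an assumption based on this repo's naming pattern; check https://huggingface.co/gonglinyuan for the exact names):

```python
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# Assumed repository id following the naming pattern of this repo;
# verify the exact name on the author's Hugging Face profile.
repo = "gonglinyuan/metro_t0pp_largepp"
model = AutoModelForSeq2SeqLM.from_pretrained(repo, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(repo, trust_remote_code=True)
```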

## Citation

If you find the code and models useful for your research, please cite the following paper:

```bibtex
@misc{gong2023modelgenerated,
      title={Model-Generated Pretraining Signals Improves Zero-Shot Generalization of Text-to-Text Transformers},
      author={Linyuan Gong and Chenyan Xiong and Xiaodong Liu and Payal Bajaj and Yiqing Xie and Alvin Cheung and Jianfeng Gao and Xia Song},
      year={2023},
      eprint={2305.12567},
      archivePrefix={arXiv},
      primaryClass={cs.CL},
      url={https://arxiv.org/abs/2305.12567}
}
```