---
license: apache-2.0
language:
- en
model-index:
- name: Graphcore/gptj-mnli
  results:
  - task:
      name: Text Classification
      type: text-classification
    dataset:
      name: GLUE MNLI
      type: glue
      split: validation_mismatched
      args: mnli
    metrics:
    - name: Accuracy
      type: accuracy
      value: 0.825
      config: mnli_mismatched
datasets:
- glue
tags:
- pytorch
- causal-lm
- text-classification
- text-generation
pipeline_tag: text-generation
widget:
- text: "mnli hypothesis: Your contributions were of no help with our students' education. premise: Your contribution helped make it possible for us to provide our students with a quality education. target:"
---

# Graphcore/gptj-mnli
This model is a fine-tuned version of [EleutherAI/gpt-j-6B](https://huggingface.co/EleutherAI/gpt-j-6B) on the [GLUE MNLI dataset](https://huggingface.co/datasets/glue#mnli).
The MNLI dataset consists of pairs of sentences, a *premise* and a *hypothesis*.
The task is to predict the relation between the premise and the hypothesis, which can be:
- `entailment`: the hypothesis follows from the premise,
- `contradiction`: the hypothesis contradicts the premise,
- `neutral`: the hypothesis neither follows from nor contradicts the premise.

We finetune the model as a Causal Language Model (CLM): given a sequence of tokens, the task is to predict the next token.
To achieve this, we create a stylised prompt string, following the approach of the [T5 paper](https://arxiv.org/pdf/1910.10683.pdf).
```
mnli hypothesis: {hypothesis} premise: {premise} target: {class_label} <|endoftext|>
```
For example:
```
mnli hypothesis: Your contributions were of no help with our students' education. premise: Your contribution helped make it possible for us to provide our students with a quality education. target: contradiction <|endoftext|>
```
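A minimal sketch of how the prompt template above could be filled in, both for training strings (class label plus `<|endoftext|>`) and for inference prompts that stop at `target:`. The function name `build_mnli_text` is hypothetical and is not taken from the repository's `data_utils.py`.

```python
from typing import Optional

EOS = "<|endoftext|>"  # GPT-2/GPT-J end-of-text token


def build_mnli_text(hypothesis: str, premise: str, class_label: Optional[str] = None) -> str:
    """Fill the T5-style MNLI template (hypothetical helper).

    With a class_label, returns a training string ending in the EOS token;
    without one, returns an inference prompt ending at "target:".
    """
    prompt = f"mnli hypothesis: {hypothesis} premise: {premise} target:"
    if class_label is None:
        return prompt
    return f"{prompt} {class_label} {EOS}"


hyp = "Your contributions were of no help with our students' education."
prem = ("Your contribution helped make it possible for us to provide "
        "our students with a quality education.")
print(build_mnli_text(hyp, prem))                   # inference prompt
print(build_mnli_text(hyp, prem, "contradiction"))  # training string
```

At inference time the model is expected to continue the prompt with the class label, which is why the prompt deliberately ends at `target:`.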

## Model description

GPT-J 6B is a transformer model trained using Ben Wang's [Mesh Transformer JAX](https://github.com/kingoflolz/mesh-transformer-jax/). "GPT-J" refers to the class of model, while "6B" represents the number of trainable parameters.

<figure>

| Hyperparameter       | Value      |
|----------------------|------------|
| \\(n_{parameters}\\) | 6053381344 |
| \\(n_{layers}\\)     | 28&ast;    |
| \\(d_{model}\\)      | 4096       |
| \\(d_{ff}\\)         | 16384      |
| \\(n_{heads}\\)      | 16         |
| \\(d_{head}\\)       | 256        |
| \\(n_{ctx}\\)        | 2048       |
| \\(n_{vocab}\\)      | 50257/50400&dagger; (same tokenizer as GPT-2/3)  |
| Positional Encoding  | [Rotary Position Embedding (RoPE)](https://arxiv.org/abs/2104.09864) |
| RoPE Dimensions      | [64](https://github.com/kingoflolz/mesh-transformer-jax/blob/f2aa66e0925de6593dcbb70e72399b97b4130482/mesh_transformer/layers.py#L223) |
<figcaption><p><strong>&ast;</strong> Each layer consists of one feedforward block and one self attention block.</p>
<p><strong>&dagger;</strong> Although the embedding matrix has a size of 50400, only 50257 entries are used by the GPT-2 tokenizer.</p></figcaption></figure>

The model consists of 28 layers with a model dimension of 4096, and a feedforward dimension of 16384. The model
dimension is split into 16 heads, each with a dimension of 256. Rotary Position Embedding (RoPE) is applied to 64
dimensions of each head. The model is trained with a tokenization vocabulary of 50257, using the same set of BPEs as
GPT-2/GPT-3.
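To illustrate what applying RoPE to only 64 of the 256 dimensions of each head means, here is a minimal pure-Python sketch for a single head vector. It assumes the standard RoPE frequency base of 10000 and that the rotated pairs are adjacent dimensions (2i, 2i+1); it is an illustration of the technique, not GPT-J's actual implementation.

```python
import math
from typing import List


def apply_partial_rope(head: List[float], position: int,
                       rotary_dim: int = 64, base: float = 10000.0) -> List[float]:
    """Rotate the first `rotary_dim` entries of one attention-head vector.

    Each pair (head[2i], head[2i+1]) with 2i < rotary_dim is rotated by
    position * base**(-2i / rotary_dim) radians; the remaining
    len(head) - rotary_dim entries are passed through unchanged.
    """
    out = list(head)
    for i in range(rotary_dim // 2):
        theta = position * base ** (-2 * i / rotary_dim)
        cos_t, sin_t = math.cos(theta), math.sin(theta)
        x0, x1 = head[2 * i], head[2 * i + 1]
        out[2 * i] = x0 * cos_t - x1 * sin_t
        out[2 * i + 1] = x0 * sin_t + x1 * cos_t
    return out


def dot(a: List[float], b: List[float]) -> float:
    return sum(x * y for x, y in zip(a, b))


# GPT-J dimensions from the table above: d_model = n_heads * d_head
n_heads, d_head, d_model = 16, 256, 4096
assert n_heads * d_head == d_model

q = [math.sin(0.1 * j) for j in range(d_head)]
k = [math.cos(0.2 * j) for j in range(d_head)]
# The query-key score depends only on the offset between positions (7-3 == 14-10)
s1 = dot(apply_partial_rope(q, 7), apply_partial_rope(k, 3))
s2 = dot(apply_partial_rope(q, 14), apply_partial_rope(k, 10))
print(abs(s1 - s2) < 1e-9)
```

Because each pair is rotated rather than scaled, vector norms are preserved, and the 192 unrotated dimensions contribute to attention scores independently of position.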

[EleutherAI/gpt-j-6B](https://huggingface.co/EleutherAI/gpt-j-6B), our starting point for finetuning, is trained on [the Pile](https://pile.eleuther.ai), a large-scale curated dataset created by [EleutherAI](https://www.eleuther.ai).

## Fine-tuning and validation data
Fine-tuning is done using the `train` split of the GLUE MNLI dataset, and performance is measured on the [validation_mismatched](https://huggingface.co/datasets/glue#mnli_mismatched) split.

`validation_mismatched` contains examples drawn from different sources (genres) than those in the training set, so they do not closely resemble any of the examples seen at training time.

Data splits for the MNLI dataset are as follows:

|train |validation_matched|validation_mismatched|
|-----:|-----------------:|--------------------:|
|392702|              9815|                 9832|

## Fine-tuning procedure
Fine-tuned on a Graphcore IPU-POD64 using `popxl`.

Prompt sentences are tokenized and packed together to form 1024-token sequences, following the [HF packing algorithm](https://github.com/huggingface/transformers/blob/v4.20.1/examples/pytorch/language-modeling/run_clm.py). No padding is used.
The packing process works in groups of 1000 examples and discards any tokens left over in each group that do not fill a whole sequence.
For the 392,702 training examples this gives a total of 17,762 sequences per epoch.
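The grouping described above can be sketched in pure Python as follows. It mirrors the idea of the `group_texts` step in the linked `run_clm.py` (concatenate, cut into fixed-size blocks, drop the remainder), but the function names `pack_group`/`pack_dataset` and the toy token IDs are illustrative assumptions, not the actual training code.

```python
from itertools import chain
from typing import List


def pack_group(examples: List[List[int]], block_size: int) -> List[List[int]]:
    """Concatenate tokenized examples and cut them into full blocks.

    Tokens left over after the last complete block are discarded.
    """
    concatenated = list(chain.from_iterable(examples))
    usable = (len(concatenated) // block_size) * block_size
    return [concatenated[i:i + block_size] for i in range(0, usable, block_size)]


def pack_dataset(examples: List[List[int]], block_size: int = 1024,
                 group_size: int = 1000) -> List[List[int]]:
    """Pack examples group by group, so each group drops its own remainder."""
    sequences: List[List[int]] = []
    for start in range(0, len(examples), group_size):
        sequences.extend(pack_group(examples[start:start + group_size], block_size))
    return sequences


# Toy example: blocks of 4 tokens, groups of 2 examples
toy = [[1, 2, 3], [4, 5], [6, 7, 8, 9]]
print(pack_dataset(toy, block_size=4, group_size=2))
# [[1, 2, 3, 4], [6, 7, 8, 9]]
```

Because every group drops its own remainder, which tokens are lost (and how many sequences result) depends on the group size as well as the block size: with `group_size=3` the same toy data packs into `[[1, 2, 3, 4], [5, 6, 7, 8]]`.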

Since the model is trained to predict the next token, labels are simply the input sequence shifted by one token.
Given the training format, no extra care is needed to account for the different examples packed into one sequence: each example ends with `<|endoftext|>`, and the model does not need to know which example a token belongs to.
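As a small illustration of the shift, this sketch pairs each position of a packed sequence with the token it has to predict. It assumes the inputs are the first n-1 tokens and the targets the last n-1; in Hugging Face causal-LM models the same effect is usually obtained by passing `labels` equal to `input_ids` and letting the model shift them internally. The function name and the token IDs other than 50256 are made up for illustration.

```python
from typing import List, Tuple


def shift_for_next_token(sequence: List[int]) -> Tuple[List[int], List[int]]:
    """Return (inputs, targets) where targets[t] is the token following inputs[t]."""
    return sequence[:-1], sequence[1:]


# Two toy examples packed together; 50256 is <|endoftext|> in the GPT-2 vocabulary
packed = [11, 12, 13, 50256, 21, 22]
inputs, targets = shift_for_next_token(packed)
for x, y in zip(inputs, targets):
    print(f"{x} -> {y}")
```

The pair `50256 -> 21` crosses an example boundary: the model simply learns that a new prompt can start after `<|endoftext|>`, so no attention masking between examples is required for this format.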

### Hyperparameters
- optimiser: AdamW (beta1: 0.9, beta2: 0.999, eps: 1e-6, weight decay: 0.0, learning rate: 5e-6)
- learning rate schedule: warmup schedule (min: 1e-7, max: 5e-6, warmup proportion: 0.005995)
- batch size: 128
- training steps: 300. Each epoch consists of ceil(17,762/128) = 139 steps, so 300 steps are approximately 2 epochs.
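The step arithmetic and the warmup can be sketched as below. The card only gives the warmup's minimum, maximum and proportion, so the linear ramp and the constant learning rate after warmup are assumptions made for illustration, not the schedule actually used in `popxl`.

```python
import math

SEQUENCES_PER_EPOCH = 17_762
BATCH_SIZE = 128
TOTAL_STEPS = 300

steps_per_epoch = math.ceil(SEQUENCES_PER_EPOCH / BATCH_SIZE)  # 139
epochs = TOTAL_STEPS / steps_per_epoch                          # about 2.16


def warmup_lr(step: int, total_steps: int = TOTAL_STEPS, min_lr: float = 1e-7,
              max_lr: float = 5e-6, warmup_proportion: float = 0.005995) -> float:
    """Assumed schedule: linear warmup from min_lr to max_lr, then constant."""
    warmup_steps = warmup_proportion * total_steps  # about 1.8 steps
    if step < warmup_steps:
        return min_lr + (max_lr - min_lr) * step / warmup_steps
    return max_lr


print(steps_per_epoch, round(epochs, 2))
print([warmup_lr(s) for s in range(4)])
```

With only 300 steps, a warmup proportion of 0.005995 amounts to fewer than two optimiser steps.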

## Performance
The resulting model reaches 82.5% accuracy (8107 of 9832 examples correct) on the `validation_mismatched` split.
```
Total number of examples                 9832
Number with badly formed result          0
Number with incorrect result             1725
Number with correct result               8107 
[82.5%]

example 0 = {'prompt_text': "mnli hypothesis: Your contributions were of no help with our students' education. premise: Your contribution helped make it possible for us to provide our students with a quality education. target:", 'class_label': 'contradiction'}
result = {'generated_text': ' contradiction'}

First 10 generated_text and expected class_label results:
 0: 'contradiction'                          contradiction
 1: 'contradiction'                          contradiction
 2: 'entailment'                             entailment
 3: 'contradiction'                          contradiction
 4: 'entailment'                             entailment
 5: 'entailment'                             entailment
 6: 'contradiction'                          contradiction
 7: 'contradiction'                          contradiction
 8: 'entailment'                             neutral
 9: 'contradiction'                          contradiction
```
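The counts above can be reproduced from (generated_text, class_label) pairs with a scoring function along these lines. Treating any generation that is not one of the three labels as "badly formed" is an assumption about how the evaluation script classifies results; the function name `score` is hypothetical.

```python
from typing import Dict, Iterable, Tuple

LABELS = {"entailment", "neutral", "contradiction"}


def score(results: Iterable[Tuple[str, str]]) -> Dict[str, int]:
    """Count correct, incorrect and badly formed generations."""
    counts = {"total": 0, "badly_formed": 0, "incorrect": 0, "correct": 0}
    for generated_text, class_label in results:
        counts["total"] += 1
        prediction = generated_text.strip()  # generations start with a space
        if prediction not in LABELS:
            counts["badly_formed"] += 1
        elif prediction == class_label:
            counts["correct"] += 1
        else:
            counts["incorrect"] += 1
    return counts


# The first 10 results listed above
first_ten = [
    (" contradiction", "contradiction"), (" contradiction", "contradiction"),
    (" entailment", "entailment"), (" contradiction", "contradiction"),
    (" entailment", "entailment"), (" entailment", "entailment"),
    (" contradiction", "contradiction"), (" contradiction", "contradiction"),
    (" entailment", "neutral"), (" contradiction", "contradiction"),
]
print(score(first_ten))
# Accuracy over the full split, using the totals reported above
print(f"{8107 / 9832:.1%}")  # 82.5%
```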
## How to use
The model can be loaded with `AutoModelForCausalLM` and used through the `pipeline` API for text generation.

```python
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer

# The fine-tuned model uses the original GPT-J tokenizer
tokenizer = AutoTokenizer.from_pretrained('EleutherAI/gpt-j-6B')
# GPT-J has no padding token, so reuse the end-of-text token for padding
hf_model = AutoModelForCausalLM.from_pretrained("Graphcore/gptj-mnli", pad_token_id=tokenizer.eos_token_id)
generator = pipeline('text-generation', model=hf_model, tokenizer=tokenizer)
# Note the space before "premise:" so the prompt matches the training format
prompt = "mnli hypothesis: Your contributions were of no help with our students' education. " \
         "premise: Your contribution helped make it possible for us to provide our students with a quality education. target:"
# A few new tokens are enough for the class label; return only the generated continuation
out = generator(prompt, return_full_text=False, max_new_tokens=5, top_k=1)
# [{'generated_text': ' contradiction'}]
```

You can create prompt-like inputs from the GLUE MNLI dataset using the functions provided in the `data_utils.py` script.
```python
from datasets import load_dataset
from data_utils import form_text, split_text

# Load the split used for evaluation
dataset = load_dataset('glue', 'mnli', split='validation_mismatched')
# Build the full prompt string, including the class label, from each example
dataset = dataset.map(
    form_text, remove_columns=['hypothesis', 'premise', 'label', 'idx'])
# dataset[0] {'text': "mnli hypothesis: Your contributions were of no help with our students' education. premise: Your contribution helped make it possible for us to provide our students with a quality education. target: contradiction<|endoftext|>"}
# Split each string into the prompt and its expected class label
dataset = dataset.map(split_text, remove_columns=['text'])
# dataset[0] {'prompt_text': "mnli hypothesis: Your contributions were of no help with our students' education. premise: Your contribution helped make it possible for us to provide our students with a quality education. target:",
#             'class_label': 'contradiction'}
```
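If `data_utils.py` is not at hand, the two steps can be approximated as below. These are hypothetical stand-ins written for illustration, not the script's actual code: they assume GLUE MNLI label IDs 0, 1 and 2 mean `entailment`, `neutral` and `contradiction` (the label names of `glue`/`mnli` in the Hugging Face `datasets` library), and they reproduce the string format shown in the comments of the block above.

```python
from typing import Dict

LABEL_NAMES = ["entailment", "neutral", "contradiction"]  # GLUE MNLI label IDs 0, 1, 2
EOS = "<|endoftext|>"


def form_text_sketch(example: Dict) -> Dict[str, str]:
    """Build the full training-format string from one GLUE MNLI row (hypothetical)."""
    label = LABEL_NAMES[example["label"]]
    return {"text": f"mnli hypothesis: {example['hypothesis']} "
                    f"premise: {example['premise']} target: {label}{EOS}"}


def split_text_sketch(example: Dict[str, str]) -> Dict[str, str]:
    """Split a training-format string into the prompt and its class label (hypothetical)."""
    # rpartition splits on the last "target:", in case a sentence contains that word
    prompt, _, rest = example["text"].rpartition("target:")
    return {"prompt_text": prompt + "target:",
            "class_label": rest.replace(EOS, "").strip()}


row = {"hypothesis": "Your contributions were of no help with our students' education.",
       "premise": "Your contribution helped make it possible for us to provide "
                  "our students with a quality education.",
       "label": 2, "idx": 0}
split = split_text_sketch(form_text_sketch(row))
print(split["class_label"])  # contradiction
```

Both functions take and return plain dictionaries, so they have the shape expected by `Dataset.map` when called without `batched=True`.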