---
license: afl-3.0
language: en
tags:
- t5
datasets:
- wikipedia
---

# chunked T5 - base (cT5-base)

Github: https://github.com/mtreviso/chunked-t5

A T5 model trained with a new loss in which a special end-of-chunk token `</c>` is appended to the end of each masked chunk in the target sequence.
The decoder has to predict every masked chunk of the input, each followed by `</c>`.
This allows much faster autoregressive generation, since the decoder can predict one token for every chunk in parallel at each step.

For example, for the input `the quick brown fox jumps over the lazy dog`:
```
encoder: the <extra_id_0> fox jumps <extra_id_1> the lazy dog

T5 decoder : <extra_id_0> quick brown <extra_id_1> over <extra_id_2>
cT5 decoder: <extra_id_0> quick brown </c> <extra_id_1> over </c> <extra_id_2>
```
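As a minimal sketch of how such encoder/decoder pairs could be built, the helper below masks word-level spans and appends `</c>` after each chunk. The function name `make_ct5_example` and the word-level splitting are assumptions made for illustration; the actual training script works on tokenizer ids and may select spans differently.

```python
EOC = "</c>"  # end-of-chunk token used by cT5


def make_ct5_example(text, spans):
    """Build (encoder_input, ct5_target) strings from a sentence.

    `spans` is a sorted list of non-overlapping (start, end) word indices,
    with `end` exclusive. Hypothetical helper, not part of the cT5 repo.
    """
    words = text.split()
    encoder, decoder = [], []
    cursor = 0
    for i, (start, end) in enumerate(spans):
        sentinel = f"<extra_id_{i}>"
        encoder.extend(words[cursor:start])  # keep the unmasked words
        encoder.append(sentinel)             # replace the span by a sentinel
        decoder.append(sentinel)
        decoder.extend(words[start:end])     # the masked chunk
        decoder.append(EOC)                  # close the chunk
        cursor = end
    encoder.extend(words[cursor:])
    decoder.append(f"<extra_id_{len(spans)}>")  # final sentinel, as in T5
    return " ".join(encoder), " ".join(decoder)


enc, dec = make_ct5_example(
    "the quick brown fox jumps over the lazy dog", [(1, 3), (5, 6)]
)
print(enc)
print(dec)
```

With the spans above, the two printed strings match the `encoder` and `cT5 decoder` lines of the example.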

The generation may look like this for T5 and cT5:
```
T5: <extra_id_0>
T5: <extra_id_0> quick
T5: <extra_id_0> quick brown
T5: <extra_id_0> quick brown <extra_id_1>
T5: <extra_id_0> quick brown <extra_id_1> over
T5: <extra_id_0> quick brown <extra_id_1> over <extra_id_2>
T5: <extra_id_0> quick brown <extra_id_1> over <extra_id_2> </s>

cT5: <extra_id_0> <pad> <extra_id_1> <pad> <extra_id_2> </s>
cT5: <extra_id_0> quick <pad> <extra_id_1> over <pad> <extra_id_2> </s>
cT5: <extra_id_0> quick brown <pad> <extra_id_1> over </c> <extra_id_2> </s>
cT5: <extra_id_0> quick brown </c> <extra_id_1> over </c> <extra_id_2> </s>
```
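The cT5 trace above can be reproduced with a small toy simulation. This is only a sketch: it reads the "predicted" tokens from a known answer instead of calling a model, and the name `simulate_parallel_decoding` is hypothetical.

```python
EOC, PAD = "</c>", "<pad>"


def simulate_parallel_decoding(chunks):
    """Return the decoder state after each parallel step.

    `chunks` is a list of token lists, one per sentinel. At every step each
    unfinished chunk receives its next token; a finished chunk ends in </c>.
    """
    targets = [list(chunk) + [EOC] for chunk in chunks]
    n_steps = max(len(t) for t in targets)  # max_i |s_i| + 1
    states = []
    for step in range(n_steps + 1):  # step 0 is the initial state
        parts = []
        for i, target in enumerate(targets):
            parts.append(f"<extra_id_{i}>")
            parts.extend(target[:step])   # tokens generated so far
            if step < len(target):
                parts.append(PAD)         # slot for the next prediction
        parts.append(f"<extra_id_{len(targets)}>")
        parts.append("</s>")
        states.append(" ".join(parts))
    return states


for state in simulate_parallel_decoding([["quick", "brown"], ["over"]]):
    print("cT5:", state)
```

The number of states after the initial one equals the length of the longest chunk plus one, which is why the cost depends on the longest chunk rather than on the total number of masked tokens.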

In the original T5, the decoder is called \\(n_s + 1 + \sum_i |s_i|\\) times autoregressively, 
where \\(n_s\\) is the number of sentinel tokens and \\(s_1, \dots, s_{n_s}\\) are the predicted chunks. 
In contrast, cT5's decoder is called just \\(\max_i |s_i| + 1\\) times. 
Generation stops once every chunk is complete, i.e., once a `</c>` token has been generated for each chunk. 
Alternatively, you can set `max_chunk_size` to force generation to stop once a chunk reaches `max_chunk_size` tokens.
The overhead of calling the decoder with a longer input is less pronounced, since this computation can be parallelized on GPUs/TPUs.
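To make the two counts concrete, here is a short sketch that evaluates both formulas for given chunk lengths. The function names are illustrative only.

```python
def t5_decoder_calls(chunk_lengths):
    """n_s + 1 + sum_i |s_i|: sentinels plus every chunk token, one per call."""
    n_s = len(chunk_lengths)
    return n_s + 1 + sum(chunk_lengths)


def ct5_decoder_calls(chunk_lengths):
    """max_i |s_i| + 1: the longest chunk plus its closing </c>."""
    return max(chunk_lengths, default=0) + 1


# Chunks "quick brown" (2 tokens) and "over" (1 token) from the example above.
lengths = [2, 1]
print(t5_decoder_calls(lengths), ct5_decoder_calls(lengths))
```

For the running example this gives 6 calls for T5 and 3 for cT5. The gap widens as the number of chunks grows, because only the longest chunk matters for cT5.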

## Training details

cT5 models were initialized from T5's weights and then finetuned on the 
English [wikipedia](https://huggingface.co/datasets/wikipedia) dataset for 3 epochs, 
achieving ~74% validation accuracy (ct5-base).
The training script is in JAX + Flax and can be found in `pretrain_ct5.py`.

Flax checkpoints can be converted to PyTorch via `convert_flax_to_pytorch.py [flax_dirname]`.


## Checkpoints

- ct5-small: https://huggingface.co/mtreviso/ct5-small-en-wiki
- ct5-base: https://huggingface.co/mtreviso/ct5-base-en-wiki
- ct5-large: todo


## Usage

```python
from transformers import AutoTokenizer
from modeling_ct5 import CT5ForConditionalGeneration

tokenizer = AutoTokenizer.from_pretrained("mtreviso/ct5-base-en-wiki")
model = CT5ForConditionalGeneration.from_pretrained("mtreviso/ct5-base-en-wiki")
```

For training:

```python
input_ids = tokenizer("The <extra_id_0> walks in <extra_id_1> park", return_tensors="pt").input_ids
labels = tokenizer("<extra_id_0> man </c> <extra_id_1> the </c> <extra_id_2>", return_tensors="pt").input_ids
outputs = model(input_ids=input_ids, labels=labels)
loss = outputs.loss
logits = outputs.logits
```

For generation:

```python
texts = [
    "The <extra_id_0> walks in <extra_id_1> park",
    "UN Chief says there is no way to <extra_id_0> in Syria",
]
input_ids = tokenizer(texts, return_tensors="pt", padding=True).input_ids
generated_ids = model.generate(
    input_ids, 
    use_cache=False,  # important to set to False to avoid caching
    eoc_token_id=tokenizer.vocab['</c>'],  # important to set to the correct end-of-chunk id
    max_chunk_size=5,  # the default is 9999999, which is a large number
)
```

This will produce the following tokens:
```python
>> ['<pad>', '<extra_id_0>', '▁Walking', '▁Trail', '</c>', '<extra_id_1>', '▁the', '</c>', '<extra_id_2>', '</s>']
>> ['<pad>', '<extra_id_0>', '▁treat', '▁Syria', '</c>', '<extra_id_1>', '</s>', '<pad>', '<pad>', '<pad>']
```
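A possible way to post-process such token lists is to group the tokens of each chunk by sentinel index. The sketch below is a hypothetical helper, not part of `modeling_ct5`. It assumes the tokens were obtained with something like `tokenizer.convert_ids_to_tokens`, and it keeps a chunk that was never closed by `</c>` (for example, one cut off by `max_chunk_size`) only if it contains tokens.

```python
import re

SENTINEL = re.compile(r"<extra_id_(\d+)>")
EOC = "</c>"
SKIP = {"<pad>", "</s>"}  # special tokens that never belong to a chunk


def extract_chunks(tokens):
    """Map each sentinel index to the list of tokens predicted for it."""
    chunks = {}
    current_id, current = None, []
    for tok in tokens:
        match = SENTINEL.fullmatch(tok)
        if match or tok == EOC:
            # Close the open chunk. An unclosed chunk is kept only if it has
            # content, so the trailing sentinel is ignored.
            if current_id is not None and (tok == EOC or current):
                chunks[current_id] = current
            current_id = int(match.group(1)) if match else None
            current = []
        elif tok in SKIP:
            continue
        elif current_id is not None:
            current.append(tok)
    if current_id is not None and current:
        chunks[current_id] = current
    return chunks


def detokenize(chunk):
    """Join SentencePiece-style pieces, where '▁' marks a word boundary."""
    return "".join(chunk).replace("▁", " ").strip()


row = ['<pad>', '<extra_id_0>', '▁Walking', '▁Trail', '</c>',
       '<extra_id_1>', '▁the', '</c>', '<extra_id_2>', '</s>']
print({i: detokenize(c) for i, c in extract_chunks(row).items()})
```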

You must pass `use_cache=False` to `generate()`, because caching is not available for parallel decoding. 
Currently, parallel decoding is only supported for PyTorch (greedy search, greedy sampling, beam search, beam sampling) and JAX (greedy search and greedy sampling).

**Note on the beam search implementation**: my beam search implementation is slower than optimal.
This is because I use the structures provided by HuggingFace's implementation, namely `BeamScorer` and `BeamHypotheses`, to store the beam search results for each chunk in the input.
In other words, my implementation computes independent "beams" for each chunk rather than for each input sequence.
It is possible to make it faster by using custom `BeamScorer` and `BeamHypotheses` classes, but I haven't done that yet.


## Evaluation

See the notebook `evaluate_ct5.ipynb` for an example of how to evaluate cT5 in terms of accuracy and perplexity.
The notebook `profile.ipynb` shows how to profile the model to get runtimes.

Here is a comparison between cT5-small and T5-small on a subset of the WikiText-103 dataset using deterministic greedy search:

| Model | Exact match ↑ | Edit distance ratio ↑ | Perplexity ↓ | Time (seconds) ↓ |
|-------|---------------|----------------------|--------------|-----------------|
| T5-small | 0.11          | 0.60                 | 2.22         | 44.71           |
| cT5-small | 0.09          | 0.58                 | 1.48         | 10.63           |

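On this toy dataset, cT5-small has a lower perplexity (1.48 vs. 2.22) and is about 4.2x faster (10.63 s vs. 44.71 s) than T5-small, at the cost of a slightly lower exact match and edit distance ratio. However, more experiments are needed for a rigorous evaluation.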
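
For reference, metrics of this kind can be sketched as follows. The exact definitions used in `evaluate_ct5.ipynb` are not reproduced here, so these are assumptions: exact match is taken as the fraction of predictions identical to the reference, the edit distance ratio is approximated with `difflib.SequenceMatcher.ratio()` (a similarity ratio, not a true Levenshtein distance), and perplexity is the exponential of the mean per-token negative log-likelihood.

```python
import math
from difflib import SequenceMatcher


def exact_match(preds, refs):
    """Fraction of predictions that equal their reference exactly."""
    return sum(p == r for p, r in zip(preds, refs)) / len(refs)


def edit_distance_ratio(preds, refs):
    """Mean similarity ratio in [0, 1]; 1.0 means identical strings.

    Assumption: uses difflib's ratio as a stand-in for an edit-distance ratio.
    """
    ratios = [SequenceMatcher(None, p, r).ratio() for p, r in zip(preds, refs)]
    return sum(ratios) / len(ratios)


def perplexity(token_nlls):
    """exp of the mean negative log-likelihood (natural log) per token."""
    return math.exp(sum(token_nlls) / len(token_nlls))


preds = ["quick brown", "over"]
refs = ["quick brown", "under"]
print(exact_match(preds, refs), edit_distance_ratio(preds, refs))
```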

If you are interested in applying cT5 to real data, please contact me.