---
license: afl-3.0
language: en
tags:
- t5
datasets:
- wikipedia
---

# chunked T5 - base (cT5-base)

GitHub: https://github.com/mtreviso/chunked-t5

A T5 model that uses a new loss in which a special end-of-chunk token `<eoc>` is appended to the chunk predicted after each sentinel token. The decoder has to predict the full input with masked tokens, each chunk followed by `<eoc>`. This allows much faster autoregressive generation, since the decoder can predict tokens for all chunks in parallel.

For example, for the input `the quick brown fox jumps over the lazy dog`:

```
encoder: the <extra_id_0> fox jumps <extra_id_1> the lazy dog

T5 decoder : <extra_id_0> quick brown <extra_id_1> over
cT5 decoder: <extra_id_0> quick brown <eoc> <extra_id_1> over <eoc>
```

The generation may look like this for T5 and cT5:

```
 T5: <extra_id_0>
 T5: <extra_id_0> quick
 T5: <extra_id_0> quick brown
 T5: <extra_id_0> quick brown <extra_id_1>
 T5: <extra_id_0> quick brown <extra_id_1> over
 T5: <extra_id_0> quick brown <extra_id_1> over <extra_id_2>
 T5: <extra_id_0> quick brown <extra_id_1> over <extra_id_2> </s>

cT5: <extra_id_0> <extra_id_1>
cT5: <extra_id_0> quick <extra_id_1> over
cT5: <extra_id_0> quick brown <extra_id_1> over <eoc>
cT5: <extra_id_0> quick brown <eoc> <extra_id_1> over <eoc> </s>
```

In the original T5, the decoder is called \\(n_s + 1 + \sum_i |s_i|\\) times autoregressively, where \\(n_s\\) is the number of sentinel tokens and \\(s_1, \ldots, s_{n_s}\\) are the predicted chunks. In contrast, cT5's decoder is called just \\(\max_i |s_i| + 1\\) times. In the example above, T5 needs \\(2 + 1 + 3 = 6\\) decoder calls, whereas cT5 needs only \\(\max(2, 1) + 1 = 3\\).

Generation stops when all chunks are complete, i.e., once an `<eoc>` token has been generated for every sentinel. Alternatively, you can set `max_chunk_size` to force the model to stop after generating a chunk with `max_chunk_size` tokens. The overhead of calling the decoder with a longer input is small, since this computation can be parallelized on GPUs/TPUs.

## Training details

cT5 models use T5's weights as a starting point and were then finetuned on the English [wikipedia](https://huggingface.co/datasets/wikipedia) for 3 epochs, achieving ~74% validation accuracy (ct5-base). The training script is written in JAX + Flax and can be found in `pretrain_ct5.py`. Flax checkpoints can be converted to PyTorch via `convert_flax_to_pytorch.py [flax_dirname]`.

## Checkpoints

- ct5-small: https://huggingface.co/mtreviso/ct5-small-en-wiki
- ct5-base: https://huggingface.co/mtreviso/ct5-base-en-wiki
- ct5-large: todo

## Usage

```python
from transformers import AutoTokenizer
from modeling_ct5 import CT5ForConditionalGeneration

tokenizer = AutoTokenizer.from_pretrained("mtreviso/ct5-base-en-wiki")
model = CT5ForConditionalGeneration.from_pretrained("mtreviso/ct5-base-en-wiki")
```

For training:

```python
input_ids = tokenizer("The <extra_id_0> walks in <extra_id_1> park", return_tensors="pt").input_ids
labels = tokenizer("<extra_id_0> man <extra_id_1> the <extra_id_2>", return_tensors="pt").input_ids
outputs = model(input_ids=input_ids, labels=labels)
loss = outputs.loss
logits = outputs.logits
```

For generation:

```python
texts = [
    "The <extra_id_0> walks in <extra_id_1> park",
    "UN Chief says there is no way to <extra_id_0> in Syria",
]
input_ids = tokenizer(texts, return_tensors="pt", padding=True).input_ids
generated_ids = model.generate(
    input_ids,
    use_cache=False,  # important: caching is not available for parallel decoding
    eoc_token_id=tokenizer.vocab['<eoc>'],  # important: the id of the end-of-chunk token
    max_chunk_size=5,  # the default is 9999999, i.e., effectively unlimited
)
```

This will produce the following tokens:

```python
>> ['<pad>', '<extra_id_0>', '▁Walking', '▁Trail', '<eoc>', '<extra_id_1>', '▁the', '<eoc>', '</s>', '<pad>']
>> ['<pad>', '<extra_id_0>', '▁treat', '▁Syria', '<eoc>', '</s>', '<pad>', '<pad>', '<pad>', '<pad>']
```

You have to pass `use_cache=False` to `generate()`, since caching is not available for parallel decoding. Currently, parallel decoding is only supported for PyTorch (greedy search, greedy sampling, beam search, beam sampling) and JAX (greedy search and greedy sampling).
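If you want to read the predictions as text, the ids returned by `generate()` can be mapped back with the standard `transformers` tokenizer methods. This is just a convenience sketch using the `tokenizer` and `generated_ids` objects from the snippet above; the exact tokens depend on the checkpoint and decoding settings:

```python
# Map ids back to tokens to see sentinels, <eoc> markers, and padding explicitly.
token_lists = [tokenizer.convert_ids_to_tokens(ids.tolist()) for ids in generated_ids]
print(token_lists)

# Or decode to plain strings; keep special tokens visible to see the chunk boundaries.
print(tokenizer.batch_decode(generated_ids, skip_special_tokens=False))
```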
**Note on the beam search implementation**: my beam search implementation is slower than optimal. This is because I reuse the structures provided by HuggingFace's implementation, namely `BeamScores` and `BeamHypotheses`, to store the beam search results for each chunk in the input. In other words, my implementation computes independent "beams" for each chunk rather than for each input sequence. It would be possible to make it faster with custom `BeamScores` and `BeamHypotheses` classes, but I haven't done that yet.

## Evaluation

See the notebook `evaluate_ct5.ipynb` for an example of how to evaluate cT5 in terms of accuracy and perplexity. The notebook `profile.ipynb` shows how to profile the model to get runtimes.

Here is a comparison between cT5-small and T5-small on a subset of the WikiText-103 dataset using deterministic greedy search:

| Model     | Exact match ↑ | Edit distance ratio ↑ | Perplexity ↓ | Time (seconds) ↓ |
|-----------|---------------|-----------------------|--------------|------------------|
| T5-small  | 0.11          | 0.60                  | 2.22         | 44.71            |
| cT5-small | 0.09          | 0.58                  | 1.48         | 10.63            |

On this toy dataset, cT5-small reaches a lower perplexity while being about 4x faster than T5-small, at the cost of a slightly lower exact match and edit distance ratio. However, more experiments are needed for a rigorous evaluation.

If you are interested in applying cT5 to real data, please contact me.
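For reference, the two text-similarity columns in the table can be computed with the Python standard library alone. The sketch below is illustrative rather than the code from `evaluate_ct5.ipynb`; it assumes "edit distance ratio" means a normalized similarity in the spirit of `difflib.SequenceMatcher.ratio()`, and the toy predictions/references are hypothetical:

```python
from difflib import SequenceMatcher

def exact_match(predictions, references):
    # Fraction of predictions that are string-identical to their reference.
    return sum(p == r for p, r in zip(predictions, references)) / len(references)

def edit_distance_ratio(predictions, references):
    # Mean normalized similarity (1.0 means identical strings).
    return sum(
        SequenceMatcher(None, p, r).ratio() for p, r in zip(predictions, references)
    ) / len(references)

# Toy usage with hypothetical model outputs and gold spans.
preds = ["quick brown", "over"]
golds = ["quick brown", "over the"]
print(exact_match(preds, golds), edit_distance_ratio(preds, golds))
```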