mC4 dataset

Together with the T5 model architecture and SeqIO, the T5 authors also created and released the multilingual mC4 dataset. AllenNLP made it available on the HuggingFace Dataset hub. Our team confirmed that the Dutch portion of mC4 was deduplicated, and we cleaned it using code adapted from the TensorFlow C4 dataset. The resulting mc4_nl_cleaned dataset on the HuggingFace hub has configs for several sizes, as well as configs that mix Dutch and English texts, e.g. micro_en_nl. The _en_nl configs were added to support multilingual pre-training with the HuggingFace pre-training script, which accepts only a single dataset as input. Cleaned English C4 is roughly 5 times larger than its Dutch counterpart, so interleaving the two datasets in a 1:1 ratio uses only about one fifth of the English data and discards approximately 80% of it. (When pre-training with T5X and SeqIO, it is possible to define task mixtures that include multiple datasets, so these _en_nl configs are not needed.)
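The 80% figure follows from how 1:1 interleaving stops: once the smaller Dutch set runs out, the remaining English examples are never seen. Below is a minimal sketch in plain Python, assuming a "stop at the first exhausted dataset" strategy (the default `stopping_strategy` of HuggingFace `datasets.interleave_datasets`). The toy document lists and the function name are hypothetical.

```python
def interleave_first_exhausted(a, b):
    """Alternate examples 1:1 from a and b, stopping as soon as either runs out."""
    mixed = []
    for x, y in zip(a, b):  # zip stops at the shorter input
        mixed.extend((x, y))
    return mixed


# Hypothetical toy corpora: English is about 5x larger than Dutch.
dutch = [f"nl-{i}" for i in range(100)]
english = [f"en-{i}" for i in range(500)]

mixed = interleave_first_exhausted(dutch, english)
used_english = sum(doc.startswith("en-") for doc in mixed)
discarded = 1 - used_english / len(english)
print(f"English used: {used_english}/{len(english)}, discarded: {discarded:.0%}")
# English used: 100/500, discarded: 80%
```

Sampling with unequal probabilities instead of strict alternation would use more of the English data, at the cost of changing the language balance the model sees.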

The full, cleaned Dutch mC4 dataset is 151GB and remains (as of June 2022) the largest available Dutch corpus on the HuggingFace Dataset hub.

Additional books, Wikipedia and Dutch news articles datasets

The t5_1_1 and ul2 models have also been trained on Dutch books, the Dutch subset of Wikipedia (2022-03-20), the English subset of Wikipedia (2022-03-01), and a subset of "mc4_nl_cleaned" containing only texts from Dutch and Belgian newspapers. These datasets were mixed in to bias the model towards descriptions of events in the Netherlands and Belgium.
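Mixing datasets to bias a model amounts to drawing each training example from one of several sources according to a rate. Below is a minimal sketch of rate-weighted sampling in plain Python. The source names and rates are hypothetical and are not the values used for these models. In SeqIO, such a mixture is declared with `seqio.MixtureRegistry.add`, passing (task name, rate) pairs, and SeqIO performs the sampling.

```python
import itertools
import random
from collections import Counter


def sample_mixture(sources, rates, n, seed=0):
    """Draw n examples; each draw picks a source with probability proportional to its rate.

    Sources are cycled when they run out, so small datasets are repeated.
    """
    rng = random.Random(seed)
    names = list(sources)
    weights = [rates[name] for name in names]
    streams = {name: itertools.cycle(sources[name]) for name in names}
    drawn = []
    for _ in range(n):
        name = rng.choices(names, weights=weights, k=1)[0]
        drawn.append((name, next(streams[name])))
    return drawn


# Hypothetical sources and rates, for illustration only.
sources = {
    "mc4_nl": ["nl web doc 1", "nl web doc 2"],
    "books_nl": ["nl book 1"],
    "wiki_nl": ["nl wiki page 1"],
    "news_nl": ["nl news article 1"],
}
rates = {"mc4_nl": 6.0, "books_nl": 1.0, "wiki_nl": 1.0, "news_nl": 2.0}

counts = Counter(name for name, _ in sample_mixture(sources, rates, n=10_000))
print(counts.most_common())
```

Raising the rate of a small source such as the news subset makes the model see it more often than its raw size would suggest, which is how the bias towards Dutch and Belgian events is achieved.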

Pre-Training Objectives

The T5 models are pre-trained using the span corruption denoising objective. 15% of the tokens in the text are masked, and each span of consecutive masked tokens is replaced by a special token known as a sentinel token, with every span getting its own sentinel. The model is then trained to predict, for each sentinel token, the original text that this sentinel replaced.
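Below is a minimal sketch of how a single example is corrupted, assuming the masked spans have already been chosen. In real pre-training the spans are sampled randomly, with a noise density of about 15% and, by default in T5, a mean span length of 3 tokens. The sentinels use the `<extra_id_N>` names from the T5 vocabulary, and whole words stand in for subword tokens. As in the example in the T5 paper, the target ends with one extra sentinel.

```python
def span_corrupt(tokens, spans):
    """Apply T5-style span corruption.

    spans: sorted, non-overlapping (start, end) index pairs, end exclusive.
    Returns (inputs, targets): inputs keep the unmasked tokens with one
    sentinel per masked span; targets list each sentinel followed by the
    tokens it replaced.
    """
    inputs, targets = [], []
    pos = 0
    for i, (start, end) in enumerate(spans):
        if not pos <= start < end <= len(tokens):
            raise ValueError("spans must be sorted, non-empty and non-overlapping")
        sentinel = f"<extra_id_{i}>"
        inputs.extend(tokens[pos:start])   # unmasked text before this span
        inputs.append(sentinel)            # the whole span collapses to one sentinel
        targets.append(sentinel)
        targets.extend(tokens[start:end])  # the original text to be predicted
        pos = end
    inputs.extend(tokens[pos:])
    targets.append(f"<extra_id_{len(spans)}>")  # final sentinel closes the target
    return inputs, targets


tokens = "Thank you for inviting me to your party last week .".split()
inputs, targets = span_corrupt(tokens, [(2, 4), (8, 9)])
print(" ".join(inputs))   # Thank you <extra_id_0> me to your party <extra_id_1> week .
print(" ".join(targets))  # <extra_id_0> for inviting <extra_id_1> last <extra_id_2>
```

Because each span shrinks to a single sentinel, both the input and the target are much shorter than the original text, which keeps pre-training cheap.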

The UL2 models are pre-trained with the Mixture-of-Denoisers (MoD) objective, which combines diverse pre-training paradigms. UL2 frames different objective functions for training language models as denoising tasks, in which the model has to recover missing sub-sequences of a given input. During pre-training, the mixture-of-denoisers samples from a varied set of such objectives, each with its own configuration. UL2 is trained using a mixture of three denoising tasks:

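  1. R-denoising (regular span corruption), which emulates the standard T5 span corruption objective;
  2. X-denoising (extreme span corruption), which uses longer masked spans and/or a higher corruption rate; and
  3. S-denoising (sequential PrefixLM), in which the model predicts the continuation of a text given its prefix.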
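R- and X-denoising both reuse span corruption and differ only in span length and corruption rate, so the new ingredients are S-denoising and the per-example choice of denoiser. Below is a minimal sketch of both in plain Python. UL2 does prepend a mode token that tells the model which kind of denoising an example comes from, but the mode-token string and the mixture weights used here are hypothetical placeholders, not values from this project.

```python
import random


def s_denoise(tokens, rng, mode_token="[S]"):
    """S-denoising (PrefixLM): split at a random point; the prefix is the
    input and the continuation is the target."""
    if len(tokens) < 2:
        raise ValueError("need at least two tokens to split")
    split = rng.randint(1, len(tokens) - 1)  # inclusive bounds: both parts non-empty
    return [mode_token] + tokens[:split], tokens[split:]


def pick_denoiser(rng, weights):
    """Choose which denoiser ('R', 'X' or 'S') to apply to the next example."""
    names = list(weights)
    return rng.choices(names, weights=[weights[n] for n in names], k=1)[0]


rng = random.Random(0)
weights = {"R": 0.4, "X": 0.4, "S": 0.2}  # hypothetical mixture weights
choices = [pick_denoiser(rng, weights) for _ in range(5)]
inputs, targets = s_denoise("de kat zat op de mat".split(), rng)
print(choices, inputs, targets)
```

Unlike span corruption, S-denoising never sees text after the split point in its input, which is what makes it a language-modeling style objective.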

Pre-training software

Huggingface run_t5_mlm_flax.py

All models except t5_1_1 and ul2 were pre-trained using the HuggingFace run_t5_mlm_flax.py script. This script is a good fit if you want to grasp what is needed to pre-train a language model with Flax and JAX, since data preparation, model instantiation, the loss function, and the training loop are all contained in a single file.
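One detail of the data preparation is worth spelling out. Because span corruption shortens the encoder input (each masked span collapses to one sentinel), the script has to read more raw tokens than the configured input length and compute the matching target length. Below is a sketch of that arithmetic in plain Python. It is modeled on the length computation in run_t5_mlm_flax.py but reimplemented here; the function names are our own, and it assumes a noise density of 0.15 and a mean span length of 3.

```python
def corrupted_lengths(tokens_length, noise_density=0.15, mean_noise_span_length=3.0):
    """Input and target lengths after span-corrupting tokens_length raw tokens."""
    num_noise = int(round(tokens_length * noise_density))
    num_nonnoise = tokens_length - num_noise
    num_spans = int(round(num_noise / mean_noise_span_length))
    # Inputs: kept tokens + one sentinel per span + EOS.
    inputs_length = num_nonnoise + num_spans + 1
    # Targets: masked tokens + one sentinel per span + EOS.
    targets_length = num_noise + num_spans + 1
    return inputs_length, targets_length


def raw_length_for_inputs(inputs_length, noise_density=0.15, mean_noise_span_length=3.0):
    """Largest raw token count whose corrupted input still fits in inputs_length,
    together with the resulting target length."""
    n = inputs_length
    while corrupted_lengths(n + 1, noise_density, mean_noise_span_length)[0] <= inputs_length:
        n += 1
    return n, corrupted_lengths(n, noise_density, mean_noise_span_length)[1]


print(raw_length_for_inputs(512))  # (568, 114)
```

So for a 512-token encoder input, about 568 raw tokens are consumed per example and the decoder target is 114 tokens long.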

Google's T5X

The Dutch t5_1_1 and ul2 models were pre-trained using T5X, a modular framework for pre-training, fine-tuning, and evaluating T5 models. Thanks to its modular and pluggable design, you can pre-train with your own definitions by supplying only a few configuration and code files. It is even possible to define custom neural network layers and architectures. I did not do this: I pre-trained only the default T5 encoder-decoder architecture, and varied only the pre-training objective and the datasets used and mixed with SeqIO.

Conversion script from T5X to HF

The T5X models were converted to the HuggingFace Flax T5 format using a script adapted from the T5X-checkpoint-to-HuggingFace-Flax conversion script. The script was modified to cast the weights to bf16 and to also convert them to PyTorch format. For the conversion to succeed, the T5X model had to be saved with use_gda=False set in the GIN file.
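Casting to bf16 halves the checkpoint size. bfloat16 keeps the sign bit and all 8 exponent bits of a float32 but only the top 7 mantissa bits, so the value range is preserved while precision drops. Below is a minimal sketch of the cast for a single value, using round-to-nearest-even on the bit pattern. In practice a library routine does this for whole parameter trees (for example, the Flax models in HuggingFace Transformers provide a `to_bf16` method); this sketch only illustrates the numeric effect and is not the conversion script itself.

```python
import math
import struct


def float32_to_bf16_bits(x: float) -> int:
    """Round a float32-representable value to bfloat16, returning its 16-bit pattern."""
    if math.isnan(x):
        return 0x7FC0  # canonical quiet NaN
    bits = struct.unpack("<I", struct.pack("<f", x))[0]  # raw float32 bits
    # Round to nearest, ties to even, on the 16 low bits that are dropped.
    rounding_bias = 0x7FFF + ((bits >> 16) & 1)
    return ((bits + rounding_bias) >> 16) & 0xFFFF


def bf16_bits_to_float(b: int) -> float:
    """Expand a bfloat16 bit pattern back to float32 by zero-filling the low 16 bits."""
    return struct.unpack("<f", struct.pack("<I", b << 16))[0]


for value in (1.0, 3.14159, -0.1):
    b = float32_to_bf16_bits(value)
    print(f"{value} -> 0x{b:04X} -> {bf16_bits_to_float(b)}")
```

Values keep roughly 2 to 3 significant decimal digits after the cast, which is generally considered acceptable for inference and fine-tuning with T5-style models.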