# Language model training examples

The following example showcases how to train a language model from scratch using the JAX/Flax backend.

JAX/Flax allows you to trace pure functions and compile them into efficient, fused accelerator code on both GPU and TPU. Models written in JAX/Flax are **immutable** and updated in a purely functional way, which enables simple and efficient model parallelism.

## Causal language modeling

In the following, we demonstrate how to train an auto-regressive causal transformer model in JAX/Flax. More specifically, we pretrain a randomly initialized 124M-parameter [**`gpt2`**](https://huggingface.co/gpt2) model in Norwegian on a single TPUv3-8 pod.

The example script uses the 🤗 Datasets library. You can easily customize it if you need extra processing on your datasets.

Let's start by creating a model repository to save the trained model and logs. Here we call the model `"norwegian-gpt2"`, but you can change the model name as you like.

You can do this either directly on [huggingface.co](https://huggingface.co/new) (assuming that you are logged in) or via the command line:

```
huggingface-cli repo create norwegian-gpt2
```

Next we clone the model repository to add the tokenizer and model files.

```
git clone https://huggingface.co/<your-username>/norwegian-gpt2
```

To ensure that all tensorboard traces will be uploaded correctly, we need to track them. You can run the following command inside your model repo to do so.

```
cd norwegian-gpt2
git lfs track "*tfevents*"
```

Great, we have set up our model repository. During training, we will automatically push the training logs and model weights to the repo.

Next, let's add a symbolic link to `run_clm_flax.py`.

```bash
export MODEL_DIR="./norwegian-gpt2"
ln -s ~/transformers/examples/flax/language-modeling/run_clm_flax.py run_clm_flax.py
```

### Train tokenizer

In the first step, we train a tokenizer to efficiently process the text input for the model. Similar to how it is shown in [How to train a new language model from scratch using Transformers and Tokenizers](https://huggingface.co/blog/how-to-train), we use a **`ByteLevelBPETokenizer`**. The tokenizer is trained on the complete Norwegian dataset of OSCAR and subsequently saved in `${MODEL_DIR}`. This can take up to 10 minutes depending on your hardware ☕.

```python
from datasets import load_dataset
from tokenizers import trainers, Tokenizer, normalizers, ByteLevelBPETokenizer

model_dir = "./norwegian-gpt2"  # ${MODEL_DIR}

# load dataset
dataset = load_dataset("oscar", "unshuffled_deduplicated_no", split="train")

# Instantiate tokenizer
tokenizer = ByteLevelBPETokenizer()

def batch_iterator(batch_size=1000):
    for i in range(0, len(dataset), batch_size):
        yield dataset[i: i + batch_size]["text"]

# Customized training
tokenizer.train_from_iterator(batch_iterator(), vocab_size=50265, min_frequency=2, special_tokens=[
    "<s>",
    "<pad>",
    "</s>",
    "<unk>",
    "<mask>",
])

# Save files to disk
tokenizer.save(f"{model_dir}/tokenizer.json")
```
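As a quick sanity check, you can load the saved tokenizer back and encode a short sentence. This is only a minimal sketch and not part of the example script; the file path matches the `model_dir` above and the Norwegian sentence is just an illustrative placeholder.

```python
from tokenizers import Tokenizer

# Load the tokenizer we just saved and inspect how it splits an example sentence
tokenizer = Tokenizer.from_file("./norwegian-gpt2/tokenizer.json")

encoding = tokenizer.encode("Dette er en test.")  # illustrative placeholder sentence
print(encoding.tokens)  # subword tokens
print(encoding.ids)     # corresponding vocabulary ids
```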
### Create configuration

Next, we create the model's configuration file. This is as simple as loading and storing [**`gpt2`**](https://huggingface.co/gpt2) in the local model folder:

```python
from transformers import GPT2Config

model_dir = "./norwegian-gpt2"  # ${MODEL_DIR}

config = GPT2Config.from_pretrained("gpt2", resid_pdrop=0.0, embd_pdrop=0.0, attn_pdrop=0.0)
config.save_pretrained(model_dir)
```

### Train model

Next we can run the example script to pretrain the model:

```bash
./run_clm_flax.py \
    --output_dir="${MODEL_DIR}" \
    --model_type="gpt2" \
    --config_name="${MODEL_DIR}" \
    --tokenizer_name="${MODEL_DIR}" \
    --dataset_name="oscar" \
    --dataset_config_name="unshuffled_deduplicated_no" \
    --do_train --do_eval \
    --block_size="512" \
    --per_device_train_batch_size="64" \
    --per_device_eval_batch_size="64" \
    --learning_rate="5e-3" --warmup_steps="1000" \
    --adam_beta1="0.9" --adam_beta2="0.98" --weight_decay="0.01" \
    --overwrite_output_dir \
    --num_train_epochs="20" \
    --push_to_hub
```

Training should converge at a loss and perplexity of 3.24 and 25.72 respectively after 20 epochs on a single TPUv3-8. This should take less than 21 hours. Training statistics can be accessed on [tensorboard.dev](https://tensorboard.dev/experiment/2zEhLwJ0Qp2FAkI3WVH9qA).
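Once training has finished, the Flax checkpoint can be used for text generation. The snippet below is a minimal sketch, assuming the trained weights and tokenizer were saved to `./norwegian-gpt2`; the prompt string is only an illustration.

```python
from transformers import AutoTokenizer, FlaxGPT2LMHeadModel

model_dir = "./norwegian-gpt2"  # assumes the training run saved weights and tokenizer here

tokenizer = AutoTokenizer.from_pretrained(model_dir)
model = FlaxGPT2LMHeadModel.from_pretrained(model_dir)

# Sample a short continuation from an illustrative Norwegian prompt
inputs = tokenizer("Norge er et land som", return_tensors="np")
outputs = model.generate(inputs.input_ids, max_length=50, do_sample=True, top_k=50)

print(tokenizer.decode(outputs.sequences[0], skip_special_tokens=True))
```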