# Setting up a Google Cloud TPU VM for training a tokenizer

## TPU VM Configurations

To start off, follow the guide from the Flax/JAX community week 2021 [here](https://github.com/huggingface/transformers/tree/master/examples/research_projects/jax-projects#how-to-setup-tpu-vm), but **NOTE**: replace every `pip` command with `pip3`.

Some might encounter this error message:

```
Building wheel for jax (setup.py) ... error
ERROR: Command errored out with exit status 1:
 command: /home/patrick/patrick/bin/python3 -u -c 'import sys, setuptools, tokenize; sys.argv[0] = '"'"'/tmp/pip-install-lwseckn1/jax/setup.py'"'"'; __file__='"'"'/tmp/pip-install-lwseckn1/jax/setup.py'"'"';f=getattr(tokenize, '"'"'open'"'"', open)(__file__);code=f.read().replace('"'"'\r\n'"'"', '"'"'\n'"'"');f.close();exec(compile(code, __file__, '"'"'exec'"'"'))' bdist_wheel -d /tmp/pip-wheel-pydotzlo
 cwd: /tmp/pip-install-lwseckn1/jax/
Complete output (6 lines):
usage: setup.py [global_opts] cmd1 [cmd1_opts] [cmd2 [cmd2_opts] ...]
 or: setup.py --help [cmd1 cmd2 ...]
 or: setup.py --help-commands
 or: setup.py cmd --help

error: invalid command 'bdist_wheel'
----------------------------------------
ERROR: Failed building wheel for jax
```

If you encounter this error message, run the following commands:

```bash
pip3 install --upgrade clu
pip3 install "jax[tpu]>=0.2.16" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
```

Then give all users (including yours) read, write and execute permissions on the temporary files and TPU logs in `/tmp`:

```bash
chmod a+rwx /tmp/*
chmod a+rwx /tmp/tpu_logs/* # Just to be sure ;-)
```

Afterwards, you can verify the installation either by running the following script:

```python
from transformers import FlaxRobertaModel, RobertaTokenizerFast
from datasets import load_dataset
import jax

# Stream a single English example from OSCAR instead of downloading the full dataset
dataset = load_dataset("oscar", "unshuffled_deduplicated_en", split="train", streaming=True)
dummy_input = next(iter(dataset))["text"]

tokenizer = RobertaTokenizerFast.from_pretrained("roberta-base")
# Keep only the first 10 tokens of the example
input_ids = tokenizer(dummy_input, return_tensors="np").input_ids[:, :10]

model = FlaxRobertaModel.from_pretrained("julien-c/dummy-unknown")

# Run a forward pass; this should return a `FlaxBaseModelOutputWithPooling` object
outputs = model(input_ids)
print(type(outputs).__name__)
```

Or by running the following Python commands:

```python
import jax

# Should list the TPU devices attached to the VM
print(jax.devices())
```

## Training the tokenizer

To train the tokenizer, run the `train_tokenizer.py` script:

```bash
python3 train_tokenizer.py
```

### Problems while developing the script:

- Loading the *mc4* dataset with `load_dataset()` from Hugging Face's `datasets` package did not work for multiple languages in a single call, contrary to what is stated [here](https://huggingface.co/datasets/mc4). We therefore loaded each language separately and concatenated them.
- Furthermore, even when a subset is specified with the `split` argument, the entire dataset still has to be downloaded.
- A bug occurred while downloading the Danish dataset; to work around it, we had to force a re-download so the VM fetched the data again.