|
# Setting up a Google Cloud TPU VM for training a tokenizer |
|
|
|
## TPU VM Configurations |
|
To start off, follow the guide from the Flax/JAX community week 2021 [here](https://github.com/huggingface/transformers/tree/master/examples/research_projects/jax-projects#how-to-setup-tpu-vm), but **NOTE**: change all the `pip` commands to `pip3`.
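
For example (using `datasets` as an illustrative package name, not necessarily one from the guide):

```bash
# Where the guide says:
pip install datasets
# run instead:
pip3 install datasets
```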
|
|
|
When installing JAX, some might encounter this error message:
|
```
Building wheel for jax (setup.py) ... error
ERROR: Command errored out with exit status 1:
 command: /home/patrick/patrick/bin/python3 -u -c 'import sys, setuptools, tokenize; sys.argv[0] = '"'"'/tmp/pip-install-lwseckn1/jax/setup.py'"'"'; __file__='"'"'/tmp/pip-install-lwseckn1/jax/setup.py'"'"';f=getattr(tokenize, '"'"'open'"'"', open)(__file__);code=f.read().replace('"'"'\r\n'"'"', '"'"'\n'"'"');f.close();exec(compile(code, __file__, '"'"'exec'"'"'))' bdist_wheel -d /tmp/pip-wheel-pydotzlo
 cwd: /tmp/pip-install-lwseckn1/jax/
 Complete output (6 lines):
 usage: setup.py [global_opts] cmd1 [cmd1_opts] [cmd2 [cmd2_opts] ...]
 or: setup.py --help [cmd1 cmd2 ...]
 or: setup.py --help-commands
 or: setup.py cmd --help

 error: invalid command 'bdist_wheel'
 ----------------------------------------
ERROR: Failed building wheel for jax
```
|
|
|
If you encounter this error message, run the following commands:
|
```bash
pip3 install --upgrade clu
pip3 install "jax[tpu]>=0.2.16" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
```
|
|
|
Then make the temporary files and TPU logs in `/tmp` writable for your user:

```bash
chmod a+rwx /tmp/*
chmod a+rwx /tmp/tpu_logs/*  # Just to be sure ;-)
```
|
|
|
Afterwards, you can verify the installation either by running the following script:
|
|
|
```python
from transformers import FlaxRobertaModel, RobertaTokenizerFast
from datasets import load_dataset
import jax

dataset = load_dataset("oscar", "unshuffled_deduplicated_en", split="train", streaming=True)

dummy_input = next(iter(dataset))["text"]

tokenizer = RobertaTokenizerFast.from_pretrained("roberta-base")
input_ids = tokenizer(dummy_input, return_tensors="np").input_ids[:, :10]

model = FlaxRobertaModel.from_pretrained("julien-c/dummy-unknown")

# run a forward pass, should return an object `FlaxBaseModelOutputWithPooling`
model(input_ids)
```
|
|
|
or by running the following Python commands:
|
```python
import jax
jax.devices()
```
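
If the TPU runtime is set up correctly, this should list `TpuDevice` entries (eight on a v3-8 VM) rather than falling back to CPU devices.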
|
|
|
## Training the tokenizer |
|
To train the tokenizer, run the `train_tokenizer.py` script:
|
```bash
python3 train_tokenizer.py
```
|
|
|
### Problems encountered while developing the script
|
- Loading the '*mc4*' dataset using `load_dataset()` from Hugging Face's `datasets` package could not load multiple languages in one line of code, as otherwise suggested [here](https://huggingface.co/datasets/mc4). We therefore chose to load each language separately and concatenate them (see the sketch after this list).

- Furthermore, it seems that even if you predefine a subset split using the `split` argument, the entire dataset still needs to be downloaded.

- A bug occurred when downloading the Danish dataset; we had to force a redownload to mitigate it and make the VM download the data again.
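
A minimal sketch of the per-language loading and concatenation described above (the language codes are illustrative):

```python
from datasets import load_dataset, concatenate_datasets

# Illustrative language subsets; mC4 configs are named by language code.
languages = ["da", "sv", "no"]

# Load each language subset separately ...
parts = [load_dataset("mc4", lang, split="train") for lang in languages]

# ... then concatenate them into a single dataset.
dataset = concatenate_datasets(parts)
```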
|
|