# Setting up a Google Cloud TPU VM for training a tokenizer
## TPU VM Configurations
To get started, follow the guide from the Flax/JAX community week 2021 [here](https://github.com/huggingface/transformers/tree/master/examples/research_projects/jax-projects#how-to-setup-tpu-vm), but **NOTE**: replace every `pip` command with `pip3`.
You might encounter this error message:
```
Building wheel for jax (setup.py) ... error
ERROR: Command errored out with exit status 1:
command: /home/patrick/patrick/bin/python3 -u -c 'import sys, setuptools, tokenize; sys.argv[0] = '"'"'/tmp/pip-install-lwseckn1/jax/setup.py'"'"'; __file__='"'"'/tmp/pip-install-lwseckn1/jax/setup.py'"'"';f=getattr(tokenize, '"'"'open'"'"', open)(__file__);code=f.read().replace('"'"'\r\n'"'"', '"'"'\n'"'"');f.close();exec(compile(code, __file__, '"'"'exec'"'"'))' bdist_wheel -d /tmp/pip-wheel-pydotzlo
cwd: /tmp/pip-install-lwseckn1/jax/
Complete output (6 lines):
usage: setup.py [global_opts] cmd1 [cmd1_opts] [cmd2 [cmd2_opts] ...]
or: setup.py --help [cmd1 cmd2 ...]
or: setup.py --help-commands
or: setup.py cmd --help
error: invalid command 'bdist_wheel'
----------------------------------------
ERROR: Failed building wheel for jax
```
If you encounter this error message, run the following commands:
```bash
pip3 install --upgrade clu
pip3 install "jax[tpu]>=0.2.16" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
```
Then grant all users read, write, and execute permissions on the files in `/tmp`:
```bash
chmod a+rwx /tmp/*
chmod a+rwx /tmp/tpu_logs/* # Just to be sure ;-)
```
Afterwards, you can verify the installation either by running the following script:
```python
from transformers import FlaxRobertaModel, RobertaTokenizerFast
from datasets import load_dataset
import jax
dataset = load_dataset('oscar', "unshuffled_deduplicated_en", split='train', streaming=True)
dummy_input = next(iter(dataset))["text"]
tokenizer = RobertaTokenizerFast.from_pretrained("roberta-base")
input_ids = tokenizer(dummy_input, return_tensors="np").input_ids[:, :10]
model = FlaxRobertaModel.from_pretrained("julien-c/dummy-unknown")
# run a forward pass, should return an object `FlaxBaseModelOutputWithPooling`
model(input_ids)
```
Or by running the following Python commands:
```python
import jax
jax.devices()
```
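The device check above can also be wrapped in a small script that fails loudly when no TPU is visible. This is a minimal sketch: the helper names `count_tpu_devices` and `check_tpu` are hypothetical, and it assumes that each entry returned by `jax.devices()` exposes a `platform` attribute equal to `"tpu"` on a TPU VM.

```python
def count_tpu_devices(devices):
    """Count the devices whose `platform` attribute is "tpu"."""
    return sum(1 for d in devices if getattr(d, "platform", None) == "tpu")


def check_tpu():
    """Raise if JAX cannot see any TPU device (hypothetical helper)."""
    import jax  # imported lazily; requires the installation steps above

    devices = jax.devices()
    n_tpu = count_tpu_devices(devices)
    if n_tpu == 0:
        raise RuntimeError(f"No TPU devices found, JAX sees: {devices}")
    return n_tpu
```

Calling `check_tpu()` on the VM should return the number of TPU devices, or raise an error if JAX only sees CPU devices.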
## Training the tokenizer
To train the tokenizer run the `train_tokenizer.py` script:
```bash
python3 train_tokenizer.py
```
### Problems while developing the script:
- Loading the `mc4` dataset with `load_dataset()` from Hugging Face's `datasets` package did not support loading multiple languages in one line of code, contrary to what is specified [here](https://huggingface.co/datasets/mc4). We therefore chose to load each language separately and concatenate them.
- Furthermore, it seems that even when you predefine a subset using the `split` argument, the entire dataset still needs to be downloaded.
- A bug occurred when downloading the Danish dataset, so we had to force a redownload to work around it and make the VM download the data.
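
The per-language workaround described above could look roughly like the sketch below. It is not the code from `train_tokenizer.py`: the helper names, the language codes in the usage note, and the choice to chain the texts into one iterator (instead of building a single concatenated dataset) are assumptions made for illustration. `download_mode="force_redownload"` is the `datasets` option for ignoring a cached download.

```python
from itertools import chain


def load_language(lang, force_redownload=False):
    """Load the train split of one mC4 language subset (hypothetical helper)."""
    from datasets import load_dataset  # imported lazily; requires `datasets`

    kwargs = {"split": "train"}
    if force_redownload:
        # Ignore any cached (possibly broken) download and fetch it again.
        kwargs["download_mode"] = "force_redownload"
    return load_dataset("mc4", lang, **kwargs)


def iter_texts(languages, loader=load_language, redownload=()):
    """Yield the `text` field of every example, one language after another."""
    # Each language is only loaded when the iterator reaches it.
    per_language = (
        loader(lang, force_redownload=lang in redownload) for lang in languages
    )
    for example in chain.from_iterable(per_language):
        yield example["text"]
```

For example, `iter_texts(["da", "sv"], redownload={"da"})` would load Danish with a forced redownload and then Swedish from the cache, assuming those are the language codes you need. An iterator like this is the kind of input that a tokenizer's `train_from_iterator` method accepts.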