# Setting up a Google Cloud TPU VM for training a tokenizer

## TPU VM Configurations

To start off, follow the guide from the Flax/JAX community week 2021 here, but NOTE: change all `pip` commands to `pip3`.
Some might encounter this error message:

```
Building wheel for jax (setup.py) ... error
ERROR: Command errored out with exit status 1:
command: /home/patrick/patrick/bin/python3 -u -c 'import sys, setuptools, tokenize; sys.argv[0] = '"'"'/tmp/pip-install-lwseckn1/jax/setup.py'"'"'; file='"'"'/tmp/pip-install-lwseckn1/jax/setup.py'"'"';f=getattr(tokenize, '"'"'open'"'"', open)(file);code=f.read().replace('"'"'\r\n'"'"', '"'"'\n'"'"');f.close();exec(compile(code, file, '"'"'exec'"'"'))' bdist_wheel -d /tmp/pip-wheel-pydotzlo
cwd: /tmp/pip-install-lwseckn1/jax/
Complete output (6 lines):
usage: setup.py [global_opts] cmd1 [cmd1_opts] [cmd2 [cmd2_opts] ...]
or: setup.py --help [cmd1 cmd2 ...]
or: setup.py --help-commands
or: setup.py cmd --help
error: invalid command 'bdist_wheel'
ERROR: Failed building wheel for jax
```
If you encounter this error message, run the following commands:

```bash
pip3 install --upgrade clu
pip3 install "jax[tpu]>=0.2.16" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
```
Then make the temporary TPU log files writable by your user:

```bash
chmod a+rwx /tmp/*
chmod a+rwx /tmp/tpu_logs/* # Just to be sure ;-)
```
Afterwards, you can verify the installation either by running the following script:

```python
from transformers import FlaxRobertaModel, RobertaTokenizerFast
from datasets import load_dataset
import jax

dataset = load_dataset("oscar", "unshuffled_deduplicated_en", split="train", streaming=True)
dummy_input = next(iter(dataset))["text"]

tokenizer = RobertaTokenizerFast.from_pretrained("roberta-base")
input_ids = tokenizer(dummy_input, return_tensors="np").input_ids[:, :10]

model = FlaxRobertaModel.from_pretrained("julien-c/dummy-unknown")

# run a forward pass, should return an object `FlaxBaseModelOutputWithPooling`
model(input_ids)
```
or by running the following Python commands:

```python
import jax
jax.devices()
```
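As a quick sanity check, you can also count the visible accelerator devices. The expected count of 8 below assumes a v3-8 TPU VM; on other hardware the number will differ:

```python
import jax

# On a v3-8 TPU VM this should report 8 TPU cores (assumption: v3-8 topology).
# On a machine without a TPU, JAX falls back to CPU and reports at least 1 device.
n_devices = jax.device_count()
print(n_devices)
```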
## Training the tokenizer

To train the tokenizer, run the `train_tokenizer.py` script:

```bash
python3 train_tokenizer.py
```
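The contents of `train_tokenizer.py` are not shown here; a minimal sketch of what training a RoBERTa-style byte-level BPE tokenizer with the `tokenizers` library could look like is below. The training texts, vocabulary size, and special tokens are assumptions, not taken from the actual script:

```python
from tokenizers import ByteLevelBPETokenizer

# Hypothetical sketch; the real train_tokenizer.py may differ.
tokenizer = ByteLevelBPETokenizer()

# Stand-in corpus; in practice this would iterate over the Danish dataset.
corpus = ["hej verden", "dette er en test", "hej igen"]

tokenizer.train_from_iterator(
    corpus,
    vocab_size=50265,  # assumed RoBERTa-like vocabulary size
    min_frequency=2,
    special_tokens=["<s>", "<pad>", "</s>", "<unk>", "<mask>"],
)

# Quick check that the trained tokenizer produces tokens.
enc = tokenizer.encode("hej verden")
print(enc.tokens)
```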
## Problems while developing the script

- Loading the 'mc4' dataset using `load_dataset()` from Hugging Face's `datasets` package could not load multiple languages in one line of code, as otherwise specified here. We therefore chose to load each language separately and concatenate them.
- Furthermore, it seems that even though you predefine a subset split using the `split` argument, the entire dataset still needs to be downloaded.
- A bug occurred when downloading the Danish dataset; we had to force a redownload to mitigate it and make the VM download the data.