
Setting up a Google Cloud TPU VM for training a tokenizer

TPU VM Configurations

To start, follow the guide from the Flax/JAX community week 2021 here, but note that every pip command must be changed to pip3.

Some might encounter this error message:

 Building wheel for jax (setup.py) ... error
  ERROR: Command errored out with exit status 1:
   command: /home/patrick/patrick/bin/python3 -u -c 'import sys, setuptools, tokenize; sys.argv[0] = '"'"'/tmp/pip-install-lwseckn1/jax/setup.py'"'"'; __file__='"'"'/tmp/pip-install-lwseckn1/jax/setup.py'"'"';f=getattr(tokenize, '"'"'open'"'"', open)(__file__);code=f.read().replace('"'"'\r\n'"'"', '"'"'\n'"'"');f.close();exec(compile(code, __file__, '"'"'exec'"'"'))' bdist_wheel -d /tmp/pip-wheel-pydotzlo
       cwd: /tmp/pip-install-lwseckn1/jax/
  Complete output (6 lines):
  usage: setup.py [global_opts] cmd1 [cmd1_opts] [cmd2 [cmd2_opts] ...]
     or: setup.py --help [cmd1 cmd2 ...]
     or: setup.py --help-commands
     or: setup.py cmd --help
  
  error: invalid command 'bdist_wheel'
  ----------------------------------------
  ERROR: Failed building wheel for jax

If you encounter this error message, run the following commands:

pip3 install --upgrade clu
pip3 install "jax[tpu]>=0.2.16" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
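The `invalid command 'bdist_wheel'` message typically means that the `wheel` package is missing from the environment pip is building in. As a minimal, stdlib-only sketch (the helper name `has_module` is hypothetical and not part of the original setup), the following check reports whether the interpreter can see the build packages:

```python
import importlib.util


def has_module(name: str) -> bool:
    """Return True if the top-level module `name` can be imported by this interpreter."""
    return importlib.util.find_spec(name) is not None


# `wheel` usually provides the `bdist_wheel` command that the failed build asked for.
for module in ("setuptools", "wheel"):
    status = "found" if has_module(module) else "MISSING"
    print(f"{module}: {status}")
```

If `wheel` is reported as missing, running `pip3 install wheel` before retrying is the usual remedy; whether that alone is enough on a given VM image is an assumption.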

Then give all users read, write and execute permissions on the files in /tmp (including the TPU logs):

chmod a+rwx /tmp/*
chmod a+rwx /tmp/tpu_logs/*  # Just to be sure ;-)

Afterwards, you can verify the installation either by running the following script:

from transformers import FlaxRobertaModel, RobertaTokenizerFast
from datasets import load_dataset
import jax

dataset = load_dataset('oscar', "unshuffled_deduplicated_en", split='train', streaming=True)

dummy_input = next(iter(dataset))["text"]

tokenizer = RobertaTokenizerFast.from_pretrained("roberta-base")
input_ids = tokenizer(dummy_input, return_tensors="np").input_ids[:, :10]

model = FlaxRobertaModel.from_pretrained("julien-c/dummy-unknown")

# run a forward pass, should return an object `FlaxBaseModelOutputWithPooling`
model(input_ids)

Or by running the following Python commands:

import jax
jax.devices()
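The output of `jax.devices()` is easier to check when it is summarised per platform. The sketch below is an illustration, not part of the original setup: the helper names `summarize_devices` and `check_tpu` are hypothetical, and the eight TPU devices are stand-in objects so that the block runs without JAX installed. It relies only on the `platform` attribute that JAX device objects expose.

```python
from collections import Counter
from types import SimpleNamespace


def summarize_devices(devices) -> dict:
    """Count devices per platform string ("cpu", "gpu" or "tpu")."""
    return dict(Counter(d.platform for d in devices))


def check_tpu(devices) -> bool:
    """Return True if at least one device reports the "tpu" platform."""
    return summarize_devices(devices).get("tpu", 0) > 0


# Stand-in objects (assumption: 8 TPU cores) so the sketch runs without JAX.
fake_devices = [SimpleNamespace(platform="tpu") for _ in range(8)]
print(summarize_devices(fake_devices))  # {'tpu': 8}
print(check_tpu(fake_devices))          # True
```

On the TPU VM itself, pass the real list instead: `check_tpu(jax.devices())`. A result of `False` would suggest that JAX fell back to the CPU backend.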

Training the tokenizer

To train the tokenizer, run the train_tokenizer.py script:

python3 train_tokenizer.py
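The contents of train_tokenizer.py are not reproduced in this log. As a rough sketch of what such a script might do, assuming a RoBERTa-style byte-level BPE tokenizer trained with the `tokenizers` library, the code below batches a text stream and trains on it. The function names, the vocabulary size, the minimum frequency, the special tokens and the output path are all assumptions rather than the settings actually used.

```python
from typing import Iterable, Iterator, List


def batch_iterator(texts: Iterable[str], batch_size: int = 1000) -> Iterator[List[str]]:
    """Yield lists of up to `batch_size` texts; works on streams of unknown length."""
    batch: List[str] = []
    for text in texts:
        batch.append(text)
        if len(batch) == batch_size:
            yield batch
            batch = []
    if batch:  # flush the final, possibly shorter, batch
        yield batch


def train_tokenizer(texts: Iterable[str], output_path: str = "tokenizer.json") -> None:
    """Train a byte-level BPE tokenizer on `texts` and save it to `output_path`."""
    # Imported lazily so batch_iterator can be used without `tokenizers` installed.
    from tokenizers import ByteLevelBPETokenizer

    tokenizer = ByteLevelBPETokenizer()
    tokenizer.train_from_iterator(
        batch_iterator(texts),
        vocab_size=50265,  # assumption: same vocabulary size as roberta-base
        min_frequency=2,   # assumption: only merge pairs seen at least twice
        special_tokens=["<s>", "<pad>", "</s>", "<unk>", "<mask>"],
    )
    tokenizer.save(output_path)
```

With a loaded dataset, a call could look like `train_tokenizer(example["text"] for example in dataset)`. Batching keeps memory use bounded when the dataset is streamed.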

Problems while developing the script:

  • The load_dataset() function from Hugging Face's datasets package could not load several languages of the 'mc4' dataset in a single call, contrary to what is specified here. We therefore chose to load each language separately and concatenate the results.
  • Furthermore, even when a subset is requested through the split argument, the entire dataset still appears to be downloaded.
  • A bug occurred while downloading the Danish dataset; we had to force a redownload to work around it and get the VM to download the data.
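The first and third workarounds above can be sketched as follows. This is an illustration under stated assumptions, not the code from train_tokenizer.py: the helper `load_languages` and its parameters are hypothetical, the language codes `"da"` and `"sv"` are placeholders, and the installed `datasets` version is assumed to still provide the `mc4` loader. The only library features relied on are `datasets.load_dataset` with its `download_mode="force_redownload"` option and `datasets.concatenate_datasets`.

```python
from typing import Callable, Optional, Sequence


def load_languages(
    languages: Sequence[str],
    loader: Optional[Callable] = None,
    combine: Optional[Callable] = None,
    force_redownload: Sequence[str] = (),
):
    """Load one 'mc4' train split per language and concatenate the results."""
    if loader is None or combine is None:
        # Imported lazily so the helper can be exercised with stand-ins.
        import datasets

        loader = loader or datasets.load_dataset
        combine = combine or datasets.concatenate_datasets

    parts = []
    for lang in languages:
        kwargs = {"split": "train"}
        if lang in force_redownload:
            # Ignore a possibly corrupted cache and download the data again.
            kwargs["download_mode"] = "force_redownload"
        parts.append(loader("mc4", lang, **kwargs))
    return combine(parts)


# Stand-ins so the sketch runs offline; omit them to use the real `datasets` functions.
calls = []


def fake_loader(name, config, **kwargs):
    calls.append((name, config, kwargs))
    return [f"{config}-doc"]


def fake_combine(parts):
    return [doc for part in parts for doc in part]


combined = load_languages(["da", "sv"], fake_loader, fake_combine, force_redownload=["da"])
print(combined)  # ['da-doc', 'sv-doc']
```

On the VM, a real call could be `load_languages(["da", "sv"], force_redownload=["da"])`. Forcing the redownload only for the affected language avoids fetching the other languages a second time.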