# Setting up a Google Cloud TPU VM for training a tokenizer

## TPU VM Configurations

To start off, follow the guide from the Flax/JAX community week 2021 here, but NOTE: change all `pip` commands to `pip3`.

Some might encounter this error message:

```
Building wheel for jax (setup.py) ... error
ERROR: Command errored out with exit status 1:
 command: /home/patrick/patrick/bin/python3 -u -c 'import sys, setuptools, tokenize; sys.argv[0] = '"'"'/tmp/pip-install-lwseckn1/jax/setup.py'"'"'; __file__='"'"'/tmp/pip-install-lwseckn1/jax/setup.py'"'"';f=getattr(tokenize, '"'"'open'"'"', open)(__file__);code=f.read().replace('"'"'\r\n'"'"', '"'"'\n'"'"');f.close();exec(compile(code, __file__, '"'"'exec'"'"'))' bdist_wheel -d /tmp/pip-wheel-pydotzlo
 cwd: /tmp/pip-install-lwseckn1/jax/
 Complete output (6 lines):
 usage: setup.py [global_opts] cmd1 [cmd1_opts] [cmd2 [cmd2_opts] ...]
    or: setup.py --help [cmd1 cmd2 ...]
    or: setup.py --help-commands
    or: setup.py cmd --help

 error: invalid command 'bdist_wheel'

ERROR: Failed building wheel for jax
```

If you encounter this error message, run the following commands:

```bash
pip3 install --upgrade clu
pip3 install "jax[tpu]>=0.2.16" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
```
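
If the error persists, `invalid command 'bdist_wheel'` usually indicates that the `wheel` package itself is missing from the environment, so installing it explicitly may help (this extra step is our suggestion, not part of the original guide):

```bash
# Assumption: a missing wheel package is the cause of the bdist_wheel error
pip3 install wheel
```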


Then make sure all users have read/write access to the temporary TPU log files:
```bash
chmod a+rwx /tmp/*
chmod a+rwx /tmp/tpu_logs/*  # Just to be sure ;-)
```

Afterwards, you can verify the installation either by running the following script:

```python
from transformers import FlaxRobertaModel, RobertaTokenizerFast
from datasets import load_dataset
import jax

# Stream the dataset so nothing large needs to be downloaded
dataset = load_dataset('oscar', "unshuffled_deduplicated_en", split='train', streaming=True)

dummy_input = next(iter(dataset))["text"]

tokenizer = RobertaTokenizerFast.from_pretrained("roberta-base")
input_ids = tokenizer(dummy_input, return_tensors="np").input_ids[:, :10]

model = FlaxRobertaModel.from_pretrained("julien-c/dummy-unknown")

# run a forward pass, should return an object `FlaxBaseModelOutputWithPooling`
model(input_ids)
```

Or by running the following Python commands:

```python
import jax
jax.devices()
```
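
If the TPU is set up correctly, `jax.devices()` should list the TPU cores, e.g. eight `TpuDevice` entries on a v3-8 VM.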

## Training the tokenizer

To train the tokenizer, run the `train_tokenizer.py` script:

```bash
python3 train_tokenizer.py
```
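
The script itself lives in this repository. Purely for illustration, a minimal tokenizer-training script along these lines might look as follows; the dataset choice and batch size are assumptions (the vocabulary size and special tokens match the standard RoBERTa setup), and the actual `train_tokenizer.py` may differ:

```python
from datasets import load_dataset
from tokenizers import ByteLevelBPETokenizer

# Illustrative: the Danish subset of mC4 (see the problems listed below
# for how the data was actually loaded and combined)
dataset = load_dataset("mc4", "da", split="train")

def batch_iterator(batch_size=1000):
    # Yield batches of raw text to the tokenizer trainer
    for i in range(0, len(dataset), batch_size):
        yield dataset[i : i + batch_size]["text"]

tokenizer = ByteLevelBPETokenizer()
tokenizer.train_from_iterator(
    batch_iterator(),
    vocab_size=50265,  # standard RoBERTa-base vocabulary size
    min_frequency=2,
    special_tokens=["<s>", "<pad>", "</s>", "<unk>", "<mask>"],
)

# Save the trained tokenizer to a single JSON file
tokenizer.save("./tokenizer.json")
```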

Problems while developing the script:

- Loading the `mc4` dataset using `load_dataset()` from Hugging Face's `datasets` package could not load multiple languages in one line of code, as otherwise specified here. We therefore chose to load each language separately and concatenate them (see the sketch after this list).
- Furthermore, it seems that even though you predefine a subset split using the `split` argument, the entire dataset still needs to be downloaded.
- A bug occurred when downloading the Danish dataset; we had to force a redownload to mitigate it and make the VM download the data.
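
As a sketch of the workaround from the first point, each language can be loaded in a separate call and the results combined with `concatenate_datasets`; the language codes below are hypothetical placeholders, not necessarily the ones actually used:

```python
from datasets import load_dataset, concatenate_datasets

# Hypothetical language codes; substitute the mC4 subsets actually needed
languages = ["da", "sv", "no"]

# One load_dataset() call per language, since passing several languages
# to a single call did not work
parts = [load_dataset("mc4", lang, split="train") for lang in languages]

# Combine all language subsets into one dataset for tokenizer training
dataset = concatenate_datasets(parts)
```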