|
""" |
|
Usage: |
|
python train_unigram.py --export_to_hub |
|
|
|
Note that you'd need to execute `huggingface-cli login` before if you passed export_to_hub. |
|
|
|
Reference: |
|
https://colab.research.google.com/github/huggingface/notebooks/blob/master/examples/tokenizer_training.ipynb |
|
""" |
|
|
|
import argparse |
|
import logging |
|
|
|
import datasets |
|
import torch |
|
from datasets import Dataset |
|
from tokenizers import ( |
|
Tokenizer, |
|
decoders, |
|
normalizers, |
|
pre_tokenizers, |
|
processors, |
|
) |
|
from tokenizers.models import Unigram |
|
from tokenizers.trainers import UnigramTrainer |
|
from transformers import AlbertTokenizerFast |
|
|
|
|
|
def parse_args(): |
|
parser = argparse.ArgumentParser( |
|
description="Train a unigram tokenizer on the wikitext dataset." |
|
) |
|
parser.add_argument( |
|
"-bs", |
|
"--batch-size", |
|
type=int, |
|
default=1000, |
|
help="Batch size during training.", |
|
) |
|
parser.add_argument( |
|
"-vs", |
|
"--vocab-size", |
|
type=int, |
|
default=10000, |
|
help="Size of the desired vocabulary.", |
|
) |
|
parser.add_argument( |
|
"--limit", |
|
default=None, |
|
type=int, |
|
help="Limit the number of shards (used for debugging).", |
|
) |
|
parser.add_argument( |
|
"--export_to_hub", |
|
action="store_true", |
|
) |
|
|
|
args = parser.parse_args() |
|
return args |
|
|
|
|
|
def get_unigram_tokenizer() -> Tokenizer: |
|
tokenizer = Tokenizer(Unigram()) |
|
tokenizer.normalizer = normalizers.Sequence( |
|
[normalizers.Replace("``", '"'), normalizers.Replace("''", '"')] |
|
) |
|
tokenizer.pre_tokenizer = pre_tokenizers.Metaspace() |
|
return tokenizer |
|
|
|
|
|
def get_unigram_trainer(vocab_size: int) -> UnigramTrainer: |
|
trainer = UnigramTrainer( |
|
unk_token="<unk>", |
|
special_tokens=["[CLS]", "[SEP]", "<unk>", "<pad>", "[MASK]"], |
|
vocab_size=vocab_size, |
|
) |
|
return trainer |
|
|
|
|
|
def main(args): |
|
wikitext = datasets.load_dataset( |
|
"wikitext", "wikitext-103-raw-v1", split="train" |
|
) |
|
|
|
if args.limit is not None: |
|
wikitext = wikitext[: args.limit] |
|
wikitext = Dataset.from_dict(wikitext) |
|
logging.info(f"Limiting the dataset to {args.limit} entries.") |
|
|
|
dataloader = torch.utils.data.DataLoader( |
|
wikitext, num_workers=0, batch_size=args.batch_size |
|
) |
|
logging.info("Training the tokenizer.") |
|
tokenizer = get_unigram_tokenizer() |
|
trainer = get_unigram_trainer(args.vocab_size) |
|
tokenizer.train_from_iterator(dataloader, trainer=trainer) |
|
logging.info("Tokenizer training complete!") |
|
|
|
cls_token_id = tokenizer.token_to_id("[CLS]") |
|
sep_token_id = tokenizer.token_to_id("[SEP]") |
|
tokenizer.post_processor = processors.TemplateProcessing( |
|
single="[CLS]:0 $A:0 [SEP]:0", |
|
pair="[CLS]:0 $A:0 [SEP]:0 $B:1 [SEP]:1", |
|
special_tokens=[ |
|
("[CLS]", cls_token_id), |
|
("[SEP]", sep_token_id), |
|
], |
|
) |
|
tokenizer.decoder = decoders.Metaspace() |
|
|
|
if args.export_to_hub: |
|
logging.info("Exporting the trained tokenzier to Hub.") |
|
new_tokenizer = AlbertTokenizerFast(tokenizer_object=tokenizer) |
|
new_tokenizer.push_to_hub("sayakpaul/unigram-tokenizer-wikitext") |
|
|
|
|
|
if __name__ == "__main__": |
|
args = parse_args() |
|
main(args) |
|
|