unigram-tokenizer-wikitext / train_unigram.py
"""
Usage:
python train_unigram.py --export_to_hub
Note that you'd need to execute `huggingface-cli login` before if you passed export_to_hub.
Reference:
https://colab.research.google.com/github/huggingface/notebooks/blob/master/examples/tokenizer_training.ipynb
"""
import argparse
import logging

import datasets
import torch
from datasets import Dataset
from tokenizers import (
    Tokenizer,
    decoders,
    normalizers,
    pre_tokenizers,
    processors,
)
from tokenizers.models import Unigram
from tokenizers.trainers import UnigramTrainer
from transformers import AlbertTokenizerFast


def parse_args():
    parser = argparse.ArgumentParser(
        description="Train a unigram tokenizer on the wikitext dataset."
    )
    parser.add_argument(
        "-bs",
        "--batch-size",
        type=int,
        default=1000,
        help="Batch size during training.",
    )
    parser.add_argument(
        "-vs",
        "--vocab-size",
        type=int,
        default=10000,
        help="Size of the desired vocabulary.",
    )
    parser.add_argument(
        "--limit",
        default=None,
        type=int,
        help="Limit the number of training entries (used for debugging).",
    )
    parser.add_argument(
        "--export_to_hub",
        action="store_true",
        help="Whether to push the trained tokenizer to the Hugging Face Hub.",
    )
    args = parser.parse_args()
    return args


def get_unigram_tokenizer() -> Tokenizer:
    tokenizer = Tokenizer(Unigram())
    # Wikitext uses ``/'' (Penn Treebank style) for quotation marks;
    # normalize both to a plain double quote.
    tokenizer.normalizer = normalizers.Sequence(
        [normalizers.Replace("``", '"'), normalizers.Replace("''", '"')]
    )
    tokenizer.pre_tokenizer = pre_tokenizers.Metaspace()
    return tokenizer

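
# Illustrative aside (not executed): with default settings, the Metaspace
# pre-tokenizer marks word boundaries with "▁" (U+2581), so a string like
# "Hello world" is pre-tokenized into the pieces ["▁Hello", "▁world"].
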
def get_unigram_trainer(vocab_size: int) -> UnigramTrainer:
    trainer = UnigramTrainer(
        unk_token="<unk>",
        special_tokens=["[CLS]", "[SEP]", "<unk>", "<pad>", "[MASK]"],
        vocab_size=vocab_size,
    )
    return trainer

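
# The special tokens above follow ALBERT's convention
# ([CLS]/[SEP]/<unk>/<pad>/[MASK]), which is what lets the trained tokenizer
# be wrapped in an AlbertTokenizerFast at export time below.
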
def main(args):
    wikitext = datasets.load_dataset(
        "wikitext", "wikitext-103-raw-v1", split="train"
    )
    if args.limit is not None:
        # Slicing a Dataset returns a plain dict of columns, so rebuild a Dataset.
        wikitext = Dataset.from_dict(wikitext[: args.limit])
        logging.info(f"Limiting the dataset to {args.limit} entries.")
    dataloader = torch.utils.data.DataLoader(
        wikitext, num_workers=0, batch_size=args.batch_size
    )

    logging.info("Training the tokenizer.")
    tokenizer = get_unigram_tokenizer()
    trainer = get_unigram_trainer(args.vocab_size)
    # Each DataLoader batch is a dict like {"text": [...]}; the trainer expects
    # an iterator of strings (or lists of strings), so yield the text column.
    tokenizer.train_from_iterator(
        (batch["text"] for batch in dataloader), trainer=trainer
    )
    logging.info("Tokenizer training complete!")

    # ALBERT-style post-processing: wrap sequences with [CLS]/[SEP].
    cls_token_id = tokenizer.token_to_id("[CLS]")
    sep_token_id = tokenizer.token_to_id("[SEP]")
    tokenizer.post_processor = processors.TemplateProcessing(
        single="[CLS]:0 $A:0 [SEP]:0",
        pair="[CLS]:0 $A:0 [SEP]:0 $B:1 [SEP]:1",
        special_tokens=[
            ("[CLS]", cls_token_id),
            ("[SEP]", sep_token_id),
        ],
    )
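    # Illustrative (not executed): with the template above, a single sentence
    # encodes as "[CLS] ... [SEP]", e.g.
    #   tokenizer.encode("Hello world").tokens
    #   -> ["[CLS]", "▁Hello", "▁world", "[SEP]"]
    # (the exact subword pieces depend on the learned vocabulary).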
    tokenizer.decoder = decoders.Metaspace()

    if args.export_to_hub:
        logging.info("Exporting the trained tokenizer to the Hub.")
        new_tokenizer = AlbertTokenizerFast(tokenizer_object=tokenizer)
        new_tokenizer.push_to_hub("sayakpaul/unigram-tokenizer-wikitext")


if __name__ == "__main__":
    # Configure logging so the INFO messages above are actually emitted.
    logging.basicConfig(level=logging.INFO)
    args = parse_args()
    main(args)
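
# To reuse the exported tokenizer later (repo name as pushed above):
#   from transformers import AlbertTokenizerFast
#   tokenizer = AlbertTokenizerFast.from_pretrained("sayakpaul/unigram-tokenizer-wikitext")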