"""Train a Unigram tokenizer on the WikiText-103 dataset.

Usage:
    python train_unigram.py --export_to_hub

Note that you'd need to execute `huggingface-cli login` before if you passed
export_to_hub.

Reference:
    https://colab.research.google.com/github/huggingface/notebooks/blob/master/examples/tokenizer_training.ipynb
"""

import argparse
import logging

import datasets
import torch
from datasets import Dataset
from tokenizers import (
    Tokenizer,
    decoders,
    normalizers,
    pre_tokenizers,
    processors,
)
from tokenizers.models import Unigram
from tokenizers.trainers import UnigramTrainer
from transformers import AlbertTokenizerFast


def parse_args() -> argparse.Namespace:
    """Parse the command-line arguments of the training script.

    Returns:
        Namespace with ``batch_size``, ``vocab_size``, ``limit`` and
        ``export_to_hub`` attributes.
    """
    parser = argparse.ArgumentParser(
        description="Train a unigram tokenizer on the wikitext dataset."
    )
    parser.add_argument(
        "-bs",
        "--batch-size",
        type=int,
        default=1000,
        help="Batch size during training.",
    )
    parser.add_argument(
        "-vs",
        "--vocab-size",
        type=int,
        default=10000,
        help="Size of the desired vocabulary.",
    )
    parser.add_argument(
        "--limit",
        default=None,
        type=int,
        # The value is used to slice dataset rows in main(), not shards.
        help="Limit the number of dataset entries (used for debugging).",
    )
    parser.add_argument(
        "--export_to_hub",
        action="store_true",
        help="Push the trained tokenizer to the Hugging Face Hub.",
    )
    return parser.parse_args()


def get_unigram_tokenizer() -> Tokenizer:
    """Build an untrained Unigram tokenizer.

    The normalizer rewrites the two-character quote sequences `` and ''
    to a plain double quote, and the pre-tokenizer is ``Metaspace``.

    Returns:
        The configured, not yet trained, tokenizer.
    """
    tokenizer = Tokenizer(Unigram())
    tokenizer.normalizer = normalizers.Sequence(
        [normalizers.Replace("``", '"'), normalizers.Replace("''", '"')]
    )
    tokenizer.pre_tokenizer = pre_tokenizers.Metaspace()
    return tokenizer


def get_unigram_trainer(vocab_size: int) -> UnigramTrainer:
    """Create the trainer used to fit the Unigram model.

    Args:
        vocab_size: Size of the desired vocabulary.

    Returns:
        A ``UnigramTrainer`` configured with the special tokens.
    """
    # The unknown and padding tokens were empty strings, which registers a
    # duplicated empty special token and an empty unknown token. "<unk>" and
    # "<pad>" are the names AlbertTokenizerFast (used for the export in
    # main()) expects by default.
    return UnigramTrainer(
        unk_token="<unk>",
        special_tokens=["[CLS]", "[SEP]", "<unk>", "<pad>", "[MASK]"],
        vocab_size=vocab_size,
    )


def _iter_text_batches(dataloader):
    """Yield the list of raw strings held in each batch of *dataloader*.

    Batches collated from the dataset are mappings keyed by column name;
    handing them to the tokenizer unchanged would iterate over the keys
    instead of the text. Assumes the text lives in the "text" column, as it
    does for the wikitext dataset.
    """
    for batch in dataloader:
        yield batch["text"]


def main(args: argparse.Namespace) -> None:
    """Train the tokenizer on wikitext and optionally push it to the Hub.

    Args:
        args: Parsed arguments; reads ``limit``, ``batch_size``,
            ``vocab_size`` and ``export_to_hub``.
    """
    wikitext = datasets.load_dataset(
        "wikitext", "wikitext-103-raw-v1", split="train"
    )

    if args.limit is not None:
        # Slicing returns a plain dict of columns, so rebuild a Dataset.
        wikitext = wikitext[: args.limit]
        wikitext = Dataset.from_dict(wikitext)
        logging.info("Limiting the dataset to %s entries.", args.limit)

    dataloader = torch.utils.data.DataLoader(
        wikitext, num_workers=0, batch_size=args.batch_size
    )

    logging.info("Training the tokenizer.")
    tokenizer = get_unigram_tokenizer()
    trainer = get_unigram_trainer(args.vocab_size)
    tokenizer.train_from_iterator(
        _iter_text_batches(dataloader), trainer=trainer
    )
    logging.info("Tokenizer training complete!")

    # The special-token ids are only known once training has built the vocab.
    cls_token_id = tokenizer.token_to_id("[CLS]")
    sep_token_id = tokenizer.token_to_id("[SEP]")
    tokenizer.post_processor = processors.TemplateProcessing(
        single="[CLS]:0 $A:0 [SEP]:0",
        pair="[CLS]:0 $A:0 [SEP]:0 $B:1 [SEP]:1",
        special_tokens=[
            ("[CLS]", cls_token_id),
            ("[SEP]", sep_token_id),
        ],
    )
    tokenizer.decoder = decoders.Metaspace()

    # NOTE(review): without --export_to_hub the trained tokenizer is not
    # persisted anywhere; confirm whether a local save is wanted.
    if args.export_to_hub:
        logging.info("Exporting the trained tokenzier to Hub.")
        new_tokenizer = AlbertTokenizerFast(tokenizer_object=tokenizer)
        new_tokenizer.push_to_hub("sayakpaul/unigram-tokenizer-wikitext")


if __name__ == "__main__":
    # The root logger defaults to WARNING, which would hide every progress
    # message emitted through logging.info above.
    logging.basicConfig(level=logging.INFO)
    args = parse_args()
    main(args)