"""
Usage:
    python train_unigram.py --export_to_hub

Note that you need to run `huggingface-cli login` beforehand if you pass `--export_to_hub`.

Reference:
    https://colab.research.google.com/github/huggingface/notebooks/blob/master/examples/tokenizer_training.ipynb
"""

import argparse
import logging

import datasets
import torch
from tokenizers import (
    Tokenizer,
    decoders,
    normalizers,
    pre_tokenizers,
    processors,
)
from tokenizers.models import Unigram
from tokenizers.trainers import UnigramTrainer
from transformers import AlbertTokenizerFast


def parse_args():
    parser = argparse.ArgumentParser(
        description="Train a unigram tokenizer on the wikitext dataset."
    )
    parser.add_argument(
        "-bs",
        "--batch-size",
        type=int,
        default=1000,
        help="Batch size during training.",
    )
    parser.add_argument(
        "-vs",
        "--vocab-size",
        type=int,
        default=10000,
        help="Size of the desired vocabulary.",
    )
    parser.add_argument(
        "--limit",
        default=None,
        type=int,
        help="Limit the number of shards (used for debugging).",
    )
    parser.add_argument(
        "--export_to_hub",
        action="store_true",
        help="Whether to push the trained tokenizer to the Hugging Face Hub.",
    )

    args = parser.parse_args()
    return args


def get_unigram_tokenizer() -> Tokenizer:
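    """Build a Unigram tokenizer that normalizes ``/'' quotes to " and pre-tokenizes with Metaspace."""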
    tokenizer = Tokenizer(Unigram())
    tokenizer.normalizer = normalizers.Sequence(
        [normalizers.Replace("``", '"'), normalizers.Replace("''", '"')]
    )
    tokenizer.pre_tokenizer = pre_tokenizers.Metaspace()
    return tokenizer


def get_unigram_trainer(vocab_size: int) -> UnigramTrainer:
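    """Build a UnigramTrainer with the requested vocabulary size and ALBERT-style special tokens."""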
    trainer = UnigramTrainer(
        unk_token="<unk>",
        special_tokens=["[CLS]", "[SEP]", "<unk>", "<pad>", "[MASK]"],
        vocab_size=vocab_size,
    )
    return trainer


def main(args):
    wikitext = datasets.load_dataset(
        "wikitext", "wikitext-103-raw-v1", split="train"
    )

    if args.limit is not None:
        max_train_samples = min(len(wikitext), args.limit)
        wikitext = wikitext.select(range(max_train_samples))
        logging.info(f"Limiting the dataset to {args.limit} entries.")

    dataloader = torch.utils.data.DataLoader(
        wikitext, num_workers=0, batch_size=args.batch_size
    )

    def batch_iterator():
        # Each DataLoader batch is a dict like {"text": [str, ...]}, but
        # `train_from_iterator` expects strings or lists of strings, so we
        # yield only the "text" column.
        for batch in dataloader:
            yield batch["text"]

    logging.info("Training the tokenizer.")
    tokenizer = get_unigram_tokenizer()
    trainer = get_unigram_trainer(args.vocab_size)
    tokenizer.train_from_iterator(batch_iterator(), trainer=trainer)
    logging.info("Tokenizer training complete!")

    cls_token_id = tokenizer.token_to_id("[CLS]")
    sep_token_id = tokenizer.token_to_id("[SEP]")
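    # Attach an ALBERT-style template post-processor: wrap sequences in
    # [CLS] ... [SEP], and give the second segment of a pair token type ID 1.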
    tokenizer.post_processor = processors.TemplateProcessing(
        single="[CLS]:0 $A:0 [SEP]:0",
        pair="[CLS]:0 $A:0 [SEP]:0 $B:1 [SEP]:1",
        special_tokens=[
            ("[CLS]", cls_token_id),
            ("[SEP]", sep_token_id),
        ],
    )
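    # Decode with Metaspace to mirror the Metaspace pre-tokenizer used during training.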
    tokenizer.decoder = decoders.Metaspace()

    if args.export_to_hub:
        logging.info("Exporting the trained tokenzier to Hub.")
        new_tokenizer = AlbertTokenizerFast(tokenizer_object=tokenizer)
        new_tokenizer.push_to_hub("sayakpaul/unigram-tokenizer-wikitext")
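        # The pushed tokenizer can later be reloaded with
        # `AlbertTokenizerFast.from_pretrained("sayakpaul/unigram-tokenizer-wikitext")`.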


if __name__ == "__main__":
    # Emit the INFO-level log messages used throughout this script.
    logging.basicConfig(level=logging.INFO)
    args = parse_args()
    main(args)