#!/usr/bin/env python
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Script for training a Unigram tokenizer."""

import argparse
import logging

import datasets
from tokenizers import Tokenizer, decoders, normalizers, pre_tokenizers, processors
from tokenizers.models import Unigram
from tokenizers.trainers import UnigramTrainer

from transformers import AlbertTokenizerFast


logger = logging.getLogger(__name__)


def parse_args():
    parser = argparse.ArgumentParser(description="Train a unigram tokenizer on the wikitext dataset.")
    parser.add_argument(
        "--dataset_name",
        type=str,
        default="wikitext",
        help="Name of the training dataset. Explore datasets at: hf.co/datasets.",
    )
    parser.add_argument(
        "--dataset_config", type=str, default="wikitext-103-raw-v1", help="Configuration name of the dataset."
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=1000,
        help="Batch size during training.",
    )
    parser.add_argument(
        "--vocab_size",
        type=int,
        default=10048,
        help="Size of the desired vocabulary.",
    )
    parser.add_argument(
        "--limit",
        default=None,
        type=int,
        help="Limit the number of shards (used for debugging).",
    )
    parser.add_argument(
        "--export_to_hub",
        action="store_true",
        help="Whether to export the trained tokenizer to the Hugging Face Hub.",
    )

    args = parser.parse_args()
    return args


def main(args):
    dataset = datasets.load_dataset(args.dataset_name, args.dataset_config, split="train")

    if args.limit is not None:
        max_train_samples = min(len(dataset), args.limit)
        dataset = dataset.select(range(max_train_samples))
        logger.info(f"Limiting the dataset to {args.limit} entries.")

    def batch_iterator():
        # Yield the raw text column in batches so the trainer can stream over the dataset.
        for i in range(0, len(dataset), args.batch_size):
            yield dataset[i : i + args.batch_size]["text"]

    # Prepare the tokenizer.
    tokenizer = Tokenizer(Unigram())
    # Normalize wikitext-style double backticks/apostrophes to plain double quotes.
    tokenizer.normalizer = normalizers.Sequence([normalizers.Replace("``", '"'), normalizers.Replace("''", '"')])
    tokenizer.pre_tokenizer = pre_tokenizers.Metaspace()

    # Prepare the trainer.
    trainer = UnigramTrainer(
        unk_token="<unk>",
        special_tokens=["[CLS]", "[SEP]", "<unk>", "<pad>", "[MASK]"],
        vocab_size=args.vocab_size,
    )

    logger.info("Training the tokenizer.")
    tokenizer.train_from_iterator(batch_iterator(), trainer=trainer)
    logger.info("Tokenizer training complete!")

    # Wrap single sequences and pairs with [CLS]/[SEP], matching the ALBERT input format.
    cls_token_id = tokenizer.token_to_id("[CLS]")
    sep_token_id = tokenizer.token_to_id("[SEP]")
    tokenizer.post_processor = processors.TemplateProcessing(
        single="[CLS]:0 $A:0 [SEP]:0",
        pair="[CLS]:0 $A:0 [SEP]:0 $B:1 [SEP]:1",
        special_tokens=[
            ("[CLS]", cls_token_id),
            ("[SEP]", sep_token_id),
        ],
    )
    tokenizer.decoder = decoders.Metaspace()

    if args.export_to_hub:
        logger.info("Exporting the trained tokenizer to Hub.")
        new_tokenizer = AlbertTokenizerFast(tokenizer_object=tokenizer)
        new_tokenizer.push_to_hub("unigram-tokenizer-dataset")


if __name__ == "__main__":
    args = parse_args()
    main(args)
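
# Example invocation (a sketch only — the filename "train_unigram.py" is an assumption;
# every flag below is defined in parse_args() above, and --export_to_hub pushes the
# tokenizer to the hard-coded "unigram-tokenizer-dataset" repository):
#
#   python train_unigram.py \
#       --dataset_name wikitext \
#       --dataset_config wikitext-103-raw-v1 \
#       --batch_size 1000 \
#       --vocab_size 10048 \
#       --export_to_hub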