#!/usr/bin/env python
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Script for training a Unigram tokenizer."""

import argparse
import logging

import datasets
from tokenizers import Tokenizer, decoders, normalizers, pre_tokenizers, processors
from tokenizers.models import Unigram
from tokenizers.trainers import UnigramTrainer

from transformers import AlbertTokenizerFast


logger = logging.getLogger(__name__)


def parse_args():
    parser = argparse.ArgumentParser(description="Train a unigram tokenizer on the wikitext dataset.")

    parser.add_argument(
        "--dataset_name",
        type=str,
        default="wikitext",
        help="Name of the training dataset. Explore datasets at: hf.co/datasets.",
    )
    parser.add_argument(
        "--dataset_config", type=str, default="wikitext-103-raw-v1", help="Configuration name of the dataset."
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=1000,
        help="Batch size during training.",
    )
    parser.add_argument(
        "--vocab_size",
        type=int,
        default=10048,
        help="Size of the desired vocabulary.",
    )
    parser.add_argument(
        "--limit",
        default=None,
        type=int,
        help="Limit the number of training examples (used for debugging).",
    )
    parser.add_argument(
        "--export_to_hub",
        action="store_true",
        help="Whether to export the trained tokenizer to the Hugging Face Hub.",
    )

    args = parser.parse_args()
    return args


def main(args):
    dataset = datasets.load_dataset(args.dataset_name, args.dataset_config, split="train")

    if args.limit is not None:
        max_train_samples = min(len(dataset), args.limit)
        dataset = dataset.select(range(max_train_samples))
        logger.info(f"Limiting the dataset to {args.limit} entries.")

    def batch_iterator():
        # Yield the raw text column in batches so the trainer can stream over the dataset.
        for i in range(0, len(dataset), args.batch_size):
            yield dataset[i : i + args.batch_size]["text"]

    # Prepare the tokenizer: normalize wikitext-style `` and '' quotes to plain double quotes,
    # and pre-tokenize with Metaspace (whitespace replaced by the ▁ marker, as in SentencePiece).
    tokenizer = Tokenizer(Unigram())
    tokenizer.normalizer = normalizers.Sequence([normalizers.Replace("``", '"'), normalizers.Replace("''", '"')])
    tokenizer.pre_tokenizer = pre_tokenizers.Metaspace()

    # Prepare the trainer.
    trainer = UnigramTrainer(
        unk_token="<unk>",
        special_tokens=["[CLS]", "[SEP]", "<unk>", "<pad>", "[MASK]"],
        vocab_size=args.vocab_size,
    )

    logger.info("Training the tokenizer.")
    tokenizer.train_from_iterator(batch_iterator(), trainer=trainer)
    logger.info("Tokenizer training complete!")

    # Post-process in the ALBERT format: wrap sequences in [CLS] ... [SEP] and assign token type IDs.
    cls_token_id = tokenizer.token_to_id("[CLS]")
    sep_token_id = tokenizer.token_to_id("[SEP]")
    tokenizer.post_processor = processors.TemplateProcessing(
        single="[CLS]:0 $A:0 [SEP]:0",
        pair="[CLS]:0 $A:0 [SEP]:0 $B:1 [SEP]:1",
        special_tokens=[
            ("[CLS]", cls_token_id),
            ("[SEP]", sep_token_id),
        ],
    )
    tokenizer.decoder = decoders.Metaspace()

    if args.export_to_hub:
        logger.info("Exporting the trained tokenizer to the Hub.")
        new_tokenizer = AlbertTokenizerFast(tokenizer_object=tokenizer)
        new_tokenizer.push_to_hub("unigram-tokenizer-dataset")


if __name__ == "__main__":
    args = parse_args()
    main(args)
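
# Downstream usage sketch (hypothetical repo id; push_to_hub above creates
# "unigram-tokenizer-dataset" under your own Hub namespace):
#
#   from transformers import AlbertTokenizerFast
#
#   tok = AlbertTokenizerFast.from_pretrained("<your-username>/unigram-tokenizer-dataset")
#   tok("The quick brown fox.")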