import ast
import logging
import os
import sys
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple, Union, Any
from datasets import load_dataset
from tokenizers import ByteLevelBPETokenizer
from transformers import (
HfArgumentParser,
)
from data_utils import (
filter_by_lang_regex,
filter_by_num_tokens,
filter_by_num_sents,
filter_by_adv,
normalizer
)
# Module-level logger for this script; main() raises its level to INFO.
logger = logging.getLogger(__name__)
@dataclass
class TokenizerArguments:
    """
    Command-line arguments for tokenizer training: where the training text comes
    from (a `datasets` Hub dataset or a local csv/json/txt file), where the trained
    tokenizer is written, and the BPE training hyper-parameters.

    After ``__post_init__`` runs, ``special_tokens`` holds a ``list`` of strings even
    though the field is declared as a comma-separated ``str`` for argument parsing.
    """
    output_dir: str = field(
        default=".",
        metadata={"help": "The output directory where the config will be written."},
    )
    dataset_name: Optional[str] = field(
        default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
    )
    dataset_config_name: Optional[str] = field(
        default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
    )
    train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."})
    cache_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"},
    )
    special_tokens: Optional[str] = field(
        default=None,
        metadata={"help": "The list of special tokens that you want to add in your training."}
    )
    vocab_size: Optional[int] = field(
        default=56000,
        metadata={"help": "The size of the final vocabulary, including all tokens and alphabet"}
    )
    min_frequency: Optional[int] = field(
        default=2,
        metadata={"help": "The minimum frequency a pair should have in order to be merged"}
    )
    show_progress: Optional[bool] = field(
        default=True,
        metadata={"help": "Whether to show progress bars while training"}
    )

    def __post_init__(self):
        """Normalize ``special_tokens`` into a list and validate the data source.

        Raises:
            ValueError: if neither ``dataset_name`` nor ``train_file`` is set, or if
                ``train_file`` does not have a ``csv``, ``json`` or ``txt`` extension.
        """
        if self.special_tokens is None:
            # NOTE(review): in the reviewed copy these literals were empty strings
            # (angle-bracket tags appear to have been stripped), which would register
            # "" as a special token. Reconstructed as the conventional tag set --
            # confirm against the model config that consumes this tokenizer.
            special_tokens = [
                "<s>", "<pad>", "</s>", "<unk>", "<mask>",
                "<|endoftext|>", "<|startoftext|>",
                "<sep>", "<cls>", "<nl>", "<tab>", "<zwnj>",
            ]
            # Append the extra tokens [U1] .. [U20].
            special_tokens += [f"[U{i}]" for i in range(1, 21)]
        elif isinstance(self.special_tokens, str):
            # Comma-separated CLI value: strip surrounding spaces so "<s>, <pad>"
            # does not yield " <pad>", and drop empty entries (e.g. a trailing comma).
            special_tokens = [
                token.strip() for token in self.special_tokens.split(",") if token.strip()
            ]
        else:
            # Already a sequence (e.g. a JSON list from a config file).
            special_tokens = list(self.special_tokens)
        self.special_tokens = special_tokens

        if self.dataset_name is None and self.train_file is None:
            raise ValueError("Need either a dataset name or a training file.")
        if self.train_file is not None:
            extension = self.train_file.split(".")[-1]
            # Raise rather than ``assert`` so the check survives ``python -O``.
            if extension not in ["csv", "json", "txt"]:
                raise ValueError("`train_file` should be a csv, a json or a txt file.")
def _parse_args() -> "TokenizerArguments":
    """Parse ``TokenizerArguments`` from a lone ``*.json`` argument or from the CLI."""
    parser = HfArgumentParser([TokenizerArguments])
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # A single argument ending in .json is treated as a config file holding
        # every argument.
        return parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))[0]
    return parser.parse_args_into_dataclasses()[0]


def _load_raw_dataset(tokenizer_args: "TokenizerArguments"):
    """Load the ``train`` split, either from the named dataset or from ``train_file``."""
    if tokenizer_args.dataset_name is not None:
        return load_dataset(
            tokenizer_args.dataset_name,
            tokenizer_args.dataset_config_name,
            cache_dir=tokenizer_args.cache_dir,
            split="train",
        )

    extension = tokenizer_args.train_file.split(".")[-1]
    builder_kwargs = {}
    if extension == "txt":
        extension = "text"
    elif extension == "csv":
        # Only the csv builder has a ``delimiter`` option; the text/json builder
        # configs reject unknown keys. The tab delimiter is kept from the original.
        builder_kwargs["delimiter"] = "\t"
    # ``split="train"`` returns a single Dataset rather than a DatasetDict, so the
    # ``len()`` and row slicing in ``main`` work the same as in the branch above.
    return load_dataset(
        extension,
        data_files={"train": tokenizer_args.train_file},
        split="train",
        cache_dir=tokenizer_args.cache_dir,
        **builder_kwargs,
    )


def _preprocess(raw_dataset):
    """Filter rows with the ``data_utils`` predicates on ``"text"``, then map ``normalizer``."""
    dataset = raw_dataset.filter(lambda example: filter_by_lang_regex(example["text"], ratio=0.75))
    dataset = dataset.filter(lambda example: filter_by_num_tokens(example["text"], gt=64))
    dataset = dataset.filter(lambda example: filter_by_num_sents(example["text"], gt=2))
    dataset = dataset.filter(lambda example: filter_by_adv(example["text"], ratio=50))
    return dataset.map(normalizer)


def main():
    """Train a byte-level BPE tokenizer and write it to ``output_dir``.

    Steps:
    1. Parse ``TokenizerArguments``.
    2. Load the ``train`` split from the dataset or local file.
    3. Filter and normalize the ``"text"`` column.
    4. Train a ``ByteLevelBPETokenizer`` on it in batches of 1000 rows.
    5. Save the tokenizer via ``save_model`` and as ``tokenizer.json`` in
       ``output_dir``, creating the directory if needed.
    """
    tokenizer_args = _parse_args()

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )
    logger.setLevel(logging.INFO)
    logger.info("Training tokenizer")

    raw_dataset = _load_raw_dataset(tokenizer_args)

    logger.info("Preprocessing the dataset")
    dataset = _preprocess(raw_dataset)
    logger.info(f"Preprocessed dataset kept {len(dataset)} out of {len(raw_dataset)}")
    if len(dataset) == 0:
        logger.warning("No examples survived preprocessing; the tokenizer will be trained on no text.")

    tokenizer = ByteLevelBPETokenizer()

    def batch_iterative(batch_size=1000):
        # Yield lists of raw strings so the whole corpus is never materialized at once.
        for start in range(0, len(dataset), batch_size):
            yield dataset[start:start + batch_size]["text"]

    tokenizer.train_from_iterator(
        batch_iterative(),
        vocab_size=tokenizer_args.vocab_size,
        special_tokens=tokenizer_args.special_tokens,
        min_frequency=tokenizer_args.min_frequency,
        show_progress=tokenizer_args.show_progress,
    )

    os.makedirs(tokenizer_args.output_dir, exist_ok=True)
    tokenizer.save_model(tokenizer_args.output_dir)
    tokenizer.save(os.path.join(tokenizer_args.output_dir, "tokenizer.json"), pretty=True)
    # Log only after the files have actually been written.
    logger.info(f"Your tokenizer saved here {tokenizer_args.output_dir}")
# Script entry point: train only when executed directly, never on import.
if __name__ == "__main__":
    main()