gpt2-medium-persian / src /train_tokenizer.py
m3hrdadfi's picture
Add runner, fix some bugs
31bf2aa
raw
history blame
No virus
5.46 kB
import ast
import logging
import os
import sys
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple, Union, Any
from datasets import load_dataset
from tokenizers import ByteLevelBPETokenizer
from transformers import (
HfArgumentParser,
)
from data_utils import (
filter_by_lang_regex,
filter_by_num_tokens,
filter_by_num_sents,
filter_by_adv,
normalizer
)
logger = logging.getLogger(__name__)
@dataclass
class TokenizerArguments:
"""
Arguments to which tokenizer we are going to set up.
"""
output_dir: str = field(
default=".",
metadata={"help": "The output directory where the config will be written."},
)
dataset_name: Optional[str] = field(
default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
)
dataset_config_name: Optional[str] = field(
default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
)
train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."})
cache_dir: Optional[str] = field(
default=None,
metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"},
)
special_tokens: Optional[str] = field(
default=None,
metadata={"help": "The list of special tokens that you want to add in your training."}
)
vocab_size: Optional[int] = field(
default=56000,
metadata={"help": "The size of the final vocabulary, including all tokens and alphabet"}
)
min_frequency: Optional[int] = field(
default=2,
metadata={"help": "The minimum frequency a pair should have in order to be merged"}
)
show_progress: Optional[bool] = field(
default=True,
metadata={"help": "Whether to show progress bars while training"}
)
def __post_init__(self):
if self.special_tokens is None:
special_tokens = [
"<s>", "<pad>", "</s>", "<unk>", "<mask>",
"<|endoftext|>", "<|startoftext|>",
"<sep>", "<cls>", "<nl>", "<tab>", "<zwnj>"
]
special_tokens += [f"[U{i}]" for i in range(1, 21)]
else:
special_tokens = list(self.special_tokens.split(","))
self.special_tokens = special_tokens
if self.dataset_name is None and self.train_file is None:
raise ValueError("Need either a dataset name or a training file.")
else:
if self.train_file is not None:
extension = self.train_file.split(".")[-1]
assert extension in ["csv", "json", "txt"], "`train_file` should be a csv, a json or a txt file."
def main():
parser = HfArgumentParser([TokenizerArguments])
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
# If we pass only one argument to the script and it's the path to a json file,
# let's parse it to get our arguments.
tokenizer_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))[0]
else:
tokenizer_args = parser.parse_args_into_dataclasses()[0]
# Setup logging
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
handlers=[logging.StreamHandler(sys.stdout)],
)
logger.setLevel(logging.INFO)
logger.info(f"Training tokenizer")
if tokenizer_args.dataset_name is not None:
raw_dataset = load_dataset(
tokenizer_args.dataset_name,
tokenizer_args.dataset_config_name,
cache_dir=tokenizer_args.cache_dir,
split="train"
)
else:
data_files = {"train": tokenizer_args.train_file}
extension = tokenizer_args.train_file.split(".")[-1]
if extension == "txt":
extension = "text"
raw_dataset = load_dataset(
extension,
data_files=data_files,
delimiter="\t",
cache_dir=tokenizer_args.cache_dir,
)
logger.info("Preprocessing the dataset")
dataset = raw_dataset.filter(lambda example: filter_by_lang_regex(example["text"], ratio=0.75))
dataset = dataset.filter(lambda example: filter_by_num_tokens(example["text"], gt=64))
dataset = dataset.filter(lambda example: filter_by_num_sents(example["text"], gt=2))
dataset = dataset.filter(lambda example: filter_by_adv(example["text"], ratio=50))
dataset = dataset.map(normalizer)
logger.info(f"Preprocessed dataset kept {len(dataset)} out of {len(raw_dataset)}")
tokenizer = ByteLevelBPETokenizer()
def batch_iterative(batch_size=1000):
for i in range(0, len(dataset), batch_size):
yield dataset[i: i + batch_size]["text"]
tokenizer.train_from_iterator(
batch_iterative(),
vocab_size=tokenizer_args.vocab_size,
special_tokens=tokenizer_args.special_tokens,
min_frequency=tokenizer_args.min_frequency,
show_progress=tokenizer_args.show_progress,
)
logger.info(f"Your tokenizer saved here {tokenizer_args.output_dir}")
os.makedirs(tokenizer_args.output_dir, exist_ok=True)
tokenizer.save_model(tokenizer_args.output_dir)
tokenizer.save(f"{tokenizer_args.output_dir}/tokenizer.json", pretty=True)
if __name__ == '__main__':
main()