# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import multiprocessing as mp
from pathlib import Path
from nemo.collections.nlp.data.token_classification.punctuation_capitalization_tarred_dataset import (
DEFAULT_CAPIT_LABEL_VOCAB_FILE_NAME,
DEFAULT_PUNCT_LABEL_VOCAB_FILE_NAME,
METADATA_CAPIT_LABEL_VOCAB_KEY,
METADATA_PUNCT_LABEL_VOCAB_KEY,
build_label_ids_from_list_of_labels,
check_labels_for_being_unique_before_building_label_ids,
check_tar_file_prefix,
create_tarred_dataset,
)
"""
A tarred dataset allows to train on large amounts without storing it all into memory simultaneously. In case of
punctuation and capitalization model, tarred dataset is a directory which contains metadata file, tar files with
batches, punct_label_vocab.csv and capit_label_vocab.csv files.
A metadata file is a JSON file with 4 fields: 'num_batches', 'tar_files', 'punct_label_vocab_file',
'capit_label_vocab_file'. 'num_batches' (int) is a total number of batches in tarred dataset. 'tar_files' is a list of
paths to tar files relative to directory containing the metadata file. 'punct_label_vocab_file' and
'capit_label_vocab_file' are paths to .csv files containing all unique punctuation and capitalization labels. Each
label in these files is written in a separate line. The first labels in both files are equal and serve for padding and
as neutral labels.
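
For illustration, a metadata file might look like the following (the tar file names here are hypothetical and depend
on `--tar_file_prefix` and the size of the dataset):

    {
        "num_batches": 2000,
        "tar_files": ["punctuation_capitalization.batches.0.tar", "punctuation_capitalization.batches.1.tar"],
        "punct_label_vocab_file": "punct_label_vocab.csv",
        "capit_label_vocab_file": "capit_label_vocab.csv"
    }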

Every tar file contains objects written using `webdataset.TarWriter`. Each object is a dictionary with two items:
'__key__' and 'batch.pyd'. '__key__' is the name of a batch, and 'batch.pyd' is a pickled dictionary which contains
'input_ids', 'subtokens_mask', 'punct_labels', and 'capit_labels'. 'input_ids' is an array containing ids of source
tokens, 'subtokens_mask' is a boolean array marking the first subtoken of each word, and 'punct_labels' and
'capit_labels' are arrays with ids of labels. The metadata file should be passed to the constructor of
`nemo.collections.nlp.data.token_classification.PunctuationCapitalizationTarredDataset`, and the instance of the
class will handle iteration and construct masks and token types for the BERT model.
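
For a quick sanity check, a batch can be read back from a tar file with the standard library (a minimal sketch; the
tar file path is a placeholder, and member names follow the '__key__' plus '.batch.pyd' naming convention of
`webdataset.TarWriter`):

    import pickle
    import tarfile

    with tarfile.open("<PATH/TO/OUTPUT/DIR>/<TAR_FILE_NAME>") as tar:
        member = tar.next()  # first object in the tar file, named like '<batch_key>.batch.pyd'
        batch = pickle.load(tar.extractfile(member))
        print(batch.keys())  # 'input_ids', 'subtokens_mask', 'punct_labels', 'capit_labels'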
Example of usage:
python create_punctuation_capitalization_tarred_dataset.py \
--text <PATH/TO/TEXT/FILE> \
--labels <PATH/TO/LABELS/FILE> \
--output_dir <PATH/TO/OUTPUT/DIR> \
--lines_per_dataset_fragment 10000 \
--tokens_in_batch 8000 \
--num_batches_per_tarfile 5 \
--tokenizer_name char \
--vocab_file <PATH_TO_CHAR_TOKENIZER_VOCABULARY>
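
If `--use_audio` is set, a lexical audio dataset for `PunctuationCapitalizationLexicalAudioModel` can be created
analogously (a sketch; the paths are placeholders and 16000 is only an example target sample rate):
python create_punctuation_capitalization_tarred_dataset.py \
    --text <PATH/TO/TEXT/FILE> \
    --labels <PATH/TO/LABELS/FILE> \
    --audio_file <PATH/TO/AUDIO/PATHS/FILE> \
    --use_audio \
    --sample_rate 16000 \
    --output_dir <PATH/TO/OUTPUT/DIR>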
"""


def get_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
description=f"A tarred dataset allows to train on large amounts without storing it all into memory "
f"simultaneously. In case of punctuation and capitalization model, tarred dataset is a directory which "
f"contains metadata file, tar files with batches, {DEFAULT_PUNCT_LABEL_VOCAB_FILE_NAME} and "
f"{DEFAULT_CAPIT_LABEL_VOCAB_FILE_NAME} files. A metadata file is a JSON file with 4 fields: 'num_batches', "
f"'tar_files', '{METADATA_PUNCT_LABEL_VOCAB_KEY}', '{METADATA_CAPIT_LABEL_VOCAB_KEY}'. 'num_batches' (int) is "
f"a total number of batches in tarred dataset. 'tar_files' is a list of paths to tar files relative "
f"to directory containing the metadata file. '{METADATA_PUNCT_LABEL_VOCAB_KEY}' and "
f"'{METADATA_CAPIT_LABEL_VOCAB_KEY}' are paths to .csv files containing all unique punctuation and "
f"capitalization labels. Each label in these files is written in a separate line. The first labels in both "
f"files are equal and serve for padding and as neutral labels. Every tar file contains objects written "
f"using `webdataset.TarWriter`. Each object is a dictionary with two items: '__key__' and 'batch.pyd'. "
f"'__key__' is a name of a batch and 'batch.pyd' is a pickled dictionary which contains 'input_ids', "
f"'subtokens_mask', 'punct_labels', 'capit_labels'. 'input_ids' is an array containing ids of source tokens, "
f"'subtokens_mask' is a boolean array showing first tokens in words, 'punct_labels' and 'capit_labels' are "
f"arrays with ids of labels. Metadata file should be passed to constructor of "
"`nemo.collections.nlp.data.token_classification.PunctuationCapitalizationTarredDataset` and the instance of "
"the class will handle iteration and constructing masks and token types for BERT model.",
)
parser.add_argument(
"--text",
"-t",
help="Path to source lowercased text without punctuation. Number of lines in `--text` file has to be equal "
"to number of lines in `--labels` file.",
type=Path,
required=True,
)
parser.add_argument(
"--audio_file",
type=Path,
required=False,
help="Path to source file which contains paths to audio one path per line. "
"Number of lines in `--audio_file` has to be equal to number of lines in `--labels` file",
)
parser.add_argument(
"--use_audio",
required=False,
action="store_true",
help="If set to `True` script creates lexical audio dataset which can be used with `PunctuationCapitalizationLexicalAudioModel`.",
)
parser.add_argument(
"--sample_rate",
type=int,
required=False,
help="Target sample rate of audios. Can be used for downsampling or upsampling.",
)
parser.add_argument(
"--labels",
"-L",
type=Path,
required=True,
help="Path to file with labels in the format described here "
"https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/main/nlp/punctuation_and_capitalization.html#"
"nemo-data-format . Number of lines in `--labels` file has to be equal to the number of lines in `--text` "
"file.",
)
parser.add_argument(
"--output_dir",
"-o",
type=Path,
required=True,
help="Path to directory where .tar files, metadata file, label id files are stored.",
)
parser.add_argument(
"--max_seq_length",
"-s",
type=int,
default=512,
help="Maximum number of subtokens in an input sequence. A source sequence which contain too many subtokens are "
"clipped to `--max_seq_length - 2` subtokens and then [CLS] token is prepended to the clipped sequence and "
"[SEP] token is appended to the clipped sequence. The clipping is performed via removal of subtokens in the "
"end of a source sequence.",
)
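    # E.g., with the default `--max_seq_length` of 512, an overlong sequence is clipped to 510 subtokens, and the
    # prepended [CLS] and appended [SEP] tokens bring the total back to 512.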
parser.add_argument(
"--tokens_in_batch",
"-b",
type=int,
default=15000,
help="Maximum number of tokens in a batch including [CLS], [SEP], [UNK], and [PAD] tokens. Before packing into "
"batches source sequences are sorted by number of tokens in order to reduce number of pad tokens. So the "
"number of sequences in a batch may be different.",
)
parser.add_argument(
"--lines_per_dataset_fragment",
type=int,
default=10 ** 6,
help="A number of lines processed by one worker during creation of tarred dataset. A worker tokenizes "
"`--lines_per_dataset_fragment` lines and keeps in RAM tokenized text labels before packing them into "
"batches. Reducing `--lines_per_dataset_fragment` leads to reducing of the amount of memory required by this "
"script.",
)
parser.add_argument(
"--num_batches_per_tarfile",
type=int,
default=1000,
help="A number of batches saved in a tar file. If you increase `--num_batches_per_tarfile`, then there will "
"be less tar files in the dataset. There cannot be less then `--num_batches_per_tarfile` batches in a tar "
"file, and all excess batches are removed. Maximum number of discarded batches is "
"`--num_batches_per_tarfile - 1`.",
)
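    # E.g., with 10250 batches and the default `--num_batches_per_tarfile` of 1000, ten tar files are written and
    # the remaining 250 batches are discarded (an illustrative count based on the help text above).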
parser.add_argument(
"--tokenizer_name",
"-T",
default="bert-base-uncased",
help="Name of the tokenizer used for tokenization of source sequences. Possible options are 'sentencepiece', "
"'word', 'char', HuggingFace tokenizers. For more options see function "
"`nemo.collections.nlp.modules.common.get_tokenizer`. The tokenizer has to have properties `cls_id`, "
"`pad_id`, `sep_id`, `unk_id`.",
)
parser.add_argument(
"--tokenizer_model", "-m", type=Path, help="Path to tokenizer model required for 'sentencepiece' tokenizer."
)
parser.add_argument(
"--vocab_file",
"-v",
type=Path,
help="Path to vocabulary file which can be used in 'word', 'char', and HuggingFace tokenizers.",
)
parser.add_argument(
"--merges_file", "-M", type=Path, help="Path to merges file which can be used in HuggingFace tokenizers."
)
parser.add_argument(
"--special_token_names",
"-n",
nargs="+",
help="Names of special tokens which may be passed to constructors of 'char', 'word', 'sentencepiece', and "
"HuggingFace tokenizers.",
)
parser.add_argument(
"--special_token_values",
"-V",
nargs="+",
help="Values of special tokens which may be passed to constructors of 'char', 'word', 'sentencepiece', and "
"HuggingFace tokenizers.",
)
parser.add_argument(
"--use_fast_tokenizer", "-f", action="store_true", help="Whether to use fast HuggingFace tokenizer."
)
parser.add_argument(
"--pad_label",
"-P",
default='O',
help="Pad label both for punctuation and capitalization. This label is also is used for marking words which "
"do not need punctuation and capitalization. It is also a neutral label used for marking words which do "
"not require punctuation and capitalization.",
)
punct = parser.add_mutually_exclusive_group(required=False)
punct.add_argument(
"--punct_labels",
"-p",
nargs="+",
help="All punctuation labels EXCEPT PAD LABEL. Punctuation labels are strings separated by spaces. "
"Alternatively you can use parameter `--punct_label_vocab_file`. If none of parameters `--punct_labels` "
"and `--punct_label_vocab_file` are provided, then punctuation label ids will be inferred from `--labels` "
"file.",
)
punct.add_argument(
"--punct_label_vocab_file",
type=Path,
help="A path to file with punctuation labels. These labels include pad label. Pad label has to be the first "
"label in the file. Each label is written on separate line. Alternatively you can use `--punct_labels` "
"parameter. If none of parameters `--punct_labels` and `--punct_label_vocab_file` are provided, then "
"punctuation label ids will be inferred from `--labels` file.",
)
capit = parser.add_mutually_exclusive_group(required=False)
capit.add_argument(
"--capit_labels",
"-c",
nargs="+",
help="All capitalization labels EXCEPT PAD LABEL. Capitalization labels are strings separated by spaces. "
"Alternatively you can use parameter `--capit_label_vocab_file`. If none of parameters `--capit_labels` "
"and `--capit_label_vocab_file` are provided, then capitalization label ids will be inferred from `--labels` "
"file.",
)
capit.add_argument(
"--capit_label_vocab_file",
type=Path,
help="A path to file with capitalization labels. These labels include pad label. Pad label has to be the "
"first label in the file. Each label is written on separate line. Alternatively you can use `--capit_labels` "
"parameter. If none of parameters `--capit_labels` and `--capit_label_vocab_file` are provided, then "
"capitalization label ids will be inferred from `--labels` file.",
)
parser.add_argument(
"--tar_file_prefix",
"-x",
default="punctuation_capitalization",
help="A string from which tar file names start. It can contain only characters 'A-Z', 'a-z', '0-9', '_', '-', "
"'.'.",
)
parser.add_argument(
"--n_jobs",
"-j",
type=int,
default=mp.cpu_count(),
help="Number of workers for creating tarred dataset. By default it is equal to the number of CPU cores.",
)
args = parser.parse_args()
for name in [
"text",
"labels",
"output_dir",
"tokenizer_model",
"vocab_file",
"merges_file",
"punct_label_vocab_file",
"capit_label_vocab_file",
]:
if getattr(args, name) is not None:
setattr(args, name, getattr(args, name).expanduser())
if args.special_token_names is not None or args.special_token_values is not None:
if args.special_token_names is None:
            parser.error(
                "If you provide the parameter `--special_token_values`, you also have to provide the parameter "
                "`--special_token_names`."
            )
if args.special_token_values is None:
            parser.error(
                "If you provide the parameter `--special_token_names`, you also have to provide the parameter "
                "`--special_token_values`."
            )
if len(args.special_token_names) != len(args.special_token_values):
            parser.error(
                f"Parameters `--special_token_names` and `--special_token_values` have to have an equal number of "
                f"values, whereas the parameter `--special_token_names` has {len(args.special_token_names)} values "
                f"and the parameter `--special_token_values` has {len(args.special_token_values)} values."
            )
if len(set(args.special_token_names)) != len(args.special_token_names):
for i in range(len(args.special_token_names) - 1):
if args.special_token_names[i] in args.special_token_names[i + 1 :]:
                parser.error(
                    f"Values of the parameter `--special_token_names` have to be unique. Found duplicate value "
                    f"'{args.special_token_names[i]}'."
                )
    if args.punct_labels is not None:
        check_labels_for_being_unique_before_building_label_ids(
            args.pad_label, args.punct_labels, '--pad_label', '--punct_labels', parser.error
        )
    if args.capit_labels is not None:
        check_labels_for_being_unique_before_building_label_ids(
            args.pad_label, args.capit_labels, '--pad_label', '--capit_labels', parser.error
        )
check_tar_file_prefix(args.tar_file_prefix, parser.error, '--tar_file_prefix')
return args


def main() -> None:
args = get_args()
    # Pair each special token name with its value, e.g. {'cls_token': '[CLS]'} (illustrative example).
    if args.special_token_names is None:
        special_tokens = None
    else:
        special_tokens = dict(zip(args.special_token_names, args.special_token_values))
    # Build label id mappings only if labels were given on the command line; otherwise `create_tarred_dataset`
    # infers them from the `--labels` file or reads them from the label vocab files.
    if args.punct_labels is not None:
        punct_label_ids = build_label_ids_from_list_of_labels(args.pad_label, args.punct_labels)
    else:
        punct_label_ids = None
    if args.capit_labels is not None:
        capit_label_ids = build_label_ids_from_list_of_labels(args.pad_label, args.capit_labels)
    else:
        capit_label_ids = None
create_tarred_dataset(
args.text,
args.labels,
args.output_dir,
args.max_seq_length,
args.tokens_in_batch,
args.lines_per_dataset_fragment,
args.num_batches_per_tarfile,
args.tokenizer_name,
tokenizer_model=args.tokenizer_model,
vocab_file=args.vocab_file,
merges_file=args.merges_file,
special_tokens=special_tokens,
use_fast_tokenizer=args.use_fast_tokenizer,
pad_label=args.pad_label,
punct_label_ids=punct_label_ids,
capit_label_ids=capit_label_ids,
punct_label_vocab_file=args.punct_label_vocab_file,
capit_label_vocab_file=args.capit_label_vocab_file,
tar_file_prefix=args.tar_file_prefix,
n_jobs=args.n_jobs,
audio_file=args.audio_file,
sample_rate=args.sample_rate,
use_audio=args.use_audio,
)


if __name__ == "__main__":
main()