|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import argparse |
|
import json |
|
from pathlib import Path |
|
from typing import Dict, List, Union |
|
|
|
import torch.cuda |
|
|
|
from nemo.collections.nlp.models import PunctuationCapitalizationLexicalAudioModel, PunctuationCapitalizationModel |
|
|
|
|
|
""" |
|
This script is for restoring punctuation and capitalization. |
|
|
|
Usage example: |
|
|
|
python punctuate_capitalize.py \ |
|
--input_manifest <PATH/TO/INPUT/MANIFEST> \ |
|
--output_manifest <PATH/TO/OUTPUT/MANIFEST> |
|
|
|
Usage example for lexical audio model: |
|
python punctuate_capitalize.py \ |
|
--input_manifest <PATH/TO/INPUT/MANIFEST> \ |
|
--output_manifest <PATH/TO/OUTPUT/MANIFEST> \ |
|
--use_audio |
|
|
|
|
|
<PATH/TO/INPUT/MANIFEST> is a path to NeMo ASR manifest. Usually it is an output of |
|
NeMo/examples/asr/transcribe_speech.py but can be a manifest with 'text' key. Alternatively you can use |
|
--input_text parameter for passing text for inference. |
|
<PATH/TO/OUTPUT/MANIFEST> is a path to NeMo ASR manifest into which script output will be written. Alternatively |
|
you can use parameter --output_text. |
|
|
|
For more details on this script usage look in argparse help. |
|
""" |
|
|
|
|
|
def get_args() -> argparse.Namespace:
    """Parse command line arguments and perform cross-parameter validation.

    Returns:
        argparse.Namespace with all script parameters. All path-valued
        parameters which were provided are user-expanded (``~`` resolved).
        If neither ``--pretrained_name`` nor ``--model_path`` is given, then
        ``pretrained_name`` is set to the default model
        ``punctuation_en_bert``.
    """
    default_model_parameter = "pretrained_name"
    default_model = "punctuation_en_bert"
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        description="The script is for restoring punctuation and capitalization in text or text and audio. To use text and audio use '--use_audio'. Long strings are split into "
        "segments of length `--max_seq_length`. `--max_seq_length` is the length which includes [CLS] and [SEP] "
        "tokens. If `--use_audio` is set, samples with texts longer than `--max_seq_length` will be ignored. Parameter `--step` controls segments overlapping. `--step` is a distance between beginnings of "
        "consequent segments. Model outputs for tokens near the borders of tensors are less accurate and can be "
        "discarded before final predictions computation. Parameter `--margin` is number of discarded outputs near "
        "segments borders. Probabilities of tokens in overlapping parts of segments multiplied before selecting the "
        "best prediction. Default values of parameters `--max_seq_length`, `--step`, and `--margin` are optimal for "
        "IWSLT 2019 test dataset.",
    )
    parser.add_argument(
        '--use_audio',
        required=False,
        action="store_true",
        help="If set `PunctuationCapitalizationLexicalAudioModel` will be used for inference",
    )
    # Exactly one source of input text: a NeMo manifest or a plain text file.
    input_ = parser.add_mutually_exclusive_group(required=True)
    input_.add_argument(
        "--input_manifest",
        "-m",
        type=Path,
        help="Path to the file with NeMo manifest which needs punctuation and capitalization. If the first element "
        "of manifest contains key 'pred_text', 'pred_text' values are passed for tokenization. Otherwise 'text' "
        "values are passed for punctuation and capitalization. Exactly one parameter of `--input_manifest` and "
        "`--input_text` should be provided.",
    )
    input_.add_argument(
        "--input_text",
        "-t",
        type=Path,
        help="Path to file with text which needs punctuation and capitalization. Exactly one parameter of "
        "`--input_manifest` and `--input_text` should be provided.",
    )
    parser.add_argument(
        '--audio_file',
        required=False,
        type=Path,
        help="Path to file with paths to audio. One path per row. Required if '--input_text' provided. Else 'audio_filepath' from manifest will be used.",
    )
    # Exactly one output destination: a NeMo manifest or a plain text file.
    output = parser.add_mutually_exclusive_group(required=True)
    output.add_argument(
        "--output_manifest",
        "-M",
        type=Path,
        help="Path to output NeMo manifest. Text with restored punctuation and capitalization will be saved in "
        "'pred_text' elements if 'pred_text' key is present in the input manifest. Otherwise text with restored "
        "punctuation and capitalization will be saved in 'text' elements. Exactly one parameter of `--output_manifest` "
        "and `--output_text` should be provided.",
    )
    output.add_argument(
        "--output_text",
        "-T",
        type=Path,
        help="Path to file with text with restored punctuation and capitalization. Exactly one parameter of "
        "`--output_manifest` and `--output_text` should be provided.",
    )
    # At most one way to select a model; if neither is given a default
    # pretrained model is substituted after parsing (see below).
    model = parser.add_mutually_exclusive_group(required=False)
    model.add_argument(
        "--pretrained_name",
        "-p",
        help=f"The name of NGC pretrained model. No more than one of parameters `--pretrained_name`, `--model_path` "
        f"should be provided. If neither of parameters `--pretrained_name` and `--model_path` are provided, then the "
        f"script is run with `--{default_model_parameter}={default_model}`.",
        choices=[m.pretrained_model_name for m in PunctuationCapitalizationModel.list_available_models()]
        + [m.pretrained_model_name for m in PunctuationCapitalizationLexicalAudioModel.list_available_models()],
    )
    model.add_argument(
        "--model_path",
        "-P",
        type=Path,
        help=f"Path to .nemo checkpoint of punctuation and capitalization model. No more than one of parameters "
        f"`--pretrained_name` and `--model_path` should be provided. If neither of parameters `--pretrained_name` and "
        f"`--model_path` are provided, then the script is run with `--{default_model_parameter}={default_model}`.",
    )
    parser.add_argument(
        "--max_seq_length",
        "-L",
        type=int,
        default=64,
        help="Length of segments into which queries are split. `--max_seq_length` includes [CLS] and [SEP] tokens.",
    )
    parser.add_argument(
        "--step",
        "-s",
        type=int,
        default=8,
        help="Relative shift of consequent segments into which long queries are split. Long queries are split into "
        "segments which can overlap. Parameter `step` controls such overlapping. Imagine that queries are "
        "tokenized into characters, `max_seq_length=5`, and `step=2`. In such a case query 'hello' is tokenized "
        "into segments `[['[CLS]', 'h', 'e', 'l', '[SEP]'], ['[CLS]', 'l', 'l', 'o', '[SEP]']]`.",
    )
    parser.add_argument(
        "--margin",
        "-g",
        type=int,
        default=16,
        help="A number of subtokens in the beginning and the end of segments which output probabilities are not used "
        "for prediction computation. The first segment does not have left margin and the last segment does not have "
        "right margin. For example, if input sequence is tokenized into characters, `max_seq_length=5`, `step=1`, "
        "and `margin=1`, then query 'hello' will be tokenized into segments `[['[CLS]', 'h', 'e', 'l', '[SEP]'], "
        "['[CLS]', 'e', 'l', 'l', '[SEP]'], ['[CLS]', 'l', 'l', 'o', '[SEP]']]`. These segments are passed to the "
        "model. Before final predictions computation, margins are removed. In the next list, subtokens which logits "
        "are not used for final predictions computation are marked with asterisk: `[['[CLS]'*, 'h', 'e', 'l'*, "
        "'[SEP]'*], ['[CLS]'*, 'e'*, 'l', 'l'*, '[SEP]'*], ['[CLS]'*, 'l'*, 'l', 'o', '[SEP]'*]]`.",
    )
    parser.add_argument(
        "--batch_size", "-b", type=int, default=128, help="Number of segments which are processed simultaneously.",
    )
    parser.add_argument(
        "--save_labels_instead_of_text",
        "-B",
        action="store_true",
        help="If this option is set, then punctuation and capitalization labels are saved instead text with restored "
        "punctuation and capitalization. Labels are saved in format described here "
        "https://docs.nvidia.com/deeplearning/nemo/"
        "user-guide/docs/en/main/nlp/punctuation_and_capitalization.html#nemo-data-format",
    )
    parser.add_argument(
        "--device",
        "-d",
        choices=['cpu', 'cuda'],
        help="Which device to use. If device is not set and CUDA is available, then GPU will be used. If device is "
        "not set and CUDA is not available, then CPU is used.",
    )
    parser.add_argument(
        "--sample_rate",
        type=int,
        default=16000,
        help="Target sample rate for audios if `--use_audio` was passed",
        required=False,
    )
    args = parser.parse_args()
    # Writing an output manifest requires a manifest input: the non-text
    # fields of output items are copied from the input manifest.
    if args.input_manifest is None and args.output_manifest is not None:
        parser.error("--output_manifest requires --input_manifest")
    # With `--input_text` there is no manifest to take audio paths from, so
    # audio inference requires an explicit `--audio_file`.
    if args.use_audio and (args.input_manifest is None and args.audio_file is None):
        parser.error("--use_audio and --input_text require --audio_file")
    # Fall back to the default NGC pretrained model when no model was chosen.
    if args.pretrained_name is None and args.model_path is None:
        setattr(args, default_model_parameter, default_model)
    # Expand `~` in all path-valued arguments which were provided.
    for name in ["input_manifest", "input_text", "output_manifest", "output_text", "model_path", "audio_file"]:
        if getattr(args, name) is not None:
            setattr(args, name, getattr(args, name).expanduser())
    return args
|
|
|
|
|
def load_manifest(manifest: Path) -> List[Dict[str, Union[str, float]]]:
    """Load a NeMo ASR manifest: a file with one JSON object per line.

    Args:
        manifest: path to the manifest file.

    Returns:
        A list of manifest items in file order, each item parsed into a dict.
    """
    # Each manifest line is an independent JSON document (JSON Lines format).
    with manifest.open() as f:
        return [json.loads(line) for line in f]
|
|
|
|
|
def main() -> None:
    """Restore punctuation and capitalization in text (optionally with audio).

    Loads the model selected by command line arguments, reads input texts
    (and audio paths, if `--use_audio`) from a manifest or plain files, runs
    inference, and writes the result to an output manifest or text file.
    """
    args = get_args()
    # Select the model class by `--use_audio` and load the checkpoint either
    # from a local .nemo file (`--model_path`) or from NGC (`--pretrained_name`).
    model_class = PunctuationCapitalizationLexicalAudioModel if args.use_audio else PunctuationCapitalizationModel
    if args.pretrained_name is None:
        model = model_class.restore_from(args.model_path)
    else:
        # Bug fix: the lexical-audio branch previously called
        # `restore_from(args.model_path)` here, but `args.model_path` is None
        # when `--pretrained_name` is given. Both model classes must be loaded
        # with `from_pretrained` in this branch.
        model = model_class.from_pretrained(args.pretrained_name)
    if args.device is None:
        # No explicit device requested: prefer GPU when available.
        model = model.cuda() if torch.cuda.is_available() else model.cpu()
    else:
        model = model.to(args.device)
    texts = []
    audios = []
    manifest = None
    text_key = None
    if args.input_manifest is None:
        # Plain-text input: one query per line; audio paths (if needed) come
        # from `--audio_file`, one path per line, aligned with the texts.
        with args.input_text.open() as f:
            texts = [line.strip() for line in f]
        if args.use_audio:
            with args.audio_file.open() as f:
                audios = [line.strip() for line in f]
    else:
        manifest = load_manifest(args.input_manifest)
        # Prefer ASR hypotheses ('pred_text') over reference 'text' when present.
        text_key = "pred_text" if "pred_text" in manifest[0] else "text"
        for item in manifest:
            texts.append(item[text_key])
            if args.use_audio:
                audios.append(item["audio_filepath"])
    # Common inference parameters; audio-specific ones are added only when
    # the lexical-audio model is used.
    inference_kwargs = dict(
        batch_size=args.batch_size,
        max_seq_length=args.max_seq_length,
        step=args.step,
        margin=args.margin,
        return_labels=args.save_labels_instead_of_text,
    )
    if args.use_audio:
        inference_kwargs["audio_queries"] = audios
        inference_kwargs["target_sr"] = args.sample_rate
    processed_texts = model.add_punctuation_capitalization(texts, **inference_kwargs)
    if args.output_manifest is None:
        args.output_text.parent.mkdir(exist_ok=True, parents=True)
        with args.output_text.open('w') as f:
            for t in processed_texts:
                f.write(t + '\n')
    else:
        # `manifest` and `text_key` are set here: get_args() guarantees that
        # `--output_manifest` is only allowed together with `--input_manifest`.
        args.output_manifest.parent.mkdir(exist_ok=True, parents=True)
        with args.output_manifest.open('w') as f:
            for item, t in zip(manifest, processed_texts):
                item[text_key] = t
                f.write(json.dumps(item) + '\n')
|
|
|
|
|
# Script entry point: run inference only when executed directly, not on import.
if __name__ == "__main__":

    main()
|
|