# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import subprocess
import tempfile
from pathlib import Path
from typing import Optional, Tuple
from omegaconf import DictConfig, OmegaConf, open_dict
from nemo.collections.asr.metrics.wer import word_error_rate_detail
from nemo.utils import logging
def run_asr_inference(cfg: DictConfig) -> DictConfig:
"""
Execute ASR inference based on input mode and parameters.
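    Supported cfg.inference.mode values: "offline", "chunked" and "offline_by_chunked"
    (the latter runs chunked inference with defaults intended to approximate offline results).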
"""
if (cfg.model_path and cfg.pretrained_name) or (not cfg.model_path and not cfg.pretrained_name):
raise ValueError("Please specify either cfg.model_path or cfg.pretrained_name!")
if cfg.inference.mode == "offline":
cfg = run_offline_inference(cfg)
elif cfg.inference.mode == "chunked":
if (
"total_buffer_in_secs" not in cfg.inference
or "chunk_len_in_secs" not in cfg.inference
or not cfg.inference.total_buffer_in_secs
or not cfg.inference.chunk_len_in_secs
):
raise ValueError(f"Please specify both total_buffer_in_secs and chunk_len_in_secs for chunked inference")
cfg = run_chunked_inference(cfg)
elif cfg.inference.mode == "offline_by_chunked":
        # When using Conformer to transcribe long audio samples, we may run into CUDA out-of-memory issues.
        # Here we use offline_by_chunked mode to simulate offline mode for Conformer,
        # with default total_buffer_in_secs=22 and chunk_len_in_secs=20 to avoid the problem above.
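        # Illustrative override for this mode (the entry point and exact config nesting are
        # assumptions, not defined in this file):
        #   inference.mode=offline_by_chunked inference.total_buffer_in_secs=22 inference.chunk_len_in_secs=20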
OmegaConf.set_struct(cfg, True)
if 'total_buffer_in_secs' not in cfg.inference or not cfg.inference.total_buffer_in_secs:
with open_dict(cfg):
cfg.inference.total_buffer_in_secs = 22
logging.info(
f"Does not provide total_buffer_in_secs required by {cfg.inference.mode} mode. Using default value {cfg.inference.total_buffer_in_secs}"
)
if 'chunk_len_in_secs' not in cfg.inference or not cfg.inference.chunk_len_in_secs:
with open_dict(cfg):
cfg.inference.chunk_len_in_secs = 20
logging.info(
f"Does not provide total_buffer_in_secs required by {cfg.inference.mode} mode. Using default value {cfg.inference.chunk_len_in_secs}"
)
cfg = run_chunked_inference(cfg)
else:
raise ValueError(f"inference could only be offline or chunked, but got {cfg.inference.mode}")
return cfg
def run_chunked_inference(cfg: DictConfig) -> DictConfig:
if "output_filename" not in cfg or not cfg.output_filename:
if cfg.model_path:
model_name = Path(cfg.model_path).stem
else:
model_name = cfg.pretrained_name
dataset_name = Path(cfg.test_ds.manifest_filepath).stem
mode_name = (
cfg.inference.mode
+ "B"
+ str(cfg.inference.total_buffer_in_secs)
+ "C"
+ str(cfg.inference.chunk_len_in_secs)
)
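        # The generated name encodes mode, buffer and chunk sizes, e.g. (illustrative values only):
        #   "stt_en_conformer_ctc_large-test_manifest-chunkedB22C20.json"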
OmegaConf.set_struct(cfg, True)
with open_dict(cfg):
cfg.output_filename = model_name + "-" + dataset_name + "-" + mode_name + ".json"
script_path = (
Path(__file__).parents[2]
/ "examples"
/ "asr"
/ "asr_chunked_inference"
/ "ctc"
/ "speech_to_text_buffered_infer_ctc.py"
)
if (cfg.pretrained_name and 'transducer' in cfg.pretrained_name) or (
cfg.model_path and 'transducer' in cfg.model_path
):
script_path = (
Path(__file__).parents[2]
/ "examples"
/ "asr"
/ "asr_chunked_inference"
/ "rnnt"
/ "speech_to_text_buffered_infer_rnnt.py"
)
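    # Illustrative: a name such as "stt_en_conformer_transducer_large" (containing "transducer")
    # selects the RNNT buffered-inference script above; otherwise the default CTC script is used.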
subprocess.run(
f"python {script_path} "
f"model_path={cfg.model_path} "
f"pretrained_name={cfg.pretrained_name} "
f"dataset_manifest={cfg.test_ds.manifest_filepath} "
f"output_filename={cfg.output_filename} "
f"random_seed={cfg.random_seed} "
f"batch_size={cfg.test_ds.batch_size} "
f"chunk_len_in_secs={cfg.inference.chunk_len_in_secs} "
f"total_buffer_in_secs={cfg.inference.total_buffer_in_secs} "
f"model_stride={cfg.inference.model_stride} ",
shell=True,
check=True,
)
return cfg
def run_offline_inference(cfg: DictConfig) -> DictConfig:
if "output_filename" not in cfg or not cfg.output_filename:
if cfg.model_path:
model_name = Path(cfg.model_path).stem
else:
model_name = cfg.pretrained_name
dataset_name = Path(cfg.test_ds.manifest_filepath).stem
mode_name = cfg.inference.mode
OmegaConf.set_struct(cfg, True)
with open_dict(cfg):
cfg.output_filename = model_name + "-" + dataset_name + "-" + mode_name + ".json"
with tempfile.NamedTemporaryFile(mode='w', encoding='utf-8') as f:
OmegaConf.save(cfg, f)
f.seek(0) # reset file pointer
script_path = Path(__file__).parents[2] / "examples" / "asr" / "transcribe_speech.py"
        # To pass additional config such as the decoding strategy, either:
        # 1) change TranscriptionConfig at the top of the executed script (e.g. transcribe_speech.py in examples/asr), or
        # 2) append an override such as "rnnt_decoding.strategy=greedy_batch " to the command below.
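        # Option 2) would look like the following extra line appended to the command string below
        # (shown only as an illustrative sketch, not enabled here):
        #     f"rnnt_decoding.strategy=greedy_batch "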
subprocess.run(
f"python {script_path} "
f"model_path={cfg.model_path} "
f"pretrained_name={cfg.pretrained_name} "
f"dataset_manifest={cfg.test_ds.manifest_filepath} "
f"output_filename={cfg.output_filename} "
f"batch_size={cfg.test_ds.batch_size} "
f"random_seed={cfg.random_seed} "
f"eval_config_yaml={f.name} ",
shell=True,
check=True,
)
return cfg
def clean_label(_str: str, num_to_words: bool = True, langid="en") -> str:
"""
    Remove unsupported characters from a string, lowercase it, and collapse redundant spaces.
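    Example (illustrative):
        clean_label("Hello, World! 12")  ->  "hello world one two"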
"""
replace_with_space = [char for char in '/?*\",.:=?_{|}~¨«·»¡¿„…‧‹›≪≫!:;ː→']
replace_with_blank = [char for char in '`¨´‘’“”`ʻ‘’“"‘”']
replace_with_apos = [char for char in '‘’ʻ‘’‘']
_str = _str.strip()
_str = _str.lower()
for i in replace_with_blank:
_str = _str.replace(i, "")
for i in replace_with_space:
_str = _str.replace(i, " ")
for i in replace_with_apos:
_str = _str.replace(i, "'")
if num_to_words:
if langid == "en":
_str = convert_num_to_words(_str, langid="en")
else:
logging.info(
"Currently support basic num_to_words in English only. Please use Text Normalization to convert other languages! Skipping!"
)
ret = " ".join(_str.split())
return ret
def convert_num_to_words(_str: str, langid: str = "en") -> str:
"""
Convert digits to corresponding words. Note this is a naive approach and could be replaced with text normalization.
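    Example (illustrative): each digit is converted independently, not as a full number:
        convert_num_to_words("room 42")  ->  "room four two"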
"""
if langid == "en":
num_to_words = ["zero", "one", "two", "three", "four", "five", "six", "seven", "eight", "nine"]
_str = _str.strip()
words = _str.split()
out_str = ""
num_word = []
for word in words:
            if word.isdigit():
                num = int(word)
                if num == 0:  # "0" would otherwise be dropped, since the while loop below never runs
                    out_str += num_to_words[0] + " "
while num:
digit = num % 10
digit_word = num_to_words[digit]
num_word.append(digit_word)
num = int(num / 10)
if not (num):
num_str = ""
num_word = num_word[::-1]
for ele in num_word:
num_str += ele + " "
out_str += num_str + " "
num_word.clear()
else:
out_str += word + " "
out_str = out_str.strip()
else:
raise ValueError(
"Currently support basic num_to_words in English only. Please use Text Normalization to convert other languages!"
)
return out_str
def cal_write_wer(cfg: DictConfig, pred_text_attr_name: Optional[str] = None) -> Tuple[DictConfig, dict, str]:
"""
    Calculate WER plus insertion, deletion and substitution rates based on the ground-truth text and pred_text_attr_name (pred_text).
    We use "WER" in the function name as a convention, but both Word Error Rate (WER) and Character Error Rate (CER) are supported.
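    Returns the (possibly updated) cfg, a summary dict and the metric name. Illustrative summary shape:
        {"samples": 100, "tokens": 2000, "wer": 0.12, "ins_rate": 0.01, "del_rate": 0.02, "sub_rate": 0.09}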
"""
samples = []
hyps = []
refs = []
with open(cfg.engine.output_filename, 'r') as fp:
for line in fp:
sample = json.loads(line)
if 'text' not in sample:
raise ValueError(
"ground-truth text does not present in manifest! Cannot calculate Word Error Rate. Exiting!"
)
if not pred_text_attr_name:
pred_text_attr_name = "pred_text"
hyp = sample[pred_text_attr_name]
ref = sample['text']
if cfg.analyst.metric_calculator.clean_groundtruth_text:
ref = clean_label(ref, langid=cfg.analyst.metric_calculator.langid)
wer, tokens, ins_rate, del_rate, sub_rate = word_error_rate_detail(
hypotheses=[hyp], references=[ref], use_cer=cfg.analyst.metric_calculator.use_cer
)
eval_metric = "wer"
if cfg.analyst.metric_calculator.use_cer:
eval_metric = "cer"
            sample[eval_metric] = wer  # evaluation metric: word error rate or character error rate
sample['tokens'] = tokens # number of word/characters/tokens
sample['ins_rate'] = ins_rate # insertion error rate
sample['del_rate'] = del_rate # deletion error rate
sample['sub_rate'] = sub_rate # substitution error rate
samples.append(sample)
hyps.append(hyp)
refs.append(ref)
total_wer, total_tokens, total_ins_rate, total_del_rate, total_sub_rate = word_error_rate_detail(
hypotheses=hyps, references=refs, use_cer=cfg.analyst.metric_calculator.use_cer
)
if "output_filename" not in cfg.analyst.metric_calculator or not cfg.analyst.metric_calculator.output_filename:
# overwrite the current generated manifest
OmegaConf.set_struct(cfg, True)
with open_dict(cfg):
cfg.analyst.metric_calculator.output_filename = cfg.engine.output_filename
with open(cfg.analyst.metric_calculator.output_filename, 'w') as fout:
for sample in samples:
json.dump(sample, fout)
fout.write('\n')
fout.flush()
total_res = {
"samples": len(samples),
"tokens": total_tokens,
eval_metric: total_wer,
"ins_rate": total_ins_rate,
"del_rate": total_del_rate,
"sub_rate": total_sub_rate,
}
return cfg, total_res, eval_metric
def cal_target_metadata_wer(manifest: str, target: str, meta_cfg: DictConfig, eval_metric: str = "wer") -> Optional[dict]:
"""
    Calculate the number of samples (samples), number of words/characters/tokens (tokens),
    wer/cer, insertion error rate (ins_rate), deletion error rate (del_rate) and substitution error rate (sub_rate) for each group/slot of the target metadata.
    A group could be e.g. [female, male], and a slot group could be e.g. [0-2s, 2-5s, >5s] audio durations.
    Args:
        manifest (str): Filepath of the generated manifest which contains the prediction and evaluation result for each sample.
        target (str): Target metadata field, such as 'duration', 'speaker' or 'emotion'. Only evaluated if the field is present in the manifest.
        meta_cfg (DictConfig): Config for calculating the group eval_metric for the target metadata.
        eval_metric (str): Evaluation metric. Currently 'wer' and 'cer' are supported.
Return:
ret (dict): Generated dictionary containing all results regarding the target metadata.
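            Illustrative shape, assuming target='duration' and meta_cfg.slot=[[0, 2], [2, 5]]:
                {"slot-0,2": {"samples": 10, "tokens": 120, "wer": 0.15, "ins_rate": 0.02,
                              "del_rate": 0.03, "sub_rate": 0.10}, ...}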
"""
if eval_metric not in ['wer', 'cer']:
raise ValueError(
"Currently support wer and cer as eval_metric. Please implement it in cal_target_metadata_wer if using different eval_metric"
)
wer_per_class = {}
with open(manifest, 'r') as fp:
for line in fp:
sample = json.loads(line)
if target in sample:
target_class = sample[target]
if target_class not in wer_per_class:
wer_per_class[target_class] = {
'samples': 0,
'tokens': 0,
"errors": 0,
"inss": 0,
"dels": 0,
"subs": 0,
}
wer_per_class[target_class]['samples'] += 1
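                # Recover absolute error counts from the per-sample rates (rate * tokens) so that
                # group-level rates can be re-computed over the pooled token count below.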
tokens = sample["tokens"]
wer_per_class[target_class]["tokens"] += tokens
wer_per_class[target_class]["errors"] += tokens * sample[eval_metric]
wer_per_class[target_class]["inss"] += tokens * sample["ins_rate"]
wer_per_class[target_class]["dels"] += tokens * sample["del_rate"]
wer_per_class[target_class]["subs"] += tokens * sample["sub_rate"]
if len(wer_per_class) > 0:
res_wer_per_class = {}
for target_class in wer_per_class:
res_wer_per_class[target_class] = {}
res_wer_per_class[target_class]["samples"] = wer_per_class[target_class]["samples"]
res_wer_per_class[target_class][eval_metric] = (
wer_per_class[target_class]["errors"] / wer_per_class[target_class]["tokens"]
)
res_wer_per_class[target_class]["tokens"] = wer_per_class[target_class]["tokens"]
res_wer_per_class[target_class]["ins_rate"] = (
wer_per_class[target_class]["inss"] / wer_per_class[target_class]["tokens"]
)
res_wer_per_class[target_class]["del_rate"] = (
wer_per_class[target_class]["dels"] / wer_per_class[target_class]["tokens"]
)
res_wer_per_class[target_class]["sub_rate"] = (
wer_per_class[target_class]["subs"] / wer_per_class[target_class]["tokens"]
)
else:
logging.info(f"metadata '{target}' does not present in manifest. Skipping! ")
return None
values = ['samples', 'tokens', 'errors', 'inss', 'dels', 'subs']
slot_wer = {}
if 'slot' in meta_cfg and meta_cfg.slot:
for target_class in wer_per_class:
for s in meta_cfg.slot:
if isinstance(s[0], float) or isinstance(s[0], int):
if s[0] <= target_class < s[1]:
slot_key = "slot-" + ",".join(str(i) for i in s)
if slot_key not in slot_wer:
slot_wer[slot_key] = {
'samples': 0,
'tokens': 0,
"errors": 0,
"inss": 0,
"dels": 0,
"subs": 0,
}
for v in values:
slot_wer[slot_key][v] += wer_per_class[target_class][v]
break
elif isinstance(s[0], str):
if target_class in s:
slot_key = "slot-" + ",".join(s)
if slot_key not in slot_wer:
slot_wer[slot_key] = {
'samples': 0,
'tokens': 0,
"errors": 0,
"inss": 0,
"dels": 0,
"subs": 0,
}
for v in values:
slot_wer[slot_key][v] += wer_per_class[target_class][v]
break
else:
raise ValueError("Current only support target metadata belongs to numeric or string ")
for slot_key in slot_wer:
slot_wer[slot_key][eval_metric] = slot_wer[slot_key]['errors'] / slot_wer[slot_key]['tokens']
slot_wer[slot_key]['ins_rate'] = slot_wer[slot_key]['inss'] / slot_wer[slot_key]['tokens']
slot_wer[slot_key]['del_rate'] = slot_wer[slot_key]['dels'] / slot_wer[slot_key]['tokens']
slot_wer[slot_key]['sub_rate'] = slot_wer[slot_key]['subs'] / slot_wer[slot_key]['tokens']
slot_wer[slot_key].pop('errors')
slot_wer[slot_key].pop('inss')
slot_wer[slot_key].pop('dels')
slot_wer[slot_key].pop('subs')
res_wer_per_class.update(slot_wer)
ret = None
if meta_cfg.save_wer_per_class:
ret = res_wer_per_class
if (not meta_cfg.save_wer_per_class) and ('slot' in meta_cfg and meta_cfg.slot):
ret = slot_wer
return ret
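# Minimal usage sketch (illustrative; the surrounding evaluator config layout, e.g. cfg.engine and
# cfg.analyst.metadata, is an assumption based on the field names used above, not defined in this file):
#
#     cfg = OmegaConf.load("eval_config.yaml")
#     cfg.engine = run_asr_inference(cfg.engine)
#     cfg, total_res, eval_metric = cal_write_wer(cfg)
#     res = cal_target_metadata_wer(
#         manifest=cfg.analyst.metric_calculator.output_filename,
#         target="duration",
#         meta_cfg=cfg.analyst.metadata.duration,
#         eval_metric=eval_metric,
#     )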