import json |
|
import subprocess |
|
import tempfile |
|
from pathlib import Path |
|
from typing import Optional, Tuple
|
|
|
from omegaconf import DictConfig, OmegaConf, open_dict |
|
|
|
from nemo.collections.asr.metrics.wer import word_error_rate_detail |
|
from nemo.utils import logging |
|
|
|
|
|
def run_asr_inference(cfg: DictConfig) -> DictConfig: |
|
""" |
|
Execute ASR inference based on input mode and parameters. |
|
""" |
|
    if (cfg.model_path and cfg.pretrained_name) or (not cfg.model_path and not cfg.pretrained_name):

        raise ValueError("Please specify exactly one of cfg.model_path and cfg.pretrained_name!")
|
|
|
if cfg.inference.mode == "offline": |
|
cfg = run_offline_inference(cfg) |
|
|
|
elif cfg.inference.mode == "chunked": |
|
if ( |
|
"total_buffer_in_secs" not in cfg.inference |
|
or "chunk_len_in_secs" not in cfg.inference |
|
or not cfg.inference.total_buffer_in_secs |
|
or not cfg.inference.chunk_len_in_secs |
|
): |
|
            raise ValueError("Please specify both total_buffer_in_secs and chunk_len_in_secs for chunked inference")
|
cfg = run_chunked_inference(cfg) |
|
|
|
elif cfg.inference.mode == "offline_by_chunked": |
|
|
|
|
|
|
|
OmegaConf.set_struct(cfg, True) |
|
if 'total_buffer_in_secs' not in cfg.inference or not cfg.inference.total_buffer_in_secs: |
|
with open_dict(cfg): |
|
cfg.inference.total_buffer_in_secs = 22 |
|
            logging.info(
                f"total_buffer_in_secs is not provided but is required by {cfg.inference.mode} mode. Using default value {cfg.inference.total_buffer_in_secs}"
            )
|
if 'chunk_len_in_secs' not in cfg.inference or not cfg.inference.chunk_len_in_secs: |
|
with open_dict(cfg): |
|
cfg.inference.chunk_len_in_secs = 20 |
|
            logging.info(
                f"chunk_len_in_secs is not provided but is required by {cfg.inference.mode} mode. Using default value {cfg.inference.chunk_len_in_secs}"
            )
|
cfg = run_chunked_inference(cfg) |
|
|
|
else: |
|
        raise ValueError(
            f"inference.mode must be one of 'offline', 'chunked' or 'offline_by_chunked', but got {cfg.inference.mode}"
        )
|
|
|
return cfg |
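
# Illustrative sketch only (not a definitive schema): the config fields that run_asr_inference
# reads, inferred from the accesses in this function and in run_chunked_inference /
# run_offline_inference below. All values are placeholders.
#
#   pretrained_name: stt_en_conformer_ctc_large   # or model_path: /path/to/model.nemo (specify exactly one)
#   random_seed: 42
#   output_filename: null                         # auto-generated when left empty
#   test_ds:
#     manifest_filepath: /path/to/test_manifest.json
#     batch_size: 16
#   inference:
#     mode: chunked                               # offline | chunked | offline_by_chunked
#     chunk_len_in_secs: 20
#     total_buffer_in_secs: 22
#     model_stride: 4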
|
|
|
|
|
def run_chunked_inference(cfg: DictConfig) -> DictConfig: |
|
if "output_filename" not in cfg or not cfg.output_filename: |
|
if cfg.model_path: |
|
model_name = Path(cfg.model_path).stem |
|
else: |
|
model_name = cfg.pretrained_name |
|
dataset_name = Path(cfg.test_ds.manifest_filepath).stem |
|
mode_name = ( |
|
cfg.inference.mode |
|
+ "B" |
|
+ str(cfg.inference.total_buffer_in_secs) |
|
+ "C" |
|
+ str(cfg.inference.chunk_len_in_secs) |
|
) |
|
|
|
OmegaConf.set_struct(cfg, True) |
|
with open_dict(cfg): |
|
cfg.output_filename = model_name + "-" + dataset_name + "-" + mode_name + ".json" |
|
|
|
script_path = ( |
|
Path(__file__).parents[2] |
|
/ "examples" |
|
/ "asr" |
|
/ "asr_chunked_inference" |
|
/ "ctc" |
|
/ "speech_to_text_buffered_infer_ctc.py" |
|
) |
|
|
|
if (cfg.pretrained_name and 'transducer' in cfg.pretrained_name) or ( |
|
cfg.model_path and 'transducer' in cfg.model_path |
|
): |
|
script_path = ( |
|
Path(__file__).parents[2] |
|
/ "examples" |
|
/ "asr" |
|
/ "asr_chunked_inference" |
|
/ "rnnt" |
|
/ "speech_to_text_buffered_infer_rnnt.py" |
|
) |
|
|
|
subprocess.run( |
|
f"python {script_path} " |
|
f"model_path={cfg.model_path} " |
|
f"pretrained_name={cfg.pretrained_name} " |
|
f"dataset_manifest={cfg.test_ds.manifest_filepath} " |
|
f"output_filename={cfg.output_filename} " |
|
f"random_seed={cfg.random_seed} " |
|
f"batch_size={cfg.test_ds.batch_size} " |
|
f"chunk_len_in_secs={cfg.inference.chunk_len_in_secs} " |
|
f"total_buffer_in_secs={cfg.inference.total_buffer_in_secs} " |
|
f"model_stride={cfg.inference.model_stride} ", |
|
shell=True, |
|
check=True, |
|
) |
|
return cfg |
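
# Example with illustrative values: for pretrained_name=stt_en_conformer_ctc_large, a manifest
# named test_manifest.json, total_buffer_in_secs=22 and chunk_len_in_secs=20, the auto-generated
# output filename built above would be
#   stt_en_conformer_ctc_large-test_manifest-chunkedB22C20.json
# and the subprocess call above is equivalent to invoking the buffered inference script directly:
#   python examples/asr/asr_chunked_inference/ctc/speech_to_text_buffered_infer_ctc.py \
#       pretrained_name=stt_en_conformer_ctc_large dataset_manifest=... output_filename=... \
#       chunk_len_in_secs=20 total_buffer_in_secs=22 model_stride=4 batch_size=16 random_seed=42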
|
|
|
|
|
def run_offline_inference(cfg: DictConfig) -> DictConfig: |
|
if "output_filename" not in cfg or not cfg.output_filename: |
|
if cfg.model_path: |
|
model_name = Path(cfg.model_path).stem |
|
else: |
|
model_name = cfg.pretrained_name |
|
dataset_name = Path(cfg.test_ds.manifest_filepath).stem |
|
mode_name = cfg.inference.mode |
|
|
|
OmegaConf.set_struct(cfg, True) |
|
with open_dict(cfg): |
|
cfg.output_filename = model_name + "-" + dataset_name + "-" + mode_name + ".json" |
|
|
|
with tempfile.NamedTemporaryFile(mode='w', encoding='utf-8') as f: |
|
OmegaConf.save(cfg, f) |
|
f.seek(0) |
|
script_path = Path(__file__).parents[2] / "examples" / "asr" / "transcribe_speech.py" |
|
|
|
|
|
|
|
|
|
subprocess.run( |
|
f"python {script_path} " |
|
f"model_path={cfg.model_path} " |
|
f"pretrained_name={cfg.pretrained_name} " |
|
f"dataset_manifest={cfg.test_ds.manifest_filepath} " |
|
f"output_filename={cfg.output_filename} " |
|
f"batch_size={cfg.test_ds.batch_size} " |
|
f"random_seed={cfg.random_seed} " |
|
f"eval_config_yaml={f.name} ", |
|
shell=True, |
|
check=True, |
|
) |
|
|
|
return cfg |
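
# Note: offline mode saves the current config to a temporary YAML and passes it to
# examples/asr/transcribe_speech.py as eval_config_yaml. Illustrative equivalent call
# (paths and values are placeholders):
#   python examples/asr/transcribe_speech.py pretrained_name=... dataset_manifest=... \
#       output_filename=... batch_size=16 random_seed=42 eval_config_yaml=/tmp/tmpXXXX.yaml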
|
|
|
|
|
def clean_label(_str: str, num_to_words: bool = True, langid: str = "en") -> str:

    """

    Remove unwanted characters from a string, lowercase it, and collapse redundant spaces.

    """
|
replace_with_space = [char for char in '/?*\",.:=?_{|}~¨«·»¡¿„…‧‹›≪≫!:;ː→'] |
|
replace_with_blank = [char for char in '`¨´‘’“”`ʻ‘’“"‘”'] |
|
replace_with_apos = [char for char in '‘’ʻ‘’‘'] |
|
_str = _str.strip() |
|
_str = _str.lower() |
|
for i in replace_with_blank: |
|
_str = _str.replace(i, "") |
|
for i in replace_with_space: |
|
_str = _str.replace(i, " ") |
|
for i in replace_with_apos: |
|
_str = _str.replace(i, "'") |
|
if num_to_words: |
|
if langid == "en": |
|
_str = convert_num_to_words(_str, langid="en") |
|
else: |
|
            logging.info(
                "Currently only basic num_to_words in English is supported. Please use Text Normalization to convert other languages! Skipping!"
            )
|
|
|
ret = " ".join(_str.split()) |
|
return ret |
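
# Worked example (illustrative): clean_label('Hello, World! 42') first lowercases and maps
# punctuation to spaces, then spells digits out word by word, yielding 'hello world four two'.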
|
|
|
|
|
def convert_num_to_words(_str: str, langid: str = "en") -> str: |
|
""" |
|
Convert digits to corresponding words. Note this is a naive approach and could be replaced with text normalization. |
|
""" |
|
if langid == "en": |
|
num_to_words = ["zero", "one", "two", "three", "four", "five", "six", "seven", "eight", "nine"] |
|
_str = _str.strip() |
|
words = _str.split() |
|
out_str = "" |
|
num_word = [] |
|
        for word in words:
            if word.isdigit():
                num = int(word)
                # Spell the number out digit by digit, e.g. "42" -> "four two".
                if num == 0:
                    num_word.append(num_to_words[0])
                while num:
                    num_word.append(num_to_words[num % 10])
                    num //= 10
                out_str += " ".join(num_word[::-1]) + " "
                num_word.clear()
            else:
                out_str += word + " "
|
out_str = out_str.strip() |
|
else: |
|
        raise ValueError(
            "Currently only basic num_to_words in English is supported. Please use Text Normalization to convert other languages!"
        )
|
return out_str |
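
# Example (illustrative): convert_num_to_words("room 12") -> "room one two"; digits are spelled
# out one by one rather than as "twelve", which is why the docstring calls this approach naive.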
|
|
|
|
|
def cal_write_wer(cfg: DictConfig, pred_text_attr_name: Optional[str] = None) -> Tuple[DictConfig, dict, str]:

    """

    Calculate WER, insertion, deletion and substitution rates based on the ground-truth text and pred_text_attr_name (pred_text).

    We use WER in the function name as a convention, but it supports both Word Error Rate (WER) and Character Error Rate (CER).

    """
|
samples = [] |
|
hyps = [] |
|
refs = [] |
|
|
|
with open(cfg.engine.output_filename, 'r') as fp: |
|
for line in fp: |
|
sample = json.loads(line) |
|
|
|
if 'text' not in sample: |
|
                raise ValueError(
                    "ground-truth text is not present in the manifest! Cannot calculate Word Error Rate. Exiting!"
                )
|
|
|
if not pred_text_attr_name: |
|
pred_text_attr_name = "pred_text" |
|
|
|
hyp = sample[pred_text_attr_name] |
|
ref = sample['text'] |
|
|
|
if cfg.analyst.metric_calculator.clean_groundtruth_text: |
|
ref = clean_label(ref, langid=cfg.analyst.metric_calculator.langid) |
|
|
|
wer, tokens, ins_rate, del_rate, sub_rate = word_error_rate_detail( |
|
hypotheses=[hyp], references=[ref], use_cer=cfg.analyst.metric_calculator.use_cer |
|
) |
|
eval_metric = "wer" |
|
if cfg.analyst.metric_calculator.use_cer: |
|
eval_metric = "cer" |
|
|
|
sample[eval_metric] = wer |
|
sample['tokens'] = tokens |
|
sample['ins_rate'] = ins_rate |
|
sample['del_rate'] = del_rate |
|
sample['sub_rate'] = sub_rate |
|
|
|
samples.append(sample) |
|
hyps.append(hyp) |
|
refs.append(ref) |
|
|
|
total_wer, total_tokens, total_ins_rate, total_del_rate, total_sub_rate = word_error_rate_detail( |
|
hypotheses=hyps, references=refs, use_cer=cfg.analyst.metric_calculator.use_cer |
|
) |
|
|
|
if "output_filename" not in cfg.analyst.metric_calculator or not cfg.analyst.metric_calculator.output_filename: |
|
|
|
OmegaConf.set_struct(cfg, True) |
|
with open_dict(cfg): |
|
cfg.analyst.metric_calculator.output_filename = cfg.engine.output_filename |
|
|
|
with open(cfg.analyst.metric_calculator.output_filename, 'w') as fout: |
|
for sample in samples: |
|
json.dump(sample, fout) |
|
fout.write('\n') |
|
fout.flush() |
|
|
|
total_res = { |
|
"samples": len(samples), |
|
"tokens": total_tokens, |
|
eval_metric: total_wer, |
|
"ins_rate": total_ins_rate, |
|
"del_rate": total_del_rate, |
|
"sub_rate": total_sub_rate, |
|
} |
|
return cfg, total_res, eval_metric |
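
# Illustrative example of what cal_write_wer produces (values are made up): each line of the
# output manifest is the input sample augmented with per-sample error rates, e.g.
#   {"audio_filepath": "...", "text": "the cat", "pred_text": "the hat",
#    "wer": 0.5, "tokens": 2, "ins_rate": 0.0, "del_rate": 0.0, "sub_rate": 0.5}
# and total_res aggregates the same fields over the whole manifest:
#   {"samples": 1000, "tokens": 12345, "wer": 0.08, "ins_rate": 0.01, "del_rate": 0.02, "sub_rate": 0.05}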
|
|
|
|
|
def cal_target_metadata_wer(manifest: str, target: str, meta_cfg: DictConfig, eval_metric: str = "wer") -> Optional[dict]:

    """

    Calculate the number of samples (samples), number of words/characters/tokens (tokens),

    wer/cer, insertion error rate (ins_rate), deletion error rate (del_rate) and substitution error rate (sub_rate)

    for each group/slot of the target metadata.

    The groups could be classes such as [female, male], or slots such as [0-2s, 2-5s, >5s] audio durations.

    Args:
        manifest (str): Filepath of the generated manifest which contains the prediction and evaluation result for each sample.
        target (str): Target metadata field, such as 'duration', 'speaker' or 'emotion'. Only evaluated if the field is present in the manifest.
        meta_cfg (DictConfig): Config for calculating the group eval_metric for the target metadata.
        eval_metric (str): Evaluation metric. Currently 'wer' and 'cer' are supported.

    Returns:
        ret (dict): Dictionary containing all results regarding the target metadata, or None if the target field is absent from the manifest.

    """
|
|
|
if eval_metric not in ['wer', 'cer']: |
|
        raise ValueError(
            "Currently only 'wer' and 'cer' are supported as eval_metric. Please extend cal_target_metadata_wer if you need a different metric."
        )
|
|
|
wer_per_class = {} |
|
with open(manifest, 'r') as fp: |
|
for line in fp: |
|
sample = json.loads(line) |
|
if target in sample: |
|
target_class = sample[target] |
|
if target_class not in wer_per_class: |
|
wer_per_class[target_class] = { |
|
'samples': 0, |
|
'tokens': 0, |
|
"errors": 0, |
|
"inss": 0, |
|
"dels": 0, |
|
"subs": 0, |
|
} |
|
wer_per_class[target_class]['samples'] += 1 |
|
|
|
tokens = sample["tokens"] |
|
wer_per_class[target_class]["tokens"] += tokens |
|
wer_per_class[target_class]["errors"] += tokens * sample[eval_metric] |
|
wer_per_class[target_class]["inss"] += tokens * sample["ins_rate"] |
|
wer_per_class[target_class]["dels"] += tokens * sample["del_rate"] |
|
wer_per_class[target_class]["subs"] += tokens * sample["sub_rate"] |
|
|
|
if len(wer_per_class) > 0: |
|
res_wer_per_class = {} |
|
for target_class in wer_per_class: |
|
res_wer_per_class[target_class] = {} |
|
res_wer_per_class[target_class]["samples"] = wer_per_class[target_class]["samples"] |
|
res_wer_per_class[target_class][eval_metric] = ( |
|
wer_per_class[target_class]["errors"] / wer_per_class[target_class]["tokens"] |
|
) |
|
res_wer_per_class[target_class]["tokens"] = wer_per_class[target_class]["tokens"] |
|
res_wer_per_class[target_class]["ins_rate"] = ( |
|
wer_per_class[target_class]["inss"] / wer_per_class[target_class]["tokens"] |
|
) |
|
res_wer_per_class[target_class]["del_rate"] = ( |
|
wer_per_class[target_class]["dels"] / wer_per_class[target_class]["tokens"] |
|
) |
|
res_wer_per_class[target_class]["sub_rate"] = ( |
|
wer_per_class[target_class]["subs"] / wer_per_class[target_class]["tokens"] |
|
) |
|
else: |
|
        logging.info(f"metadata '{target}' is not present in the manifest. Skipping!")
|
return None |
|
|
|
values = ['samples', 'tokens', 'errors', 'inss', 'dels', 'subs'] |
|
slot_wer = {} |
|
if 'slot' in meta_cfg and meta_cfg.slot: |
|
for target_class in wer_per_class: |
|
for s in meta_cfg.slot: |
|
if isinstance(s[0], float) or isinstance(s[0], int): |
|
if s[0] <= target_class < s[1]: |
|
slot_key = "slot-" + ",".join(str(i) for i in s) |
|
if slot_key not in slot_wer: |
|
slot_wer[slot_key] = { |
|
'samples': 0, |
|
'tokens': 0, |
|
"errors": 0, |
|
"inss": 0, |
|
"dels": 0, |
|
"subs": 0, |
|
} |
|
|
|
for v in values: |
|
slot_wer[slot_key][v] += wer_per_class[target_class][v] |
|
break |
|
|
|
elif isinstance(s[0], str): |
|
if target_class in s: |
|
slot_key = "slot-" + ",".join(s) |
|
if slot_key not in slot_wer: |
|
slot_wer[slot_key] = { |
|
'samples': 0, |
|
'tokens': 0, |
|
"errors": 0, |
|
"inss": 0, |
|
"dels": 0, |
|
"subs": 0, |
|
} |
|
|
|
for v in values: |
|
slot_wer[slot_key][v] += wer_per_class[target_class][v] |
|
break |
|
else: |
|
                    raise ValueError("Currently only numeric or string target metadata values are supported.")
|
|
|
for slot_key in slot_wer: |
|
slot_wer[slot_key][eval_metric] = slot_wer[slot_key]['errors'] / slot_wer[slot_key]['tokens'] |
|
slot_wer[slot_key]['ins_rate'] = slot_wer[slot_key]['inss'] / slot_wer[slot_key]['tokens'] |
|
slot_wer[slot_key]['del_rate'] = slot_wer[slot_key]['dels'] / slot_wer[slot_key]['tokens'] |
|
slot_wer[slot_key]['sub_rate'] = slot_wer[slot_key]['subs'] / slot_wer[slot_key]['tokens'] |
|
slot_wer[slot_key].pop('errors') |
|
slot_wer[slot_key].pop('inss') |
|
slot_wer[slot_key].pop('dels') |
|
slot_wer[slot_key].pop('subs') |
|
res_wer_per_class.update(slot_wer) |
|
|
|
ret = None |
|
if meta_cfg.save_wer_per_class: |
|
ret = res_wer_per_class |
|
if (not meta_cfg.save_wer_per_class) and ('slot' in meta_cfg and meta_cfg.slot): |
|
ret = slot_wer |
|
return ret |
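
# Illustrative meta_cfg sketch (field names are taken from the accesses above, values are examples):
#
#   save_wer_per_class: true
#   slot:                      # numeric slots are half-open ranges [start, end)
#     - [0, 2]
#     - [2, 5]
#     - [5, 100000]
#
# With target='duration', samples would be pooled into keys like 'slot-0,2', 'slot-2,5' and
# 'slot-5,100000'; string slots (e.g. - [female, male]) pool samples whose metadata value is in the list.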
|
|