Spaces:
Paused
Paused
# SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | |
# SPDX-License-Identifier: Apache-2.0 | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
import logging | |
import os | |
from collections import defaultdict | |
from functools import lru_cache | |
from pathlib import Path | |
from subprocess import CalledProcessError, run | |
from typing import Dict, Iterable, List, Optional, TextIO, Tuple, Union | |
import kaldialign | |
import numpy as np | |
import soundfile | |
import torch | |
import torch.nn.functional as F | |
# Type alias for filesystem paths: either a plain string or a pathlib.Path.
Pathlike = Union[str, Path]

# Audio front-end hyper-parameters.
SAMPLE_RATE = 16000  # samples per second (16 kHz)
N_FFT = 400  # STFT size / window length: 400 samples = 25 ms at 16 kHz
HOP_LENGTH = 160  # STFT hop: 160 samples = 10 ms at 16 kHz
CHUNK_LENGTH = 30  # seconds of audio per chunk
N_SAMPLES = SAMPLE_RATE * CHUNK_LENGTH  # 480000 samples in a 30-second chunk
def load_audio(file: str, sr: int = SAMPLE_RATE):
    """
    Open an audio file and read as mono waveform, resampling as necessary

    Parameters
    ----------
    file: str
        The audio file to open
    sr: int
        The sample rate to resample the audio if necessary

    Returns
    -------
    A 1-D NumPy array containing the audio waveform, in float32 dtype,
    scaled from signed 16-bit PCM into the range [-1.0, 1.0).

    Raises
    ------
    RuntimeError
        If ffmpeg exits with a non-zero status; the message carries ffmpeg's
        stderr. (If the ffmpeg executable itself is not on PATH, subprocess
        raises FileNotFoundError instead.)
    """
    # This launches a subprocess to decode audio while down-mixing
    # and resampling as necessary. Requires the ffmpeg CLI in PATH.
    # fmt: off
    cmd = [
        "ffmpeg", "-nostdin", "-threads", "0", "-i", file, "-f", "s16le", "-ac",
        "1", "-acodec", "pcm_s16le", "-ar",
        str(sr), "-"
    ]
    # fmt: on
    try:
        out = run(cmd, capture_output=True, check=True).stdout
    except CalledProcessError as e:
        # ffmpeg's stderr is raw bytes and may not be valid UTF-8 (e.g. file
        # names in another encoding); a strict decode would raise
        # UnicodeDecodeError here and hide the real failure.
        raise RuntimeError(
            f"Failed to load audio: {e.stderr.decode(errors='replace')}"
        ) from e
    # Raw s16le samples -> float32 in [-1, 1).
    return np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0
def load_audio_wav_format(wav_path):
    """
    Read a 16 kHz ``.wav`` file with soundfile.

    Parameters
    ----------
    wav_path: str
        Path to the audio file; it must end with ``.wav``.

    Returns
    -------
    (waveform, sample_rate)
        ``waveform`` is the NumPy array read by ``soundfile.read``. Files with
        more than one channel are averaged down to a 1-D mono signal, so the
        layout matches the mono output of ``load_audio``.
        ``sample_rate`` is always SAMPLE_RATE (16000).

    Raises
    ------
    AssertionError
        If the path does not end with ``.wav`` or the file is not 16 kHz.
    """
    # make sure audio in .wav format
    assert wav_path.endswith(
        '.wav'), f"Only support .wav format, but got {wav_path}"
    waveform, sample_rate = soundfile.read(wav_path)
    assert sample_rate == SAMPLE_RATE, \
        f"Only support 16k sample rate, but got {sample_rate}"
    # soundfile returns a (frames, channels) array for multi-channel files.
    # Callers such as log_mel_spectrogram pad/trim along the LAST axis, which
    # would then be the channel axis, so down-mix to mono here.
    if waveform.ndim > 1:
        waveform = waveform.mean(axis=1)
    return waveform, sample_rate
def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1):
    """
    Pad or trim the audio array to N_SAMPLES, as expected by the encoder.

    Accepts either a torch tensor or a NumPy array. Along ``axis`` the input
    keeps only its first ``length`` entries when it is longer, or is
    right-padded with zeros when it is shorter; every other axis is left
    untouched. An input that already has exactly ``length`` entries is
    returned as-is.
    """
    size = array.shape[axis]
    if size == length:
        return array

    if torch.is_tensor(array):
        if size > length:
            kept = torch.arange(length, device=array.device)
            return array.index_select(dim=axis, index=kept)
        widths = [(0, 0)] * array.ndim
        widths[axis] = (0, length - size)
        # F.pad expects a flat (before, after, ...) list that starts from the
        # LAST dimension, so walk the per-dim widths in reverse.
        flat = []
        for before_after in reversed(widths):
            flat.extend(before_after)
        return F.pad(array, flat)

    if size > length:
        return array.take(indices=range(length), axis=axis)
    widths = [(0, 0)] * array.ndim
    widths[axis] = (0, length - size)
    return np.pad(array, widths)
@lru_cache(maxsize=None)
def mel_filters(device,
                n_mels: int,
                mel_filters_dir: Optional[str] = None) -> torch.Tensor:
    """
    Load the mel filterbank matrix for projecting STFT into a Mel spectrogram.

    Allows decoupling librosa dependency; saved using:

        np.savez_compressed(
            "mel_filters.npz",
            mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80),
        )

    The file is looked up under the key ``f"mel_{n_mels}"``, so it must also
    contain a ``mel_128`` array when 128 filters are requested.

    Parameters
    ----------
    device:
        Device the returned tensor is moved to (must be hashable, e.g. a str
        or torch.device, because results are cached).
    n_mels: int
        Number of Mel filters; only 80 and 128 are supported.
    mel_filters_dir: Optional[str]
        Directory holding ``mel_filters.npz``. Defaults to the ``assets``
        directory next to this module.

    Returns
    -------
    torch.Tensor
        The filterbank for ``n_mels``. Results are cached per
        (device, n_mels, mel_filters_dir), so the .npz is read only once per
        combination and the SAME tensor is returned on later calls; callers
        must not modify it in place.

    Raises
    ------
    AssertionError
        If ``n_mels`` is not 80 or 128 (failures are not cached).
    """
    assert n_mels in {80, 128}, f"Unsupported n_mels: {n_mels}"
    if mel_filters_dir is None:
        mel_filters_path = os.path.join(os.path.dirname(__file__), "assets",
                                        "mel_filters.npz")
    else:
        mel_filters_path = os.path.join(mel_filters_dir, "mel_filters.npz")
    with np.load(mel_filters_path) as f:
        return torch.from_numpy(f[f"mel_{n_mels}"]).to(device)
def log_mel_spectrogram(
    audio: Union[str, np.ndarray, torch.Tensor],
    n_mels: int,
    padding: int = 0,
    device: Optional[Union[str, torch.device]] = None,
    return_duration: bool = False,
    mel_filters_dir: Optional[str] = None,
):
    """
    Compute the log-Mel spectrogram of an audio clip.

    Parameters
    ----------
    audio: Union[str, np.ndarray, torch.Tensor], shape = (*)
        The path to audio or either a NumPy array or Tensor containing the
        audio waveform in 16 kHz. Paths ending in ``.wav`` are read with
        ``load_audio_wav_format``; other paths are decoded by ``load_audio``.
        Paths and NumPy arrays are padded or trimmed to N_SAMPLES (30 s);
        tensors are used at the length they are given.
    n_mels: int
        The number of Mel-frequency filters, only 80 and 128 are supported
    padding: int
        Number of zero samples to pad to the right
    device: Optional[Union[str, torch.device]]
        If given, the audio tensor is moved to this device before STFT
    return_duration: bool
        If True, also return the duration of the input in seconds, measured
        before any padding or trimming as ``audio.shape[-1] / SAMPLE_RATE``.
    mel_filters_dir: Optional[str]
        Directory containing ``mel_filters.npz``; when None, ``mel_filters``
        uses its default location.

    Returns
    -------
    torch.Tensor, shape = (n_mels, n_frames) for 1-D input
        A Tensor that contains the log-Mel spectrogram. When
        ``return_duration`` is True, a ``(log_spec, duration)`` tuple is
        returned instead.
    """
    if torch.is_tensor(audio):
        # Tensors are not padded/trimmed, so their duration is just their
        # current length. (Previously `duration` was never assigned on this
        # path and return_duration=True raised UnboundLocalError.)
        duration = audio.shape[-1] / SAMPLE_RATE
    else:
        if isinstance(audio, str):
            if audio.endswith('.wav'):
                audio, _ = load_audio_wav_format(audio)
            else:
                audio = load_audio(audio)
        assert isinstance(audio,
                          np.ndarray), f"Unsupported audio type: {type(audio)}"
        # Measure the real length before it is padded/trimmed to 30 s.
        duration = audio.shape[-1] / SAMPLE_RATE
        audio = pad_or_trim(audio, N_SAMPLES)
        audio = torch.from_numpy(audio.astype(np.float32))
    if device is not None:
        audio = audio.to(device)
    if padding > 0:
        audio = F.pad(audio, (0, padding))
    window = torch.hann_window(N_FFT, device=audio.device)
    stft = torch.stft(audio,
                      N_FFT,
                      HOP_LENGTH,
                      window=window,
                      return_complex=True)
    # Power spectrum; the final STFT frame is dropped.
    magnitudes = stft[..., :-1].abs()**2
    filters = mel_filters(audio.device, n_mels, mel_filters_dir)
    mel_spec = filters @ magnitudes
    log_spec = torch.clamp(mel_spec, min=1e-10).log10()
    # Limit dynamic range to 8 log10 units (80 dB) below the peak, then
    # shift/scale the values.
    log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
    log_spec = (log_spec + 4.0) / 4.0
    if return_duration:
        return log_spec, duration
    return log_spec
def store_transcripts(filename: Pathlike,
                      texts: Iterable[Tuple[str, str, str]]) -> None:
    """Save predicted results and reference transcripts to a file.

    Adapted from https://github.com/k2-fsa/icefall/blob/master/icefall/utils.py

    For every entry two lines are written, the cut id followed by a TAB and
    either ``ref=<reference>`` or ``hyp=<hypothesis>``.

    Args:
      filename:
        File to save the results to. An existing file is overwritten.
      texts:
        An iterable of tuples. The first element is the cut_id, the second is
        the reference transcript and the third element is the predicted result.
    Returns:
      Return None.
    """
    # Explicit UTF-8 so non-ASCII transcripts are written the same way on
    # every platform instead of depending on the locale's default encoding.
    with open(filename, "w", encoding="utf-8") as f:
        for cut_id, ref, hyp in texts:
            print(f"{cut_id}:\tref={ref}", file=f)
            print(f"{cut_id}:\thyp={hyp}", file=f)
def _format_utt_alignment(ali, err: str) -> str:
    """Render one utterance alignment as space-separated report tokens.

    Correct pairs are shown as the word itself; error pairs as
    ``(ref->hyp)``. Runs of consecutive error pairs are merged into a single
    ``(ref words->hyp words)`` token, the ``err`` placeholder is removed from
    inside a merged run, and ``err`` is shown only for a side that ends up
    empty.
    """
    # Wrap each side in a list so that adjacent error pairs can be concatenated.
    pairs = [[[x], [y]] for x, y in ali]
    for i in range(len(pairs) - 1):
        if pairs[i][0] != pairs[i][1] and pairs[i + 1][0] != pairs[i + 1][1]:
            # Fold this error pair into the next one and leave an empty
            # placeholder that is filtered out below.
            pairs[i + 1][0] = pairs[i][0] + pairs[i + 1][0]
            pairs[i + 1][1] = pairs[i][1] + pairs[i + 1][1]
            pairs[i] = [[], []]
    pairs = [[
        [a for a in x if a != err],
        [a for a in y if a != err],
    ] for x, y in pairs]
    pairs = [p for p in pairs if p != [[], []]]
    tokens = []
    for x, y in pairs:
        ref_words = err if x == [] else " ".join(x)
        hyp_words = err if y == [] else " ".join(y)
        tokens.append(ref_words if ref_words ==
                      hyp_words else f"({ref_words}->{hyp_words})")
    return " ".join(tokens)


def write_error_stats(
    f: TextIO,
    test_set_name: str,
    results: List[Tuple[str, str, str]],
    enable_log: bool = True,
) -> float:
    """Write statistics based on predicted results and reference transcripts.

    Adapted from https://github.com/k2-fsa/icefall/blob/master/icefall/utils.py

    It will write the following to the given file:

        - WER
        - number of insertions, deletions, substitutions, corrects and total
          reference words. For example::

              Errors: 23 insertions, 57 deletions, 212 substitutions, over 2606
              reference words (2337 correct)

        - The difference between the reference transcript and predicted result.
          An instance is given below::

            THE ASSOCIATION OF (EDISON->ADDISON) ILLUMINATING COMPANIES

          The above example shows that the reference word is `EDISON`,
          but it is predicted to `ADDISON` (a substitution error).

          Another example is::

            FOR THE FIRST DAY (SIR->*) I THINK

          The reference word `SIR` is missing in the predicted
          results (a deletion error).

        - Per-word substitution / deletion / insertion counts and per-word
          totals.

    Args:
      f:
        Open text stream the report is written to.
      test_set_name:
        Name used to tag the console log line.
      results:
        An iterable of tuples. The first element is the cut_id, the second is
        the reference transcript and the third element is the predicted result.
        ref/hyp are aligned element-wise by kaldialign and measured with
        ``len()``, so pass sequences of words for WER (a plain str would
        presumably be compared character by character). It is iterated once.
      enable_log:
        If True, also print detailed WER to the console.
        Otherwise, it is written only to the given file.
    Returns:
      The error rate in percent, rounded to two decimals. If there are no
      reference tokens at all, 0.0 is returned when there are also no errors
      and ``inf`` otherwise (the rate is undefined in that case).
    """
    subs: Dict[Tuple[str, str], int] = defaultdict(int)
    ins: Dict[str, int] = defaultdict(int)
    dels: Dict[str, int] = defaultdict(int)
    # `words` stores counts per word, as follows:
    #   corr, ref_sub, hyp_sub, ins, dels
    words: Dict[str, List[int]] = defaultdict(lambda: [0, 0, 0, 0, 0])
    num_corr = 0
    ref_len = 0
    ERR = "*"
    # Align every utterance exactly once; the alignments are kept for the
    # PER-UTT DETAILS section instead of re-running kaldialign.
    alignments = []
    for cut_id, ref, hyp in results:
        ali = list(kaldialign.align(ref, hyp, ERR))
        alignments.append((cut_id, ali))
        ref_len += len(ref)
        for ref_word, hyp_word in ali:
            if ref_word == ERR:
                ins[hyp_word] += 1
                words[hyp_word][3] += 1
            elif hyp_word == ERR:
                dels[ref_word] += 1
                words[ref_word][4] += 1
            elif hyp_word != ref_word:
                subs[(ref_word, hyp_word)] += 1
                words[ref_word][1] += 1
                words[hyp_word][2] += 1
            else:
                words[ref_word][0] += 1
                num_corr += 1
    sub_errs = sum(subs.values())
    ins_errs = sum(ins.values())
    del_errs = sum(dels.values())
    tot_errs = sub_errs + ins_errs + del_errs
    if ref_len > 0:
        err_ratio = tot_errs / ref_len
        tot_err_rate = "%.2f" % (100.0 * tot_errs / ref_len)
    else:
        # No reference tokens (e.g. empty `results`): avoid ZeroDivisionError.
        err_ratio = 0.0 if tot_errs == 0 else float("inf")
        tot_err_rate = "%.2f" % (100.0 * err_ratio)
    if enable_log:
        logging.info(f"[{test_set_name}] %WER {err_ratio:.2%} "
                     f"[{tot_errs} / {ref_len}, {ins_errs} ins, "
                     f"{del_errs} del, {sub_errs} sub ]")
    print(f"%WER = {tot_err_rate}", file=f)
    print(
        f"Errors: {ins_errs} insertions, {del_errs} deletions, "
        f"{sub_errs} substitutions, over {ref_len} reference "
        f"words ({num_corr} correct)",
        file=f,
    )
    print(
        "Search below for sections starting with PER-UTT DETAILS:, "
        "SUBSTITUTIONS:, DELETIONS:, INSERTIONS:, PER-WORD STATS:",
        file=f,
    )
    print("", file=f)
    print("PER-UTT DETAILS: corr or (ref->hyp)  ", file=f)
    for cut_id, ali in alignments:
        print(f"{cut_id}:\t{_format_utt_alignment(ali, ERR)}", file=f)
    print("", file=f)
    print("SUBSTITUTIONS: count ref -> hyp", file=f)
    for count, (ref, hyp) in sorted([(v, k) for k, v in subs.items()],
                                    reverse=True):
        print(f"{count}   {ref} -> {hyp}", file=f)
    print("", file=f)
    print("DELETIONS: count ref", file=f)
    for count, ref in sorted([(v, k) for k, v in dels.items()], reverse=True):
        print(f"{count}   {ref}", file=f)
    print("", file=f)
    print("INSERTIONS: count hyp", file=f)
    for count, hyp in sorted([(v, k) for k, v in ins.items()], reverse=True):
        print(f"{count}   {hyp}", file=f)
    print("", file=f)
    print("PER-WORD STATS: word  corr tot_errs count_in_ref count_in_hyp",
          file=f)
    # Most-error-prone words first (ties broken by word, descending).
    for _, word, counts in sorted([(sum(v[1:]), k, v)
                                   for k, v in words.items()],
                                  reverse=True):
        corr, ref_sub, hyp_sub, n_ins, n_dels = counts
        word_errs = ref_sub + hyp_sub + n_ins + n_dels
        ref_count = corr + ref_sub + n_dels
        hyp_count = corr + hyp_sub + n_ins
        print(f"{word}   {corr} {word_errs} {ref_count} {hyp_count}", file=f)
    return float(tot_err_rate)