Spaces:
Sleeping
Sleeping
# Copyright (c) 2023 Amphion. | |
# | |
# This source code is licensed under the MIT license found in the | |
# LICENSE file in the root directory of this source tree. | |
# This module is modified from [Whisper](https://github.com/openai/whisper.git). | |
# ## Citations | |
# ```bibtex | |
# @inproceedings{openai-whisper, | |
# author = {Alec Radford and | |
# Jong Wook Kim and | |
# Tao Xu and | |
# Greg Brockman and | |
# Christine McLeavey and | |
# Ilya Sutskever}, | |
# title = {Robust Speech Recognition via Large-Scale Weak Supervision}, | |
# booktitle = {{ICML}}, | |
# series = {Proceedings of Machine Learning Research}, | |
# volume = {202}, | |
# pages = {28492--28518}, | |
# publisher = {{PMLR}}, | |
# year = {2023} | |
# } | |
# ``` | |
# | |
import os | |
from functools import lru_cache | |
from typing import Union | |
import ffmpeg | |
import numpy as np | |
import torch | |
import torch.nn.functional as F | |
from .utils import exact_div | |
# Fixed audio hyperparameters of the model's front end (not configurable).
SAMPLE_RATE = 16000  # Hz; target rate used when decoding/resampling audio
N_FFT = 400  # FFT size and Hann window length, in samples
N_MELS = 80  # number of Mel filterbank channels
HOP_LENGTH = 160  # samples between successive STFT frames
CHUNK_LENGTH = 30  # seconds of audio per model input chunk

# Derived sizes. exact_div presumably enforces that the division is exact
# (TODO confirm against .utils).
N_SAMPLES = SAMPLE_RATE * CHUNK_LENGTH  # 480000 samples per chunk
N_FRAMES = exact_div(N_SAMPLES, HOP_LENGTH)  # 3000 mel frames per chunk
def load_audio(file: str, sr: int = SAMPLE_RATE):
    """
    Open an audio file and read it as a mono waveform, resampling as necessary.

    Parameters
    ----------
    file: str
        The audio file to open.
    sr: int
        The sample rate to resample the audio to, if necessary.

    Returns
    -------
    A 1-D NumPy array containing the audio waveform, in float32 dtype, with
    samples scaled from signed 16-bit PCM into the range [-1.0, 1.0).

    Raises
    ------
    RuntimeError
        If ffmpeg fails to decode the file; the message includes ffmpeg's stderr.
    """
    try:
        # This launches a subprocess to decode audio while down-mixing and resampling as necessary.
        # Requires the ffmpeg CLI and `ffmpeg-python` package to be installed.
        out, _ = (
            ffmpeg.input(file, threads=0)
            .output("-", format="s16le", acodec="pcm_s16le", ac=1, ar=sr)
            .run(cmd=["ffmpeg", "-nostdin"], capture_stdout=True, capture_stderr=True)
        )
    except ffmpeg.Error as e:
        # ffmpeg's stderr is raw bytes and is not guaranteed to be valid UTF-8
        # (it can echo file names in a legacy encoding). Decode leniently so the
        # caller sees the actual ffmpeg failure instead of a UnicodeDecodeError.
        stderr = (e.stderr or b"").decode(errors="replace")
        raise RuntimeError(f"Failed to load audio: {stderr}") from e
    # Raw s16le PCM -> float32 in [-1, 1).
    return np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0
def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1):
    """
    Force `array` to exactly `length` entries along `axis`, as expected by the encoder.

    Inputs longer than `length` keep only their first `length` entries; shorter
    inputs are zero-padded at the end of `axis`. Both `torch.Tensor` and NumPy
    arrays are accepted, and the result has the same container type as the input.
    """
    is_tensor = torch.is_tensor(array)

    # Trim: keep the leading `length` entries along `axis`.
    if array.shape[axis] > length:
        if is_tensor:
            keep = torch.arange(length, device=array.device)
            array = array.index_select(dim=axis, index=keep)
        else:
            array = array.take(indices=range(length), axis=axis)

    # Pad: append zeros along `axis` until it reaches `length`.
    shortfall = length - array.shape[axis]
    if shortfall > 0:
        widths = [(0, 0) for _ in range(array.ndim)]
        widths[axis] = (0, shortfall)
        if is_tensor:
            # F.pad takes a flat list ordered from the last dimension backwards.
            flat = []
            for before, after in reversed(widths):
                flat.extend((before, after))
            array = F.pad(array, flat)
        else:
            array = np.pad(array, widths)

    return array
@lru_cache(maxsize=None)
def mel_filters(device, n_mels: int = N_MELS) -> torch.Tensor:
    """
    Load the mel filterbank matrix for projecting STFT into a Mel spectrogram.
    Allows decoupling librosa dependency; saved using:
        np.savez_compressed(
            "mel_filters.npz",
            mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80),
        )

    The result is memoized per call arguments (`device`, `n_mels`), so the
    `.npz` asset is read from disk and moved to `device` only once rather than
    on every spectrogram computation. Because the same tensor object is
    returned to every caller, callers must not modify it in place.

    Parameters
    ----------
    device
        Device the filterbank is moved to (must be hashable, e.g. a
        `torch.device` or a device string).
    n_mels: int
        Number of Mel filters; only 80 is supported.

    Returns
    -------
    torch.Tensor
        The filterbank stored under key `mel_{n_mels}` in `assets/mel_filters.npz`.
    """
    assert n_mels == 80, f"Unsupported n_mels: {n_mels}"
    with np.load(
        os.path.join(os.path.dirname(__file__), "assets", "mel_filters.npz")
    ) as f:
        return torch.from_numpy(f[f"mel_{n_mels}"]).to(device)
def log_mel_spectrogram(
    audio: Union[str, os.PathLike, np.ndarray, torch.Tensor], n_mels: int = N_MELS
):
    """
    Compute the log-Mel spectrogram of an audio waveform or audio file.

    Parameters
    ----------
    audio: Union[str, os.PathLike, np.ndarray, torch.Tensor], shape = (*)
        Either a path to an audio file (decoded with `load_audio`), or a NumPy
        array / Tensor containing the audio waveform sampled at 16 kHz.
    n_mels: int
        The number of Mel-frequency filters; only 80 is supported.

    Returns
    -------
    torch.Tensor, shape = (n_mels, n_frames) for a 1-D waveform,
    or (batch, n_mels, n_frames) for a 2-D batch of waveforms
        The log10 Mel spectrogram, floored at 8 below its maximum value and
        rescaled as (x + 4) / 4.
    """
    if not torch.is_tensor(audio):
        if isinstance(audio, (str, os.PathLike)):
            audio = load_audio(os.fspath(audio))
        audio = torch.from_numpy(audio)

    window = torch.hann_window(N_FFT).to(audio.device)
    stft = torch.stft(audio, N_FFT, HOP_LENGTH, window=window, return_complex=True)
    # With torch.stft's default center=True there is one frame more than
    # samples // HOP_LENGTH; dropping the last one makes a full chunk produce
    # exactly N_FRAMES frames.
    magnitudes = stft[..., :-1].abs() ** 2

    filters = mel_filters(audio.device, n_mels)
    mel_spec = filters @ magnitudes

    log_spec = torch.clamp(mel_spec, min=1e-10).log10()
    # Limit the dynamic range to 8 (log10 units) below the peak. Note the peak
    # is taken over the whole tensor, i.e. across all items of a batch.
    log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
    log_spec = (log_spec + 4.0) / 4.0
    return log_spec