|
import numpy as np |
|
|
|
|
|
|
|
class FeatureExtractor: |
|
def __init__( |
|
self, |
|
feature_size=80, |
|
sampling_rate=16000, |
|
hop_length=160, |
|
chunk_length=30, |
|
n_fft=400, |
|
): |
|
self.n_fft = n_fft |
|
self.hop_length = hop_length |
|
self.chunk_length = chunk_length |
|
self.n_samples = chunk_length * sampling_rate |
|
self.nb_max_frames = self.n_samples // hop_length |
|
self.time_per_frame = hop_length / sampling_rate |
|
self.sampling_rate = sampling_rate |
|
self.mel_filters = self.get_mel_filters( |
|
sampling_rate, n_fft, n_mels=feature_size |
|
) |
|
|
|
def get_mel_filters(self, sr, n_fft, n_mels=128, dtype=np.float32): |
|
|
|
n_mels = int(n_mels) |
|
weights = np.zeros((n_mels, int(1 + n_fft // 2)), dtype=dtype) |
|
|
|
|
|
fftfreqs = np.fft.rfftfreq(n=n_fft, d=1.0 / sr) |
|
|
|
|
|
min_mel = 0.0 |
|
max_mel = 45.245640471924965 |
|
|
|
mels = np.linspace(min_mel, max_mel, n_mels + 2) |
|
|
|
mels = np.asanyarray(mels) |
|
|
|
|
|
f_min = 0.0 |
|
f_sp = 200.0 / 3 |
|
freqs = f_min + f_sp * mels |
|
|
|
|
|
min_log_hz = 1000.0 |
|
min_log_mel = (min_log_hz - f_min) / f_sp |
|
logstep = np.log(6.4) / 27.0 |
|
|
|
|
|
log_t = mels >= min_log_mel |
|
freqs[log_t] = min_log_hz * np.exp(logstep * (mels[log_t] - min_log_mel)) |
|
|
|
mel_f = freqs |
|
|
|
fdiff = np.diff(mel_f) |
|
ramps = np.subtract.outer(mel_f, fftfreqs) |
|
|
|
for i in range(n_mels): |
|
|
|
lower = -ramps[i] / fdiff[i] |
|
upper = ramps[i + 2] / fdiff[i + 1] |
|
|
|
|
|
weights[i] = np.maximum(0, np.minimum(lower, upper)) |
|
|
|
|
|
enorm = 2.0 / (mel_f[2 : n_mels + 2] - mel_f[:n_mels]) |
|
weights *= enorm[:, np.newaxis] |
|
|
|
return weights |
|
|
|
def fram_wave(self, waveform, center=True): |
|
""" |
|
Transform a raw waveform into a list of smaller waveforms. |
|
The window length defines how much of the signal is |
|
contain in each frame (smalle waveform), while the hope length defines the step |
|
between the beginning of each new frame. |
|
Centering is done by reflecting the waveform which is first centered around |
|
`frame_idx * hop_length`. |
|
""" |
|
frames = [] |
|
for i in range(0, waveform.shape[0] + 1, self.hop_length): |
|
half_window = (self.n_fft - 1) // 2 + 1 |
|
if center: |
|
start = i - half_window if i > half_window else 0 |
|
end = ( |
|
i + half_window |
|
if i < waveform.shape[0] - half_window |
|
else waveform.shape[0] |
|
) |
|
|
|
frame = waveform[start:end] |
|
|
|
if start == 0: |
|
padd_width = (-i + half_window, 0) |
|
frame = np.pad(frame, pad_width=padd_width, mode="reflect") |
|
|
|
elif end == waveform.shape[0]: |
|
padd_width = (0, (i - waveform.shape[0] + half_window)) |
|
frame = np.pad(frame, pad_width=padd_width, mode="reflect") |
|
|
|
else: |
|
frame = waveform[i : i + self.n_fft] |
|
frame_width = frame.shape[0] |
|
if frame_width < waveform.shape[0]: |
|
frame = np.lib.pad( |
|
frame, |
|
pad_width=(0, self.n_fft - frame_width), |
|
mode="constant", |
|
constant_values=0, |
|
) |
|
|
|
frames.append(frame) |
|
return np.stack(frames, 0) |
|
|
|
def stft(self, frames, window): |
|
""" |
|
Calculates the complex Short-Time Fourier Transform (STFT) of the given framed signal. |
|
Should give the same results as `torch.stft`. |
|
""" |
|
frame_size = frames.shape[1] |
|
fft_size = self.n_fft |
|
|
|
if fft_size is None: |
|
fft_size = frame_size |
|
|
|
if fft_size < frame_size: |
|
raise ValueError("FFT size must greater or equal the frame size") |
|
|
|
num_fft_bins = (fft_size >> 1) + 1 |
|
|
|
data = np.empty((len(frames), num_fft_bins), dtype=np.complex64) |
|
fft_signal = np.zeros(fft_size) |
|
|
|
for f, frame in enumerate(frames): |
|
if window is not None: |
|
np.multiply(frame, window, out=fft_signal[:frame_size]) |
|
else: |
|
fft_signal[:frame_size] = frame |
|
data[f] = np.fft.fft(fft_signal, axis=0)[:num_fft_bins] |
|
return data.T |
|
|
|
def __call__(self, waveform, padding=True, chunk_length=None): |
|
""" |
|
Compute the log-Mel spectrogram of the provided audio, gives similar results |
|
whisper's original torch implementation with 1e-5 tolerance. |
|
""" |
|
if chunk_length is not None: |
|
self.n_samples = chunk_length * self.sampling_rate |
|
self.nb_max_frames = self.n_samples // self.hop_length |
|
|
|
if padding: |
|
waveform = np.pad(waveform, [(0, self.n_samples)]) |
|
|
|
window = np.hanning(self.n_fft + 1)[:-1] |
|
|
|
frames = self.fram_wave(waveform) |
|
stft = self.stft(frames, window=window) |
|
magnitudes = np.abs(stft[:, :-1]) ** 2 |
|
|
|
filters = self.mel_filters |
|
mel_spec = filters @ magnitudes |
|
|
|
log_spec = np.log10(np.clip(mel_spec, a_min=1e-10, a_max=None)) |
|
log_spec = np.maximum(log_spec, log_spec.max() - 8.0) |
|
log_spec = (log_spec + 4.0) / 4.0 |
|
|
|
return log_spec |
|
|