P-FAD / src /frontends.py
mrneuralnet's picture
Initial commit
833847d
from typing import List, Union, Callable
import torch
import torchaudio
# Sample rate (Hz) handed to every transform defined in this module.
SAMPLING_RATE = 16_000

# Framing expressed in samples: a 25 ms window advanced in 10 ms steps.
win_length = SAMPLING_RATE * 25 // 1_000  # 400
hop_length = SAMPLING_RATE * 10 // 1_000  # 160

# Everything is pinned to the CPU; the GPU auto-selection is left commented out:
#   device = "cuda" if torch.cuda.is_available() else "cpu"
device = "cpu"
# MFCC transform producing 128 coefficients; the underlying mel spectrogram
# uses a 512-point FFT with the module-level window and hop lengths.
MFCC_FN = torchaudio.transforms.MFCC(
    sample_rate=SAMPLING_RATE,
    n_mfcc=128,
    melkwargs=dict(n_fft=512, win_length=win_length, hop_length=hop_length),
)
MFCC_FN = MFCC_FN.to(device)
# LFCC transform producing 128 coefficients; the underlying spectrogram uses a
# 512-point FFT with the module-level window and hop lengths.
LFCC_FN = torchaudio.transforms.LFCC(
    sample_rate=SAMPLING_RATE,
    n_lfcc=128,
    speckwargs=dict(n_fft=512, win_length=win_length, hop_length=hop_length),
)
LFCC_FN = LFCC_FN.to(device)
# Mel filterbank mapping 257 STFT bins (512 // 2 + 1) onto 80 mel bands.
MEL_SCALE_FN = torchaudio.transforms.MelScale(
    sample_rate=SAMPLING_RATE,
    n_stft=512 // 2 + 1,
    n_mels=80,
)
MEL_SCALE_FN = MEL_SCALE_FN.to(device)
# Delta (derivative) transform shared by both front-ends; edges are padded by
# replication.
# NOTE(review): win_length=400 equals the STFT window length in samples, while
# ComputeDeltas interprets it as a window over frames (library default is 5).
# Left as-is because changing it would alter the features -- confirm intent.
delta_fn = torchaudio.transforms.ComputeDeltas(mode="replicate", win_length=400)
delta_fn = delta_fn.to(device)
def get_frontend(
    frontends: List[str],
) -> Callable[[torch.Tensor], torch.Tensor]:
    """Select the feature-extraction function for the requested front-end.

    Args:
        frontends: Names of the requested front-ends. ``"mfcc"`` is checked
            first, so it wins when both ``"mfcc"`` and ``"lfcc"`` are present.

    Returns:
        ``prepare_mfcc_double_delta`` or ``prepare_lfcc_double_delta`` -- a
        plain function, not a torchaudio transform instance.

    Raises:
        ValueError: If neither ``"mfcc"`` nor ``"lfcc"`` is in ``frontends``.
    """
    if "mfcc" in frontends:
        return prepare_mfcc_double_delta
    if "lfcc" in frontends:
        return prepare_lfcc_double_delta
    raise ValueError(f"{frontends} frontend is not supported!")
def prepare_lfcc_double_delta(input):
    """Compute LFCCs stacked with their first- and second-order deltas.

    Args:
        input: Tensor to featurise. When it has fewer than four dimensions a
            singleton dimension is inserted at position 1 before the
            transform is applied.
            NOTE(review): presumably a (batch, samples) waveform, giving
            (batch, 1, n_lfcc * 3, frames) on output -- confirm with callers.

    Returns:
        Tensor on ``device`` holding the LFCCs, their deltas and their
        delta-deltas concatenated along dimension 2, with dimension 3
        truncated to at most 3000 entries.
    """
    if input.ndim < 4:
        input = input.unsqueeze(1)
    # Keep the input on the same device as the transforms, as the MFCC
    # front-end intends; Tensor.to() returns the moved tensor, so rebind it.
    input = input.to(device)
    features = LFCC_FN(input)
    delta = delta_fn(features)
    double_delta = delta_fn(delta)
    # Stack static, delta and delta-delta coefficients on the coefficient axis.
    stacked = torch.cat((features, delta, double_delta), 2)
    return stacked[:, :, :, :3000]
def prepare_mfcc_double_delta(input):
    """Compute MFCCs stacked with their first- and second-order deltas.

    Args:
        input: Tensor to featurise. When it has fewer than four dimensions a
            singleton dimension is inserted at position 1 before the
            transform is applied.
            NOTE(review): presumably a (batch, samples) waveform, giving
            (batch, 1, n_mfcc * 3, frames) on output -- confirm with callers.

    Returns:
        Tensor on ``device`` holding the MFCCs, their deltas and their
        delta-deltas concatenated along dimension 2, with dimension 3
        truncated to at most 3000 entries.
    """
    if input.ndim < 4:
        input = input.unsqueeze(1)
    # Tensor.to() is not in-place: the original call discarded its result, so
    # the input never actually moved. Rebind to make the device move effective.
    input = input.to(device)
    features = MFCC_FN(input)
    delta = delta_fn(features)
    double_delta = delta_fn(delta)
    # Stack static, delta and delta-delta coefficients on the coefficient axis.
    stacked = torch.cat((features, delta, double_delta), 2)
    return stacked[:, :, :, :3000]