# melspec.py
# Provenance (from the hosting page header): commit a63132d "Create melspec.py",
# uploaded by Hecheng0625, 2.62 kB.
import torch
import pyworld as pw
import numpy as np
import soundfile as sf
import os
from torchaudio.functional import pitch_shift
import librosa
from librosa.filters import mel as librosa_mel_fn
import torch.nn as nn
import torch.nn.functional as F
def dynamic_range_compression(x, C=1, clip_val=1e-5):
    """Log-compress ``x`` with NumPy.

    Args:
        x: Array-like of values to compress.
        C: Multiplicative factor applied before the logarithm.
        clip_val: Lower bound; entries of ``x`` below it are raised to it.

    Returns:
        ``log(max(x, clip_val) * C)`` computed element-wise.
    """
    # Floor first so values below clip_val cannot reach the logarithm.
    floored = np.clip(x, a_min=clip_val, a_max=None)
    scaled = floored * C
    return np.log(scaled)
def dynamic_range_decompression(x, C=1):
    """Undo :func:`dynamic_range_compression` with NumPy.

    Args:
        x: Array-like of log-compressed values.
        C: The factor that was applied before the logarithm.

    Returns:
        ``exp(x) / C`` computed element-wise.
    """
    expanded = np.exp(x)
    return expanded / C
def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
    """Log-compress a tensor.

    Args:
        x: Tensor of values to compress.
        C: Multiplicative factor applied before the logarithm.
        clip_val: Lower bound; entries of ``x`` below it are raised to it.

    Returns:
        Tensor holding ``log(max(x, clip_val) * C)`` element-wise.
    """
    # Floor first so values below clip_val cannot reach the logarithm.
    floored = torch.clamp(x, min=clip_val)
    scaled = floored * C
    return torch.log(scaled)
def dynamic_range_decompression_torch(x, C=1):
    """Undo :func:`dynamic_range_compression_torch`.

    Args:
        x: Tensor of log-compressed values.
        C: The factor that was applied before the logarithm.

    Returns:
        Tensor holding ``exp(x) / C`` element-wise.
    """
    expanded = torch.exp(x)
    return expanded / C
def spectral_normalize_torch(magnitudes):
    """Log-compress spectral magnitudes.

    Delegates to ``dynamic_range_compression_torch`` with its default
    ``C`` and ``clip_val``.

    Args:
        magnitudes: Tensor of spectral magnitudes.

    Returns:
        The log-compressed tensor.
    """
    return dynamic_range_compression_torch(magnitudes)
def spectral_de_normalize_torch(magnitudes):
    """Invert :func:`spectral_normalize_torch`.

    Delegates to ``dynamic_range_decompression_torch`` with its default ``C``.

    Args:
        magnitudes: Tensor of log-compressed spectral values.

    Returns:
        The decompressed (exponentiated) tensor.
    """
    return dynamic_range_decompression_torch(magnitudes)
class MelSpectrogram(nn.Module):
    """Turn waveforms into log-compressed mel spectrograms.

    The mel filterbank (built with ``librosa.filters.mel``) and the Hann
    analysis window are created once in ``__init__`` and registered as
    buffers, so they move with the module across devices.

    Args:
        n_fft: FFT size passed to ``torch.stft``.
        num_mels: Number of mel bands in the filterbank.
        sampling_rate: Sample rate used to build the mel filterbank.
        hop_size: Hop length between STFT frames.
        win_size: Length of the Hann window.
        fmin: Lowest frequency of the mel filterbank.
        fmax: Highest frequency of the mel filterbank.
        center: Forwarded to ``torch.stft`` as its ``center`` argument.
    """

    def __init__(
        self,
        n_fft,
        num_mels,
        sampling_rate,
        hop_size,
        win_size,
        fmin,
        fmax,
        center=False,
    ):
        super().__init__()
        self.n_fft = n_fft
        self.hop_size = hop_size
        self.win_size = win_size
        self.sampling_rate = sampling_rate
        self.num_mels = num_mels
        self.fmin = fmin
        self.fmax = fmax
        self.center = center

        mel = librosa_mel_fn(
            sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax
        )

        # Buffers (not parameters): fixed tensors that follow .to()/.cuda().
        self.register_buffer("mel_basis", torch.from_numpy(mel).float())
        self.register_buffer("hann_window", torch.hann_window(win_size))

    def forward(self, y):
        """Compute the log-mel spectrogram of ``y``.

        Args:
            y: Waveform tensor; expected layout is (batch, samples), since a
                channel axis is inserted at dim 1 for padding and removed
                again afterwards.

        Returns:
            Log-compressed mel spectrogram; presumably
            (batch, num_mels, frames) for a (batch, samples) input.
        """
        # Same amount of reflect padding on both sides of the time axis.
        # NOTE(review): this padding is applied regardless of ``center``; with
        # center=True torch.stft pads again on top of it -- confirm intended.
        pad = int((self.n_fft - self.hop_size) / 2)
        y = F.pad(y.unsqueeze(1), (pad, pad), mode="reflect")
        y = y.squeeze(1)

        spec = torch.stft(
            y,
            self.n_fft,
            hop_length=self.hop_size,
            win_length=self.win_size,
            window=self.hann_window,
            center=self.center,
            pad_mode="reflect",
            normalized=False,
            onesided=True,
            return_complex=True,
        )

        # Magnitude from (real, imag) pairs; the small epsilon keeps sqrt
        # away from exactly zero.
        spec = torch.view_as_real(spec)
        spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))

        # Project linear-frequency bins onto the mel filterbank, then log-compress.
        spec = torch.matmul(self.mel_basis, spec)
        spec = spectral_normalize_torch(spec)

        return spec