# dmsp/src/utils/audio.py
# Last commit: "first commit" (bc3e180) by szin94
import math
import torch
import torch.nn.functional as F
import numpy as np
import librosa
import soundfile as sf
from einops import rearrange
# Machine epsilon of float32. Added inside the numpy RMS square root and to the
# RMS denominators of the gain computations below, to avoid dividing by (or
# taking the log of) zero for silent input.
eps = np.finfo(np.float32).eps
def calculate_rms(amp):
    """Return the root-mean-square of ``amp`` over its last axis.

    The two backends intentionally differ:
      * torch.Tensor: the reduced axis is kept (``keepdim=True``) and no eps is
        added, so an all-zero row gives exactly 0.
      * np.ndarray: the reduced axis is dropped and ``eps`` is added under the
        square root.

    Raises:
        TypeError: if ``amp`` is neither a torch.Tensor nor an np.ndarray.
    """
    if isinstance(amp, torch.Tensor):
        mean_square = amp.pow(2).mean(-1, keepdim=True)
        return mean_square.pow(.5)
    if isinstance(amp, np.ndarray):
        mean_square = np.mean(np.square(amp), axis=-1)
        return np.sqrt(mean_square + eps)
    raise TypeError(f"argument 'amp' must be torch.Tensor or np.ndarray. got: {type(amp)}")
def dB2amp(dB):
    """Convert a level in decibels to a linear amplitude ratio, 10 ** (dB / 20)."""
    exponent = dB / 20.
    return np.power(10., exponent)
def amp2dB(amp):
    """Convert a linear amplitude ratio to decibels, 20 * log10(amp).

    No epsilon is added, so an amplitude of 0 maps to -inf.
    """
    log_amp = np.log10(amp)
    return log_amp * 20.
def rms_normalize(wav, ref_dBFS=-23.0, skip_nan=True):
    """Scale ``wav`` so that its RMS level over the last axis equals ``ref_dBFS``.

    Level convention (a full-scale sine reads 0 dBFS):
        value_dBFS = 20*log10(rms(signal) * sqrt(2))
                   = 20*log10(rms(signal)) + 3.0103

    Args:
        wav: np.ndarray or torch.Tensor; the RMS is taken over the last axis.
        ref_dBFS: target level in dBFS (float or torch.Tensor).
        skip_nan: when True, input containing NaN is returned unchanged with a
            gain of 1.; when False, such input raises AssertionError.

    Returns:
        ``(scaled_wav, gain)`` with ``scaled_wav == gain * wav``. For torch input,
        and for numpy input with more than one axis, ``gain`` keeps a trailing
        singleton axis so it scales each row independently. For 1-D numpy input
        it is a scalar.
    """
    is_tensor = isinstance(wav, torch.Tensor)
    isnan = torch.isnan if is_tensor else np.isnan
    # NaN propagates through a sum, so one reduction detects any NaN sample.
    # Tensors are reduced with their own .sum(): np.sum() would forward
    # numpy-specific keyword arguments (axis/out) to Tensor.sum().
    total = wav.sum() if is_tensor else np.sum(wav)
    exists_nan = bool(isnan(total))
    if not skip_nan:
        assert not exists_nan, isnan(wav)
    if exists_nan:
        return wav, 1.

    rms = calculate_rms(wav)
    if not is_tensor and np.ndim(wav) > 1:
        # calculate_rms drops the reduced axis for numpy input. Restore it so a
        # (B,) gain broadcasts per row of (B, T), not along T.
        rms = np.expand_dims(rms, -1)
    if isinstance(ref_dBFS, torch.Tensor):
        ref_linear = torch.pow(10, (ref_dBFS - 3.0103) / 20.)
    else:
        ref_linear = np.power(10, (ref_dBFS - 3.0103) / 20.)
    gain = ref_linear / (rms + eps)
    return gain * wav, gain
def ell_infty_normalize(wav, skip_nan=True):
    """Peak-normalize ``wav`` so its maximum absolute value over the last axis is 1.

    Args:
        wav: np.ndarray or torch.Tensor.
        skip_nan: when True, input containing NaN is returned unchanged with a
            gain of 1.; when False, such input raises AssertionError.

    Returns:
        ``(normalized_wav, gain)`` with ``normalized_wav == gain * wav``. Rows
        whose peak is 0 get gain 1 and are left unchanged. For 1-D numpy input
        ``gain`` is a scalar. Otherwise it keeps a trailing singleton axis so it
        scales each row independently.

    Raises:
        TypeError: if ``wav`` is neither an np.ndarray nor a torch.Tensor.
    """
    if isinstance(wav, np.ndarray):
        exists_nan = np.isnan(np.sum(wav))
        if not skip_nan:
            assert not exists_nan, np.isnan(wav)
        if exists_nan:
            return wav, 1.
        # Keep the reduced axis for batched input so the per-row gain broadcasts
        # against (..., T), matching the torch branch below.
        maxv = np.max(np.abs(wav), axis=-1, keepdims=wav.ndim > 1)
        if maxv.ndim == 0:
            gain = 1 if maxv == 0 else 1. / maxv
        else:
            # Zero peaks produce inf here; they are replaced by 1 immediately after.
            with np.errstate(divide='ignore'):
                gain = 1. / maxv
            gain[maxv == 0] = 1
    elif isinstance(wav, torch.Tensor):
        exists_nan = torch.isnan(wav.sum())
        if not skip_nan:
            assert not exists_nan, torch.isnan(wav)
        if exists_nan:
            return wav, 1.
        maxv = wav.abs().max(-1).values.unsqueeze(-1)
        # 1 where the peak is 0, otherwise 1 / peak.
        gain = torch.where(maxv.eq(0), torch.ones_like(maxv), 1. / maxv)
    else:
        raise TypeError(f"argument 'wav' must be torch.Tensor or np.ndarray. got: {type(wav)}")
    return gain * wav, gain
def dB_RMS(wav):
    """Return the RMS level of ``wav`` over its last axis, in dB (20 * log10(rms)).

    The torch path of ``calculate_rms`` adds no epsilon, so an all-zero tensor
    row gives -inf. The numpy path adds ``eps`` under the square root.

    Raises:
        TypeError: if ``wav`` is neither a torch.Tensor nor an np.ndarray
            (previously such input silently returned None).
    """
    if isinstance(wav, torch.Tensor):
        return 20 * torch.log10(calculate_rms(wav))
    if isinstance(wav, np.ndarray):
        return 20 * np.log10(calculate_rms(wav))
    raise TypeError(f"argument 'wav' must be torch.Tensor or np.ndarray. got: {type(wav)}")
def mel_basis(sr, n_fft, n_mel):
    """Return librosa's mel filterbank spanning 0 Hz to the Nyquist frequency.

    ``norm=1`` asks librosa to scale each filter to unit L1 norm. The matrix is
    (n_mel, 1 + n_fft // 2) and is meant to left-multiply a linear spectrogram.
    """
    nyquist = sr // 2
    return librosa.filters.mel(
        sr=sr,
        n_fft=n_fft,
        n_mels=n_mel,
        fmin=0,
        fmax=nyquist,
        norm=1,
    )
def inv_mel_basis(sr, n_fft, n_mel):
    """Return the transposed, unnormalized (``norm=None``) mel filterbank.

    Shape is (1 + n_fft // 2, n_mel). Despite the name, this is a plain transpose
    used to map mel bins back onto linear-frequency bins, not a pseudo-inverse of
    :func:`mel_basis` (which also uses ``norm=1``).
    """
    nyquist = sr // 2
    filterbank = librosa.filters.mel(
        sr=sr, n_fft=n_fft, n_mels=n_mel, norm=None, fmin=0, fmax=nyquist,
    )
    return filterbank.T
def lin_to_mel(linspec, sr, n_fft, n_mel=80):
    """Project a linear-frequency spectrogram onto ``n_mel`` mel bands.

    The filterbank is rebuilt on every call. ``linspec`` presumably has the
    frequency axis (1 + n_fft // 2) first, since it is right-multiplied by an
    (n_mel, 1 + n_fft // 2) matrix; TODO confirm against callers.
    """
    return mel_basis(sr, n_fft, n_mel) @ linspec
def save_waves(est, save_dir, sr=16000):
    """Write every item of a batch of waveforms to ``{save_dir}/{b}.wav``.

    Args:
        est: batch indexed as ``est[b, 0]``. Presumably (Bs, Ch, Nt) with channel
            0 written out; TODO confirm. May be an np.ndarray or a torch.Tensor.
        save_dir: existing directory to write into (it is not created here).
        sr: sample rate passed to ``soundfile.write``.
    """
    # The batch size comes from ``est`` itself (the original referenced an
    # undefined name ``inp`` here and raised NameError).
    for b in range(est.shape[0]):
        est_wav = est[b, 0].squeeze()
        if isinstance(est_wav, torch.Tensor):
            # soundfile works on NumPy data. Detach and move to host memory so
            # CUDA / grad-tracking tensors can be written too.
            est_wav = est_wav.detach().cpu().numpy()
        sf.write(f"{save_dir}/{b}.wav", est_wav, samplerate=sr)
def get_inverse_window(forward_window, frame_length, frame_step):
    """Build the synthesis window that undoes ``forward_window`` under overlap-add.

    Each output sample is the forward window divided by the sum of squared
    forward-window values sharing its position modulo ``frame_step``, i.e. the
    values contributed by every overlapping frame. As a result, forward * inverse
    sums to one across overlapping frames. The construction appears to mirror
    ``tf.signal.inverse_stft_window_fn``.
    """
    whole, rest = divmod(frame_length, frame_step)
    n_overlap = whole + (1 if rest else 0)  # ceil(frame_length / frame_step)
    pad = n_overlap * frame_step - frame_length
    squared = F.pad(torch.square(forward_window), (0, pad))
    # Energy per phase (position within a hop), summed over overlapping frames.
    phase_energy = squared.reshape(n_overlap, frame_step).sum(0, keepdim=True)
    denom = phase_energy.repeat(n_overlap, 1).reshape(n_overlap * frame_step)
    return forward_window / denom[:frame_length]
def state_to_wav(state, normalize=True, sr=48000):
    """Collapse a state trajectory into a single signal per batch item.

    Takes the forward finite difference along time (dim 1), scales it by ``sr``
    and sums over the last (spatial) axis.

    Args:
        state: tensor of shape (Bs, Nt, Nx).
        normalize: if True, peak-normalize the result with ``ell_infty_normalize``.
        sr: sample rate used to scale the finite difference.

    Returns:
        Tensor of shape (Bs, Nt - 1).
    """
    assert state.ndim == 3, state.shape
    n_steps = state.size(1)
    later = state.narrow(1, 1, n_steps - 1)
    earlier = state.narrow(1, 0, n_steps - 1)
    velocity = ((later - earlier) * sr).sum(-1)
    if not normalize:
        return velocity
    return ell_infty_normalize(velocity)[0]
def state_to_spec(x, window):
    """STFT every (position, channel) trace of ``x`` along its time axis.

    Args:
        x: tensor (Bs, Nt, Nx, Ch).
        window: 1-D analysis window. Its length sets n_fft and the hop is n_fft // 4.

    Returns:
        Real tensor (Bs, Nf, Nx, Ch * n_freq * 2), where Nf is the number of STFT
        frames and n_freq = n_fft // 2 + 1. The last axis packs
        (channel, frequency, real/imag) in that order.
    """
    Bs, Nt, Nx, Ch = x.shape
    n_ffts = window.size(-1)
    n_freq = n_ffts // 2 + 1
    hop_length = n_ffts // 4
    x = rearrange(x, 'b t x c -> (b x c) t')
    # Current PyTorch requires return_complex for real input. view_as_real
    # restores the (..., freq, frame, 2) real/imag layout used below.
    s = torch.stft(x, n_ffts, hop_length=hop_length, window=window, return_complex=True)
    s = torch.view_as_real(s)
    s = rearrange(s, '(b x c) f t k -> b t x (c f k)',
                  b=Bs, x=Nx, c=Ch, f=n_freq, k=2)
    return s
def spec_to_state(x, window, length):
    """Invert :func:`state_to_spec` back to time-domain traces.

    Args:
        x: real tensor (Bs, Nf, Nx, Ch * n_freq * 2), last axis packed as
            (channel, frequency, real/imag), with n_freq = n_fft // 2 + 1.
        window: the same window used for analysis (its length sets n_fft).
        length: number of output time samples Nt.

    Returns:
        Tensor (Bs, Nt, Nx, Ch).
    """
    Bs, Nt, Nx, _ = x.shape
    n_ffts = window.size(-1)
    n_freq = n_ffts // 2 + 1
    hop_length = n_ffts // 4  # same as torch.istft's default, stated explicitly
    x = rearrange(x, 'b t x (c f k) -> (b x c) f t k', f=n_freq, k=2)
    # torch.istft requires complex input. view_as_complex needs a unit-stride
    # real/imag axis, hence contiguous().
    x = torch.view_as_complex(x.contiguous())
    x = torch.istft(x, n_ffts, hop_length=hop_length, length=length, window=window)
    x = rearrange(x, '(b x c) t -> b t x c', b=Bs, x=Nx)
    return x
def to_spec(x, window, reduce_channel=True):
    """STFT a batch of signals and return real-valued spectra, frames first.

    Args:
        x: tensor (Bs, Nt).
        window: 1-D analysis window. Its length sets n_fft and the hop is n_fft // 4.
        reduce_channel: if True, merge the real/imag axis into the frequency axis.

    Returns:
        (Bs, Nf, n_freq * 2) if ``reduce_channel`` else (Bs, Nf, n_freq, 2), where
        Nf is the number of STFT frames and n_freq = n_fft // 2 + 1. In the merged
        form, real and imaginary parts are interleaved per frequency bin.

    Raises:
        ValueError: if ``x`` is not 2-D.
    """
    if x.dim() != 2:
        raise ValueError(f"expected a (Bs, Nt) tensor, got shape {tuple(x.shape)}")
    n_ffts = window.size(-1)
    hop_length = n_ffts // 4
    # Current PyTorch requires return_complex for real input. view_as_real gives
    # (Bs, n_freq, Nf, 2).
    s = torch.stft(x, n_ffts, hop_length=hop_length, window=window, return_complex=True)
    s = torch.view_as_real(s).transpose(1, 2)  # -> (Bs, Nf, n_freq, 2)
    if reduce_channel:
        s = s.flatten(2)  # -> (Bs, Nf, n_freq * 2)
    return s
def from_spec(x, window, length):
    """Invert :func:`to_spec` (merged-channel form) back to time-domain signals.

    Args:
        x: real tensor (Bs, Nf, n_freq * 2) with real/imag interleaved per
            frequency bin, n_freq = n_fft // 2 + 1.
        window: the same window used for analysis (its length sets n_fft).
        length: number of output samples Nt.

    Returns:
        Tensor (Bs, Nt).
    """
    Bs, Nt, _ = x.shape
    n_ffts = window.size(-1)
    n_freq = n_ffts // 2 + 1
    hop_length = n_ffts // 4  # same as torch.istft's default, stated explicitly
    # (Bs, Nf, n_freq * 2) -> (Bs, n_freq, Nf, 2)
    x = x.reshape(Bs, Nt, n_freq, 2).transpose(1, 2)
    # torch.istft requires complex input.
    x = torch.view_as_complex(x.contiguous())
    return torch.istft(x, n_ffts, hop_length=hop_length, length=length, window=window)
def adjust_gain(y, x, minmax, ref_dBFS=-23.0):
    """Rescale ``y`` to the RMS of ``x``, then attenuate it by a random amount.

    A level ``g`` is drawn uniformly from [minmax[0], minmax[1]] for each row,
    with shape y[..., :1]. The output is ``y * (rms(x) + eps) / (rms(y) + eps)``
    divided by ``10 ** ((g - 3.0103) / 20)``.

    NOTE(review): ``ref_dBFS`` cancels out of the product, so it has no effect on
    the result. Also, because of the 3.0103 offset, the achieved x-to-output
    level ratio is ``g - 3.0103`` dB rather than ``g`` dB. Confirm both are
    intended.
    """
    lo, hi = minmax[0], minmax[1]
    rand_dB = (hi - lo) * torch.rand_like(y.narrow(-1, 0, 1)) + lo
    ref_lin = np.power(10, (ref_dBFS - 3.0103) / 20.)
    rand_lin = torch.pow(10, (rand_dB - 3.0103) / 20.)
    gain_x = ref_lin / (calculate_rms(x) + eps)
    gain_y = ref_lin / (calculate_rms(y) + eps)
    y_at_x_level = y * gain_y / gain_x
    return y_at_x_level / rand_lin
def degrade(x, rir, noise):
    """Reverberate ``x`` with ``rir`` and add ``noise``, both at random levels.

    Args:
        x: dry signal, (Bs, Nt).
        rir: impulse response, (Bs, Nr). Nr may differ from Nt: both are
            transformed at a common FFT size, so shorter or longer RIRs work.
        noise: (Bs, Nt), same length as ``x``.

    Returns:
        Tensor (Bs, Nt): the reverberant signal plus the scaled noise.
    """
    # Zero-padding both inputs to Nt + Nr (>= Nt + Nr - 1) makes the circular
    # FFT convolution equal the linear one. For Nr == Nt this is exactly the
    # previous padding scheme.
    n_fft = x.size(-1) + rir.size(-1)
    x_fft = torch.fft.rfft(x, n=n_fft)
    w_fft = torch.fft.rfft(rir, n=n_fft)
    wet_x = torch.fft.irfft(x_fft * w_fft, n=n_fft).narrow(-1, 0, x.size(-1))
    y = adjust_gain(wet_x, x, [0, 30])  # ser: random level range in dB
    n = adjust_gain(noise, y, [10, 30])  # snr: random level range in dB
    return y + n
def T60_to_sigma(T60, f_0, K):
    """Fit the two loss coefficients (sigma0, sigma1) to two T60 measurements.

    Args:
        T60 : (Bs, 2, 2) [[T60_freq_1, T60_1], [T60_freq_2, T60_2]]
        f_0 : (Bs, Nt, 1) fundamental frequency
        K   : (Bs, Nt, 1) kappa (K == gamma * kappa_rel)

    Returns:
        sig : (Bs, Nt, 2), the concatenation of sigma0 and sigma1. They satisfy
        sigma0 + sigma1 * zeta(f_i) = 6 ln(10) / T60_i at both given frequencies,
        where zeta(f) = -gamma^2 + sqrt(gamma^4 + 4 K^2 (2 pi f)^2) and
        gamma = 2 f_0.
    """
    gamma = f_0 * 2
    gamma_sq = gamma.pow(2)
    gamma_4 = gamma.pow(4)
    k_sq = K.pow(2)

    def zeta(freq):
        return -gamma_sq + (gamma_4 + 4 * k_sq * (2 * math.pi * freq).pow(2)).pow(.5)

    freq1, time1 = T60.narrow(1, 0, 1).chunk(2, -1)
    freq2, time2 = T60.narrow(1, 1, 1).chunk(2, -1)
    zeta1 = zeta(freq1)
    zeta2 = zeta(freq2)

    scale = 6 * math.log(10)
    spread = zeta1 - zeta2
    sig0 = scale * (-zeta2 / time1 + zeta1 / time2) / spread
    sig1 = scale * (1 / time1 - 1 / time2) / spread
    return torch.cat((sig0, sig1), dim=-1)