# dmsp/src/utils/ddsp.py
# Last commit: bc3e180 ("first commit") by szin94
import torch
import torch.nn as nn
import torch.fft as fft
import numpy as np
import librosa as li
import math
def safe_log(x):
    """Return the natural log of ``x`` shifted by a small epsilon.

    The ``1e-7`` offset keeps ``log(0)`` finite; it is intended for
    non-negative magnitudes (values below ``-1e-7`` still produce NaN).

    Args:
        x: Input tensor.

    Returns:
        ``torch.log(x + 1e-7)``.
    """
    eps = 1e-7
    shifted = x + eps
    return torch.log(shifted)
@torch.no_grad()
def mean_std_loudness(dataset):
    """Average the per-item loudness mean and std over a dataset.

    Each item of ``dataset`` is unpacked as a 3-tuple whose last element is
    a loudness tensor. Note that the returned ``std`` is the average of the
    per-item standard deviations, not the pooled std over all frames.

    Args:
        dataset: Iterable of ``(_, _, loudness)`` triples.

    Returns:
        ``(mean, std)`` as Python numbers; ``(0, 0)`` for an empty dataset.
    """
    running_mean = 0
    running_std = 0
    for count, (_, _, loudness) in enumerate(dataset, start=1):
        # Incremental running average: no need to keep every item around.
        running_mean += (loudness.mean().item() - running_mean) / count
        running_std += (loudness.std().item() - running_std) / count
    return running_mean, running_std
def multiscale_fft(signal, scales, overlap):
    """Compute normalized magnitude STFTs of ``signal`` at several FFT sizes.

    Args:
        signal: Time-domain tensor accepted by ``torch.stft``.
        scales: Iterable of FFT sizes; each is also used as the window length.
        overlap: Fraction of window overlap; the hop is
            ``int(scale * (1 - overlap))``.

    Returns:
        List of magnitude spectrograms, one per entry of ``scales``, in order.
    """
    def _magnitude(n_fft):
        hop = int(n_fft * (1 - overlap))
        # Match the window's dtype/device to the signal.
        window = torch.hann_window(n_fft).to(signal)
        spec = torch.stft(
            signal,
            n_fft=n_fft,
            hop_length=hop,
            win_length=n_fft,
            window=window,
            center=True,
            normalized=True,
            return_complex=True,
        )
        return spec.abs()

    return [_magnitude(n_fft) for n_fft in scales]
def resample(x, factor: int):
    """Upsample ``x`` along its frame axis by an integer ``factor``.

    Frames are zero-stuffed and smoothed with a periodic Hann kernel of
    length ``2 * factor``. For that kernel ``w[k] + w[k + factor] == 1``, so
    the output passes exactly through the input frames at multiples of
    ``factor`` and moves between neighbouring frames with a raised-cosine ramp.

    Args:
        x: Tensor of shape ``(batch, frame, channel)``.
        factor: Positive integer upsampling factor.

    Returns:
        Tensor of shape ``(batch, factor * frame, channel)``.
    """
    batch, frame, channel = x.shape
    if frame == 0:
        return x.new_zeros(batch, 0, channel)

    x = x.permute(0, 2, 1).reshape(batch * channel, 1, frame)

    window = torch.hann_window(
        factor * 2,
        dtype=x.dtype,
        device=x.device,
    ).reshape(1, 1, -1)

    # Repeat the last frame one hop past the end, so the tail holds the final
    # value instead of ramping to zero. Writing it into the last stuffed sample
    # instead would sum with the regular impulse at factor * (frame - 1) and
    # overshoot, e.g. by 50% for factor == 2.
    x = torch.cat([x, x[..., -1:]], -1)
    y = x.new_zeros(x.shape[0], x.shape[1], factor * frame + 1)
    y[..., ::factor] = x
    y = torch.nn.functional.pad(y, [factor, factor])
    y = torch.nn.functional.conv1d(y, window)[..., :factor * frame]

    y = y.reshape(batch, channel, factor * frame).permute(0, 2, 1)
    return y
def upsample(signal, factor):
    """Linearly interpolate a 3-D tensor along dim 1 by ``factor``.

    Args:
        signal: 3-D tensor; dim 1 is stretched (presumably
            ``(batch, frame, channel)`` — confirm against callers).
        factor: Integer upsampling factor.

    Returns:
        Tensor with dim 1 of length ``signal.shape[1] * factor``.
    """
    channels_first = signal.transpose(1, 2)
    target_len = channels_first.shape[-1] * factor
    stretched = nn.functional.interpolate(
        channels_first, size=target_len, mode='linear'
    )
    return stretched.transpose(1, 2)
def remove_above_nyquist(amplitudes, pitch, sampling_rate):
    """Attenuate harmonics whose frequency is at or above Nyquist.

    amplitudes: (batch, frames, n_harmonics)
    pitch: (batch, frames, 1), in the same units as ``sampling_rate``

    Harmonic k (1-based) is taken to have frequency ``k * pitch``. Harmonics
    below ``sampling_rate / 2`` are scaled by ``1 + 1e-4``; the rest are
    scaled by ``1e-4`` rather than zeroed.
    """
    n_harm = amplitudes.shape[-1]
    # k * f0 built as a running sum of the fundamental along the harmonic axis.
    harmonic_freqs = torch.cumsum(pitch.repeat(1, 1, n_harm), dim=-1)
    below = harmonic_freqs < sampling_rate / 2
    return amplitudes * (below.float() + 1e-4)
def remove_above_nyquist_mode(amplitudes, frequencies, sampling_rate):
    """Attenuate modes whose frequency is at or above Nyquist.

    amplitudes: (batch, frames, n_harmonics)
    frequencies: (batch, frames, n_harmonics), same units as ``sampling_rate``

    Modes below ``sampling_rate / 2`` are scaled by ``1 + 1e-4``; the rest by
    ``1e-4`` rather than zeroed.
    """
    keep = (frequencies < sampling_rate / 2).float()
    return amplitudes * (keep + 1e-4)
def scale_function(x):
    """Squash ``x`` into roughly (0, 2) as ``2 * sigmoid(x) ** ln(10) + 1e-7``.

    The ``1e-7`` floor keeps the output strictly positive.
    """
    exponent = math.log(10)
    squashed = torch.sigmoid(x) ** exponent
    return 2 * squashed + 1e-7
def extract_loudness(signal, sampling_rate, block_size, n_fft=2048):
    """A-weighted log-magnitude loudness, one value per ``block_size`` hop.

    Args:
        signal: Audio array passed to ``librosa.stft`` (presumably mono 1-D;
            multichannel input would be averaged over the wrong axis — TODO
            confirm against callers).
        sampling_rate: Sample rate of ``signal`` in Hz.
        block_size: STFT hop length in samples.
        n_fft: FFT size and window length.

    Returns:
        Numpy array of per-frame loudness, with the final STFT frame dropped.
    """
    S = li.stft(
        signal,
        n_fft=n_fft,
        hop_length=block_size,
        win_length=n_fft,
        center=True,
    )
    S = np.log(abs(S) + 1e-7)
    # librosa >= 0.10 makes `sr`/`n_fft` keyword-only; keywords also work on
    # earlier versions, whereas positional arguments raise TypeError on 0.10+.
    f = li.fft_frequencies(sr=sampling_rate, n_fft=n_fft)
    a_weight = li.A_weighting(f)
    # One weight per frequency bin (rows of S).
    # NOTE(review): A_weighting returns dB offsets while S is a natural-log
    # magnitude, so the units are mixed; kept as-is to preserve behavior.
    S = S + a_weight.reshape(-1, 1)
    # Average over frequency, then drop the last frame (center=True produces
    # one more frame than len(signal) // block_size).
    S = np.mean(S, 0)[..., :-1]
    return S
def extract_pitch(signal, sampling_rate, block_size):
    """Estimate f0 with CREPE, resampled to one value per ``block_size`` samples.

    Args:
        signal: Audio array passed to ``crepe.predict``.
        sampling_rate: Sample rate of ``signal`` in Hz.
        block_size: Hop in samples; CREPE's step is ``block_size`` converted
            to (truncated) milliseconds.

    Returns:
        Numpy array of length ``signal.shape[-1] // block_size``.
    """
    # `crepe` was called here but never imported by this module (NameError).
    # Import it locally so the dependency is loaded only when pitch is needed.
    import crepe

    length = signal.shape[-1] // block_size
    f0 = crepe.predict(
        signal,
        sampling_rate,
        step_size=int(1000 * block_size / sampling_rate),
        verbose=1,
        center=True,
        viterbi=True,
    )
    # Per crepe's docs, predict returns (time, frequency, confidence,
    # activation); keep the frequency track and drop its trailing frame.
    f0 = f0[1].reshape(-1)[:-1]

    if f0.shape[-1] != length:
        # The ms step is truncated to an int, so the frame count can drift;
        # linearly resample the track to exactly `length` frames.
        f0 = np.interp(
            np.linspace(0, 1, length, endpoint=False),
            np.linspace(0, 1, f0.shape[-1], endpoint=False),
            f0,
        )
    return f0
def harmonic_synth(pitch, amplitudes, sampling_rate):
    """Additive synthesis of a harmonic series.

    Phase is the running sum of ``2*pi*pitch/sampling_rate`` along dim 1
    (presumably time at audio rate, with ``pitch`` in Hz). Harmonic k uses
    ``k`` times that phase and is weighted by ``amplitudes[..., k-1]``.

    Args:
        pitch: Fundamental frequency; last dim broadcasts against the
            harmonic axis (presumably size 1).
        amplitudes: Per-harmonic amplitudes; last dim is the harmonic count.
        sampling_rate: Sample rate in the same units as ``pitch``.

    Returns:
        Sum of all partials, with the last dim kept (size 1).
    """
    n_harmonic = amplitudes.shape[-1]
    phase = torch.cumsum(2 * math.pi * pitch / sampling_rate, 1)
    harmonic_numbers = torch.arange(
        1, n_harmonic + 1, dtype=phase.dtype, device=phase.device
    )
    partials = torch.sin(phase * harmonic_numbers) * amplitudes
    return partials.sum(-1, keepdim=True)
def modal_synth(modes, amplitude, sampling_rate, n_chunks=16):
    """Sum of cosine modes driven by per-sample phase increments.

    Each mode's phase is the running sum of ``modes`` along dim 1, computed
    chunk by chunk. The last phase of each chunk is carried into the next and
    wrapped to ``[0, 2*pi)``. ``cos`` is 2*pi-periodic, so the wrap does not
    change the result mathematically. It does keep the absolute phase to
    roughly one chunk's worth, which preserves float precision on long inputs.

    Args:
        modes: Per-sample phase increments in radians (fed straight to
            ``cos`` after summation); presumably ``(batch, time, n_modes)`` —
            TODO confirm against callers.
        amplitude: Mode amplitudes, chunked along dim 1 like ``modes`` and
            multiplied elementwise with the cosines.
        sampling_rate: Not used by this function; kept for interface
            compatibility.
        n_chunks: Number of chunks the time axis (dim 1) is split into.

    Returns:
        Tensor with the last (mode) axis summed out, ``keepdim=True``.
    """
    carry = None
    outputs = []
    for f, c in zip(modes.chunk(n_chunks, 1), amplitude.chunk(n_chunks, 1)):
        phase = f.cumsum(1)
        if carry is not None:
            phase = phase + carry
        outputs.append((torch.cos(phase) * c).sum(-1, keepdim=True))
        carry = torch.remainder(phase.narrow(1, -1, 1), 2 * math.pi)
    return torch.cat(outputs, 1)
def amp_to_impulse_response(amp, target_size):
    """Convert a real magnitude response into a Hann-windowed FIR kernel.

    ``amp`` is treated as a zero-phase spectrum (imaginary part zero) and
    inverted with ``irfft`` along its last axis. The kernel is centred,
    tapered with a Hann window, zero-padded (or cropped, if smaller) to
    ``target_size``, and shifted back. Its centre then sits at index 0, with
    the left half wrapped to the end.

    Args:
        amp: Real tensor; last axis holds frequency bins.
        target_size: Desired kernel length on the last axis.

    Returns:
        Real tensor whose last axis has length ``target_size``.
    """
    spectrum = torch.complex(amp, torch.zeros_like(amp))
    kernel = fft.irfft(spectrum)
    filter_size = kernel.shape[-1]
    kernel = torch.roll(kernel, filter_size // 2, -1)
    taper = torch.hann_window(filter_size, dtype=kernel.dtype, device=kernel.device)
    kernel = kernel * taper
    kernel = nn.functional.pad(kernel, (0, int(target_size) - int(filter_size)))
    return torch.roll(kernel, -filter_size // 2, -1)
def fft_convolve(signal, kernel):
    """Convolve ``signal`` with ``kernel`` along the last axis via FFT.

    ``signal`` is zero-padded on the right and ``kernel`` on the left, each
    to twice its length. The second half of the circular convolution is
    returned. When both inputs share the same last-axis length N, this equals
    the first N samples of their linear convolution.

    Args:
        signal: Real tensor.
        kernel: Real tensor with the same last-axis length as ``signal``.

    Returns:
        Real tensor with ``signal``'s last-axis length.
    """
    padded_signal = nn.functional.pad(signal, (0, signal.shape[-1]))
    padded_kernel = nn.functional.pad(kernel, (kernel.shape[-1], 0))
    spectrum = fft.rfft(padded_signal) * fft.rfft(padded_kernel)
    full = fft.irfft(spectrum)
    half = full.shape[-1] // 2
    return full[..., half:]