dmsp / src /utils /ddsp.py
szin94's picture
first commit
bc3e180
raw
history blame
4.65 kB
import torch
import torch.nn as nn
import torch.fft as fft
import numpy as np
import librosa as li
import math
def safe_log(x):
    '''Natural logarithm of ``x`` shifted by a small constant.

    1e-7 is added to ``x`` before the log is taken, so an input of exactly
    zero yields a finite value instead of ``-inf``.
    '''
    shifted = x + 1e-7
    return torch.log(shifted)
@torch.no_grad()
def mean_std_loudness(dataset):
    '''Average the per-item loudness mean and standard deviation.

    Each item of ``dataset`` must unpack into three values; only the third
    one (the loudness tensor) is used.  Both statistics are accumulated as
    running averages over the items, so the returned ``std`` is the average
    of the per-item standard deviations, not the standard deviation of the
    pooled data.

    Returns:
        ``(mean, std)`` as Python numbers; ``(0, 0)`` for an empty dataset.
    '''
    running_mean = 0
    running_std = 0
    for count, sample in enumerate(dataset, start=1):
        _, _, loudness = sample
        # Incremental average: avg_k = avg_{k-1} + (value_k - avg_{k-1}) / k
        running_mean += (loudness.mean().item() - running_mean) / count
        running_std += (loudness.std().item() - running_std) / count
    return running_mean, running_std
def multiscale_fft(signal, scales, overlap):
    '''Compute STFT magnitudes of ``signal`` at several FFT sizes.

    Args:
        signal: tensor passed directly to ``torch.stft``.
        scales: iterable of FFT sizes; each size is also used as the window
            length.
        overlap: fraction of overlap between consecutive frames; the hop
            length is ``int(size * (1 - overlap))``.

    Returns:
        A list with one magnitude spectrogram per entry of ``scales``, in
        the same order.
    '''
    def _magnitude(n_fft):
        # The Hann window is moved to the signal's dtype and device.
        window = torch.hann_window(n_fft).to(signal)
        spectrum = torch.stft(
            signal,
            n_fft=n_fft,
            hop_length=int(n_fft * (1 - overlap)),
            win_length=n_fft,
            window=window,
            center=True,
            normalized=True,
            return_complex=True,
        )
        return spectrum.abs()

    return [_magnitude(n_fft) for n_fft in scales]
def resample(x, factor: int):
    '''Upsample ``x`` along its frame axis by an integer ``factor``.

    ``x`` has shape (batch, frame, channel).  Every channel is handled as an
    independent 1-D sequence: the samples are placed ``factor`` steps apart
    in a zero-filled buffer, the last buffer position is overwritten with
    the last input sample, and the buffer is convolved with a Hann window of
    length ``2 * factor``.

    Returns:
        Tensor of shape (batch, factor * frame, channel).
    '''
    n_batch, n_frame, n_channel = x.shape
    # Fold channels into the batch axis so conv1d sees one channel per row.
    flat = x.permute(0, 2, 1).reshape(n_batch * n_channel, 1, n_frame)

    kernel = torch.hann_window(
        factor * 2,
        dtype=flat.dtype,
        device=flat.device,
    )
    kernel = kernel.reshape(1, 1, -1)

    stuffed = torch.zeros(
        flat.shape[0], flat.shape[1], factor * flat.shape[2]
    ).to(flat)
    stuffed[..., ::factor] = flat
    stuffed[..., -1:] = flat[..., -1:]

    padded = nn.functional.pad(stuffed, [factor, factor])
    # The convolution yields factor * frame + 1 samples; drop the last one.
    smoothed = nn.functional.conv1d(padded, kernel)[..., :-1]

    unfolded = smoothed.reshape(n_batch, n_channel, factor * n_frame)
    return unfolded.permute(0, 2, 1)
def upsample(signal, factor):
    '''Linearly interpolate ``signal`` along its second axis.

    The last two axes are swapped so the interpolated axis comes last,
    ``nn.functional.interpolate`` stretches it to ``factor`` times its
    length in ``'linear'`` mode, and the axes are swapped back.
    '''
    swapped = signal.permute(0, 2, 1)
    target_length = swapped.shape[-1] * factor
    stretched = nn.functional.interpolate(
        swapped, size=target_length, mode='linear'
    )
    return stretched.permute(0, 2, 1)
def remove_above_nyquist(amplitudes, pitch, sampling_rate):
    ''' Attenuate harmonics whose frequency is at or above Nyquist.

    amplitudes: (batch, frames, n_harmonics)
    pitch: (batch, frames, 1)

    The k-th harmonic frequency is obtained as the cumulative sum of the
    pitch repeated along the last axis.  Harmonics below
    ``sampling_rate / 2`` are scaled by 1 + 1e-4, all others by 1e-4.
    '''
    harmonic_count = amplitudes.shape[-1]
    harmonic_freqs = pitch.repeat(1, 1, harmonic_count).cumsum(-1)
    nyquist = sampling_rate / 2
    gain = (harmonic_freqs < nyquist).float() + 1e-4
    return amplitudes * gain
def remove_above_nyquist_mode(amplitudes, frequencies, sampling_rate):
    ''' Attenuate modes whose frequency is at or above Nyquist.

    amplitudes: (batch, frames, n_harmonics)
    frequencies: (batch, frames, n_harmonics)

    Entries with a frequency below ``sampling_rate / 2`` are scaled by
    1 + 1e-4, all others by 1e-4.
    '''
    nyquist = sampling_rate / 2
    below_nyquist = frequencies < nyquist
    gain = below_nyquist.float() + 1e-4
    return amplitudes * gain
def scale_function(x):
    ''' Map ``x`` to roughly the range 0 ~ 2.

    Computes ``2 * sigmoid(x) ** ln(10) + 1e-7``; the added constant keeps
    the result strictly positive.
    '''
    exponent = math.log(10)
    squashed = torch.sigmoid(x) ** exponent
    return 2 * squashed + 1e-7
def extract_loudness(signal, sampling_rate, block_size, n_fft=2048):
    '''Compute a per-frame A-weighted log-magnitude loudness curve.

    Args:
        signal: audio array passed to ``librosa.stft``.
        sampling_rate: sampling rate used to derive the FFT bin frequencies.
        block_size: STFT hop length, in samples.
        n_fft: FFT size, also used as the window length.

    Returns:
        The A-weighted log-magnitude spectrogram averaged over its first
        axis (the frequency bins for a 1-D signal), with the final frame
        dropped.
    '''
    S = li.stft(
        signal,
        n_fft=n_fft,
        hop_length=block_size,
        win_length=n_fft,
        center=True,
    )
    # Small offset keeps the log finite for silent bins.
    S = np.log(np.abs(S) + 1e-7)
    # Keyword arguments are required here: librosa >= 0.10 made `sr` and
    # `n_fft` keyword-only, and older releases accept the keywords as well.
    f = li.fft_frequencies(sr=sampling_rate, n_fft=n_fft)
    a_weight = li.A_weighting(f)
    # Broadcast the per-bin weighting across all frames.
    S = S + a_weight.reshape(-1, 1)
    # Centered framing produces one extra frame; drop it.
    S = np.mean(S, 0)[..., :-1]
    return S
def extract_pitch(signal, sampling_rate, block_size):
    '''Estimate one pitch value per block of ``block_size`` samples.

    Args:
        signal: audio array passed to ``crepe.predict``.
        sampling_rate: sampling rate of ``signal``.
        block_size: number of samples per output frame.

    Returns:
        1-D array with ``signal.shape[-1] // block_size`` values.  If the
        prediction has a different length it is linearly re-interpolated
        onto that length.

    Raises:
        ImportError: if the ``crepe`` package is not installed.
    '''
    # `crepe` is not imported at module level, so import it here; this also
    # keeps the rest of the module usable without crepe installed.
    import crepe

    length = signal.shape[-1] // block_size
    f0 = crepe.predict(
        signal,
        sampling_rate,
        # Block duration in milliseconds, truncated to an integer.
        step_size=int(1000 * block_size / sampling_rate),
        verbose=1,
        center=True,
        viterbi=True,
    )
    # Element 1 of the prediction is used as the pitch track; the last
    # value is dropped.
    f0 = f0[1].reshape(-1)[:-1]

    if f0.shape[-1] != length:
        # Resample the track onto exactly `length` points; both grids are
        # normalized to [0, 1).
        f0 = np.interp(
            np.linspace(0, 1, length, endpoint=False),
            np.linspace(0, 1, f0.shape[-1], endpoint=False),
            f0,
        )

    return f0
def harmonic_synth(pitch, amplitudes, sampling_rate):
    '''Additive synthesis of a harmonic sinusoid bank.

    The per-sample phase increment ``2 * pi * pitch / sampling_rate`` is
    accumulated along axis 1 to get the fundamental phase.  Harmonic ``k``
    (1-based, one per entry of the last axis of ``amplitudes``) uses ``k``
    times that phase.  The weighted sines are summed over the last axis,
    which is kept with size 1.
    '''
    harmonic_count = amplitudes.shape[-1]
    phase_step = 2 * math.pi * pitch / sampling_rate
    fundamental_phase = torch.cumsum(phase_step, 1)
    multipliers = torch.arange(1, harmonic_count + 1).to(fundamental_phase)
    partials = torch.sin(fundamental_phase * multipliers) * amplitudes
    return partials.sum(-1, keepdim=True)
def modal_synth(modes, amplitude, sampling_rate, n_chunks=16):
    '''Synthesize a sum of cosines from per-step phase increments.

    ``modes`` is accumulated along axis 1 and fed straight to ``cos``, so
    its values act as phase increments per step.  The accumulation is done
    in ``n_chunks`` pieces along axis 1, carrying the final phase of each
    piece into the next one.  Each cosine is weighted by ``amplitude`` and
    the result is summed over the last axis, which is kept with size 1.

    ``sampling_rate`` is accepted but not used by this function.
    '''
    mode_pieces = modes.chunk(n_chunks, 1)
    amp_pieces = amplitude.chunk(n_chunks, 1)

    carried_phase = torch.zeros_like(mode_pieces[0])
    outputs = []
    for mode_piece, amp_piece in zip(mode_pieces, amp_pieces):
        phase = mode_piece.cumsum(1) + carried_phase
        outputs.append((torch.cos(phase) * amp_piece).sum(-1, keepdim=True))
        # Keep only the last step's phase as the offset for the next piece.
        carried_phase = phase.narrow(1, -1, 1)
    return torch.cat(outputs, 1)
def amp_to_impulse_response(amp, target_size):
    '''Turn a real frequency response into a windowed impulse response.

    ``amp`` is treated as the real part of a one-sided spectrum with zero
    imaginary part and inverted with ``irfft``.  The response is rotated by
    half its length, multiplied by a Hann window, padded on the right to
    ``target_size`` samples, and rotated back.

    Returns:
        Tensor whose last axis has length ``target_size``.
    '''
    zero_imag = torch.zeros_like(amp)
    spectrum = torch.view_as_complex(torch.stack([amp, zero_imag], -1))
    response = fft.irfft(spectrum)

    filter_size = response.shape[-1]
    # Centre the response so the window tapers both of its ends.
    centred = torch.roll(response, filter_size // 2, -1)
    window = torch.hann_window(
        filter_size, dtype=centred.dtype, device=centred.device
    )
    windowed = centred * window

    extra = int(target_size) - int(filter_size)
    padded = nn.functional.pad(windowed, (0, extra))
    # Undo the centring rotation on the padded response.
    return torch.roll(padded, -filter_size // 2, -1)
def fft_convolve(signal, kernel):
    '''Convolve ``signal`` with ``kernel`` along the last axis via the FFT.

    ``signal`` is zero-padded on the right by its own length and ``kernel``
    on the left by its own length.  Their real FFTs are multiplied and
    inverted, and the second half of the result is returned.
    '''
    signal_length = signal.shape[-1]
    kernel_length = kernel.shape[-1]

    padded_signal = nn.functional.pad(signal, (0, signal_length))
    padded_kernel = nn.functional.pad(kernel, (kernel_length, 0))

    product = fft.rfft(padded_signal) * fft.rfft(padded_kernel)
    full = fft.irfft(product)
    return full[..., full.shape[-1] // 2:]