import torch
import torch.nn as nn
import torch.fft as fft
import numpy as np
import librosa as li
import crepe  # used by extract_pitch
import math


def safe_log(x):
    """Logarithm with a small additive floor so zero-valued inputs do not give -inf."""
    return torch.log(x + 1e-7)


@torch.no_grad()
def mean_std_loudness(dataset):
    """Online estimate of the mean and std of the loudness feature over a dataset
    (running average of the per-example statistics)."""
    mean = 0
    std = 0
    n = 0
    for _, _, l in dataset:
        n += 1
        mean += (l.mean().item() - mean) / n
        std += (l.std().item() - std) / n
    return mean, std


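# Illustrative sketch (not part of the original module): the statistics above are
# typically used to normalise the loudness feature before it conditions a model.
# A `dataset` yielding (signal, pitch, loudness) triples is an assumption.
def example_normalize_loudness(loudness, dataset):
    mean_loudness, std_loudness = mean_std_loudness(dataset)
    return (loudness - mean_loudness) / std_loudness

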
def multiscale_fft(signal, scales, overlap):
    """Magnitude STFTs of `signal` at several resolutions, one per window size in `scales`."""
    stfts = []
    for s in scales:
        S = torch.stft(
            signal,
            n_fft=s,
            hop_length=int(s * (1 - overlap)),
            win_length=s,
            window=torch.hann_window(s).to(signal),
            center=True,
            normalized=True,
            return_complex=True,
        ).abs()
        stfts.append(S)
    return stfts


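# Illustrative sketch (assumed usage, not defined in the original file): a multi-scale
# spectral loss comparing two waveforms on the linear and log magnitudes returned by
# multiscale_fft. The scale list and overlap are example values.
def example_multiscale_spectral_loss(x, y, scales=(4096, 2048, 1024, 512, 256, 128), overlap=0.75):
    ori = multiscale_fft(x, scales, overlap)
    rec = multiscale_fft(y, scales, overlap)
    return sum(
        (s_x - s_y).abs().mean() + (safe_log(s_x) - safe_log(s_y)).abs().mean()
        for s_x, s_y in zip(ori, rec)
    )

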
def resample(x, factor: int):
    """Upsample a (batch, frame, channel) tensor by `factor` along the frame axis,
    using zero-insertion followed by smoothing with a Hann interpolation kernel."""
    batch, frame, channel = x.shape
    x = x.permute(0, 2, 1).reshape(batch * channel, 1, frame)

    window = torch.hann_window(
        factor * 2,
        dtype=x.dtype,
        device=x.device,
    ).reshape(1, 1, -1)
    y = torch.zeros(x.shape[0], x.shape[1], factor * x.shape[2]).to(x)
    y[..., ::factor] = x
    y[..., -1:] = x[..., -1:]
    y = torch.nn.functional.pad(y, [factor, factor])
    y = torch.nn.functional.conv1d(y, window)[..., :-1]

    y = y.reshape(batch, channel, factor * frame).permute(0, 2, 1)

    return y


def upsample(signal, factor):
    """Upsample a (batch, frame, channel) tensor along the frame axis by linear interpolation."""
    signal = signal.permute(0, 2, 1)
    signal = nn.functional.interpolate(signal, size=signal.shape[-1] * factor, mode='linear')
    return signal.permute(0, 2, 1)


def remove_above_nyquist(amplitudes, pitch, sampling_rate):
    """Mask out harmonics above the Nyquist frequency (a small 1e-4 floor avoids
    exactly-zero amplitudes).

    amplitudes: (batch, frames, n_harmonics)
    pitch: (batch, frames, 1)
    """
    n_harm = amplitudes.shape[-1]
    pitches = pitch.repeat(1, 1, n_harm).cumsum(-1)  # k-th harmonic frequency = k * f0
    aa = (pitches < sampling_rate / 2).float() + 1e-4
    return amplitudes * aa


def remove_above_nyquist_mode(amplitudes, frequencies, sampling_rate):
    """Same anti-aliasing mask, but for arbitrary (possibly inharmonic) mode frequencies.

    amplitudes: (batch, frames, n_harmonics)
    frequencies: (batch, frames, n_harmonics)
    """
    aa = (frequencies < sampling_rate / 2).float() + 1e-4
    return amplitudes * aa


def scale_function(x):
    """Exponentiated sigmoid mapping unconstrained values to the range (0, 2)."""
    return 2 * torch.sigmoid(x) ** math.log(10) + 1e-7


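# Worked example of the mapping above: large negative logits land near the 1e-7 floor,
# zero maps to roughly 0.4, and large positive logits saturate near 2, e.g.
# scale_function(torch.tensor([-10.0, 0.0, 10.0])) is approximately
# tensor([1e-7, 0.41, 2.0]).

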
def extract_loudness(signal, sampling_rate, block_size, n_fft=2048):
    """A-weighted log-magnitude loudness of `signal`, one value per hop of `block_size` samples."""
    S = li.stft(
        signal,
        n_fft=n_fft,
        hop_length=block_size,
        win_length=n_fft,
        center=True,
    )
    S = np.log(abs(S) + 1e-7)
    f = li.fft_frequencies(sr=sampling_rate, n_fft=n_fft)
    a_weight = li.A_weighting(f)

    S = S + a_weight.reshape(-1, 1)

    S = np.mean(S, 0)[..., :-1]

    return S


def extract_pitch(signal, sampling_rate, block_size):
    """Frame-wise f0 estimated with CREPE (Viterbi smoothing), one value per block of `block_size` samples."""
    length = signal.shape[-1] // block_size
    f0 = crepe.predict(
        signal,
        sampling_rate,
        step_size=int(1000 * block_size / sampling_rate),
        verbose=1,
        center=True,
        viterbi=True,
    )
    f0 = f0[1].reshape(-1)[:-1]  # crepe.predict returns (time, frequency, confidence, activation)

    if f0.shape[-1] != length:
        f0 = np.interp(
            np.linspace(0, 1, length, endpoint=False),
            np.linspace(0, 1, f0.shape[-1], endpoint=False),
            f0,
        )

    return f0


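# Illustrative preprocessing sketch (the sample rate and block size are assumed values):
# the two extractors above produce the per-block conditioning features for a recording.
def example_extract_features(path, sampling_rate=16000, block_size=160):
    audio, _ = li.load(path, sr=sampling_rate)
    pitch = extract_pitch(audio, sampling_rate, block_size)        # (n_blocks,)
    loudness = extract_loudness(audio, sampling_rate, block_size)  # (n_blocks,)
    return audio, pitch, loudness

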
def harmonic_synth(pitch, amplitudes, sampling_rate):
    """Additive synthesis of harmonics of `pitch`: the k-th partial runs at k times the
    instantaneous f0, weighted by its own amplitude envelope."""
    n_harmonic = amplitudes.shape[-1]
    omega = torch.cumsum(2 * math.pi * pitch / sampling_rate, 1)  # instantaneous phase
    omegas = omega * torch.arange(1, n_harmonic + 1).to(omega)
    signal = (torch.sin(omegas) * amplitudes).sum(-1, keepdim=True)
    return signal


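# Illustrative sketch of a harmonic synthesis pipeline (the decoder outputs `amp_logits`
# of shape (batch, n_frames, n_harmonics) and `pitch` of shape (batch, n_frames, 1) are
# assumptions; the helpers are the ones defined in this module).
def example_harmonic_branch(amp_logits, pitch, sampling_rate, block_size):
    amplitudes = scale_function(amp_logits)
    amplitudes = remove_above_nyquist(amplitudes, pitch, sampling_rate)
    amplitudes = upsample(amplitudes, block_size)
    pitch = upsample(pitch, block_size)
    return harmonic_synth(pitch, amplitudes, sampling_rate)

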
def modal_synth(modes, amplitude, sampling_rate, n_chunks=16):
    """Sum-of-cosines synthesis from per-mode frequencies `modes` (batch, time, n_modes).

    `modes` is accumulated directly along time (no sampling-rate conversion happens here,
    so it is expected to already hold phase increments). The cumulative sum is computed
    chunk by chunk, carrying the last phase forward, to bound memory on long signals.
    `sampling_rate` is accepted for interface symmetry with harmonic_synth but not used.
    """
    freqs = modes.chunk(n_chunks, 1)
    coefs = amplitude.chunk(n_chunks, 1)
    lastf = torch.zeros_like(freqs[0])
    sols = []
    for f, c in zip(freqs, coefs):
        fcs = f.cumsum(1) + lastf
        sol = (torch.cos(fcs) * c).sum(-1, keepdim=True)
        lastf = fcs.narrow(1, -1, 1)
        sols.append(sol)
    return torch.cat(sols, 1)


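# Illustrative sketch (an assumption based on the code above): since modal_synth
# accumulates `modes` directly, they are treated here as phase increments in radians per
# sample, converted from mode frequencies given in Hz.
def example_modal_branch(mode_freqs_hz, mode_amps, sampling_rate):
    increments = 2 * math.pi * mode_freqs_hz / sampling_rate  # (batch, n_samples, n_modes)
    return modal_synth(increments, mode_amps, sampling_rate)

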
def amp_to_impulse_response(amp, target_size):
    """Convert a frame-wise magnitude response `amp` into a Hann-windowed impulse
    response, zero-padded to `target_size` samples."""
    amp = torch.stack([amp, torch.zeros_like(amp)], -1)
    amp = torch.view_as_complex(amp)
    amp = fft.irfft(amp)

    filter_size = amp.shape[-1]

    amp = torch.roll(amp, filter_size // 2, -1)
    win = torch.hann_window(filter_size, dtype=amp.dtype, device=amp.device)

    amp = amp * win

    amp = nn.functional.pad(amp, (0, int(target_size) - int(filter_size)))
    amp = torch.roll(amp, -filter_size // 2, -1)

    return amp


def fft_convolve(signal, kernel):
    """Convolve `signal` with `kernel` in the frequency domain; zero-padding prevents
    circular wrap-around and the second half of the result is returned."""
    signal = nn.functional.pad(signal, (0, signal.shape[-1]))
    kernel = nn.functional.pad(kernel, (kernel.shape[-1], 0))

    output = fft.irfft(fft.rfft(signal) * fft.rfft(kernel))
    output = output[..., output.shape[-1] // 2:]

    return output
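

# Illustrative sketch of a filtered-noise branch (the decoder output `noise_logits` of
# shape (batch, n_frames, n_bands) and the -5 bias are assumptions): each frame's
# magnitude response filters one block of uniform noise.
def example_noise_branch(noise_logits, block_size):
    noise_mag = scale_function(noise_logits - 5)              # small initial magnitudes
    impulse = amp_to_impulse_response(noise_mag, block_size)  # (batch, n_frames, block_size)
    noise = torch.rand_like(impulse) * 2 - 1                  # uniform noise in [-1, 1)
    noise = fft_convolve(noise, impulse)
    return noise.reshape(noise.shape[0], -1, 1)               # (batch, n_frames * block_size, 1)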