import math

import numpy as np
import torch
import torch.fft as fft
import torch.nn as nn

try:
    import librosa as li
except ImportError:
    # librosa is only used by extract_loudness; keep the torch-only helpers
    # importable when it is not installed.
    li = None


def safe_log(x):
    """Return ``log(x + 1e-7)``; the offset keeps the log finite at ``x == 0``."""
    return torch.log(x + 1e-7)


@torch.no_grad()
def mean_std_loudness(dataset):
    """Return running averages of per-item loudness mean and std.

    ``dataset`` must yield 3-tuples; only the third element is used and it
    must provide ``.mean()`` and ``.std()`` returning objects with ``.item()``.

    The second returned value is the average of the per-item standard
    deviations, not the standard deviation of the pooled data. An empty
    dataset yields ``(0, 0)``.
    """
    mean = 0
    std = 0
    n = 0
    for _, _, l in dataset:
        n += 1
        # Incremental mean update: avoids storing the per-item statistics.
        mean += (l.mean().item() - mean) / n
        std += (l.std().item() - std) / n
    return mean, std


def multiscale_fft(signal, scales, overlap):
    """Return a list of magnitude STFTs of ``signal``, one per FFT size.

    For each ``s`` in ``scales`` the STFT uses ``n_fft = win_length = s``, a
    Hann window, a hop of ``int(s * (1 - overlap))``, centered frames and
    ``normalized=True``. The complex result is reduced to its magnitude.
    """
    stfts = []
    for s in scales:
        S = torch.stft(
            signal,
            n_fft=s,
            hop_length=int(s * (1 - overlap)),
            win_length=s,
            # ``.to(signal)`` matches the window's dtype/device to the input.
            window=torch.hann_window(s).to(signal),
            center=True,
            normalized=True,
            return_complex=True,
        ).abs()
        stfts.append(S)
    return stfts


def resample(x, factor: int):
    """Upsample ``x`` along its frame axis by an integer ``factor``.

    ``x`` has shape (batch, frame, channel). Each channel is zero-stuffed to
    ``factor * frame`` samples and smoothed by convolution with a Hann window
    of length ``2 * factor``. Returns a tensor of shape
    (batch, factor * frame, channel).
    """
    batch, frame, channel = x.shape
    # Fold channels into the batch so conv1d sees single-channel sequences.
    x = x.permute(0, 2, 1).reshape(batch * channel, 1, frame)

    window = torch.hann_window(
        factor * 2,
        dtype=x.dtype,
        device=x.device,
    ).reshape(1, 1, -1)

    y = torch.zeros(x.shape[0], x.shape[1], factor * x.shape[2]).to(x)
    y[..., ::factor] = x
    # Copy the final input sample into the final output slot as well.
    y[..., -1:] = x[..., -1:]
    y = torch.nn.functional.pad(y, [factor, factor])
    # The padded convolution yields one extra sample; drop it.
    y = torch.nn.functional.conv1d(y, window)[..., :-1]

    y = y.reshape(batch, channel, factor * frame).permute(0, 2, 1)
    return y


def upsample(signal, factor):
    """Linearly interpolate a (batch, frames, channels) tensor along frames.

    The output has ``frames * factor`` frames.
    """
    # interpolate works on the last axis, so move frames there and back.
    signal = signal.permute(0, 2, 1)
    signal = nn.functional.interpolate(
        signal,
        size=signal.shape[-1] * factor,
        mode='linear',
    )
    return signal.permute(0, 2, 1)


def remove_above_nyquist(amplitudes, pitch, sampling_rate):
    """Attenuate harmonics whose frequency is at or above Nyquist.

    amplitudes: (batch, frames, n_harmonics)
    pitch: (batch, frames, 1)

    Harmonic ``k`` is taken to lie at ``k * pitch``. Amplitudes are scaled by
    ``1 + 1e-4`` below ``sampling_rate / 2`` and by ``1e-4`` otherwise.
    """
    n_harm = amplitudes.shape[-1]
    # Cumulative sum of the repeated pitch gives pitch, 2*pitch, 3*pitch, ...
    pitches = pitch.repeat(1, 1, n_harm).cumsum(-1)
    aa = (pitches < sampling_rate / 2).float() + 1e-4
    return amplitudes * aa


def remove_above_nyquist_mode(amplitudes, frequencies, sampling_rate):
    """Attenuate components whose explicit frequency is at or above Nyquist.

    amplitudes: (batch, frames, n_harmonics)
    frequencies: (batch, frames, n_harmonics)

    Amplitudes are scaled by ``1 + 1e-4`` where the frequency is below
    ``sampling_rate / 2`` and by ``1e-4`` otherwise.
    """
    aa = (frequencies < sampling_rate / 2).float() + 1e-4
    return amplitudes * aa


def scale_function(x):
    """Map ``x`` to roughly (0, 2) via ``2 * sigmoid(x) ** ln(10) + 1e-7``."""
    return 2 * torch.sigmoid(x)**(math.log(10)) + 1e-7


def extract_loudness(signal, sampling_rate, block_size, n_fft=2048):
    """Return an A-weighted log-magnitude loudness curve of ``signal``.

    The STFT is computed with librosa (hop ``block_size``, window ``n_fft``,
    centered), converted to log magnitude, offset per frequency bin by
    librosa's A-weighting, averaged over axis 0 and stripped of its last
    frame.

    NOTE(review): presumably ``signal`` is a 1-D NumPy array so that axis 0 of
    the STFT is frequency - confirm against callers.

    Raises:
        ImportError: if librosa is not installed.
    """
    if li is None:
        raise ImportError("librosa is required for extract_loudness")

    S = li.stft(
        signal,
        n_fft=n_fft,
        hop_length=block_size,
        win_length=n_fft,
        center=True,
    )
    S = np.log(abs(S) + 1e-7)
    # Keyword arguments: recent librosa releases reject positional sr/n_fft.
    f = li.fft_frequencies(sr=sampling_rate, n_fft=n_fft)
    a_weight = li.A_weighting(f)

    S = S + a_weight.reshape(-1, 1)

    S = np.mean(S, 0)[..., :-1]

    return S


def extract_pitch(signal, sampling_rate, block_size):
    """Return a pitch track with ``signal.shape[-1] // block_size`` frames.

    Runs ``crepe.predict`` with Viterbi decoding, using the block duration in
    whole milliseconds as the step size. The second element of the result is
    flattened and its last frame dropped; if the frame count still differs
    from the target it is linearly resampled with ``np.interp``.

    NOTE(review): the second element is presumably CREPE's frequency
    estimate - confirm against the crepe documentation.
    """
    # The file has no module-level crepe import; import it where it is used.
    import crepe

    length = signal.shape[-1] // block_size
    f0 = crepe.predict(
        signal,
        sampling_rate,
        step_size=int(1000 * block_size / sampling_rate),
        verbose=1,
        center=True,
        viterbi=True,
    )
    f0 = f0[1].reshape(-1)[:-1]

    if f0.shape[-1] != length:
        f0 = np.interp(
            np.linspace(0, 1, length, endpoint=False),
            np.linspace(0, 1, f0.shape[-1], endpoint=False),
            f0,
        )

    return f0


def harmonic_synth(pitch, amplitudes, sampling_rate):
    """Synthesize a sum of harmonically related sinusoids.

    The phase is the cumulative sum along dim 1 of
    ``2 * pi * pitch / sampling_rate``; harmonic ``k`` (1..n) uses ``k`` times
    that phase. The sines are weighted by ``amplitudes`` and summed over the
    last axis, which is kept with size 1.
    """
    n_harmonic = amplitudes.shape[-1]
    omega = torch.cumsum(2 * math.pi * pitch / sampling_rate, 1)
    omegas = omega * torch.arange(1, n_harmonic + 1).to(omega)
    signal = (torch.sin(omegas) * amplitudes).sum(-1, keepdim=True)
    return signal


def modal_synth(modes, amplitude, sampling_rate, n_chunks=16):
    """Synthesize a sum of cosines from per-step phase increments, in chunks.

    ``modes`` and ``amplitude`` are split into ``n_chunks`` pieces along
    dim 1. Within each piece the phase is the cumulative sum of ``modes``
    plus the final phase of the previous piece, so the phase is continuous
    across chunk boundaries. The weighted cosines are summed over the last
    axis (kept with size 1) and the pieces are concatenated along dim 1.

    ``sampling_rate`` is accepted but not used by this function.
    """
    freqs = modes.chunk(n_chunks, 1)
    coefs = amplitude.chunk(n_chunks, 1)
    lastf = torch.zeros_like(freqs[0])
    sols = []
    for f, c in zip(freqs, coefs):
        fcs = f.cumsum(1) + lastf
        sol = (torch.cos(fcs) * c).sum(-1, keepdim=True)
        # Carry the last phase of this chunk into the next one.
        lastf = fcs[:, -1:]
        sols.append(sol)

    return torch.cat(sols, 1)


def amp_to_impulse_response(amp, target_size):
    """Turn a real magnitude response into a windowed impulse response.

    ``amp`` is treated as a zero-phase one-sided spectrum and inverted with
    ``irfft``. The result is rotated so its centre sits mid-buffer, multiplied
    by a Hann window, right-padded with zeros to ``target_size`` samples and
    rotated back.
    """
    # Build a complex spectrum with zero imaginary part.
    amp = torch.stack([amp, torch.zeros_like(amp)], -1)
    amp = torch.view_as_complex(amp)
    amp = fft.irfft(amp)

    filter_size = amp.shape[-1]

    amp = torch.roll(amp, filter_size // 2, -1)
    win = torch.hann_window(filter_size, dtype=amp.dtype, device=amp.device)

    amp = amp * win

    amp = nn.functional.pad(amp, (0, int(target_size) - int(filter_size)))
    amp = torch.roll(amp, -filter_size // 2, -1)

    return amp


def fft_convolve(signal, kernel):
    """Convolve ``signal`` with ``kernel`` along the last axis via the FFT.

    ``signal`` is right-padded by its own length and ``kernel`` left-padded by
    its own length; their real FFTs are multiplied, inverted, and the second
    half of the result is returned.

    NOTE(review): the padded spectra must be broadcast-compatible, so both
    inputs are presumably expected to share the same last-axis length.
    """
    signal = nn.functional.pad(signal, (0, signal.shape[-1]))
    kernel = nn.functional.pad(kernel, (kernel.shape[-1], 0))

    output = fft.irfft(fft.rfft(signal) * fft.rfft(kernel))
    output = output[..., output.shape[-1] // 2:]

    return output