|
import os |
|
import numpy as np |
|
import yaml |
|
import torch |
|
import torch.nn.functional as F |
|
from librosa.filters import mel as librosa_mel_fn |
|
from .mel2control import Mel2Control |
|
from .core import frequency_filter, upsample, remove_above_fmax |
|
|
|
class DotDict(dict):
    """A ``dict`` whose keys can also be read and written as attributes.

    Reading a key that does not exist returns ``None`` rather than raising
    ``AttributeError`` (so ``hasattr`` is always true on instances). Nested
    plain ``dict`` values are wrapped in ``DotDict`` on first attribute access
    and the wrapper is stored back into the parent, so chained writes such as
    ``cfg.model.type = 'CombSub'`` modify the config itself rather than a
    temporary copy.
    """

    def __getattr__(self, key, default=None):
        # Only reached when normal attribute lookup fails, i.e. for config keys
        # (missing dunder probes therefore also get ``None``).
        val = dict.get(self, key, default)
        if type(val) is dict:
            val = DotDict(val)
            # Cache the wrapper in place; returning a fresh copy each time
            # would silently drop attribute writes on nested sections.
            if key in self:
                self[key] = val
        return val

    __setattr__ = dict.__setitem__

    __delattr__ = dict.__delitem__
|
|
|
def load_model(
        model_path,
        device='cpu'):
    """Load a DDSP vocoder together with its training configuration.

    The configuration is read from ``config.yaml`` in the same directory as
    ``model_path``. A file with a ``.jit`` extension is loaded with
    ``torch.jit.load``; any other file is treated as a checkpoint whose
    ``'model'`` entry is the state dict of the architecture named by
    ``args.model.type``.

    Args:
        model_path: Path to a ``.jit`` TorchScript file or a checkpoint file.
        device: Device used as ``map_location`` and, for checkpoints, passed
            to ``model.to``.

    Returns:
        ``(model, args)``: the model switched to eval mode and the parsed
        config wrapped in ``DotDict``.

    Raises:
        ValueError: If a checkpoint's ``args.model.type`` is neither
            ``'Sins'`` nor ``'CombSub'``.
    """
    config_file = os.path.join(os.path.dirname(model_path), 'config.yaml')
    # YAML is UTF-8 by specification; don't depend on the platform's locale
    # encoding (e.g. cp1252 on Windows) to decode the config.
    with open(config_file, "r", encoding="utf-8") as config:
        args = yaml.safe_load(config)
    args = DotDict(args)

    # load model
    print(' [Loading] ' + model_path)
    map_location = torch.device(device)
    if os.path.splitext(model_path)[1] == '.jit':
        model = torch.jit.load(model_path, map_location=map_location)
    else:
        model_type = args.model.type
        # Pick the architecture and its architecture-specific harmonic option.
        if model_type == 'Sins':
            model_cls = Sins
            harmonic_kwargs = {'n_harmonics': args.model.n_harmonics}
        elif model_type == 'CombSub':
            model_cls = CombSub
            harmonic_kwargs = {'n_mag_harmonic': args.model.n_mag_harmonic}
        else:
            raise ValueError(f" [x] Unknown Model: {model_type}")
        model = model_cls(
            sampling_rate=args.data.sampling_rate,
            block_size=args.data.block_size,
            win_length=args.data.n_fft,
            n_mag_noise=args.model.n_mag_noise,
            n_mels=args.data.n_mels,
            **harmonic_kwargs)
        model.to(device)
        # NOTE(review): torch.load unpickles the file and can execute
        # arbitrary code -- only load checkpoints from trusted sources.
        ckpt = torch.load(model_path, map_location=map_location)
        model.load_state_dict(ckpt['model'])
    model.eval()
    return model, args
|
|
|
class Audio2Mel(torch.nn.Module):
    """Log10-mel spectrogram front end with optional key shift and speed change.

    ``keyshift`` (semitones) scales the FFT size and window length by
    ``2 ** (keyshift / 12)``; ``speed`` scales the hop length. When a key
    shift is applied, the magnitude spectrum is padded/cropped back to
    ``n_fft // 2 + 1`` bins so the same mel filterbank can be used.
    """

    def __init__(
        self,
        hop_length,
        sampling_rate,
        n_mel_channels,
        win_length,
        n_fft=None,
        mel_fmin=0,
        mel_fmax=None,
        clamp = 1e-5
    ):
        super().__init__()
        n_fft = win_length if n_fft is None else n_fft
        # Hann windows cached per "<keyshift>_<device>" key (see forward).
        self.hann_window = {}
        mel_basis = librosa_mel_fn(
            sr=sampling_rate,
            n_fft=n_fft,
            n_mels=n_mel_channels,
            fmin=mel_fmin,
            fmax=mel_fmax)
        mel_basis = torch.from_numpy(mel_basis).float()
        self.register_buffer("mel_basis", mel_basis)
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.win_length = win_length
        self.sampling_rate = sampling_rate
        self.n_mel_channels = n_mel_channels
        # Lower bound applied before log10 to avoid log(0).
        self.clamp = clamp

    def forward(self, audio, keyshift=0, speed=1):
        '''
        audio: B x C x T
        returns the log10-mel spectrogram, B x T_ x C x n_mel; the channel
        axis is squeezed away when C == 1 (B x T_ x n_mel for mono input).
        '''
        factor = 2 ** (keyshift / 12)
        n_fft_new = int(np.round(self.n_fft * factor))
        win_length_new = int(np.round(self.win_length * factor))
        hop_length_new = int(np.round(self.hop_length * speed))

        keyshift_key = str(keyshift) + '_' + str(audio.device)
        if keyshift_key not in self.hann_window:
            self.hann_window[keyshift_key] = torch.hann_window(win_length_new).to(audio.device)
        # Match the input dtype (no-op for float32 input).
        window = self.hann_window[keyshift_key].to(audio.dtype)

        B, C, T = audio.shape
        audio = audio.reshape(B * C, T)
        fft = torch.stft(
            audio,
            n_fft=n_fft_new,
            hop_length=hop_length_new,
            win_length=win_length_new,
            window=window,
            center=True,
            return_complex=True)
        magnitude = torch.sqrt(fft.real.pow(2) + fft.imag.pow(2))

        if keyshift != 0:
            # Bring the bin count back to what the mel filterbank expects and
            # rescale for the changed window length.
            size = self.n_fft // 2 + 1
            resize = magnitude.size(1)
            if resize < size:
                magnitude = F.pad(magnitude, (0, 0, 0, size - resize))
            magnitude = magnitude[:, :size, :] * self.win_length / win_length_new

        # torch.matmul does not promote dtypes: align the (float32) filterbank
        # buffer with the spectrum so float64 input or a .double() module work.
        mel_output = torch.matmul(self.mel_basis.to(magnitude.dtype), magnitude)
        log_mel_spec = torch.log10(torch.clamp(mel_output, min=self.clamp))

        # (B*C) x n_mel x T_  ->  B x T_ x C x n_mel
        T_ = log_mel_spec.shape[-1]
        log_mel_spec = log_mel_spec.reshape(B, C, self.n_mel_channels, T_)
        log_mel_spec = log_mel_spec.permute(0, 3, 1, 2)

        # Drop the channel axis for mono input; no-op when C > 1.
        log_mel_spec = log_mel_spec.squeeze(2)
        return log_mel_spec
|
|
|
class Sins(torch.nn.Module):
    """Additive DDSP synthesiser: a bank of harmonic sinusoids plus filtered noise.

    A ``Mel2Control`` network maps mel frames, conditioned on the per-frame
    exciter phase, to three control heads:

    * ``harmonic_phase``: ``win_length // 2 + 1`` values per frame, applied as
      an all-pass phase (in units of pi) to the STFT of the harmonic signal;
    * ``amplitudes``: ``n_harmonics`` log-amplitudes, one per harmonic;
    * ``noise_magnitude``: ``n_mag_noise`` log-magnitudes for the noise filter.
    """

    def __init__(self,
                 sampling_rate,
                 block_size,
                 win_length,
                 n_harmonics,
                 n_mag_noise,
                 n_mels=80):
        super().__init__()

        print(' [DDSP Model] Sinusoids Additive Synthesiser')

        # Scalar hyper-parameters are kept as buffers so they are saved in the
        # state_dict and follow the module across devices.
        scalar_buffers = (
            ("sampling_rate", sampling_rate),
            ("block_size", block_size),
            ("win_length", win_length),
        )
        for buffer_name, buffer_value in scalar_buffers:
            self.register_buffer(buffer_name, torch.tensor(buffer_value))
        self.register_buffer("window", torch.hann_window(win_length))

        # Output width of each control head predicted by Mel2Control.
        head_sizes = {
            'harmonic_phase': win_length // 2 + 1,
            'amplitudes': n_harmonics,
            'noise_magnitude': n_mag_noise,
        }
        self.mel2ctrl = Mel2Control(n_mels, head_sizes)

    def forward(self, mel_frames, f0_frames, initial_phase=None, infer=True, max_upsample_dim=32):
        '''
        mel_frames: B x n_frames x n_mels
        f0_frames: B x n_frames x 1
        returns (signal, phase, (harmonic, noise)); phase is the sample-rate
        exciter phase in radians, wrapped to [-pi, pi].
        '''
        hop = self.block_size
        sr = self.sampling_rate

        # Exciter phase in cycles: running sum of f0 / sr (in float64 when
        # infer=True), optionally offset by initial_phase (radians), then
        # wrapped to [-0.5, 0.5] and cast back to f0's dtype.
        f0 = upsample(f0_frames, hop)
        increments = f0.double() / sr if infer else f0 / sr
        cycles = torch.cumsum(increments, axis=1)
        if initial_phase is not None:
            cycles += initial_phase.to(cycles) / 2 / np.pi
        cycles = (cycles - torch.round(cycles)).to(f0)

        phase = 2 * np.pi * cycles
        phase_frames = phase[:, ::hop, :]

        # Control prediction.
        ctrls = self.mel2ctrl(mel_frames, phase_frames)

        src_allpass = torch.exp(1.j * np.pi * ctrls['harmonic_phase'])
        # Repeat the last frame; presumably so the frame count matches the
        # center=True STFT below, which yields one extra frame -- TODO confirm.
        src_allpass = torch.cat((src_allpass, src_allpass[:, -1:, :]), 1)
        amplitudes_frames = torch.exp(ctrls['amplitudes']) / 128
        noise_param = torch.exp(ctrls['noise_magnitude']) / 128

        # Harmonic exciter: sum of sinusoids at integer multiples of the phase.
        # remove_above_fmax is given sr / 2 as the limit (presumably zeroing
        # harmonics above Nyquist).
        amplitudes_frames = remove_above_fmax(amplitudes_frames, f0_frames, sr / 2, level_start = 1)
        n_harmonic = amplitudes_frames.shape[-1]
        level_harmonic = torch.arange(1, n_harmonic + 1).to(phase)
        # Process harmonics in chunks of max_upsample_dim (presumably to bound
        # the memory of the upsampled amplitude tensor).
        n_chunks = (n_harmonic - 1) // max_upsample_dim + 1
        sinusoids = 0.
        for chunk in range(n_chunks):
            lo = chunk * max_upsample_dim
            hi = lo + max_upsample_dim
            chunk_phases = phase * level_harmonic[lo:hi]
            chunk_amplitudes = upsample(amplitudes_frames[:, :, lo:hi], hop)
            sinusoids += (torch.sin(chunk_phases) * chunk_amplitudes).sum(-1)

        # Apply the predicted all-pass phase in the STFT domain.
        spectrum = torch.stft(
            sinusoids,
            n_fft = self.win_length,
            win_length = self.win_length,
            hop_length = hop,
            window = self.window,
            center = True,
            return_complex = True)
        spectrum = spectrum * src_allpass.permute(0, 2, 1)
        harmonic = torch.istft(
            spectrum,
            n_fft = self.win_length,
            win_length = self.win_length,
            hop_length = hop,
            window = self.window,
            center = True)

        # Noise branch: uniform noise in [-1, 1) filtered with a zero-phase
        # (purely real) response built from noise_param.
        white = torch.rand_like(harmonic).to(noise_param) * 2 - 1
        noise = frequency_filter(
            white,
            torch.complex(noise_param, torch.zeros_like(noise_param)),
            hann_window = True)

        return harmonic + noise, phase, (harmonic, noise)
|
|
|
class CombSub(torch.nn.Module):
    """Subtractive DDSP synthesiser driven by a comb-tooth (sinc pulse) exciter.

    A ``Mel2Control`` network maps mel frames, conditioned on the per-frame
    exciter phase, to three control heads:

    * ``harmonic_phase``: ``win_length // 2 + 1`` values per frame, applied as
      an all-pass phase (in units of pi) to the STFT of the harmonic signal;
    * ``harmonic_magnitude``: ``n_mag_harmonic`` log-magnitudes for the
      filter applied to the comb-tooth exciter;
    * ``noise_magnitude``: ``n_mag_noise`` log-magnitudes for the noise filter.
    """

    def __init__(self,
                 sampling_rate,
                 block_size,
                 win_length,
                 n_mag_harmonic,
                 n_mag_noise,
                 n_mels=80):
        super().__init__()

        print(' [DDSP Model] Combtooth Subtractive Synthesiser')

        # Scalar hyper-parameters are kept as buffers so they are saved in the
        # state_dict and follow the module across devices.
        scalar_buffers = (
            ("sampling_rate", sampling_rate),
            ("block_size", block_size),
            ("win_length", win_length),
        )
        for buffer_name, buffer_value in scalar_buffers:
            self.register_buffer(buffer_name, torch.tensor(buffer_value))
        self.register_buffer("window", torch.hann_window(win_length))

        # Output width of each control head predicted by Mel2Control.
        head_sizes = {
            'harmonic_phase': win_length // 2 + 1,
            'harmonic_magnitude': n_mag_harmonic,
            'noise_magnitude': n_mag_noise
        }
        self.mel2ctrl = Mel2Control(n_mels, head_sizes)

    def forward(self, mel_frames, f0_frames, initial_phase=None, infer=True, **kwargs):
        '''
        mel_frames: B x n_frames x n_mels
        f0_frames: B x n_frames x 1
        Extra keyword arguments are accepted and ignored.
        returns (signal, phase_frames, (harmonic, noise)); phase_frames is the
        exciter phase in radians sampled once per block (frame rate).
        '''
        hop = self.block_size
        sr = self.sampling_rate

        # Exciter phase in cycles: running sum of f0 / sr (in float64 when
        # infer=True), optionally offset by initial_phase (radians), then
        # wrapped to [-0.5, 0.5] and cast back to f0's dtype.
        f0 = upsample(f0_frames, hop)
        cycles = torch.cumsum(f0.double() / sr if infer else f0 / sr, axis=1)
        if initial_phase is not None:
            cycles += initial_phase.to(cycles) / 2 / np.pi
        cycles = (cycles - torch.round(cycles)).to(f0)

        phase_frames = 2 * np.pi * cycles[:, ::hop, :]

        # Control prediction.
        ctrls = self.mel2ctrl(mel_frames, phase_frames)

        src_allpass = torch.exp(1.j * np.pi * ctrls['harmonic_phase'])
        # Repeat the last frame; presumably so the frame count matches the
        # center=True STFT below, which yields one extra frame -- TODO confirm.
        src_allpass = torch.cat((src_allpass, src_allpass[:, -1:, :]), 1)
        src_param = torch.exp(ctrls['harmonic_magnitude'])
        noise_param = torch.exp(ctrls['noise_magnitude']) / 128

        # Comb-tooth exciter: sinc of the offset (in samples) from the nearest
        # integer-cycle point; the 1e-3 term presumably guards against f0 == 0.
        combtooth = torch.sinc(sr * cycles / (f0 + 1e-3)).squeeze(-1)

        # Harmonic filter: zero-phase response from src_param, with a
        # window half-width of 1.5 pitch periods (in samples).
        filtered = frequency_filter(
            combtooth,
            torch.complex(src_param, torch.zeros_like(src_param)),
            hann_window = True,
            half_width_frames = 1.5 * sr / (f0_frames + 1e-3))

        # Apply the predicted all-pass phase in the STFT domain.
        spectrum = torch.stft(
            filtered,
            n_fft = self.win_length,
            win_length = self.win_length,
            hop_length = hop,
            window = self.window,
            center = True,
            return_complex = True)
        spectrum = spectrum * src_allpass.permute(0, 2, 1)
        harmonic = torch.istft(
            spectrum,
            n_fft = self.win_length,
            win_length = self.win_length,
            hop_length = hop,
            window = self.window,
            center = True)

        # Noise branch: uniform noise in [-1, 1) filtered with a zero-phase
        # (purely real) response built from noise_param.
        white = torch.rand_like(harmonic).to(noise_param) * 2 - 1
        noise = frequency_filter(
            white,
            torch.complex(noise_param, torch.zeros_like(noise_param)),
            hann_window = True)

        return harmonic + noise, phase_frames, (harmonic, noise)