File size: 2,653 Bytes
bc3e180 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 |
import torch
import torch.nn as nn
from src.model.nn.blocks import FMBlock, AMBlock
from src.utils.ddsp import upsample
from src.utils.ddsp import remove_above_nyquist_mode
from src.utils.ddsp import amp_to_impulse_response, fft_convolve
from src.utils.ddsp import modal_synth
from src.utils.ddsp import resample
import math
class DDSP(nn.Module):
    """Modal (sinusoidal) synthesizer with an additive filtered-noise branch.

    The module modulates per-mode frequencies and coefficients from a
    conditioning vector, renders them with ``modal_synth`` and adds uniform
    noise shaped by a learned band-wise filter.

    Args:
        feature_size: Input width of the modulator blocks and of the linear
            layer that predicts the noise band amplitudes.
        hidden_size: Accepted for interface compatibility; not used by this
            constructor.
        n_modes: Number of modes handled by the modulator blocks.
        n_bands: Number of noise-filter bands predicted by ``noise_proj``.
        sampling_rate: Audio sampling rate, stored as a buffer.
        block_size: Upsampling factor from frame rate to sample rate, stored
            as a buffer.
        fm: When true, an ``FMBlock`` modulates the mode frequencies;
            otherwise the incoming frequencies are used unchanged.
    """

    def __init__(self,
                 feature_size, hidden_size,
                 n_modes, n_bands, sampling_rate, block_size,
                 fm=False,
                 ):
        super().__init__()
        self.n_modes = n_modes

        # Frequency modulation is optional; amplitude modulation always runs.
        self.freq_modulator = FMBlock(n_modes, feature_size) if fm else None
        self.coef_modulator = AMBlock(n_modes, feature_size)
        self.noise_proj = nn.Linear(feature_size, n_bands)

        # Learnable scalar scaling how strongly ``alpha`` opens the noise gate.
        noise_gate = nn.Parameter(torch.tensor([1e-2]), requires_grad=True)
        self.register_parameter("noise_gate", noise_gate)

        # Kept as buffers so they follow the module across devices and are
        # stored in the state dict.
        self.register_buffer("sampling_rate", torch.tensor(sampling_rate))
        self.register_buffer("block_size", torch.tensor(block_size))

    def forward(self, hidden, mode_freq, mode_coef, times, alpha, lengths):
        ''' Synthesize a signal from modal parameters plus filtered noise.

        hidden    : (Bs, 1, hidden_size)
        mode_freq : (Bs, Nt, n_modes)
        mode_coef : (Bs, 1, n_modes)
        times     : (Bs, Nt, 1)
        alpha     : scalar or tensor broadcastable against the noise
                    parameters; the noise gate is tanh((alpha - 1) * noise_gate),
                    so the noise branch is silent when alpha == 1.
        lengths   : number of samples kept along the time axis of the
                    upsampled signals (passed to ``Tensor.narrow``).
                    NOTE(review): presumably a Python int -- confirm with callers.

        Returns (signal, freq_m, coef_m): the synthesized signal with its last
        dimension squeezed, the (possibly modulated) mode frequencies, and the
        modulated coefficients after ``remove_above_nyquist_mode``.
        '''
        if self.freq_modulator is None:
            freq_m = mode_freq  # integer multiples
        else:
            freq_m = self.freq_modulator(mode_freq, hidden)
        coef_m = self.coef_modulator(mode_coef, hidden, times)

        # ==============================
        # harmonic part
        # ==============================
        # NOTE(review): this conversion implies freq_m is an angular frequency
        # in radians per sample -- confirm against the modulator blocks.
        freqs = freq_m / (2 * math.pi) * self.sampling_rate
        coef_m = remove_above_nyquist_mode(coef_m, freqs, self.sampling_rate)  # (Bs, Nt, n_modes)

        # Bring frame-rate controls up to sample rate, then trim to ``lengths``.
        freq_s = upsample(freq_m, self.block_size).narrow(1, 0, lengths)
        coef_s = upsample(coef_m, self.block_size).narrow(1, 0, lengths)
        harmonic = modal_synth(freq_s, coef_s, self.sampling_rate)

        # ==============================
        # noise part
        # ==============================
        ngate = torch.tanh((alpha - 1) * self.noise_gate)
        # The -5 offset biases the sigmoid towards small band amplitudes.
        param = ngate * torch.sigmoid(self.noise_proj(hidden) - 5)
        impulse = amp_to_impulse_response(param, self.block_size)

        # Uniform noise in [-1, 1), created directly on the impulse's device
        # and dtype instead of being allocated on the CPU and copied over on
        # every forward pass.
        noise = torch.rand(
            impulse.shape[0],
            impulse.shape[1],
            self.block_size,
            device=impulse.device,
            dtype=impulse.dtype,
        ) * 2 - 1
        noise = fft_convolve(noise, impulse).contiguous()
        # Flatten the per-frame blocks into one time axis and trim.
        noise = noise.reshape(noise.shape[0], -1, 1).narrow(1, 0, lengths)

        signal = harmonic + noise
        return signal.squeeze(-1), freq_m, coef_m
|