File size: 2,653 Bytes
bc3e180
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import torch
import torch.nn as nn
from src.model.nn.blocks import FMBlock, AMBlock
from src.utils.ddsp import upsample
from src.utils.ddsp import remove_above_nyquist_mode
from src.utils.ddsp import amp_to_impulse_response, fft_convolve
from src.utils.ddsp import modal_synth
from src.utils.ddsp import resample
import math

class DDSP(nn.Module):
    def __init__(self,
            feature_size, hidden_size,
            n_modes, n_bands, sampling_rate, block_size,
            fm=False,
        ):
        super().__init__()
        self.n_modes = n_modes

        self.freq_modulator = FMBlock(n_modes, feature_size) if fm else None
        self.coef_modulator = AMBlock(n_modes, feature_size)
        self.noise_proj = nn.Linear(feature_size, n_bands)

        noise_gate = nn.Parameter(torch.tensor([1e-2]), requires_grad=True)
        self.register_parameter("noise_gate", noise_gate)
        self.register_buffer("sampling_rate", torch.tensor(sampling_rate))
        self.register_buffer("block_size", torch.tensor(block_size))

    def forward(self, hidden, mode_freq, mode_coef, times, alpha, lengths):
        ''' hidden    : (Bs,  1, hidden_size)
            mode_freq : (Bs, Nt, n_modes)
            mode_coef : (Bs,  1, n_modes)
            times     : (Bs, Nt, 1)
        '''
        if self.freq_modulator is None:
            freq_m = mode_freq # integer multiples
        else:
            freq_m = self.freq_modulator(mode_freq, hidden)
        coef_m = self.coef_modulator(mode_coef, hidden, times)

        #============================== 
        # harmonic part
        #============================== 
        freqs = freq_m / (2*math.pi) * self.sampling_rate
        coef_m = remove_above_nyquist_mode(coef_m, freqs, self.sampling_rate) # (Bs, Nt, n_modes)
        freq_s = upsample(freq_m, self.block_size).narrow(1,0,lengths)
        coef_s = upsample(coef_m, self.block_size).narrow(1,0,lengths)
        harmonic = modal_synth(freq_s, coef_s, self.sampling_rate)

        #============================== 
        # noise part
        #============================== 
        ngate = torch.tanh((alpha - 1) * self.noise_gate)
        param = ngate * torch.sigmoid(self.noise_proj(hidden) - 5)

        impulse = amp_to_impulse_response(param, self.block_size)
        noise = torch.rand(
            impulse.shape[0],
            impulse.shape[1],
            self.block_size,
        ).to(impulse) * 2 - 1
        noise = fft_convolve(noise, impulse).contiguous()
        noise = noise.reshape(noise.shape[0], -1, 1).narrow(1,0,lengths)

        signal = harmonic + noise
        return signal.squeeze(-1), freq_m, coef_m