File size: 2,278 Bytes
607ecc1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import math
from typing import Callable

import gin
import torch
import torch.fft
import torch.nn as nn
import torch.nn.functional as F


@gin.configurable
class FIRNoiseSynth(nn.Module):
    """Synthesise filtered noise from per-frame, real-valued frequency responses.

    Each frame's frequency response is turned into a windowed, centred FIR
    impulse response, which is then applied to uniform noise by multiplying
    spectra frame by frame and overlap-adding the result.

    Args:
        ir_length: Length in samples of the impulse response; also used as the
            FFT size of the noise STFT/ISTFT. Both even and odd values work.
        hop_length: Hop in samples between successive frames.
        window_fn: Callable that receives ``ir_length`` and returns the window
            multiplied onto each impulse response (stored as buffer
            ``window``).
    """

    def __init__(
        self, ir_length: int, hop_length: int, window_fn: Callable = torch.hann_window
    ):
        super().__init__()
        self.ir_length = ir_length
        self.hop_length = hop_length
        self.register_buffer("window", window_fn(ir_length))

    def forward(self, H_re: torch.Tensor) -> torch.Tensor:
        """Filter freshly drawn noise with the given frequency responses.

        Args:
            H_re: Real tensor of shape ``(batch, ir_length // 2 + 1, n_frames)``
                holding the real part of each frame's frequency response (the
                imaginary part is taken to be zero).

        Returns:
            Tensor of shape ``(batch, 1, n_frames * hop_length)``.

        Raises:
            ValueError: If ``H_re`` is not 3-D or its bin axis does not have
                ``ir_length // 2 + 1`` entries.
        """
        expected_bins = self.ir_length // 2 + 1
        if H_re.dim() != 3 or H_re.shape[1] != expected_bins:
            raise ValueError(
                "H_re must have shape (batch, "
                f"{expected_bins}, n_frames) for ir_length={self.ir_length}, "
                f"got {tuple(H_re.shape)}"
            )
        n_frames = H_re.shape[-1]

        # Zero-phase response -> impulse response, one per frame.
        H_z = torch.complex(H_re, torch.zeros_like(H_re))
        # Passing n explicitly keeps the impulse response at ir_length samples
        # even when ir_length is odd (the default would always be even).
        h = torch.fft.irfft(H_z.transpose(1, 2), n=self.ir_length)
        # Rotate the response to the middle of the buffer before windowing so
        # the window tapers both of its tails.
        h = h.roll(self.ir_length // 2, -1)
        h = h * self.window.view(1, 1, -1)
        H = torch.fft.rfft(h)

        # With center=True the STFT yields exactly n_frames frames for this
        # length; odd FFT sizes need one extra sample to reach that count.
        noise_length = self.hop_length * n_frames - 1 + self.ir_length % 2
        # One uniform [0, 1) noise signal, shared by every batch element.
        noise = torch.rand(noise_length, device=H_re.device)

        # Explicit rectangular window: numerically the same as omitting the
        # window, but does not rely on the implicit default.
        rect = torch.ones(self.ir_length, device=H_re.device)
        X = torch.stft(
            noise, self.ir_length, self.hop_length, window=rect, return_complex=True
        )
        Y = X.unsqueeze(0) * H.transpose(1, 2)
        y = torch.istft(
            Y,
            self.ir_length,
            self.hop_length,
            window=rect.to(Y.real.dtype),
            center=False,
        )
        # Drop the overlap-add tail so the output spans n_frames hops.
        return y.unsqueeze(1)[:, :, : n_frames * self.hop_length]


@gin.configurable
class HarmonicOscillator(nn.Module):
    """Bank of sinusoids at integer multiples of a time-varying fundamental.

    Every call draws a new random phase offset per harmonic, and harmonics at
    or above half of ``sample_rate`` are zeroed.

    Args:
        n_harmonics: Number of harmonics generated (multiples 1..n_harmonics).
        sample_rate: Rate used to turn frequency into per-sample phase
            increments and to derive the half-rate cut-off.
    """

    def __init__(self, n_harmonics, sample_rate):
        super().__init__()
        self.sample_rate = sample_rate
        self.n_harmonics = n_harmonics
        # Harmonic multipliers 1..n_harmonics, shaped (1, n_harmonics, 1).
        harmonic_numbers = self._create_harmonic_axis(n_harmonics)
        self.register_buffer("harmonic_axis", harmonic_numbers)
        # Constant 2*pi per harmonic; scales uniform draws into a full cycle.
        full_cycle = torch.full((1, n_harmonics, 1), math.tau)
        self.register_buffer("rand_phase", full_cycle)

    def _create_harmonic_axis(self, n_harmonics):
        """Return the integers 1..n_harmonics with shape (1, n_harmonics, 1)."""
        multipliers = torch.arange(1, n_harmonics + 1)
        return multipliers.view(1, -1, 1)

    def _create_antialias_mask(self, f0):
        """Return a boolean mask that is True where a harmonic is below half
        the sample rate."""
        half_rate = self.sample_rate / 2
        harmonic_freqs = f0.unsqueeze(1) * self.harmonic_axis
        return torch.lt(harmonic_freqs, half_rate)

    def _create_phase_shift(self, n_harmonics):
        """Draw one random phase offset per harmonic, uniform over [-pi, pi).

        ``n_harmonics`` is accepted but not used; the shape follows the
        ``rand_phase`` buffer.
        """
        unit_draw = torch.rand_like(self.rand_phase)
        return unit_draw * self.rand_phase - math.pi

    def forward(self, f0):
        """Render the harmonics for a fundamental-frequency signal.

        Args:
            f0: Fundamental frequency per sample, in the same unit as
                ``sample_rate``; presumably shaped (batch, n_samples) -
                TODO confirm against callers.

        Returns:
            Tensor with a harmonic axis inserted at dim 1, i.e.
            (batch, n_harmonics, n_samples) for a 2-D ``f0``.
        """
        keep = self._create_antialias_mask(f0)

        # Running phase of the fundamental in radians.
        base_phase = math.tau * f0.cumsum(-1) / self.sample_rate
        # Scale to each harmonic, then add that harmonic's random offset.
        per_harmonic = self.harmonic_axis * base_phase.unsqueeze(1)
        per_harmonic = per_harmonic + self._create_phase_shift(self.n_harmonics)

        return torch.sin(per_harmonic) * keep