akhaliq3
spaces demo
607ecc1
raw
history blame
No virus
2.28 kB
import math
from typing import Callable
import gin
import torch
import torch.fft
import torch.nn as nn
import torch.nn.functional as F
@gin.configurable
class FIRNoiseSynth(nn.Module):
    """Noise synthesiser that shapes random noise with per-frame FIR filters.

    Each frame of the control input is read as the real part of a one-sided
    frequency response with zero imaginary part. The response is turned into
    a windowed impulse response of length ``ir_length`` and applied to a noise
    signal by multiplication in the STFT domain.

    Args:
        ir_length: Length of the impulse response; also used as the FFT size
            of the STFT/ISTFT pair.
        hop_length: Hop size of the STFT/ISTFT pair, i.e. output samples
            produced per control frame.
        window_fn: Callable invoked as ``window_fn(ir_length)``; its result is
            stored in the ``window`` buffer and multiplied onto every impulse
            response.
    """

    def __init__(
        self, ir_length: int, hop_length: int, window_fn: Callable = torch.hann_window
    ):
        super().__init__()
        self.ir_length = ir_length
        self.hop_length = hop_length
        self.register_buffer("window", window_fn(ir_length))

    def forward(self, H_re):
        """Filter freshly drawn noise with the given real frequency responses.

        Args:
            H_re: Real tensor indexed as three-dimensional. Dim 1 is the
                frequency axis and the last dim is the frame axis (dim 0 is
                presumably the batch -- confirm against the caller). Because
                ``irfft`` is called without an explicit length, the window
                multiply only lines up when dim 1 holds
                ``ir_length // 2 + 1`` bins and ``ir_length`` is even.

        Returns:
            Tensor with a singleton dim inserted at position 1 and the last
            dim cut to at most ``H_re.shape[-1] * hop_length`` samples.
        """
        n_frames = H_re.shape[-1]
        n_samples = n_frames * self.hop_length

        # Zero imaginary part: the response is purely real, so the impulse
        # response is symmetric around index 0 (with wrap-around).
        response = torch.complex(H_re, torch.zeros_like(H_re)).transpose(1, 2)
        impulse = torch.fft.irfft(response)
        # Rotate by half the length so the symmetric response is centred in
        # the buffer before the window is applied.
        impulse = torch.roll(impulse, self.ir_length // 2, -1)
        impulse = impulse * self.window.view(1, 1, -1)
        # Back to the frequency domain, with frequency on dim 1 again.
        filter_spec = torch.fft.rfft(impulse).transpose(1, 2)

        # One 1-D noise signal is drawn per call and broadcast over dim 0.
        # With stft's default centring, hop * frames - 1 samples give exactly
        # one STFT frame per control frame.
        # NOTE(review): torch.rand is uniform on [0, 1), so the noise has a
        # non-zero mean -- confirm this is intended before changing it.
        noise = torch.rand(n_samples - 1, device=H_re.device)
        noise_spec = torch.stft(
            noise, self.ir_length, self.hop_length, return_complex=True
        ).unsqueeze(0)

        # No window is passed to stft/istft, so PyTorch's defaults apply.
        filtered = torch.istft(
            noise_spec * filter_spec, self.ir_length, self.hop_length, center=False
        )
        return filtered.unsqueeze(1)[:, :, :n_samples]
@gin.configurable
class HarmonicOscillator(nn.Module):
    """Bank of sinusoids at integer multiples of a fundamental frequency.

    The harmonics are returned individually (they are not summed), each with
    a random phase offset, and are zeroed wherever their frequency is not
    below half the sample rate.

    Args:
        n_harmonics: Number of harmonics, numbered ``1 .. n_harmonics``.
        sample_rate: Sampling rate, in the same frequency unit as ``f0``.
    """

    def __init__(self, n_harmonics, sample_rate):
        super().__init__()
        self.sample_rate = sample_rate
        self.n_harmonics = n_harmonics
        # Harmonic multipliers 1..n_harmonics, shaped (1, n_harmonics, 1).
        self.register_buffer("harmonic_axis", self._create_harmonic_axis(n_harmonics))
        # Constant tensor of 2*pi; scales uniform [0, 1) samples to a phase.
        self.register_buffer("rand_phase", math.tau * torch.ones(1, n_harmonics, 1))

    def _create_harmonic_axis(self, n_harmonics):
        """Return the integers ``1 .. n_harmonics`` shaped ``(1, n_harmonics, 1)``."""
        multipliers = torch.arange(1, n_harmonics + 1)
        return multipliers[None, :, None]

    def _create_antialias_mask(self, f0):
        """Return a boolean mask that is True where a harmonic is below Nyquist."""
        nyquist = self.sample_rate / 2
        harmonic_freqs = f0.unsqueeze(1) * self.harmonic_axis
        return harmonic_freqs < nyquist

    def _create_phase_shift(self, n_harmonics):
        """Draw one random phase offset per harmonic, uniform on [-pi, pi).

        ``n_harmonics`` is accepted but not used; the shape comes from the
        ``rand_phase`` buffer.
        """
        return torch.rand_like(self.rand_phase) * self.rand_phase - math.pi

    def forward(self, f0):
        """Synthesise the harmonics of ``f0``.

        Args:
            f0: Fundamental frequency per sample; it is accumulated along its
                last dim to obtain phase. Presumably shaped ``(batch, time)``
                -- confirm against the caller.

        Returns:
            ``sin`` of each harmonic's phase, multiplied by the anti-aliasing
            mask; a harmonic axis is inserted at dim 1.
        """
        keep = self._create_antialias_mask(f0)

        # Running sum of the per-sample frequency gives the phase in radians.
        base_phase = f0.cumsum(-1) * math.tau / self.sample_rate
        phases = self.harmonic_axis * base_phase.unsqueeze(1)
        # A fresh random offset is drawn on every call.
        phases = phases + self._create_phase_shift(self.n_harmonics)

        return phases.sin() * keep