akhaliq3
spaces demo
607ecc1
raw
history blame contribute delete
No virus
2.28 kB
import math
from typing import Callable
import gin
import torch
import torch.fft
import torch.nn as nn
import torch.nn.functional as F
@gin.configurable
class FIRNoiseSynth(nn.Module):
def __init__(
self, ir_length: int, hop_length: int, window_fn: Callable = torch.hann_window
):
super().__init__()
self.ir_length = ir_length
self.hop_length = hop_length
self.register_buffer("window", window_fn(ir_length))
def forward(self, H_re):
H_im = torch.zeros_like(H_re)
H_z = torch.complex(H_re, H_im)
h = torch.fft.irfft(H_z.transpose(1, 2))
h = h.roll(self.ir_length // 2, -1)
h = h * self.window.view(1, 1, -1)
H = torch.fft.rfft(h)
noise = torch.rand(self.hop_length * H_re.shape[-1] - 1, device=H_re.device)
X = torch.stft(noise, self.ir_length, self.hop_length, return_complex=True)
X = X.unsqueeze(0)
Y = X * H.transpose(1, 2)
y = torch.istft(Y, self.ir_length, self.hop_length, center=False)
return y.unsqueeze(1)[:, :, : H_re.shape[-1] * self.hop_length]
@gin.configurable
class HarmonicOscillator(nn.Module):
def __init__(self, n_harmonics, sample_rate):
super().__init__()
self.sample_rate = sample_rate
self.n_harmonics = n_harmonics
self.register_buffer("harmonic_axis", self._create_harmonic_axis(n_harmonics))
self.register_buffer("rand_phase", torch.ones(1, n_harmonics, 1) * math.tau)
def _create_harmonic_axis(self, n_harmonics):
return torch.arange(1, n_harmonics + 1).view(1, -1, 1)
def _create_antialias_mask(self, f0):
freqs = f0.unsqueeze(1) * self.harmonic_axis
return freqs < (self.sample_rate / 2)
def _create_phase_shift(self, n_harmonics):
shift = torch.rand_like(self.rand_phase) * self.rand_phase - math.pi
return shift
def forward(self, f0):
phase = math.tau * f0.cumsum(-1) / self.sample_rate
harmonic_phase = self.harmonic_axis * phase.unsqueeze(1)
harmonic_phase = harmonic_phase + self._create_phase_shift(self.n_harmonics)
antialias_mask = self._create_antialias_mask(f0)
output = torch.sin(harmonic_phase) * antialias_mask
return output