yourusername's picture
:beers: cheers
66a6dc0
import torch
class FIRFilter(torch.nn.Module):
def __init__(self, num_control_params=63):
super().__init__()
self.num_control_params = num_control_params
self.adaptor = torch.nn.Linear(num_control_params, num_control_params)
#self.batched_lfilter = torch.vmap(self.lfilter)
def forward(self, x, b, **kwargs):
"""Forward pass by appling FIR filter to each batch element.
Args:
x (tensor): Input signals with shape (batch x 1 x samples)
b (tensor): Matrix of FIR filter coefficients with shape (batch x ntaps)
"""
bs, ch, s = x.size()
b = self.adaptor(b)
# pad input
x = torch.nn.functional.pad(x, (b.shape[-1] // 2, b.shape[-1] // 2))
# add extra dim for virutal batch dim
x = x.view(bs, 1, ch, -1)
b = b.view(bs, 1, 1, -1)
# exlcuding vmap for now
y = self.batched_lfilter(x, b).view(bs, ch, s)
return y
@staticmethod
def lfilter(x, b):
return torch.nn.functional.conv1d(x, b)
class FrequencyDomainFIRFilter(torch.nn.Module):
def __init__(self, num_control_params=31):
super().__init__()
self.num_control_params = num_control_params
self.adaptor = torch.nn.Linear(num_control_params, num_control_params)
def forward(self, x, b, **kwargs):
"""Forward pass by appling FIR filter to each batch element.
Args:
x (tensor): Input signals with shape (batch x 1 x samples)
b (tensor): Matrix of FIR filter coefficients with shape (batch x ntaps)
"""
bs, c, s = x.size()
b = self.adaptor(b)
# transform input to freq. domain
X = torch.fft.rfft(x.view(bs, -1))
# frequency response of filter
H = torch.fft.rfft(b.view(bs, -1))
# apply filter as multiplication in freq. domain
Y = X * H
# transform back to time domain
y = torch.fft.ifft(Y).view(bs, 1, -1)
return y