File size: 2,027 Bytes
66a6dc0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import torch


class FIRFilter(torch.nn.Module):
    """Time-domain FIR filter whose taps come from a learned linear adaptor.

    Each batch element is filtered with its own (adapted) set of taps.
    """

    def __init__(self, num_control_params=63):
        super().__init__()
        self.num_control_params = num_control_params
        # Maps the raw control parameters to the FIR taps that are actually applied.
        self.adaptor = torch.nn.Linear(num_control_params, num_control_params)

    def forward(self, x, b, **kwargs):
        """Forward pass by applying an FIR filter to each batch element.

        Args:
            x (tensor): Input signals with shape (batch x channels x samples);
                the original design uses a single channel.
            b (tensor): Matrix of FIR filter coefficients with shape (batch x ntaps),
                where ntaps == num_control_params.

        Returns:
            tensor: Filtered signals with the same shape as ``x``.  Each channel
            of batch element ``i`` is filtered with the taps derived from ``b[i]``.
        """
        bs, ch, s = x.size()
        b = self.adaptor(b)
        ntaps = b.shape[-1]

        # Pad so the "valid" convolution returns exactly `s` samples for any tap
        # count: left + right == ntaps - 1.  For odd ntaps this is the symmetric
        # ntaps // 2 padding on both sides.
        pad_left = ntaps // 2
        pad_right = ntaps - 1 - pad_left
        x = torch.nn.functional.pad(x, (pad_left, pad_right))

        # Fold batch and channels into the channel axis of a single sample and use
        # a grouped convolution: group g only sees input channel g and its own
        # kernel, so batch element i is filtered with its own taps.  This replaces
        # the per-element vmap that was never wired up.
        x = x.reshape(1, bs * ch, -1)
        weight = (
            b.reshape(bs, 1, 1, ntaps)
            .expand(bs, ch, 1, ntaps)
            .reshape(bs * ch, 1, ntaps)
        )

        # NOTE: conv1d computes cross-correlation, so the taps are applied
        # without time reversal (same as `lfilter` below).
        y = torch.nn.functional.conv1d(x, weight, groups=bs * ch)

        return y.view(bs, ch, s)

    @staticmethod
    def lfilter(x, b):
        """Apply kernel ``b`` to ``x`` with ``torch.nn.functional.conv1d`` (no padding)."""
        return torch.nn.functional.conv1d(x, b)


class FrequencyDomainFIRFilter(torch.nn.Module):
    """FIR filter applied as a multiplication in the frequency domain.

    The taps come from a learned linear adaptor, one set per batch element.
    """

    def __init__(self, num_control_params=31):
        super().__init__()
        self.num_control_params = num_control_params
        # Maps the raw control parameters to the FIR taps that are actually applied.
        self.adaptor = torch.nn.Linear(num_control_params, num_control_params)

    def forward(self, x, b, **kwargs):
        """Forward pass by applying an FIR filter to each batch element.

        Computes the causal FIR response ``y[n] = sum_k h[k] * x[n - k]``, where
        ``h`` is the adapted taps of each batch element, truncated to the input
        length.

        Args:
            x (tensor): Input signals with shape (batch x channels x samples);
                the original design uses a single channel.
            b (tensor): Matrix of FIR filter coefficients with shape (batch x ntaps),
                where ntaps == num_control_params.

        Returns:
            tensor: Real-valued filtered signals with shape (batch x channels x samples).
        """
        bs, c, s = x.size()

        # (bs, 1, ntaps) so one filter broadcasts over all channels of its batch element
        b = self.adaptor(b).reshape(bs, 1, -1)
        ntaps = b.shape[-1]

        # Zero-pad both signals to the full linear-convolution length so the
        # spectral product is linear (not circular) convolution and both
        # spectra have the same number of bins.
        n_fft = s + ntaps - 1

        # transform input to freq. domain
        X = torch.fft.rfft(x, n=n_fft)

        # frequency response of filter
        H = torch.fft.rfft(b, n=n_fft)

        # apply filter as multiplication in freq. domain
        Y = X * H

        # irfft (not ifft) inverts the one-sided rfft spectrum back to a real
        # signal; keep the first `s` samples (causal output aligned with x).
        y = torch.fft.irfft(Y, n=n_fft)[..., :s].contiguous()

        return y