File size: 6,480 Bytes
7bc29af
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
# File under the MIT license, see https://github.com/adefossez/julius/LICENSE for details.
# Author: adefossez, 2020

"""
Implementation of an FFT-based 1D convolution in PyTorch.
While cuDNN uses FFT for small kernel sizes, it does not for long ones, e.g. 512.
This module implements efficient FFT-based convolutions for such long kernels. A typical
application is evaluating FIR filters with a long receptive field, usually
with a stride of 1.
"""
from typing import Optional

import torch
try:
    import torch.fft as new_fft
except ImportError:
    new_fft = None  # type: ignore
from torch.nn import functional as F

from .core import pad_to, unfold
from .utils import simple_repr


# This is quite verbose, but sadly needed to make TorchScript happy.
def _new_rfft(x: torch.Tensor):
    """
    Real-to-complex FFT over the last dimension, computed with the `torch.fft` module.
    The complex spectrum is returned as a real tensor with one extra trailing
    dimension of size 2 holding the real and imaginary parts.
    """
    return torch.view_as_real(new_fft.rfft(x, dim=-1))


def _old_rfft(x: torch.Tensor):
    """
    Fallback real FFT over the last dimension using the legacy `torch.rfft`
    function, used only when `torch.fft` cannot be imported.
    """
    return torch.rfft(x, signal_ndim=1)  # type: ignore


def _old_irfft(x: torch.Tensor, length: int):
    """
    Fallback inverse real FFT using the legacy `torch.irfft` function, used only
    when `torch.fft` cannot be imported. `length` is the size of the output signal.
    """
    return torch.irfft(x, signal_ndim=1, signal_sizes=(length,))  # type: ignore


def _new_irfft(x: torch.Tensor, length: int):
    """
    Inverse real FFT over the last dimension, computed with the `torch.fft` module.
    `x` stores the complex spectrum as a trailing dimension of size 2
    (real, imaginary); the result is a real signal of `length` samples.
    """
    spectrum = torch.view_as_complex(x)
    return new_fft.irfft(spectrum, n=length, dim=-1)


# Select the FFT backend once, at import time: the `torch.fft` based helpers when
# that module could be imported, otherwise the legacy `torch.rfft` / `torch.irfft` ones.
_rfft = _old_rfft if new_fft is None else _new_rfft
_irfft = _old_irfft if new_fft is None else _new_irfft


def _compl_mul_conjugate(a: torch.Tensor, b: torch.Tensor):
    """
    Multiply `a` by the complex conjugate of `b`, contracting over the channel axis.

    Complex numbers are stored as a trailing dimension of size 2 (real, imaginary).
    Following the einsum spec `"bcft,dct->bdft"`, `a` is `[B, C, F, T, 2]`,
    `b` is `[D, C, T, 2]`, and the result is `[B, D, F, T, 2]` where the
    `C` axis has been summed over.
    """
    # Complex einsum is done by hand on the real/imaginary parts, since complex
    # support in PyTorch did not cover all operations at the time this was written.
    spec = "bcft,dct->bdft"
    a_re, a_im = a[..., 0], a[..., 1]
    b_re, b_im = b[..., 0], b[..., 1]
    # (a_re + i a_im) * (b_re - i b_im)
    real = torch.einsum(spec, a_re, b_re) + torch.einsum(spec, a_im, b_im)
    imag = torch.einsum(spec, a_im, b_re) - torch.einsum(spec, a_re, b_im)
    return torch.stack([real, imag], dim=-1)


def fft_conv1d(
        input: torch.Tensor, weight: torch.Tensor,
        bias: Optional[torch.Tensor] = None, stride: int = 1, padding: int = 0,
        block_ratio: float = 5):
    """
    Same as `torch.nn.functional.conv1d` but using FFT for the convolution.
    Please check PyTorch documentation for more information.

    Args:
        input (Tensor): input signal of shape `[B, C, T]`.
        weight (Tensor): weight of the convolution `[D, C, K]` with `D` the number
            of output channels.
        bias (Tensor or None): if not None, bias term of shape `[D]` added to
            every output time step.
        stride (int): stride of convolution.
        padding (int): zero padding applied to both ends of the input.
        block_ratio (float): can be tuned for speed. The input is split in chunks
            with a size of `int(block_ratio * kernel_size)`, capped at the padded
            input length. Must be at least 1.

    Returns:
        Tensor of shape `[B, D, T_out]` with
        `T_out = (T + 2 * padding - K) // stride + 1`.

    Raises:
        RuntimeError: if the padded input is shorter than the kernel, or if
            `block_ratio` is smaller than 1.

    Shape:

        - Inputs: `input` is `[B, C, T]`, `weight` is `[D, C, K]` and bias is `[D]`.
        - Output: `[B, D, T_out]`, see above.


    .. note::
        This function is faster than `torch.nn.functional.conv1d` only in specific cases.
        Typically, the kernel size should be of the order of 256 to see any real gain,
        for a stride of 1.

    .. warning::
        Dilation and groups are not supported at the moment. This function might use
        more memory than the default Conv1d implementation.
    """
    input = F.pad(input, (padding, padding))
    batch, _, length = input.shape
    out_channels, _, kernel_size = weight.shape

    if length < kernel_size:
        raise RuntimeError(f"Input should be at least as large as the kernel size {kernel_size}, "
                           f"but it is only {length} samples long.")
    if block_ratio < 1:
        raise RuntimeError("Block ratio must be greater than or equal to 1.")

    # We process the input block by block, as for some reason it is faster
    # and less memory intensive (the culprit is likely `torch.einsum`).
    # Since block_ratio >= 1 and length >= kernel_size, block_size >= kernel_size.
    block_size: int = min(int(kernel_size * block_ratio), length)
    # Each block yields `block_size - kernel_size + 1` valid outputs; consecutive
    # blocks are taken this many samples apart so that their valid outputs tile.
    fold_stride = block_size - kernel_size + 1
    weight = pad_to(weight, block_size)
    weight_z = _rfft(weight)

    # Split the padded input into frames of `block_size` samples, `fold_stride` apart.
    frames = unfold(input, block_size, fold_stride)

    frames_z = _rfft(frames)
    # Multiplying by the conjugate of the kernel spectrum gives a cross-correlation,
    # which is what `conv1d` computes.
    out_z = _compl_mul_conjugate(frames_z, weight_z)
    out = _irfft(out_z, block_size)
    # The last `kernel_size - 1` samples of each block are invalid, because FFT does a
    # circular convolution. Slicing with `:fold_stride` (instead of `:-kernel_size + 1`)
    # also keeps `kernel_size == 1` working, where `:-0` would yield an empty tensor.
    out = out[..., :fold_stride]
    out = out.reshape(batch, out_channels, -1)
    out = out[..., ::stride]
    target_length = (length - kernel_size) // stride + 1
    out = out[..., :target_length]
    if bias is not None:
        # Out-of-place add: avoids an in-place update on a sliced view of `out`.
        out = out + bias[:, None]
    return out


class FFTConv1d(torch.nn.Module):
    """
    Counterpart of `torch.nn.Conv1d` whose forward pass is computed by `fft_conv1d`.
    Refer to the PyTorch documentation of `torch.nn.Conv1d` for the general semantics.

    Args:
        in_channels (int): number of input channels.
        out_channels (int): number of output channels.
        kernel_size (int): kernel size of convolution.
        stride (int): stride of convolution.
        padding (int): zero padding applied to both ends of the input.
        bias (bool): if True, use a bias term.

    .. note::
        This module is faster than `torch.nn.Conv1d` only in specific cases.
        Typically, `kernel_size` should be of the order of 256 to see any real gain,
        for a stride of 1.

    .. warning::
        Dilation and groups are not supported at the moment. This module might use
        more memory than the default Conv1d implementation.

    >>> fftconv = FFTConv1d(12, 24, 128, 4)
    >>> x = torch.randn(4, 12, 1024)
    >>> print(list(fftconv(x).shape))
    [4, 24, 225]
    """
    def __init__(self, in_channels: int, out_channels: int, kernel_size: int,
                 stride: int = 1, padding: int = 0, bias: bool = True):
        super().__init__()
        # Hyper-parameters, stored as plain attributes.
        self.in_channels, self.out_channels = in_channels, out_channels
        self.kernel_size = kernel_size
        self.stride, self.padding = stride, padding

        # Take the parameters from a regular Conv1d so that their shapes and
        # initialization are exactly those of `torch.nn.Conv1d`.
        reference = torch.nn.Conv1d(in_channels, out_channels, kernel_size, bias=bias)
        self.weight = reference.weight
        self.bias = reference.bias  # None when bias=False

    def forward(self, input: torch.Tensor):
        return fft_conv1d(input, self.weight, bias=self.bias,
                          stride=self.stride, padding=self.padding)

    def __repr__(self):
        has_bias = self.bias is not None
        return simple_repr(self, overrides={"bias": has_bias})