| """ |
| Implementation of a FFT based 1D convolution in PyTorch. |
| While FFT is used in CUDNN for small kernel sizes, it is not the case for long ones, e.g. 512. |
| This module implements efficient FFT based convolutions for such convolutions. A typical |
| application is for evaluationg FIR filters with a long receptive field, typically |
| evaluated with a stride of 1. |
| """ |
from typing import Optional

import torch
try:
    import torch.fft as new_fft
except ImportError:
    # The `torch.fft` module is only available in recent PyTorch versions.
    new_fft = None
from torch.nn import functional as F

from .core import pad_to, unfold
from .utils import simple_repr

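
# Wrappers exposing a common interface over the old (`torch.rfft`) and new
# (`torch.fft.rfft`) APIs: both return real tensors of shape `[..., freq, 2]`.
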
def _new_rfft(x: torch.Tensor):
    z = new_fft.rfft(x, dim=-1)
    return torch.view_as_real(z)


def _old_rfft(x: torch.Tensor):
    return torch.rfft(x, 1)


def _old_irfft(x: torch.Tensor, length: int):
    result = torch.irfft(x, 1, signal_sizes=(length,))
    return result


def _new_irfft(x: torch.Tensor, length: int):
    x = torch.view_as_complex(x)
    return new_fft.irfft(x, length, dim=-1)

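
# Select the FFT implementation once at import time, depending on whether
# the `torch.fft` module is available.
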
if new_fft is None:
    _rfft = _old_rfft
    _irfft = _old_irfft
else:
    _rfft = _new_rfft
    _irfft = _new_irfft

def _compl_mul_conjugate(a: torch.Tensor, b: torch.Tensor):
    """
    Given a real-view tensor `a` of shape `[B, C, F, T, 2]` and `b` of shape
    `[D, C, T, 2]`, with the last dimension holding the real and imaginary
    parts, returns `a` multiplied by the conjugate of `b`, the multiplication
    being contracted over the channel dimension `C`.
    """
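    # With a = ar + i * ai and conj(b) = br - i * bi, the product a * conj(b)
    # equals (ar * br + ai * bi) + i * (ai * br - ar * bi), which is exactly
    # what the two einsums below compute on the real and imaginary planes.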
    op = "bcft,dct->bdft"
    return torch.stack([
        torch.einsum(op, a[..., 0], b[..., 0]) + torch.einsum(op, a[..., 1], b[..., 1]),
        torch.einsum(op, a[..., 1], b[..., 0]) - torch.einsum(op, a[..., 0], b[..., 1])
    ], dim=-1)

def fft_conv1d(
        input: torch.Tensor, weight: torch.Tensor,
        bias: Optional[torch.Tensor] = None, stride: int = 1, padding: int = 0,
        block_ratio: float = 5):
    """
    Same as `torch.nn.functional.conv1d` but using FFT for the convolution.
    Please check PyTorch documentation for more information.

    Args:
        input (Tensor): input signal of shape `[B, C, T]`.
        weight (Tensor): weight of the convolution `[D, C, K]` with `D` the number
            of output channels.
        bias (Tensor or None): if not None, bias term for the convolution.
        stride (int): stride of convolution.
        padding (int): padding to apply to the input.
        block_ratio (float): can be tuned for speed. The input is split in chunks
            with a size of `int(block_ratio * kernel_size)`.

    Shape:

        - Inputs: `input` is `[B, C, T]`, `weight` is `[D, C, K]` and bias is `[D]`.
        - Output: `[B, D, T']` with `T' = (T + 2 * padding - K) // stride + 1`.

    ..note::
        This function is faster than `torch.nn.functional.conv1d` only in specific cases.
        Typically, the kernel size should be of the order of 256 to see any real gain,
        for a stride of 1.

    ..warning::
        Dilation and groups are not supported at the moment. This function might use
        more memory than the default Conv1d implementation.
    """
    input = F.pad(input, (padding, padding))
    batch, channels, length = input.shape
    out_channels, _, kernel_size = weight.shape

    if length < kernel_size:
        raise RuntimeError(f"Input should be at least as large as the kernel size {kernel_size}, "
                           f"but it is only {length} samples long.")
    if block_ratio < 1:
        raise RuntimeError("Block ratio must be greater than or equal to 1.")

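    # Overlap-save style decomposition: cut the input into overlapping frames
    # of `block_size` samples, multiply each frame with the kernel in the
    # frequency domain, and keep only the valid part of each frame.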
    block_size: int = min(int(kernel_size * block_ratio), length)
    fold_stride = block_size - kernel_size + 1
    weight = pad_to(weight, block_size)
    weight_z = _rfft(weight)

    # Cut the padded input into overlapping frames of size `block_size`.
    frames = unfold(input, block_size, fold_stride)

    frames_z = _rfft(frames)
    out_z = _compl_mul_conjugate(frames_z, weight_z)
    out = _irfft(out_z, block_size)

    # The last `kernel_size - 1` samples of each frame are invalid, because the
    # FFT computes a circular convolution; keep the `fold_stride` valid ones.
    out = out[..., :fold_stride]
    out = out.reshape(batch, out_channels, -1)
    out = out[..., ::stride]
    target_length = (length - kernel_size) // stride + 1
    out = out[..., :target_length]
    if bias is not None:
        out += bias[:, None]
    return out
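
# A quick way to sanity check `fft_conv1d` against the direct implementation
# (a sketch, not part of the module API):
#
#     x = torch.randn(2, 4, 2048)
#     w = torch.randn(8, 4, 256)
#     ref = F.conv1d(x, w, stride=2)
#     out = fft_conv1d(x, w, stride=2)
#     print((ref - out).abs().max())  # small, up to float32 FFT round-off
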

class FFTConv1d(torch.nn.Module):
    """
    Same as `torch.nn.Conv1d` but based on `fft_conv1d`.
    Please check PyTorch documentation for more information.

    Args:
        in_channels (int): number of input channels.
        out_channels (int): number of output channels.
        kernel_size (int): kernel size of convolution.
        stride (int): stride of convolution.
        padding (int): padding to apply to the input.
        bias (bool): if True, use a bias term.

    ..note::
        This module is faster than `torch.nn.Conv1d` only in specific cases.
        Typically, `kernel_size` should be of the order of 256 to see any real gain,
        for a stride of 1.

    ..warning::
        Dilation and groups are not supported at the moment. This module might use
        more memory than the default Conv1d implementation.

    >>> fftconv = FFTConv1d(12, 24, 128, 4)
    >>> x = torch.randn(4, 12, 1024)
    >>> print(list(fftconv(x).shape))
    [4, 24, 225]
    """
    def __init__(self, in_channels: int, out_channels: int, kernel_size: int,
                 stride: int = 1, padding: int = 0, bias: bool = True):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding

        # Create a regular Conv1d only to reuse its parameter initialization,
        # then keep its weight and bias as this module's parameters.
        conv = torch.nn.Conv1d(in_channels, out_channels, kernel_size, bias=bias)
        self.weight = conv.weight
        self.bias = conv.bias

    def forward(self, input: torch.Tensor):
        return fft_conv1d(
            input, self.weight, self.bias, self.stride, self.padding)

    def __repr__(self):
        return simple_repr(self, overrides={"bias": self.bias is not None})