# File under the MIT license, see https://github.com/adefossez/julius/LICENSE for details.
# Author: adefossez, 2020
"""
FIR windowed sinc lowpass filters.
"""

import math
from typing import Sequence, Optional

import torch
from torch.nn import functional as F

from .core import sinc
from .fftconv import fft_conv1d
from .utils import simple_repr


class LowPassFilters(torch.nn.Module):
    """
    Bank of low pass filters. Note that a high pass or band pass filter can easily
    be implemented by subtracting a same signal processed with low pass filters with
    different frequencies (see `julius.bands.SplitBands` for instance).
    This uses a windowed sinc filter, very similar to the one used in
    `julius.resample`. However, because we do not change the sample rate here,
    this filter can be much more efficiently implemented using the FFT convolution
    from `julius.fftconv`.

    Args:
        cutoffs (list[float]): list of cutoff frequencies, in [0, 0.5] expressed as
            `f/f_s` where f_s is the samplerate and `f` is the cutoff frequency.
            The upper limit is 0.5, because a signal sampled at `f_s` contains only
            frequencies under `f_s / 2`. Any iterable is accepted; it is
            materialized into a list once.
        stride (int): how much to decimate the output. Keep in mind that decimation
            of the output is only acceptable if the cutoff frequency is under
            `1/ (2 * stride)` of the original sampling rate.
        pad (bool): if True, pad the input on both sides by half the filter size,
            replicating the edge samples. If `stride=1`, the output will have the
            same length as the input.
        zeros (float): Number of zero crossings to keep.
            Controls the receptive field of the Finite Impulse Response filter.
            For lowpass filters with low cutoff frequency, e.g. 40Hz at 44.1kHz,
            it is a bad idea to set this to a high value.
            The default is likely appropriate for most use. Lower values
            will result in a faster filter, but with a slower attenuation around the
            cutoff frequency.
        fft (bool or None): if True, uses `julius.fftconv` rather than PyTorch
            convolutions. If False, uses PyTorch convolutions. If None, the FFT
            convolution is used when the half filter size is larger than 32.

    Raises:
        ValueError: if `cutoffs` is empty, or if any cutoff is outside `[0, 0.5]`.

    .. warning::
        All the filters will use the same filter size, aligned on the lowest
        strictly positive frequency provided. If you combine a lot of filters
        with very diverse frequencies, it might be more efficient to split them
        over multiple modules with similar frequencies.

    .. note::
        A lowpass with a cutoff frequency of 0 is defined as the null function
        by convention here. This allows for a highpass with a cutoff of 0 to
        be equal to identity, as defined in `julius.filters.HighPassFilters`.
        When every cutoff is 0, a single-tap null filter is used.

    Shape:

        - Input: `[*, T]`
        - Output: `[F, *, T']`, with `T'=T` if `pad` is True and `stride` is 1, and
            `F` is the number of cutoff frequencies.

    >>> lowpass = LowPassFilters([1/4])
    >>> x = torch.randn(4, 12, 21, 1024)
    >>> list(lowpass(x).shape)
    [1, 4, 12, 21, 1024]
    """

    def __init__(self, cutoffs: Sequence[float], stride: int = 1, pad: bool = True,
                 zeros: float = 8, fft: Optional[bool] = None):
        super().__init__()
        # Materialize once: everything below must read `self.cutoffs`, never the
        # raw argument, which could be a one-shot iterator.
        self.cutoffs = list(cutoffs)
        if not self.cutoffs:
            raise ValueError("At least one cutoff frequency is required.")
        # A cutoff of exactly 0 is accepted (null filter); only negatives are rejected.
        if min(self.cutoffs) < 0:
            raise ValueError("Minimum cutoff must be larger than zero.")
        if max(self.cutoffs) > 0.5:
            raise ValueError("A cutoff above 0.5 does not make sense.")
        self.stride = stride
        self.pad = pad
        self.zeros = zeros

        # The filter size is driven by the lowest non-zero cutoff. With only
        # zero cutoffs there is nothing to size against, so fall back to a
        # single-tap (all zero) filter.
        positive_cutoffs = [c for c in self.cutoffs if c > 0]
        if positive_cutoffs:
            self.half_size = int(zeros / min(positive_cutoffs) / 2)
        else:
            self.half_size = 0
        if fft is None:
            fft = self.half_size > 32
        self.fft = fft

        window = torch.hann_window(2 * self.half_size + 1, periodic=False)
        time = torch.arange(-self.half_size, self.half_size + 1)
        filters = []
        for cutoff in self.cutoffs:
            if cutoff == 0:
                # `time` is an integer tensor; build the null filter from the
                # floating point window so that all filters share one dtype.
                filter_ = torch.zeros_like(window)
            else:
                filter_ = 2 * cutoff * window * sinc(2 * cutoff * math.pi * time)
                # Normalize filter to have sum = 1, otherwise we will have a small leakage
                # of the constant component in the input signal.
                filter_ /= filter_.sum()
            filters.append(filter_)
        # Stored with shape [F, 1, 2 * half_size + 1], i.e. conv1d weights
        # with a single input channel.
        self.register_buffer("filters", torch.stack(filters)[:, None])

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """
        Apply every filter of the bank along the last dimension of `input`.

        Args:
            input (torch.Tensor): tensor of shape `[*, T]`.

        Returns:
            torch.Tensor: tensor of shape `[F, *, T']`, one entry per cutoff.
        """
        shape = list(input.shape)
        # Flatten all leading dimensions into a batch of mono signals.
        # `reshape` (rather than `view`) also accepts non-contiguous inputs.
        input = input.reshape(-1, 1, shape[-1])
        if self.pad and self.half_size > 0:
            input = F.pad(input, (self.half_size, self.half_size), mode='replicate')
        if self.fft:
            out = fft_conv1d(input, self.filters, stride=self.stride)
        else:
            out = F.conv1d(input, self.filters, stride=self.stride)
        # `out` is [batch, F, T']: move the filter axis first and restore the
        # original leading dimensions.
        shape.insert(0, len(self.cutoffs))
        shape[-1] = out.shape[-1]
        return out.permute(1, 0, 2).reshape(shape)

    def __repr__(self):
        return simple_repr(self)


class LowPassFilter(torch.nn.Module):
    """
    Same as `LowPassFilters` but applies a single low pass filter.

    Shape:

        - Input: `[*, T]`
        - Output: `[*, T']`, with `T'=T` if `pad` is True and `stride` is 1.

    >>> lowpass = LowPassFilter(1/4, stride=2)
    >>> x = torch.randn(4, 124)
    >>> list(lowpass(x).shape)
    [4, 62]
    """

    def __init__(self, cutoff: float, stride: int = 1, pad: bool = True,
                 zeros: float = 8, fft: Optional[bool] = None):
        super().__init__()
        # All the work is delegated to a one-filter bank.
        self._lowpasses = LowPassFilters([cutoff], stride, pad, zeros, fft)

    # The properties below expose the configuration of the wrapped bank.
    @property
    def cutoff(self):
        return self._lowpasses.cutoffs[0]

    @property
    def stride(self):
        return self._lowpasses.stride

    @property
    def pad(self):
        return self._lowpasses.pad

    @property
    def zeros(self):
        return self._lowpasses.zeros

    @property
    def fft(self):
        return self._lowpasses.fft

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """
        Filter `input` of shape `[*, T]`; the leading filter dimension of the
        underlying bank is dropped from the result.
        """
        return self._lowpasses(input)[0]

    def __repr__(self):
        return simple_repr(self)


def lowpass_filters(input: torch.Tensor, cutoffs: Sequence[float],
                    stride: int = 1, pad: bool = True,
                    zeros: float = 8, fft: Optional[bool] = None):
    """
    Functional version of `LowPassFilters`, refer to this class for more information.
    A new filter bank is built on every call and moved to the device and dtype
    of `input` before being applied.
    """
    return LowPassFilters(cutoffs, stride, pad, zeros, fft).to(input)(input)


def lowpass_filter(input: torch.Tensor, cutoff: float,
                   stride: int = 1, pad: bool = True,
                   zeros: float = 8, fft: Optional[bool] = None):
    """
    Same as `lowpass_filters` but with a single cutoff frequency.
    Output will not have a dimension inserted in the front.
    """
    return lowpass_filters(input, [cutoff], stride, pad, zeros, fft)[0]