# File under the MIT license, see https://github.com/adefossez/julius/LICENSE for details.
# Author: adefossez, 2020
"""
Differentiable, PyTorch-based resampling.
Implementation of Julius O. Smith's algorithm for resampling.
See https://ccrma.stanford.edu/~jos/resample/ for details.
This implementation is specially optimized for when new_sr / old_sr is a fraction
with a small numerator and denominator after removing the GCD (e.g. new_sr = 700, old_sr = 500).

Very similar to [bmcfee/resampy](https://github.com/bmcfee/resampy) except this implementation
is optimized for the case mentioned before, while resampy is slower but more general.
"""

import math
from typing import Optional

import torch
from torch.nn import functional as F

from .core import sinc
from .utils import simple_repr


class ResampleFrac(torch.nn.Module):
    """
    Resampling from the sample rate `old_sr` to `new_sr`.
    """
    def __init__(self, old_sr: int, new_sr: int, zeros: int = 24, rolloff: float = 0.945):
        """
        Args:
            old_sr (int): sample rate of the input signal x.
            new_sr (int): sample rate of the output.
            zeros (int): number of zero crossings to keep in the sinc filter.
            rolloff (float): use a lowpass filter that is `rolloff * new_sr / 2`,
                to ensure sufficient margin due to the imperfection of the FIR filter used.
                Lowering this value will reduce aliasing, but will also remove some of the
                highest frequencies.

        Shape:

            - Input: `[*, T]`
            - Output: `[*, T']` with `T' = int(new_sr * T / old_sr)`

        .. caution::
            After dividing `old_sr` and `new_sr` by their GCD, both should be small
            for this implementation to be fast.

        >>> import torch
        >>> resample = ResampleFrac(4, 5)
        >>> x = torch.randn(1000)
        >>> print(len(resample(x)))
        1250
        """
        super().__init__()
        if not isinstance(old_sr, int) or not isinstance(new_sr, int):
            raise ValueError("old_sr and new_sr should be integers")
        gcd = math.gcd(old_sr, new_sr)
        self.old_sr = old_sr // gcd
        self.new_sr = new_sr // gcd
        self.zeros = zeros
        self.rolloff = rolloff
        self._init_kernels()

    def _init_kernels(self):
        if self.old_sr == self.new_sr:
            return

        kernels = []
        sr = min(self.new_sr, self.old_sr)
        # rolloff will perform antialiasing filtering by removing the highest frequencies.
        # At first I thought I only needed this when downsampling, but when upsampling
        # you will get edge artifacts without this, as the edge is equivalent to zero padding,
        # which will add high freq artifacts.
        sr *= self.rolloff

        # The key idea of the algorithm is that x(t) can be exactly reconstructed from x[i] (tensor)
        # using the sinc interpolation formula:
        #   x(t) = sum_i x[i] sinc(pi * old_sr * (i / old_sr - t))
        # We can then sample the function x(t) with a different sample rate:
        #   y[j] = x(j / new_sr)
        # or,
        #   y[j] = sum_i x[i] sinc(pi * old_sr * (i / old_sr - j / new_sr))
        # We see here that y[j] is the convolution of x[i] with a specific filter, for which
        # we take an FIR approximation, stopping when we see at least `zeros` zero crossings.
        # But y[j+1] is going to have a different set of weights and so on, until y[j + new_sr].
        # Indeed:
        #   y[j + new_sr] = sum_i x[i] sinc(pi * old_sr * (i / old_sr - (j + new_sr) / new_sr))
        #                 = sum_i x[i] sinc(pi * old_sr * ((i - old_sr) / old_sr - j / new_sr))
        #                 = sum_i x[i + old_sr] sinc(pi * old_sr * (i / old_sr - j / new_sr))
        # so y[j + new_sr] uses the same filter as y[j], but on a version of x shifted by `old_sr`.
        # This explains the F.conv1d below, with a stride of old_sr.
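        # Concretely (an added, illustrative example): with old_sr = 4 and new_sr = 5
        # after GCD reduction, there are 5 distinct filters, one per output phase
        # i in {0, ..., 4}, and y[5 * k + i] is produced by filter i applied around
        # input offset 4 * k. This is a polyphase view of the resampler.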
        self._width = math.ceil(self.zeros * self.old_sr / sr)
        # If old_sr is still big after GCD reduction, most filters will be very unbalanced, i.e.,
        # they will have a lot of almost zero values to the left or to the right...
        # There is probably a way to evaluate those filters more efficiently, but this is kept for
        # future work.
        idx = torch.arange(-self._width, self._width + self.old_sr).float()
        for i in range(self.new_sr):
            t = (-i/self.new_sr + idx/self.old_sr) * sr
            t = t.clamp_(-self.zeros, self.zeros)
            t *= math.pi
            window = torch.cos(t/self.zeros/2)**2
            kernel = sinc(t) * window
            # Renormalize kernel to ensure a constant signal is preserved.
            kernel.div_(kernel.sum())
            kernels.append(kernel)

        self.register_buffer("kernel", torch.stack(kernels).view(self.new_sr, 1, -1))
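        # The resulting buffer has shape [new_sr, 1, 2 * self._width + old_sr] (added note):
        # one filter phase per output channel of the strided conv1d in `forward`.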

    def forward(self, x: torch.Tensor, output_length: Optional[int] = None, full: bool = False):
        """
        Resample x.

        Args:
            x (Tensor): signal to resample, time should be the last dimension.
            output_length (None or int): This can be set to the desired output length
                (last dimension). Allowed values are between 0 and
                `ceil(length * new_sr / old_sr)`. When None (default) is specified, the
                floored output length will be used. In order to select the largest possible
                size, use the `full` argument.
            full (bool): return the longest possible output from the input. This can be useful
                if you chain resampling operations, and want to give the `output_length` only
                for the last one, while passing `full=True` to all the other ones.
        """
        if self.old_sr == self.new_sr:
            return x
        shape = x.shape
        length = x.shape[-1]
        x = x.reshape(-1, length)
        x = F.pad(x[:, None], (self._width, self._width + self.old_sr), mode='replicate')
        ys = F.conv1d(x, self.kernel, stride=self.old_sr)  # type: ignore
        y = ys.transpose(1, 2).reshape(list(shape[:-1]) + [-1])

        float_output_length = self.new_sr * length / self.old_sr
        max_output_length = int(math.ceil(float_output_length))
        default_output_length = int(float_output_length)
        if output_length is None:
            output_length = max_output_length if full else default_output_length
        elif output_length < 0 or output_length > max_output_length:
            raise ValueError(f"output_length must be between 0 and {max_output_length}")
        else:
            if full:
                raise ValueError("You cannot pass both full=True and output_length")
        return y[..., :output_length]

    def __repr__(self):
        return simple_repr(self)


def resample_frac(x: torch.Tensor, old_sr: int, new_sr: int,
                  zeros: int = 24, rolloff: float = 0.945,
                  output_length: Optional[int] = None, full: bool = False):
    """
    Functional version of `ResampleFrac`, refer to its documentation for more information.

    .. warning::
        If you call this function repeatedly with the same sample rates, the
        resampling kernel will be recomputed every time. For best performance, you should
        create and cache an instance of `ResampleFrac`.
    """
    return ResampleFrac(old_sr, new_sr, zeros, rolloff).to(x)(x, output_length, full)
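
# A small usage sketch (illustrative, not part of the original file): cache the module
# when converting many signals between the same pair of rates, instead of calling
# `resample_frac` in a loop, which rebuilds the kernels on every call.
#
#   resampler = ResampleFrac(16000, 24000)
#   outs = [resampler(sig) for sig in signals]  # kernels built once, reused each call
#   # vs. [resample_frac(sig, 16000, 24000) for sig in signals]  # rebuilt every time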


# Easier implementations for downsampling and upsampling by a factor of 2.
# Kept for testing and reference.

def _kernel_upsample2_downsample2(zeros):
    # Kernel for upsampling and downsampling by a factor of 2. Interestingly,
    # it is the same kernel used for both.
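    # One way to see this (added note, not from the original): both operations only ever
    # need x evaluated halfway between existing samples. Upsampling inserts an estimate
    # of x at t + 0.5 between samples t and t + 1; downsampling averages each even sample
    # with an estimate at the same instant built from the odd samples, which again sits
    # half a step away on that grid. Both use the same half-sample-shifted windowed sinc.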
    win = torch.hann_window(4 * zeros + 1, periodic=False)
    winodd = win[1::2]
    t = torch.linspace(-zeros + 0.5, zeros - 0.5, 2 * zeros)
    t *= math.pi
    kernel = (sinc(t) * winodd).view(1, 1, -1)
    return kernel


def _upsample2(x, zeros=24):
    """
    Upsample x by a factor of two. The output will be exactly twice as long as the input.

    Args:
        x (Tensor): signal to upsample, time should be the last dimension.
        zeros (int): number of zero crossings to keep in the sinc filter.

    This function is kept only for reference, you should use the more generic
    `resample_frac` instead. This function does not perform anti-aliasing filtering.
    """
    *other, time = x.shape
    kernel = _kernel_upsample2_downsample2(zeros).to(x)
    out = F.conv1d(x.view(-1, 1, time), kernel, padding=zeros)[..., 1:].view(*other, time)
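    # `out[t]` estimates x halfway between input samples t and t + 1 (added comment);
    # interleaving the original and interpolated samples doubles the sample rate.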
    y = torch.stack([x, out], dim=-1)
    return y.view(*other, -1)


def _downsample2(x, zeros=24):
    """
    Downsample x by a factor of two. The output length is half of the input, ceiled.

    Args:
        x (Tensor): signal to downsample, time should be the last dimension.
        zeros (int): number of zero crossings to keep in the sinc filter.

    This function is kept only for reference, you should use the more generic
    `resample_frac` instead. This function does not perform anti-aliasing filtering.
    """
    if x.shape[-1] % 2 != 0:
        x = F.pad(x, (0, 1))
    xeven = x[..., ::2]
    xodd = x[..., 1::2]
    *other, time = xodd.shape
    kernel = _kernel_upsample2_downsample2(zeros).to(x)
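    # Average each even sample with a sinc-interpolated estimate of the same instant
    # built from the odd samples (added comment); the final mul(0.5) is that average.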
    out = xeven + F.conv1d(xodd.view(-1, 1, time), kernel, padding=zeros)[..., :-1].view(
        *other, time)
    return out.view(*other, -1).mul(0.5)
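

# A minimal smoke test for the documented output lengths (a sketch, not part of the
# original file). Because of the relative imports above, run it as a module, e.g.
# `python -m julius.resample`, rather than executing the file directly.
if __name__ == "__main__":
    x = torch.randn(8, 1001)
    assert _upsample2(x).shape[-1] == 2 * 1001                     # exactly twice as long
    assert _downsample2(x).shape[-1] == (1001 + 1) // 2            # half, ceiled
    assert ResampleFrac(4, 5)(x).shape[-1] == int(5 * 1001 / 4)    # floored by default
    assert ResampleFrac(4, 5)(x, full=True).shape[-1] == math.ceil(5 * 1001 / 4)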