"""
Differentiable, PyTorch based resampling.

Implementation of Julius O. Smith's algorithm for resampling.
See https://ccrma.stanford.edu/~jos/resample/ for details.
This implementation is specifically optimized for the case where new_sr / old_sr is a
fraction with a small numerator and denominator after removing the GCD
(e.g. new_sr = 700, old_sr = 500).

Very similar to [bmcfee/resampy](https://github.com/bmcfee/resampy), except that this
implementation is optimized for the case mentioned above, while resampy is slower
but more general.
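
A minimal usage sketch (`ResampleFrac` and `resample_frac` are defined below):

    import torch

    resample = ResampleFrac(44100, 22050)  # build once, reuse many times
    signal = torch.randn(2, 44100)         # [channels, time]
    downsampled = resample(signal)         # shape [2, 22050]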
"""

import math
from typing import Optional

import torch
from torch.nn import functional as F

from .core import sinc
from .utils import simple_repr


class ResampleFrac(torch.nn.Module):
    """
    Resampling from the sample rate `old_sr` to `new_sr`.
    """
    def __init__(self, old_sr: int, new_sr: int, zeros: int = 24, rolloff: float = 0.945):
        """
        Args:
            old_sr (int): sample rate of the input signal x.
            new_sr (int): sample rate of the output.
            zeros (int): number of zero crossings to keep in the sinc filter.
            rolloff (float): use a lowpass filter that is `rolloff * new_sr / 2`,
                to ensure sufficient margin due to the imperfection of the FIR filter used.
                Lowering this value will reduce aliasing, but will also attenuate some of
                the highest frequencies.

        Shape:

            - Input: `[*, T]`
            - Output: `[*, T']` with `T' = int(new_sr * T / old_sr)`

        .. caution::
            After dividing `old_sr` and `new_sr` by their GCD, both should be small
            for this implementation to be fast.

        >>> import torch
        >>> resample = ResampleFrac(4, 5)
        >>> x = torch.randn(1000)
        >>> print(len(resample(x)))
        1250
        """
        super().__init__()
        if not isinstance(old_sr, int) or not isinstance(new_sr, int):
            raise ValueError("old_sr and new_sr should be integers")
        gcd = math.gcd(old_sr, new_sr)
        self.old_sr = old_sr // gcd
        self.new_sr = new_sr // gcd
        self.zeros = zeros
        self.rolloff = rolloff
        self._init_kernels()

    def _init_kernels(self):
        if self.old_sr == self.new_sr:
            return

        kernels = []
        sr = min(self.new_sr, self.old_sr)
        # The effective cutoff is `rolloff * sr / 2` instead of the full Nyquist
        # frequency, leaving some margin for the imperfect FIR approximation below.
        sr *= self.rolloff
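
        # The key idea of the algorithm: the continuous signal x(t) can be exactly
        # reconstructed from its samples x[i] with the sinc interpolation formula
        #   x(t) = sum_i x[i] sinc(pi * old_sr * (t - i / old_sr)),
        # and resampling amounts to evaluating y[j] = x(j / new_sr), i.e. convolving x
        # with a sinc filter (here windowed, truncated after `zeros` zero crossings on
        # each side, and with its cutoff lowered to `sr` to fold in the anti-aliasing
        # lowpass). Shifting the output index by new_sr shifts the input index by old_sr
        # with exactly the same weights:
        #   y[j + new_sr] = sum_i x[i + old_sr] sinc(pi * old_sr * (j / new_sr - i / old_sr)),
        # so only `new_sr` distinct filters are needed, one per output phase, each applied
        # with a stride of `old_sr`. This is the F.conv1d with stride=old_sr in `forward`.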
        self._width = math.ceil(self.zeros * self.old_sr / sr)
        # All phases share one index grid; entries clamped to +/- `zeros` below receive
        # exactly zero weight, since both the sinc and the window vanish there.
        idx = torch.arange(-self._width, self._width + self.old_sr).float()
        for i in range(self.new_sr):
            t = (-i / self.new_sr + idx / self.old_sr) * sr
            t = t.clamp_(-self.zeros, self.zeros)
            t *= math.pi
            # Squared-cosine (Hann) window tapering the truncated sinc.
            window = torch.cos(t / self.zeros / 2)**2
            kernel = sinc(t) * window
            # Renormalize so that a constant signal is preserved exactly.
            kernel.div_(kernel.sum())
            kernels.append(kernel)

        self.register_buffer("kernel", torch.stack(kernels).view(self.new_sr, 1, -1))

    def forward(self, x: torch.Tensor, output_length: Optional[int] = None, full: bool = False):
        """
        Resample x.
        Args:
            x (Tensor): signal to resample, time should be the last dimension.
            output_length (None or int): this can be set to the desired output length
                (last dimension). Allowed values are between 0 and
                `ceil(length * new_sr / old_sr)`. When None (the default), the floored
                output length is used. In order to select the largest possible
                size, use the `full` argument.
            full (bool): return the longest possible output from the input. This can be useful
                if you chain resampling operations, and want to give the `output_length` only
                for the last one, while passing `full=True` to all the other ones.
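
        For example, chaining an upsampling and a downsampling (a small sketch; with an
        odd length, `full=True` keeps the extra sample the floored length would drop):

        >>> resample_up = ResampleFrac(2, 3)
        >>> resample_down = ResampleFrac(3, 2)
        >>> x = torch.randn(101)
        >>> y = resample_up(x, full=True)
        >>> print(len(y))
        152
        >>> print(len(resample_down(y, output_length=101)))
        101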
        """
        if self.old_sr == self.new_sr:
            return x
        shape = x.shape
        length = x.shape[-1]
        x = x.reshape(-1, length)
        # Replicate padding so the first and last output samples see the full filter support.
        x = F.pad(x[:, None], (self._width, self._width + self.old_sr), mode='replicate')
        # One output channel per phase, each applied with a stride of `old_sr`.
        ys = F.conv1d(x, self.kernel, stride=self.old_sr)
        # Interleave the phases back along the time dimension.
        y = ys.transpose(1, 2).reshape(list(shape[:-1]) + [-1])

        float_output_length = self.new_sr * length / self.old_sr
        max_output_length = int(math.ceil(float_output_length))
        default_output_length = int(float_output_length)
        if output_length is None:
            output_length = max_output_length if full else default_output_length
        elif output_length < 0 or output_length > max_output_length:
            raise ValueError(f"output_length must be between 0 and {max_output_length}")
        elif full:
            raise ValueError("You cannot pass both full=True and output_length")
        return y[..., :output_length]

    def __repr__(self):
        return simple_repr(self)


def resample_frac(x: torch.Tensor, old_sr: int, new_sr: int,
                  zeros: int = 24, rolloff: float = 0.945,
                  output_length: Optional[int] = None, full: bool = False):
    """
    Functional version of `ResampleFrac`, refer to its documentation for more information.

    .. warning::
        If you call this function repeatedly with the same sample rates, the
        resampling kernel will be recomputed every time. For best performance, create
        and cache an instance of `ResampleFrac`.
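
    For example, a one-off conversion (the sample rates here are arbitrary):

    >>> import torch
    >>> x = torch.randn(2, 4000)
    >>> print(resample_frac(x, 4000, 16000).shape)
    torch.Size([2, 16000])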
    """
    return ResampleFrac(old_sr, new_sr, zeros, rolloff).to(x)(x, output_length, full)


def _kernel_upsample2_downsample2(zeros):
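    # Windowed sinc evaluated at half-sample offsets. For a factor of 2, the ideal
    # lowpass at half the Nyquist frequency has zero taps at every nonzero integer
    # offset, so the integer (even) phase is just a passthrough, and only the
    # half-sample (odd) phase needs an explicit filter; the callers handle the even
    # phase themselves.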
    win = torch.hann_window(4 * zeros + 1, periodic=False)
    winodd = win[1::2]
    t = torch.linspace(-zeros + 0.5, zeros - 0.5, 2 * zeros)
    t *= math.pi
    kernel = (sinc(t) * winodd).view(1, 1, -1)
    return kernel


def _upsample2(x, zeros=24):
    """
    Upsample x by a factor of two. The output will be exactly twice as long as the input.
    Args:
        x (Tensor): signal to upsample, time should be the last dimension.
        zeros (int): number of zero crossings to keep in the sinc filter.

    This function is kept only for reference; you should use the more generic
    `resample_frac` instead. This function does not perform anti-aliasing filtering.
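
    For example:

    >>> x = torch.randn(100)
    >>> print(_upsample2(x).shape)
    torch.Size([200])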
    """
    *other, time = x.shape
    kernel = _kernel_upsample2_downsample2(zeros).to(x)
    # Even output samples are the input itself; odd output samples come from the
    # half-sample interpolation filter. The two streams are interleaved at the end.
    out = F.conv1d(x.view(-1, 1, time), kernel, padding=zeros)[..., 1:].view(*other, time)
    y = torch.stack([x, out], dim=-1)
    return y.view(*other, -1)


def _downsample2(x, zeros=24):
    """
    Downsample x by a factor of two. The output length is half of the input, ceiled.
    Args:
        x (Tensor): signal to downsample, time should be the last dimension.
        zeros (int): number of zero crossings to keep in the sinc filter.

    This function is kept only for reference; you should use the more generic
    `resample_frac` instead. This function does not perform anti-aliasing filtering.
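
    For example (an odd length is padded before decimation):

    >>> x = torch.randn(101)
    >>> print(_downsample2(x).shape)
    torch.Size([51])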
    """
    if x.shape[-1] % 2 != 0:
        x = F.pad(x, (0, 1))
    xeven = x[..., ::2]
    xodd = x[..., 1::2]
    *other, time = xodd.shape
    kernel = _kernel_upsample2_downsample2(zeros).to(x)
    # Average the even samples with the half-sample interpolation of the odd samples,
    # i.e. the polyphase form of a stride-2 half-band lowpass.
    out = xeven + F.conv1d(xodd.view(-1, 1, time), kernel, padding=zeros)[..., :-1].view(
        *other, time)
    return out.view(*other, -1).mul(0.5)