File size: 4,802 Bytes
7bc29af
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
# File under the MIT license, see https://github.com/adefossez/julius/LICENSE for details.
# Author: adefossez, 2020
"""
Decomposition of a signal over frequency bands in the waveform domain.
"""
from typing import Optional, Sequence
import torch

from .core import mel_frequencies
from .lowpass import LowPassFilters
from .utils import simple_repr


class SplitBands(torch.nn.Module):
    """
    Decomposes a signal over the given frequency bands in the waveform domain using
    a cascade of low pass filters as implemented by `julius.lowpass.LowPassFilters`.
    You can either specify explicitly the frequency cutoffs, or just the number of bands,
    in which case the frequency cutoffs will be spread out evenly in mel scale.

    Args:
        sample_rate (float): Sample rate of the input signal in Hz.
        n_bands (int or None): number of bands, when not giving them explicitly with `cutoffs`.
            In that case, the cutoff frequencies will be evenly spaced in mel-space.
            Must be at least 1.
        cutoffs (list[float] or None): list of frequency cutoffs in Hz. Bands are obtained by
            subtracting consecutive low passes, so the cutoffs are presumably meant to be
            given in increasing order. An empty list yields a single band equal to the input.
        pad (bool): if True, appropriately pad the input with zero over the edge, so that
            the output has the same length as the input.
        zeros (float): Number of zero crossings to keep. See `LowPassFilters` for more information.
        fft (bool or None): See `LowPassFilters` for more info.

    Raises:
        ValueError: if both or neither of `n_bands` and `cutoffs` are given, if `n_bands` is
            smaller than 1, or if a cutoff is above `sample_rate / 2`.

    ..note::
        The sum of all the bands will always be the input signal.

    ..warning::
        Unlike `julius.lowpass.LowPassFilters`, the cutoffs frequencies must be provided in Hz along
        with the sample rate.

    Shape:

        - Input: `[*, T]`
        - Output: `[B, *, T']`, with `T'=T` if `pad` is True.
            If `n_bands` was provided, `B = n_bands` otherwise `B = len(cutoffs) + 1`

    >>> bands = SplitBands(sample_rate=128, n_bands=10)
    >>> x = torch.randn(6, 4, 1024)
    >>> list(bands(x).shape)
    [10, 6, 4, 1024]
    """

    def __init__(self, sample_rate: float, n_bands: Optional[int] = None,
                 cutoffs: Optional[Sequence[float]] = None, pad: bool = True,
                 zeros: float = 8, fft: Optional[bool] = None):
        super().__init__()
        # Exactly one of `n_bands` / `cutoffs` must be provided.
        if (cutoffs is None) + (n_bands is None) != 1:
            raise ValueError("You must provide either n_bands, or cutoffs, but not boths.")

        self.sample_rate = sample_rate
        self.n_bands = n_bands
        # Materialize the user cutoffs once; the validation below reuses this list so that
        # a one-shot iterable is not consumed twice.
        self._cutoffs = list(cutoffs) if cutoffs is not None else None
        self.pad = pad
        self.zeros = zeros
        self.fft = fft

        if self._cutoffs is None:
            # Cannot trigger given the check above, but narrows `n_bands` to `int` for MyPy.
            if n_bands is None:
                raise ValueError("You must provide one of n_bands or cutoffs.")
            if not n_bands >= 1:
                raise ValueError(f"n_bands must be at least one (got {n_bands})")
            # `n_bands + 1` frequencies evenly spaced in mel scale between 0 and
            # sample_rate / 2; dropping both extremes keeps the interior cutoffs only.
            cutoffs = mel_frequencies(n_bands + 1, 0, sample_rate / 2)[1:-1]
        else:
            cutoffs = self._cutoffs
            # An empty list is valid (single band), and `max` would raise on it.
            if len(cutoffs) > 0 and max(cutoffs) > 0.5 * sample_rate:
                raise ValueError("A cutoff above sample_rate/2 does not make sense.")
        # `len` rather than truthiness: the mel branch produces a tensor.
        if len(cutoffs) > 0:
            # LowPassFilters expects cutoffs as a fraction of the sample rate.
            self.lowpass = LowPassFilters(
                [c / sample_rate for c in cutoffs], pad=pad, zeros=zeros, fft=fft)
        else:
            # Here I cannot make both TorchScript and MyPy happy.
            # I miss the good old times, before all this madness was created.
            self.lowpass = None  # type: ignore

    def forward(self, input):
        """Return the stacked bands of `input`, with the band index as first dimension."""
        if self.lowpass is None:
            # No cutoff: the only band is the input itself.
            return input[None]
        lows = self.lowpass(input)
        low = lows[0]
        bands = [low]
        for low_and_band in lows[1:]:
            # Get a bandpass filter by substracting lowpasses
            band = low_and_band - low
            bands.append(band)
            low = low_and_band
        # Last band is whatever is left in the signal
        bands.append(input - low)
        return torch.stack(bands)

    @property
    def cutoffs(self):
        """Cutoff frequencies in Hz, either as given by the user or derived from `n_bands`."""
        if self._cutoffs is not None:
            return self._cutoffs
        elif self.lowpass is not None:
            # Convert back from fractions of the sample rate to Hz.
            return [c * self.sample_rate for c in self.lowpass.cutoffs]
        else:
            return []

    def __repr__(self):
        return simple_repr(self, overrides={"cutoffs": self._cutoffs})


def split_bands(signal: torch.Tensor, sample_rate: float, n_bands: Optional[int] = None,
                cutoffs: Optional[Sequence[float]] = None, pad: bool = True,
                zeros: float = 8, fft: Optional[bool] = None):
    """
    Functional version of `SplitBands`, refer to this class for more information.

    Builds a one-off `SplitBands` module from the given arguments, moves it to the
    device and dtype of `signal`, and applies it.

    >>> x = torch.randn(6, 4, 1024)
    >>> list(split_bands(x, sample_rate=64, cutoffs=[12, 24]).shape)
    [3, 6, 4, 1024]
    """
    splitter = SplitBands(sample_rate, n_bands=n_bands, cutoffs=cutoffs, pad=pad,
                          zeros=zeros, fft=fft)
    # Match the module's device and dtype to the input before filtering.
    splitter = splitter.to(signal)
    return splitter(signal)