yourusername's picture
:beers: cheers
66a6dc0
import math
import torch
from typing import List
def butter(fc, fs: float = 2.0):
"""
Recall Butterworth polynomials
N = 1 s + 1
N = 2 s^2 + sqrt(2s) + 1
N = 3 (s^2 + s + 1)(s + 1)
N = 4 (s^2 + 0.76536s + 1)(s^2 + 1.84776s + 1)
Scaling
LP to LP: s -> s/w_c
LP to HP: s -> w_c/s
Bilinear transform:
s = 2/T_d * (1 - z^-1)/(1 + z^-1)
For 1-pole butterworth lowpass
1 / (s + 1) 1-pole prototype
1 / (s/w_c + 1) LP to LP
1 / (2/T_d * (1 - z^-1)/(1 + z^-1))/w_c + 1) Bilinear transform
"""
# apply pre-warping to the cutoff
T_d = 1 / fs
w_d = (2 * math.pi * fc) / fs
# sys.exit()
w_c = (2 / T_d) * torch.tan(w_d / 2)
a0 = 2 + (T_d * w_c)
a1 = (T_d * w_c) - 2
b0 = T_d * w_c
b1 = T_d * w_c
b = torch.stack([b0, b1], dim=0).view(-1)
a = torch.stack([a0, a1], dim=0).view(-1)
# normalize
b = b.type_as(fc) / a0
a = a.type_as(fc) / a0
return b, a
def biqaud(
gain_dB: torch.Tensor,
cutoff_freq: torch.Tensor,
q_factor: torch.Tensor,
sample_rate: float,
filter_type: str = "peaking",
):
# convert inputs to Tensors if needed
# gain_dB = torch.tensor([gain_dB])
# cutoff_freq = torch.tensor([cutoff_freq])
# q_factor = torch.tensor([q_factor])
A = 10 ** (gain_dB / 40.0)
w0 = 2 * math.pi * (cutoff_freq / sample_rate)
alpha = torch.sin(w0) / (2 * q_factor)
cos_w0 = torch.cos(w0)
sqrt_A = torch.sqrt(A)
if filter_type == "high_shelf":
b0 = A * ((A + 1) + (A - 1) * cos_w0 + 2 * sqrt_A * alpha)
b1 = -2 * A * ((A - 1) + (A + 1) * cos_w0)
b2 = A * ((A + 1) + (A - 1) * cos_w0 - 2 * sqrt_A * alpha)
a0 = (A + 1) - (A - 1) * cos_w0 + 2 * sqrt_A * alpha
a1 = 2 * ((A - 1) - (A + 1) * cos_w0)
a2 = (A + 1) - (A - 1) * cos_w0 - 2 * sqrt_A * alpha
elif filter_type == "low_shelf":
b0 = A * ((A + 1) - (A - 1) * cos_w0 + 2 * sqrt_A * alpha)
b1 = 2 * A * ((A - 1) - (A + 1) * cos_w0)
b2 = A * ((A + 1) - (A - 1) * cos_w0 - 2 * sqrt_A * alpha)
a0 = (A + 1) + (A - 1) * cos_w0 + 2 * sqrt_A * alpha
a1 = -2 * ((A - 1) + (A + 1) * cos_w0)
a2 = (A + 1) + (A - 1) * cos_w0 - 2 * sqrt_A * alpha
elif filter_type == "peaking":
b0 = 1 + alpha * A
b1 = -2 * cos_w0
b2 = 1 - alpha * A
a0 = 1 + (alpha / A)
a1 = -2 * cos_w0
a2 = 1 - (alpha / A)
else:
raise ValueError(f"Invalid filter_type: {filter_type}.")
b = torch.stack([b0, b1, b2], dim=0).view(-1)
a = torch.stack([a0, a1, a2], dim=0).view(-1)
# normalize
b = b.type_as(gain_dB) / a0
a = a.type_as(gain_dB) / a0
return b, a
def freqz(b, a, n_fft: int = 512):
B = torch.fft.rfft(b, n_fft)
A = torch.fft.rfft(a, n_fft)
H = B / A
return H
def freq_domain_filter(x, H, n_fft):
X = torch.fft.rfft(x, n_fft)
# move H to same device as input x
H = H.type_as(X)
Y = X * H
y = torch.fft.irfft(Y, n_fft)
return y
def approx_iir_filter(b, a, x):
"""Approimxate the application of an IIR filter.
Args:
b (Tensor): The numerator coefficients.
"""
# round up to nearest power of 2 for FFT
# n_fft = 2 ** math.ceil(math.log2(x.shape[-1] + x.shape[-1] - 1))
n_fft = 2 ** torch.ceil(torch.log2(torch.tensor(x.shape[-1] + x.shape[-1] - 1)))
n_fft = n_fft.int()
# move coefficients to same device as x
b = b.type_as(x).view(-1)
a = a.type_as(x).view(-1)
# compute complex response
H = freqz(b, a, n_fft=n_fft).view(-1)
# apply filter
y = freq_domain_filter(x, H, n_fft)
# crop
y = y[: x.shape[-1]]
return y
def approx_iir_filter_cascade(
b_s: List[torch.Tensor],
a_s: List[torch.Tensor],
x: torch.Tensor,
):
"""Apply a cascade of IIR filters.
Args:
b (list[Tensor]): List of tensors of shape (3)
a (list[Tensor]): List of tensors of (3)
x (torch.Tensor): 1d Tensor.
"""
if len(b_s) != len(a_s):
raise RuntimeError(
f"Must have same number of coefficients. Got b: {len(b_s)} and a: {len(a_s)}."
)
# round up to nearest power of 2 for FFT
# n_fft = 2 ** math.ceil(math.log2(x.shape[-1] + x.shape[-1] - 1))
n_fft = 2 ** torch.ceil(torch.log2(torch.tensor(x.shape[-1] + x.shape[-1] - 1)))
n_fft = n_fft.int()
# this could be done in parallel
b = torch.stack(b_s, dim=0).type_as(x)
a = torch.stack(a_s, dim=0).type_as(x)
H = freqz(b, a, n_fft=n_fft)
H = torch.prod(H, dim=0).view(-1)
# apply filter
y = freq_domain_filter(x, H, n_fft)
# crop
y = y[: x.shape[-1]]
return y