diffvox / modules /functional.py
yoyolicoris's picture
manually copy part of the diffvox source code
d737ecd
raw
history blame
6.1 kB
import torch
import torch.nn.functional as F
from torchcomp import compexp_gain, db2amp
from torchlpc import sample_wise_lpc
from typing import List, Tuple, Union, Any, Optional
import math
def inv_22(a, b, c, d):
    """Invert the 2x2 matrix ``[[a, b], [c, d]]`` in closed form.

    Args:
        a, b, c, d: Matrix entries in row-major order. They are stacked and
            viewed as ``(2, 2)``, so each must be a single-element tensor.

    Returns:
        A ``(2, 2)`` tensor: the adjugate ``[[d, -b], [-c, a]]`` divided by
        the determinant ``a * d - b * c``. No singularity check is performed.
    """
    determinant = a * d - b * c
    adjugate = torch.stack([d, -b, -c, a]).view(2, 2)
    return adjugate / determinant
def eig_22(a, b, c, d):
    """Eigen-decompose the 2x2 matrix ``[[a, b], [c, d]]`` in closed form.

    Derivation reference:
    https://croninprojects.org/Vince/Geodesy/FindingEigenvectors.pdf

    Args:
        a, b, c, d: Matrix entries in row-major order.

    Returns:
        A pair ``(eigenvalues, eigenvectors)``. The eigenvalues are stacked
        along dim 0 as ``[trace/2 + root, trace/2 - root]``. The eigenvectors
        are stacked as ``[ones, y]`` along dim 0 and scaled to unit Euclidean
        norm along dim 0.
    """
    trace = a + d
    det = a * d - b * c
    mid = trace * 0.5
    # Plain sqrt of the discriminant: for real inputs a negative value gives
    # NaN (the original author left a "+ 0j" complex promotion commented out).
    root = torch.sqrt(mid * mid - det)
    eigvals = torch.stack([mid + root, mid - root])
    # From the first row of (A - lambda * I) v = 0 with v = [1, y]:
    # (a - lambda) + b * y = 0  =>  y = (lambda - a) / b.  Divides by b.
    second = (eigvals - a) / b
    eigvecs = torch.stack([torch.ones_like(second), second])
    norms = eigvecs.abs().square().sum(0).sqrt()
    return eigvals, eigvecs / norms
def fir(x, b):
    """Apply a causal FIR filter along the last dimension of ``x``.

    Args:
        x: Input tensor; every leading dimension is flattened into a batch
            for the convolution and restored afterwards.
        b: 1-D tensor of filter taps. ``b[0]`` weights the current sample and
            ``b[k]`` the sample ``k`` steps in the past.

    Returns:
        Tensor with the same shape as ``x``.
    """
    num_taps = b.size(0)
    signal = x.reshape(-1, 1, x.size(-1))
    # Left-pad with zeros so the output keeps the input length and never
    # looks at future samples.
    signal = F.pad(signal, (num_taps - 1, 0))
    # conv1d computes a cross-correlation, hence the flipped taps.
    kernel = b.flip(0).view(1, 1, -1)
    return F.conv1d(signal, kernel).view(*x.shape)
def allpole(x: torch.Tensor, a: torch.Tensor):
    """Run ``x`` through ``sample_wise_lpc`` with time-invariant coefficients.

    Args:
        x: Input tensor; leading dimensions are flattened into rows and the
            last dimension is treated as the sequence axis.
        a: Coefficient tensor. It is broadcast so that every row and every
            time step receives the same coefficients.

    Returns:
        Tensor with the same shape as ``x``.
    """
    rows = x.reshape(-1, x.shape[-1])
    # sample_wise_lpc takes per-sample coefficients, so replicate ``a`` over
    # the (row, time) grid.
    per_sample_coefs = a.broadcast_to(rows.shape + a.shape)
    filtered = sample_wise_lpc(rows, per_sample_coefs)
    return filtered.reshape(*x.shape)
def biquad(x: torch.Tensor, b0, b1, b2, a0, a1, a2):
    """Filter ``x`` along its last dimension with a single biquad section.

    The transfer function is split into a direct path plus a strictly
    recursive part::

        H(z) = b0 + (beta1 z^-1 + beta2 z^-2) / (1 + a1 z^-1 + a2 z^-2)

    and the recursive part is evaluated with first-order recursions through
    ``sample_wise_lpc`` (assumed to compute ``y[t] = x[t] - a[t] * y[t-1]``
    for a single coefficient -- see the torchlpc documentation).

    Args:
        x: Input signal; the last dimension is the time axis.
        b0, b1, b2: Numerator coefficients.
        a0, a1, a2: Denominator coefficients. All coefficients are divided by
            ``a0`` first. The pole type is selected with a Python ``if``, so
            the coefficients must be single-element tensors.

    Returns:
        Tensor with the same shape as ``x``.

    NOTE(review): in the real-pole branch a zero discriminant (repeated
    pole) makes the eigenvector matrix singular, and ``a2 == 0`` makes
    ``eig_22`` divide by zero -- confirm callers never hit these cases.
    """
    # Normalise so the leading denominator coefficient is one.
    b0, b1, b2, a1, a2 = (coef / a0 for coef in (b0, b1, b2, a1, a2))

    # Numerator that remains once the direct path b0 is split off.
    beta1 = b1 - b0 * a1
    beta2 = b2 - b0 * a2

    # The recursive part only ever sees past samples, so it is driven by
    # x[0 .. T-2] and its output is aligned with y[1 .. T-1] below.
    past = x[..., :-1]

    discriminant = a1.square() - 4 * a2
    if discriminant < 0:
        # Complex-conjugate poles: run one complex first-order recursion on
        # the upper-half-plane pole and rebuild the real second-order
        # response from its real and imaginary parts.
        pole = 0.5 * (-a1 + 1j * torch.sqrt(-discriminant))
        drive = -1j * past
        resp = sample_wise_lpc(
            drive.reshape(-1, drive.shape[-1]),
            -pole.broadcast_to(drive.shape).reshape(-1, drive.shape[-1], 1),
        ).reshape(*drive.shape)
        real_weight = beta1 * pole.real / pole.imag + beta2 / pole.imag
        recursive = resp.real * real_weight - beta1 * resp.imag
    else:
        # Real poles: diagonalise the companion matrix [[-a1, -a2], [1, 0]]
        # and run two independent first-order recursions.
        eigvals, eigvecs = eig_22(
            -a1, -a2, torch.ones_like(a1), torch.zeros_like(a1)
        )
        eigvecs_inv = inv_22(*eigvecs.view(-1))

        # Output weights expressed in the eigenbasis.
        out_weights = torch.stack([beta1, beta2]) @ eigvecs

        # Project the input into the eigenbasis; the input enters the
        # companion form through its first state only, hence the first
        # column of the inverse.
        modal_in = past.unsqueeze(-2) * eigvecs_inv[:, :1]
        modal_poles = eigvals.unsqueeze(-1).broadcast_to(modal_in.shape)

        modal_out = sample_wise_lpc(
            modal_in.reshape(-1, modal_in.shape[-1]),
            -modal_poles.reshape(-1, modal_poles.shape[-1], 1),
        ).reshape(*modal_in.shape)
        recursive = modal_out.transpose(-2, -1) @ out_weights

    direct = b0 * x
    # The first output sample has no history, so it is the direct path alone.
    return torch.cat([direct[..., :1], recursive + direct[..., 1:]], -1)
def highpass_biquad_coef(
    sample_rate: int,
    cutoff_freq: torch.Tensor,
    Q: torch.Tensor,
):
    """Compute un-normalised coefficients of a second-order high-pass filter.

    Args:
        sample_rate: Sampling rate, in the same unit as ``cutoff_freq``.
        cutoff_freq: Cutoff frequency.
        Q: Quality factor.

    Returns:
        Tuple ``(b0, b1, b2, a0, a1, a2)``; ``a0`` is not normalised to one.
    """
    omega = 2 * torch.pi * cutoff_freq / sample_rate
    cos_w = torch.cos(omega)
    alpha = torch.sin(omega) / 2.0 / Q

    one_plus_cos = 1 + cos_w
    b0 = one_plus_cos / 2
    b1 = -one_plus_cos
    b2 = b0

    a0 = 1 + alpha
    a1 = -2 * cos_w
    a2 = 1 - alpha
    return b0, b1, b2, a0, a1, a2
def apply_biquad(bq):
    """Turn a coefficient function into a ready-to-use filter function.

    Args:
        bq: Callable returning the six biquad coefficients
            ``(b0, b1, b2, a0, a1, a2)``.

    Returns:
        A callable ``f(waveform, *args, **kwargs)`` that forwards
        ``*args, **kwargs`` to ``bq`` and filters ``waveform`` with ``biquad``
        using the resulting coefficients.
    """

    def filter_fn(waveform, *args, **kwargs):
        coefs = bq(*args, **kwargs)
        return biquad(waveform, *coefs)

    return filter_fn
# High-pass filter wrapper: filters a waveform with the coefficients produced
# by highpass_biquad_coef.
# NOTE(review): this name is bound again, to an equivalent wrapper, in the
# group of filter wrappers further down the module; this early binding looks
# redundant.
highpass_biquad = apply_biquad(highpass_biquad_coef)
def lowpass_biquad_coef(
    sample_rate: int,
    cutoff_freq: torch.Tensor,
    Q: torch.Tensor,
):
    """Compute un-normalised coefficients of a second-order low-pass filter.

    Args:
        sample_rate: Sampling rate, in the same unit as ``cutoff_freq``.
        cutoff_freq: Cutoff frequency.
        Q: Quality factor.

    Returns:
        Tuple ``(b0, b1, b2, a0, a1, a2)``; ``a0`` is not normalised to one.
    """
    omega = 2 * torch.pi * cutoff_freq / sample_rate
    cos_w = torch.cos(omega)
    alpha = torch.sin(omega) / 2 / Q

    one_minus_cos = 1 - cos_w
    b0 = one_minus_cos / 2
    b1 = one_minus_cos
    b2 = b0

    a0 = 1 + alpha
    a1 = -2 * cos_w
    a2 = 1 - alpha
    return b0, b1, b2, a0, a1, a2
def equalizer_biquad_coef(
    sample_rate: int,
    center_freq: torch.Tensor,
    gain: torch.Tensor,
    Q: torch.Tensor,
):
    """Compute un-normalised coefficients of a peaking-EQ biquad.

    Args:
        sample_rate: Sampling rate, in the same unit as ``center_freq``.
        center_freq: Centre frequency of the peak or notch.
        gain: Gain; converted to a linear factor as ``10 ** (gain / 40)``.
        Q: Quality factor.

    Returns:
        Tuple ``(b0, b1, b2, a0, a1, a2)``; ``a0`` is not normalised to one.
    """
    omega = 2 * torch.pi * center_freq / sample_rate
    # 10 ** (gain / 40), written with exp/log.
    amp = torch.exp(gain / 40.0 * math.log(10))
    alpha = torch.sin(omega) / 2 / Q
    minus_two_cos = -2 * torch.cos(omega)

    boost = alpha * amp
    cut = alpha / amp

    b0 = 1 + boost
    b1 = minus_two_cos
    b2 = 1 - boost
    a0 = 1 + cut
    a1 = minus_two_cos
    a2 = 1 - cut
    return b0, b1, b2, a0, a1, a2
def lowshelf_biquad_coef(
    sample_rate: int,
    cutoff_freq: torch.Tensor,
    gain: torch.Tensor,
    Q: torch.Tensor,
):
    """Compute un-normalised coefficients of a low-shelf biquad.

    Args:
        sample_rate: Sampling rate, in the same unit as ``cutoff_freq``.
        cutoff_freq: Shelf corner frequency.
        gain: Gain; converted to a linear factor as ``10 ** (gain / 40)``.
        Q: Quality factor.

    Returns:
        Tuple ``(b0, b1, b2, a0, a1, a2)``; ``a0`` is not normalised to one.
    """
    omega = 2 * torch.pi * cutoff_freq / sample_rate
    amp = torch.exp(gain / 40.0 * math.log(10))
    alpha = torch.sin(omega) / 2 / Q
    cos_w = torch.cos(omega)

    # Terms shared by the numerator and the denominator.
    amp_plus = amp + 1
    amp_minus = amp - 1
    slope = 2 * alpha * torch.sqrt(amp)

    b0 = amp * (amp_plus - amp_minus * cos_w + slope)
    b1 = 2 * amp * (amp_minus - amp_plus * cos_w)
    b2 = amp * (amp_plus - amp_minus * cos_w - slope)
    a0 = amp_plus + amp_minus * cos_w + slope
    a1 = -2 * (amp_minus + amp_plus * cos_w)
    a2 = amp_plus + amp_minus * cos_w - slope
    return b0, b1, b2, a0, a1, a2
def highshelf_biquad_coef(
    sample_rate: int,
    cutoff_freq: torch.Tensor,
    gain: torch.Tensor,
    Q: torch.Tensor,
):
    """Compute un-normalised coefficients of a high-shelf biquad.

    Args:
        sample_rate: Sampling rate, in the same unit as ``cutoff_freq``.
        cutoff_freq: Shelf corner frequency.
        gain: Gain; converted to a linear factor as ``10 ** (gain / 40)``.
        Q: Quality factor.

    Returns:
        Tuple ``(b0, b1, b2, a0, a1, a2)``; ``a0`` is not normalised to one.
    """
    omega = 2 * torch.pi * cutoff_freq / sample_rate
    amp = torch.exp(gain / 40.0 * math.log(10))
    alpha = torch.sin(omega) / 2 / Q
    cos_w = torch.cos(omega)

    # Terms shared by the numerator and the denominator.
    amp_plus = amp + 1
    amp_minus = amp - 1
    slope = 2 * alpha * torch.sqrt(amp)

    b0 = amp * (amp_plus + amp_minus * cos_w + slope)
    b1 = -2 * amp * (amp_minus + amp_plus * cos_w)
    b2 = amp * (amp_plus + amp_minus * cos_w - slope)
    a0 = amp_plus - amp_minus * cos_w + slope
    a1 = 2 * (amp_minus - amp_plus * cos_w)
    a2 = amp_plus - amp_minus * cos_w - slope
    return b0, b1, b2, a0, a1, a2
# Ready-to-use filters: each wrapper takes the waveform first, followed by the
# arguments of the matching *_biquad_coef function, and returns the waveform
# filtered by biquad() with those coefficients.
highpass_biquad = apply_biquad(highpass_biquad_coef)
lowpass_biquad = apply_biquad(lowpass_biquad_coef)
highshelf_biquad = apply_biquad(highshelf_biquad_coef)
lowshelf_biquad = apply_biquad(lowshelf_biquad_coef)
equalizer_biquad = apply_biquad(equalizer_biquad_coef)
def avg(rms: torch.Tensor, avg_coef: Union[torch.Tensor, float]):
    """Smooth ``rms`` with a one-pole recursive average.

    Assuming ``sample_wise_lpc(x, a)`` computes ``y[t] = x[t] - a[t] * y[t-1]``
    for a single coefficient (see the torchlpc documentation), the result is
    ``y[t] = avg_coef * rms[t] + (1 - avg_coef) * y[t-1]``.

    Args:
        rms: Signal to smooth. It is handed to ``sample_wise_lpc`` without
            reshaping -- presumably ``(batch, time)``; confirm against callers.
        avg_coef: Smoothing coefficient in ``(0, 1]``. May be a tensor
            broadcastable to ``rms`` or a plain Python number; numbers are
            promoted to a tensor matching ``rms`` in dtype and device.

    Returns:
        The smoothed signal as returned by ``sample_wise_lpc``.

    Raises:
        AssertionError: If any coefficient lies outside ``(0, 1]``.
    """
    # compressor_expander advertises float coefficients, but the validation
    # and broadcast below need tensor methods, so promote plain numbers.
    # Tensors are left untouched to keep their autograd history.
    if not isinstance(avg_coef, torch.Tensor):
        avg_coef = torch.as_tensor(avg_coef, dtype=rms.dtype, device=rms.device)
    assert torch.all(avg_coef > 0) and torch.all(
        avg_coef <= 1
    ), "avg_coef must lie in (0, 1]"

    scaled = rms * avg_coef
    # One feedback coefficient per sample: trailing axis of length one.
    feedback = (avg_coef - 1).broadcast_to(scaled.shape).unsqueeze(-1)
    return sample_wise_lpc(scaled, feedback)
def avg_rms(audio: torch.Tensor, avg_coef) -> torch.Tensor:
    """Estimate a running RMS level of ``audio``.

    Args:
        audio: Input signal.
        avg_coef: Smoothing coefficient forwarded to ``avg``.

    Returns:
        Square root of the smoothed squared signal.
    """
    # Floor the instantaneous power at 1e-8 before smoothing, so the value
    # under the square root stays strictly positive.
    power = audio.square().clamp_min(1e-8)
    smoothed_power = avg(power, avg_coef)
    return smoothed_power.sqrt()
def compressor_expander(
    x: torch.Tensor,
    avg_coef: Union[torch.Tensor, float],
    cmp_th: Union[torch.Tensor, float],
    cmp_ratio: Union[torch.Tensor, float],
    exp_th: Union[torch.Tensor, float],
    exp_ratio: Union[torch.Tensor, float],
    at: Union[torch.Tensor, float],
    rt: Union[torch.Tensor, float],
    make_up: torch.Tensor,
    lookahead_func=lambda x: x,
):
    """Apply a compressor/expander gain curve plus make-up gain to ``x``.

    Args:
        x: Input signal; its first dimension indexes the rows that each get
            their own make-up gain.
        avg_coef: Smoothing coefficient of the RMS level detector.
        cmp_th, cmp_ratio, exp_th, exp_ratio, at, rt: Forwarded unchanged to
            ``torchcomp.compexp_gain`` (presumably compressor/expander
            threshold and ratio, attack and release -- see its documentation).
        make_up: Make-up gain, converted with ``db2amp`` before use.
        lookahead_func: Callable applied to the gain signal before it is
            multiplied with ``x``; identity by default.

    Returns:
        ``x`` scaled by the (optionally post-processed) gain and the make-up
        gain.
    """
    level = avg_rms(x, avg_coef=avg_coef)
    gain = lookahead_func(
        compexp_gain(level, cmp_th, cmp_ratio, exp_th, exp_ratio, at, rt)
    )
    # One make-up factor per row of x, broadcast over the remaining axis.
    make_up_amp = db2amp(make_up).broadcast_to(x.shape[0], 1)
    return x * gain * make_up_amp