Spaces:
Running
Running
Commit
·
2dc6a6a
1
Parent(s):
5146d8c
Add toJSON method to FX classes for JSON serialization of parameters
Browse files- modules/fx.py +63 -1
modules/fx.py
CHANGED
|
@@ -2,7 +2,7 @@ import torch
|
|
| 2 |
from torch import nn
|
| 3 |
import torch.nn.functional as F
|
| 4 |
from torch.nn.utils.parametrize import register_parametrization
|
| 5 |
-
from torchcomp import ms2coef, coef2ms, db2amp
|
| 6 |
from torchaudio.transforms import Spectrogram, InverseSpectrogram
|
| 7 |
|
| 8 |
from typing import List, Tuple, Union, Any, Optional, Callable
|
|
@@ -72,6 +72,9 @@ class FX(nn.Module):
|
|
| 72 |
|
| 73 |
self.params = nn.ParameterDict({k: float2param(v) for k, v in kwargs.items()})
|
| 74 |
|
|
|
|
|
|
|
|
|
|
| 75 |
|
| 76 |
class SmoothingCoef(nn.Module):
|
| 77 |
def forward(self, x):
|
|
@@ -196,6 +199,18 @@ class CompressorExpander(FX):
|
|
| 196 |
s += f"\nlookahead: {self.params.lookahead.item()} (ms)"
|
| 197 |
return s
|
| 198 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 199 |
def forward(self, x):
|
| 200 |
if self.lookahead:
|
| 201 |
lookahead_in_samples = self.params.lookahead * 0.001 * self.sr
|
|
@@ -230,6 +245,11 @@ class Panning(FX):
|
|
| 230 |
s = f"pan: {self.params.pan.item() * 200 - 100}"
|
| 231 |
return s
|
| 232 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 233 |
def forward(self, x: torch.Tensor):
|
| 234 |
angle = self.params.pan.view(1) * torch.pi * 0.5
|
| 235 |
amp = torch.concat([angle.cos(), angle.sin()]).view(2, 1) * STEREO_NORM
|
|
@@ -310,6 +330,12 @@ class LowPass(FX):
|
|
| 310 |
s = f"freq: {self.params.freq.item():.4f}, Q: {self.params.Q.item():.4f}"
|
| 311 |
return s
|
| 312 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 313 |
|
| 314 |
class HighPass(LowPass):
|
| 315 |
def __init__(
|
|
@@ -363,6 +389,13 @@ class Peak(FX):
|
|
| 363 |
s = f"freq: {self.params.freq.item():.4f}, gain: {self.params.gain.item():.4f}, Q: {self.params.Q.item():.4f}"
|
| 364 |
return s
|
| 365 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 366 |
|
| 367 |
class LowShelf(FX):
|
| 368 |
def __init__(
|
|
@@ -394,6 +427,12 @@ class LowShelf(FX):
|
|
| 394 |
s = f"freq: {self.params.freq.item():.4f}, gain: {self.params.gain.item():.4f}"
|
| 395 |
return s
|
| 396 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 397 |
|
| 398 |
class HighShelf(LowShelf):
|
| 399 |
def __init__(
|
|
@@ -611,6 +650,15 @@ class Delay(FX):
|
|
| 611 |
)
|
| 612 |
return s
|
| 613 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 614 |
|
| 615 |
class SurrogateDelay(Delay):
|
| 616 |
def __init__(self, *args, dropout=0.5, straight_through=False, **kwargs):
|
|
@@ -992,3 +1040,17 @@ class FDN(FX):
|
|
| 992 |
F.pad(x, (self.ir_length - 1, 0)),
|
| 993 |
h.flip(-1),
|
| 994 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
from torch import nn
|
| 3 |
import torch.nn.functional as F
|
| 4 |
from torch.nn.utils.parametrize import register_parametrization
|
| 5 |
+
from torchcomp import ms2coef, coef2ms, db2amp, amp2db
|
| 6 |
from torchaudio.transforms import Spectrogram, InverseSpectrogram
|
| 7 |
|
| 8 |
from typing import List, Tuple, Union, Any, Optional, Callable
|
|
|
|
| 72 |
|
| 73 |
self.params = nn.ParameterDict({k: float2param(v) for k, v in kwargs.items()})
|
| 74 |
|
| 75 |
+
def toJSON(self) -> dict[str, Any]:
|
| 76 |
+
return {k: v.item() for k, v in self.params.items() if v.numel() == 1}
|
| 77 |
+
|
| 78 |
|
| 79 |
class SmoothingCoef(nn.Module):
|
| 80 |
def forward(self, x):
|
|
|
|
| 199 |
s += f"\nlookahead: {self.params.lookahead.item()} (ms)"
|
| 200 |
return s
|
| 201 |
|
| 202 |
+
def toJSON(self) -> dict[str, Any]:
|
| 203 |
+
return {
|
| 204 |
+
"Attack (ms)": coef2ms(self.params.at, self.sr).item(),
|
| 205 |
+
"Release (ms)": coef2ms(self.params.rt, self.sr).item(),
|
| 206 |
+
"Average Coefficient": self.params.avg_coef.item(),
|
| 207 |
+
"Compressor Ratio": self.params.cmp_ratio.item(),
|
| 208 |
+
"Expander Ratio": self.params.exp_ratio.item(),
|
| 209 |
+
"Compressor Threshold (dB)": self.params.cmp_th.item(),
|
| 210 |
+
"Expander Threshold (dB)": self.params.exp_th.item(),
|
| 211 |
+
"Make Up (dB)": self.params.make_up.item(),
|
| 212 |
+
} | ({"Lookahead (ms)": self.params.lookahead.item()} if self.lookahead else {})
|
| 213 |
+
|
| 214 |
def forward(self, x):
|
| 215 |
if self.lookahead:
|
| 216 |
lookahead_in_samples = self.params.lookahead * 0.001 * self.sr
|
|
|
|
| 245 |
s = f"pan: {self.params.pan.item() * 200 - 100}"
|
| 246 |
return s
|
| 247 |
|
| 248 |
+
def toJSON(self) -> dict[str, Any]:
|
| 249 |
+
return {
|
| 250 |
+
"Pan": self.params.pan.item() * 200 - 100,
|
| 251 |
+
}
|
| 252 |
+
|
| 253 |
def forward(self, x: torch.Tensor):
|
| 254 |
angle = self.params.pan.view(1) * torch.pi * 0.5
|
| 255 |
amp = torch.concat([angle.cos(), angle.sin()]).view(2, 1) * STEREO_NORM
|
|
|
|
| 330 |
s = f"freq: {self.params.freq.item():.4f}, Q: {self.params.Q.item():.4f}"
|
| 331 |
return s
|
| 332 |
|
| 333 |
+
def toJSON(self) -> dict[str, Any]:
|
| 334 |
+
return {
|
| 335 |
+
"Frequency (Hz)": self.params.freq.item(),
|
| 336 |
+
"Q": self.params.Q.item(),
|
| 337 |
+
}
|
| 338 |
+
|
| 339 |
|
| 340 |
class HighPass(LowPass):
|
| 341 |
def __init__(
|
|
|
|
| 389 |
s = f"freq: {self.params.freq.item():.4f}, gain: {self.params.gain.item():.4f}, Q: {self.params.Q.item():.4f}"
|
| 390 |
return s
|
| 391 |
|
| 392 |
+
def toJSON(self) -> dict[str, Any]:
|
| 393 |
+
return {
|
| 394 |
+
"Frequency (Hz)": self.params.freq.item(),
|
| 395 |
+
"Gain (dB)": self.params.gain.item(),
|
| 396 |
+
"Q": self.params.Q.item(),
|
| 397 |
+
}
|
| 398 |
+
|
| 399 |
|
| 400 |
class LowShelf(FX):
|
| 401 |
def __init__(
|
|
|
|
| 427 |
s = f"freq: {self.params.freq.item():.4f}, gain: {self.params.gain.item():.4f}"
|
| 428 |
return s
|
| 429 |
|
| 430 |
+
def toJSON(self) -> dict[str, Any]:
|
| 431 |
+
return {
|
| 432 |
+
"Frequency (Hz)": self.params.freq.item(),
|
| 433 |
+
"Gain (dB)": self.params.gain.item(),
|
| 434 |
+
}
|
| 435 |
+
|
| 436 |
|
| 437 |
class HighShelf(LowShelf):
|
| 438 |
def __init__(
|
|
|
|
| 650 |
)
|
| 651 |
return s
|
| 652 |
|
| 653 |
+
def toJSON(self) -> dict[str, Any]:
|
| 654 |
+
return {
|
| 655 |
+
"Delay (ms)": self.params.delay.item(),
|
| 656 |
+
"Feedback (dB)": self.params.feedback.log10().mul(20).item(),
|
| 657 |
+
"Gain (dB)": self.params.gain.log10().mul(20).item(),
|
| 658 |
+
"Odd delays": self.odd_pan.toJSON(),
|
| 659 |
+
"Even delays": self.even_pan.toJSON(),
|
| 660 |
+
}
|
| 661 |
+
|
| 662 |
|
| 663 |
class SurrogateDelay(Delay):
|
| 664 |
def __init__(self, *args, dropout=0.5, straight_through=False, **kwargs):
|
|
|
|
| 1040 |
F.pad(x, (self.ir_length - 1, 0)),
|
| 1041 |
h.flip(-1),
|
| 1042 |
)
|
| 1043 |
+
|
| 1044 |
+
def toJSON(self) -> dict[str, Any]:
|
| 1045 |
+
return {
|
| 1046 |
+
"T60 (s)": {
|
| 1047 |
+
f"{f:.2f} Hz": g.item()
|
| 1048 |
+
for f, g in zip(
|
| 1049 |
+
torch.linspace(0, 22050, self.params.gamma.numel()),
|
| 1050 |
+
-60 * self.delays.min() / amp2db(self.params.gamma) / 44100,
|
| 1051 |
+
)
|
| 1052 |
+
},
|
| 1053 |
+
"Gain (dB, approx)": amp2db(
|
| 1054 |
+
torch.linalg.norm(self.params.b) * torch.linalg.norm(self.params.c)
|
| 1055 |
+
).item(),
|
| 1056 |
+
}
|