Spaces:
Running
Running
Commit
·
2dc6a6a
1
Parent(s):
5146d8c
Add toJSON method to FX classes for JSON serialization of parameters
Browse files- modules/fx.py +63 -1
modules/fx.py
CHANGED
@@ -2,7 +2,7 @@ import torch
|
|
2 |
from torch import nn
|
3 |
import torch.nn.functional as F
|
4 |
from torch.nn.utils.parametrize import register_parametrization
|
5 |
-
from torchcomp import ms2coef, coef2ms, db2amp
|
6 |
from torchaudio.transforms import Spectrogram, InverseSpectrogram
|
7 |
|
8 |
from typing import List, Tuple, Union, Any, Optional, Callable
|
@@ -72,6 +72,9 @@ class FX(nn.Module):
|
|
72 |
|
73 |
self.params = nn.ParameterDict({k: float2param(v) for k, v in kwargs.items()})
|
74 |
|
|
|
|
|
|
|
75 |
|
76 |
class SmoothingCoef(nn.Module):
|
77 |
def forward(self, x):
|
@@ -196,6 +199,18 @@ class CompressorExpander(FX):
|
|
196 |
s += f"\nlookahead: {self.params.lookahead.item()} (ms)"
|
197 |
return s
|
198 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
199 |
def forward(self, x):
|
200 |
if self.lookahead:
|
201 |
lookahead_in_samples = self.params.lookahead * 0.001 * self.sr
|
@@ -230,6 +245,11 @@ class Panning(FX):
|
|
230 |
s = f"pan: {self.params.pan.item() * 200 - 100}"
|
231 |
return s
|
232 |
|
|
|
|
|
|
|
|
|
|
|
233 |
def forward(self, x: torch.Tensor):
|
234 |
angle = self.params.pan.view(1) * torch.pi * 0.5
|
235 |
amp = torch.concat([angle.cos(), angle.sin()]).view(2, 1) * STEREO_NORM
|
@@ -310,6 +330,12 @@ class LowPass(FX):
|
|
310 |
s = f"freq: {self.params.freq.item():.4f}, Q: {self.params.Q.item():.4f}"
|
311 |
return s
|
312 |
|
|
|
|
|
|
|
|
|
|
|
|
|
313 |
|
314 |
class HighPass(LowPass):
|
315 |
def __init__(
|
@@ -363,6 +389,13 @@ class Peak(FX):
|
|
363 |
s = f"freq: {self.params.freq.item():.4f}, gain: {self.params.gain.item():.4f}, Q: {self.params.Q.item():.4f}"
|
364 |
return s
|
365 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
366 |
|
367 |
class LowShelf(FX):
|
368 |
def __init__(
|
@@ -394,6 +427,12 @@ class LowShelf(FX):
|
|
394 |
s = f"freq: {self.params.freq.item():.4f}, gain: {self.params.gain.item():.4f}"
|
395 |
return s
|
396 |
|
|
|
|
|
|
|
|
|
|
|
|
|
397 |
|
398 |
class HighShelf(LowShelf):
|
399 |
def __init__(
|
@@ -611,6 +650,15 @@ class Delay(FX):
|
|
611 |
)
|
612 |
return s
|
613 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
614 |
|
615 |
class SurrogateDelay(Delay):
|
616 |
def __init__(self, *args, dropout=0.5, straight_through=False, **kwargs):
|
@@ -992,3 +1040,17 @@ class FDN(FX):
|
|
992 |
F.pad(x, (self.ir_length - 1, 0)),
|
993 |
h.flip(-1),
|
994 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2 |
from torch import nn
|
3 |
import torch.nn.functional as F
|
4 |
from torch.nn.utils.parametrize import register_parametrization
|
5 |
+
from torchcomp import ms2coef, coef2ms, db2amp, amp2db
|
6 |
from torchaudio.transforms import Spectrogram, InverseSpectrogram
|
7 |
|
8 |
from typing import List, Tuple, Union, Any, Optional, Callable
|
|
|
72 |
|
73 |
self.params = nn.ParameterDict({k: float2param(v) for k, v in kwargs.items()})
|
74 |
|
75 |
+
def toJSON(self) -> dict[str, Any]:
    """Return the single-element entries of ``self.params`` as plain Python numbers.

    Returns:
        A dict mapping each parameter name to ``param.item()``. Parameters
        holding more than one element (``numel() != 1``) are left out.
    """
    scalars: dict[str, Any] = {}
    for name, param in self.params.items():
        # ``.item()`` is only defined for one-element tensors, so skip the rest.
        if param.numel() != 1:
            continue
        scalars[name] = param.item()
    return scalars
|
77 |
+
|
78 |
|
79 |
class SmoothingCoef(nn.Module):
|
80 |
def forward(self, x):
|
|
|
199 |
s += f"\nlookahead: {self.params.lookahead.item()} (ms)"
|
200 |
return s
|
201 |
|
202 |
+
def toJSON(self) -> dict[str, Any]:
    """Return the compressor/expander settings as a dict of labelled numbers.

    The ``at`` and ``rt`` parameters are passed through ``coef2ms`` together
    with ``self.sr`` and reported under the "(ms)" labels; every other
    parameter is reported as stored. The "Lookahead (ms)" entry is present
    only when ``self.lookahead`` is truthy.

    Returns:
        A dict mapping display labels to Python scalars.
    """
    p = self.params
    settings: dict[str, Any] = {
        "Attack (ms)": coef2ms(p.at, self.sr).item(),
        "Release (ms)": coef2ms(p.rt, self.sr).item(),
        "Average Coefficient": p.avg_coef.item(),
        "Compressor Ratio": p.cmp_ratio.item(),
        "Expander Ratio": p.exp_ratio.item(),
        "Compressor Threshold (dB)": p.cmp_th.item(),
        "Expander Threshold (dB)": p.exp_th.item(),
        "Make Up (dB)": p.make_up.item(),
    }
    # The lookahead parameter is only read when lookahead is enabled.
    if self.lookahead:
        settings["Lookahead (ms)"] = p.lookahead.item()
    return settings
|
213 |
+
|
214 |
def forward(self, x):
|
215 |
if self.lookahead:
|
216 |
lookahead_in_samples = self.params.lookahead * 0.001 * self.sr
|
|
|
245 |
s = f"pan: {self.params.pan.item() * 200 - 100}"
|
246 |
return s
|
247 |
|
248 |
+
def toJSON(self) -> dict[str, Any]:
    """Return the pan setting as a one-entry dict.

    The stored ``pan`` parameter is rescaled with ``pan * 200 - 100``, the
    same expression this class uses when printing the pan value.

    Returns:
        ``{"Pan": <rescaled value>}``.
    """
    raw_pan = self.params.pan.item()
    rescaled = raw_pan * 200 - 100
    return {"Pan": rescaled}
|
252 |
+
|
253 |
def forward(self, x: torch.Tensor):
|
254 |
angle = self.params.pan.view(1) * torch.pi * 0.5
|
255 |
amp = torch.concat([angle.cos(), angle.sin()]).view(2, 1) * STEREO_NORM
|
|
|
330 |
s = f"freq: {self.params.freq.item():.4f}, Q: {self.params.Q.item():.4f}"
|
331 |
return s
|
332 |
|
333 |
+
def toJSON(self) -> dict[str, Any]:
    """Return the filter's ``freq`` and ``Q`` parameters as labelled numbers.

    Returns:
        A dict with the keys "Frequency (Hz)" and "Q", in that order, each
        holding the corresponding parameter converted with ``.item()``.
    """
    # Display label -> attribute name on ``self.params``.
    labelled = (("Frequency (Hz)", "freq"), ("Q", "Q"))
    return {label: getattr(self.params, attr).item() for label, attr in labelled}
|
338 |
+
|
339 |
|
340 |
class HighPass(LowPass):
|
341 |
def __init__(
|
|
|
389 |
s = f"freq: {self.params.freq.item():.4f}, gain: {self.params.gain.item():.4f}, Q: {self.params.Q.item():.4f}"
|
390 |
return s
|
391 |
|
392 |
+
def toJSON(self) -> dict[str, Any]:
    """Return the ``freq``, ``gain`` and ``Q`` parameters as labelled numbers.

    Returns:
        A dict with the keys "Frequency (Hz)", "Gain (dB)" and "Q", in that
        order, each holding the corresponding parameter as a Python scalar.
    """
    p = self.params
    freq = p.freq.item()
    gain = p.gain.item()
    q = p.Q.item()
    return {"Frequency (Hz)": freq, "Gain (dB)": gain, "Q": q}
|
398 |
+
|
399 |
|
400 |
class LowShelf(FX):
|
401 |
def __init__(
|
|
|
427 |
s = f"freq: {self.params.freq.item():.4f}, gain: {self.params.gain.item():.4f}"
|
428 |
return s
|
429 |
|
430 |
+
def toJSON(self) -> dict[str, Any]:
    """Return the shelf's ``freq`` and ``gain`` parameters as labelled numbers.

    Returns:
        A dict with the keys "Frequency (Hz)" and "Gain (dB)", in that order,
        each holding the corresponding parameter as a Python scalar.
    """
    labels = ("Frequency (Hz)", "Gain (dB)")
    values = (self.params.freq.item(), self.params.gain.item())
    return dict(zip(labels, values))
|
435 |
+
|
436 |
|
437 |
class HighShelf(LowShelf):
|
438 |
def __init__(
|
|
|
650 |
)
|
651 |
return s
|
652 |
|
653 |
+
def toJSON(self) -> dict[str, Any]:
    """Return the delay settings, including the two nested pan summaries.

    ``feedback`` and ``gain`` are converted with ``20 * log10(value)`` and
    reported under the "(dB)" labels; ``delay`` is reported as stored. The
    "Odd delays" and "Even delays" entries hold whatever
    ``self.odd_pan.toJSON()`` and ``self.even_pan.toJSON()`` return.

    Returns:
        A dict mapping display labels to scalars or nested dicts.
    """
    p = self.params

    def in_db(amplitude):
        # 20 * log10(amplitude), converted to a Python float.
        return amplitude.log10().mul(20).item()

    summary: dict[str, Any] = {"Delay (ms)": p.delay.item()}
    summary["Feedback (dB)"] = in_db(p.feedback)
    summary["Gain (dB)"] = in_db(p.gain)
    summary["Odd delays"] = self.odd_pan.toJSON()
    summary["Even delays"] = self.even_pan.toJSON()
    return summary
|
661 |
+
|
662 |
|
663 |
class SurrogateDelay(Delay):
|
664 |
def __init__(self, *args, dropout=0.5, straight_through=False, **kwargs):
|
|
|
1040 |
F.pad(x, (self.ir_length - 1, 0)),
|
1041 |
h.flip(-1),
|
1042 |
)
|
1043 |
+
|
1044 |
+
def toJSON(self) -> dict[str, Any]:
    """Return per-band T60 values and an approximate overall gain.

    The "T60 (s)" entry maps evenly spaced frequency labels, from 0 to half
    the sample rate with one label per element of ``gamma``, to
    ``-60 * min(delays) / amp2db(gamma) / sample_rate``. The
    "Gain (dB, approx)" entry is ``amp2db`` of the product of the norms of
    the ``b`` and ``c`` parameters.

    The sample rate is read from ``self.sr`` when that attribute exists,
    and otherwise falls back to 44100, the value that was previously
    hard-coded here along with its Nyquist frequency 22050.

    Returns:
        A dict with the keys "T60 (s)" (a nested dict) and
        "Gain (dB, approx)" (a float).
    """
    # NOTE(review): assumes self.delays is expressed in samples, so dividing
    # by the sample rate yields seconds — confirm against the constructor.
    sr = getattr(self, "sr", 44100)
    gamma = self.params.gamma

    band_freqs = torch.linspace(0, sr / 2, gamma.numel())
    t60 = -60 * self.delays.min() / amp2db(gamma) / sr

    io_gain = torch.linalg.norm(self.params.b) * torch.linalg.norm(self.params.c)
    return {
        "T60 (s)": {
            f"{freq:.2f} Hz": seconds.item()
            for freq, seconds in zip(band_freqs, t60)
        },
        "Gain (dB, approx)": amp2db(io_gain).item(),
    }
|