yoyolicoris commited on
Commit
2dc6a6a
·
1 Parent(s): 5146d8c

Add toJSON method to FX classes for JSON serialization of parameters

Browse files
Files changed (1) hide show
  1. modules/fx.py +63 -1
modules/fx.py CHANGED
@@ -2,7 +2,7 @@ import torch
2
  from torch import nn
3
  import torch.nn.functional as F
4
  from torch.nn.utils.parametrize import register_parametrization
5
- from torchcomp import ms2coef, coef2ms, db2amp
6
  from torchaudio.transforms import Spectrogram, InverseSpectrogram
7
 
8
  from typing import List, Tuple, Union, Any, Optional, Callable
@@ -72,6 +72,9 @@ class FX(nn.Module):
72
 
73
  self.params = nn.ParameterDict({k: float2param(v) for k, v in kwargs.items()})
74
 
 
 
 
75
 
76
  class SmoothingCoef(nn.Module):
77
  def forward(self, x):
@@ -196,6 +199,18 @@ class CompressorExpander(FX):
196
  s += f"\nlookahead: {self.params.lookahead.item()} (ms)"
197
  return s
198
 
 
 
 
 
 
 
 
 
 
 
 
 
199
  def forward(self, x):
200
  if self.lookahead:
201
  lookahead_in_samples = self.params.lookahead * 0.001 * self.sr
@@ -230,6 +245,11 @@ class Panning(FX):
230
  s = f"pan: {self.params.pan.item() * 200 - 100}"
231
  return s
232
 
 
 
 
 
 
233
  def forward(self, x: torch.Tensor):
234
  angle = self.params.pan.view(1) * torch.pi * 0.5
235
  amp = torch.concat([angle.cos(), angle.sin()]).view(2, 1) * STEREO_NORM
@@ -310,6 +330,12 @@ class LowPass(FX):
310
  s = f"freq: {self.params.freq.item():.4f}, Q: {self.params.Q.item():.4f}"
311
  return s
312
 
 
 
 
 
 
 
313
 
314
  class HighPass(LowPass):
315
  def __init__(
@@ -363,6 +389,13 @@ class Peak(FX):
363
  s = f"freq: {self.params.freq.item():.4f}, gain: {self.params.gain.item():.4f}, Q: {self.params.Q.item():.4f}"
364
  return s
365
 
 
 
 
 
 
 
 
366
 
367
  class LowShelf(FX):
368
  def __init__(
@@ -394,6 +427,12 @@ class LowShelf(FX):
394
  s = f"freq: {self.params.freq.item():.4f}, gain: {self.params.gain.item():.4f}"
395
  return s
396
 
 
 
 
 
 
 
397
 
398
  class HighShelf(LowShelf):
399
  def __init__(
@@ -611,6 +650,15 @@ class Delay(FX):
611
  )
612
  return s
613
 
 
 
 
 
 
 
 
 
 
614
 
615
  class SurrogateDelay(Delay):
616
  def __init__(self, *args, dropout=0.5, straight_through=False, **kwargs):
@@ -992,3 +1040,17 @@ class FDN(FX):
992
  F.pad(x, (self.ir_length - 1, 0)),
993
  h.flip(-1),
994
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  from torch import nn
3
  import torch.nn.functional as F
4
  from torch.nn.utils.parametrize import register_parametrization
5
+ from torchcomp import ms2coef, coef2ms, db2amp, amp2db
6
  from torchaudio.transforms import Spectrogram, InverseSpectrogram
7
 
8
  from typing import List, Tuple, Union, Any, Optional, Callable
 
72
 
73
  self.params = nn.ParameterDict({k: float2param(v) for k, v in kwargs.items()})
74
 
75
+ def toJSON(self) -> dict[str, Any]:
76
+ return {k: v.item() for k, v in self.params.items() if v.numel() == 1}
77
+
78
 
79
  class SmoothingCoef(nn.Module):
80
  def forward(self, x):
 
199
  s += f"\nlookahead: {self.params.lookahead.item()} (ms)"
200
  return s
201
 
202
+ def toJSON(self) -> dict[str, Any]:
203
+ return {
204
+ "Attack (ms)": coef2ms(self.params.at, self.sr).item(),
205
+ "Release (ms)": coef2ms(self.params.rt, self.sr).item(),
206
+ "Average Coefficient": self.params.avg_coef.item(),
207
+ "Compressor Ratio": self.params.cmp_ratio.item(),
208
+ "Expander Ratio": self.params.exp_ratio.item(),
209
+ "Compressor Threshold (dB)": self.params.cmp_th.item(),
210
+ "Expander Threshold (dB)": self.params.exp_th.item(),
211
+ "Make Up (dB)": self.params.make_up.item(),
212
+ } | ({"Lookahead (ms)": self.params.lookahead.item()} if self.lookahead else {})
213
+
214
  def forward(self, x):
215
  if self.lookahead:
216
  lookahead_in_samples = self.params.lookahead * 0.001 * self.sr
 
245
  s = f"pan: {self.params.pan.item() * 200 - 100}"
246
  return s
247
 
248
+ def toJSON(self) -> dict[str, Any]:
249
+ return {
250
+ "Pan": self.params.pan.item() * 200 - 100,
251
+ }
252
+
253
  def forward(self, x: torch.Tensor):
254
  angle = self.params.pan.view(1) * torch.pi * 0.5
255
  amp = torch.concat([angle.cos(), angle.sin()]).view(2, 1) * STEREO_NORM
 
330
  s = f"freq: {self.params.freq.item():.4f}, Q: {self.params.Q.item():.4f}"
331
  return s
332
 
333
+ def toJSON(self) -> dict[str, Any]:
334
+ return {
335
+ "Frequency (Hz)": self.params.freq.item(),
336
+ "Q": self.params.Q.item(),
337
+ }
338
+
339
 
340
  class HighPass(LowPass):
341
  def __init__(
 
389
  s = f"freq: {self.params.freq.item():.4f}, gain: {self.params.gain.item():.4f}, Q: {self.params.Q.item():.4f}"
390
  return s
391
 
392
+ def toJSON(self) -> dict[str, Any]:
393
+ return {
394
+ "Frequency (Hz)": self.params.freq.item(),
395
+ "Gain (dB)": self.params.gain.item(),
396
+ "Q": self.params.Q.item(),
397
+ }
398
+
399
 
400
  class LowShelf(FX):
401
  def __init__(
 
427
  s = f"freq: {self.params.freq.item():.4f}, gain: {self.params.gain.item():.4f}"
428
  return s
429
 
430
+ def toJSON(self) -> dict[str, Any]:
431
+ return {
432
+ "Frequency (Hz)": self.params.freq.item(),
433
+ "Gain (dB)": self.params.gain.item(),
434
+ }
435
+
436
 
437
  class HighShelf(LowShelf):
438
  def __init__(
 
650
  )
651
  return s
652
 
653
+ def toJSON(self) -> dict[str, Any]:
654
+ return {
655
+ "Delay (ms)": self.params.delay.item(),
656
+ "Feedback (dB)": self.params.feedback.log10().mul(20).item(),
657
+ "Gain (dB)": self.params.gain.log10().mul(20).item(),
658
+ "Odd delays": self.odd_pan.toJSON(),
659
+ "Even delays": self.even_pan.toJSON(),
660
+ }
661
+
662
 
663
  class SurrogateDelay(Delay):
664
  def __init__(self, *args, dropout=0.5, straight_through=False, **kwargs):
 
1040
  F.pad(x, (self.ir_length - 1, 0)),
1041
  h.flip(-1),
1042
  )
1043
+
1044
+ def toJSON(self) -> dict[str, Any]:
1045
+ return {
1046
+ "T60 (s)": {
1047
+ f"{f:.2f} Hz": g.item()
1048
+ for f, g in zip(
1049
+ torch.linspace(0, 22050, self.params.gamma.numel()),
1050
+ -60 * self.delays.min() / amp2db(self.params.gamma) / 44100,
1051
+ )
1052
+ },
1053
+ "Gain (dB, approx)": amp2db(
1054
+ torch.linalg.norm(self.params.b) * torch.linalg.norm(self.params.c)
1055
+ ).item(),
1056
+ }