Spaces:
Runtime error
Runtime error
yourusername
commited on
Commit
β’
66a6dc0
1
Parent(s):
9f5a755
:beers: cheers
Browse files- app.py +91 -0
- deepafx_st/__init__.py +4 -0
- deepafx_st/callbacks/audio.py +184 -0
- deepafx_st/callbacks/ckpt.py +33 -0
- deepafx_st/callbacks/params.py +87 -0
- deepafx_st/callbacks/plotting.py +126 -0
- deepafx_st/data/audio.py +177 -0
- deepafx_st/data/augmentations.py +235 -0
- deepafx_st/data/dataset.py +344 -0
- deepafx_st/data/proxy.py +181 -0
- deepafx_st/data/style.py +62 -0
- deepafx_st/metrics.py +157 -0
- deepafx_st/models/baselines.py +280 -0
- deepafx_st/models/controller.py +75 -0
- deepafx_st/models/efficient_net/LICENSE +202 -0
- deepafx_st/models/efficient_net/__init__.py +9 -0
- deepafx_st/models/efficient_net/model.py +419 -0
- deepafx_st/models/efficient_net/utils.py +616 -0
- deepafx_st/models/encoder.py +113 -0
- deepafx_st/models/mobilenetv2.py +226 -0
- deepafx_st/probes/cdpam_encoder.py +68 -0
- deepafx_st/probes/probe_system.py +307 -0
- deepafx_st/probes/random_mel.py +93 -0
- deepafx_st/processors/autodiff/__init__.py +0 -0
- deepafx_st/processors/autodiff/channel.py +28 -0
- deepafx_st/processors/autodiff/compressor.py +169 -0
- deepafx_st/processors/autodiff/fir.py +68 -0
- deepafx_st/processors/autodiff/peq.py +274 -0
- deepafx_st/processors/autodiff/signal.py +194 -0
- deepafx_st/processors/dsp/compressor.py +177 -0
- deepafx_st/processors/dsp/peq.py +323 -0
- deepafx_st/processors/processor.py +87 -0
- deepafx_st/processors/proxy/channel.py +130 -0
- deepafx_st/processors/proxy/proxy_system.py +289 -0
- deepafx_st/processors/proxy/tcn.py +199 -0
- deepafx_st/processors/spsa/channel.py +179 -0
- deepafx_st/processors/spsa/eps_scheduler.py +32 -0
- deepafx_st/processors/spsa/spsa_func.py +131 -0
- deepafx_st/system.py +563 -0
- deepafx_st/utils.py +277 -0
- deepafx_st/version.py +6 -0
- packages.txt +3 -0
- requirements.txt +3 -0
app.py
ADDED
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import numpy as np
|
3 |
+
import resampy
|
4 |
+
import torch
|
5 |
+
import torchaudio
|
6 |
+
from huggingface_hub import hf_hub_download
|
7 |
+
|
8 |
+
from deepafx_st.system import System
|
9 |
+
from deepafx_st.utils import DSPMode
|
10 |
+
|
11 |
+
system = System.load_from_checkpoint(
|
12 |
+
hf_hub_download("nateraw/deepafx-st-libritts-autodiff", "lit_model.ckpt"), batch_size=1
|
13 |
+
).eval()
|
14 |
+
|
15 |
+
gpu = torch.cuda.is_available()
|
16 |
+
|
17 |
+
if gpu:
|
18 |
+
system.to("cuda")
|
19 |
+
|
20 |
+
|
21 |
+
def process(input_path, reference_path):
|
22 |
+
# load audio data
|
23 |
+
x, x_sr = torchaudio.load(input_path)
|
24 |
+
r, r_sr = torchaudio.load(reference_path)
|
25 |
+
|
26 |
+
# resample if needed
|
27 |
+
if x_sr != 24000:
|
28 |
+
print("Resampling to 24000 Hz...")
|
29 |
+
x_24000 = torch.tensor(resampy.resample(x.view(-1).numpy(), x_sr, 24000))
|
30 |
+
x_24000 = x_24000.view(1, -1)
|
31 |
+
else:
|
32 |
+
x_24000 = x
|
33 |
+
|
34 |
+
if r_sr != 24000:
|
35 |
+
print("Resampling to 24000 Hz...")
|
36 |
+
r_24000 = torch.tensor(resampy.resample(r.view(-1).numpy(), r_sr, 24000))
|
37 |
+
r_24000 = r_24000.view(1, -1)
|
38 |
+
else:
|
39 |
+
r_24000 = r
|
40 |
+
|
41 |
+
# peak normalize to -12 dBFS
|
42 |
+
x_24000 = x_24000[0:1, : 24000 * 5]
|
43 |
+
x_24000 /= x_24000.abs().max()
|
44 |
+
x_24000 *= 10 ** (-12 / 20.0)
|
45 |
+
x_24000 = x_24000.view(1, 1, -1)
|
46 |
+
|
47 |
+
# peak normalize to -12 dBFS
|
48 |
+
r_24000 = r_24000[0:1, : 24000 * 5]
|
49 |
+
r_24000 /= r_24000.abs().max()
|
50 |
+
r_24000 *= 10 ** (-12 / 20.0)
|
51 |
+
r_24000 = r_24000.view(1, 1, -1)
|
52 |
+
|
53 |
+
if gpu:
|
54 |
+
x_24000 = x_24000.to("cuda")
|
55 |
+
r_24000 = r_24000.to("cuda")
|
56 |
+
|
57 |
+
with torch.no_grad():
|
58 |
+
y_hat, p, e = system(x_24000, r_24000)
|
59 |
+
|
60 |
+
y_hat = y_hat.view(1, -1)
|
61 |
+
y_hat /= y_hat.abs().max()
|
62 |
+
x_24000 /= x_24000.abs().max()
|
63 |
+
|
64 |
+
# Sqeeze to (T,), convert to numpy, and convert to int16
|
65 |
+
out_audio = (32767 * y_hat).squeeze(0).detach().cpu().numpy().astype(np.int16)
|
66 |
+
|
67 |
+
return 24000, out_audio
|
68 |
+
|
69 |
+
|
70 |
+
gr.Interface(
|
71 |
+
fn=process,
|
72 |
+
inputs=[gr.Audio(type="filepath"), gr.Audio(type="filepath")],
|
73 |
+
outputs="audio",
|
74 |
+
examples=[
|
75 |
+
[
|
76 |
+
hf_hub_download("nateraw/examples", "voice_raw.wav", repo_type="dataset", cache_dir="./data"),
|
77 |
+
hf_hub_download("nateraw/examples", "voice_produced.wav", repo_type="dataset", cache_dir="./data"),
|
78 |
+
],
|
79 |
+
],
|
80 |
+
title="DeepAFx-ST",
|
81 |
+
description=(
|
82 |
+
"Gradio demo for DeepAFx-ST for style transfer of audio effects with differentiable signal processing. To use it, simply"
|
83 |
+
" upload your audio files or choose from one of the examples. Read more at the links below."
|
84 |
+
),
|
85 |
+
article=(
|
86 |
+
"<div style='text-align: center;'><a href='https://github.com/adobe-research/DeepAFx-ST' target='_blank'>Github Repo</a>"
|
87 |
+
" <center><img src='https://visitor-badge.glitch.me/badge?page_id=nateraw_deepafx-st' alt='visitor"
|
88 |
+
" badge'></center></div>"
|
89 |
+
),
|
90 |
+
allow_flagging="never",
|
91 |
+
).launch()
|
deepafx_st/__init__.py
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
"""Top-level module for deepafx_st"""
|
3 |
+
|
4 |
+
from .version import version as __version__
|
deepafx_st/callbacks/audio.py
ADDED
@@ -0,0 +1,184 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import auraloss
|
2 |
+
import numpy as np
|
3 |
+
import pytorch_lightning as pl
|
4 |
+
|
5 |
+
from deepafx_st.callbacks.plotting import plot_multi_spectrum
|
6 |
+
from deepafx_st.metrics import (
|
7 |
+
LoudnessError,
|
8 |
+
SpectralCentroidError,
|
9 |
+
CrestFactorError,
|
10 |
+
PESQ,
|
11 |
+
MelSpectralDistance,
|
12 |
+
)
|
13 |
+
|
14 |
+
|
15 |
+
class LogAudioCallback(pl.callbacks.Callback):
|
16 |
+
def __init__(self, num_examples=4, peak_normalize=True, sample_rate=22050):
|
17 |
+
super().__init__()
|
18 |
+
self.num_examples = 4
|
19 |
+
self.peak_normalize = peak_normalize
|
20 |
+
|
21 |
+
self.metrics = {
|
22 |
+
"PESQ": PESQ(sample_rate),
|
23 |
+
"MRSTFT": auraloss.freq.MultiResolutionSTFTLoss(
|
24 |
+
fft_sizes=[32, 128, 512, 2048, 8192, 32768],
|
25 |
+
hop_sizes=[16, 64, 256, 1024, 4096, 16384],
|
26 |
+
win_lengths=[32, 128, 512, 2048, 8192, 32768],
|
27 |
+
w_sc=0.0,
|
28 |
+
w_phs=0.0,
|
29 |
+
w_lin_mag=1.0,
|
30 |
+
w_log_mag=1.0,
|
31 |
+
),
|
32 |
+
"MSD": MelSpectralDistance(sample_rate),
|
33 |
+
"SCE": SpectralCentroidError(sample_rate),
|
34 |
+
"CFE": CrestFactorError(),
|
35 |
+
"LUFS": LoudnessError(sample_rate),
|
36 |
+
}
|
37 |
+
|
38 |
+
self.outputs = []
|
39 |
+
|
40 |
+
def on_validation_batch_end(
|
41 |
+
self,
|
42 |
+
trainer,
|
43 |
+
pl_module,
|
44 |
+
outputs,
|
45 |
+
batch,
|
46 |
+
batch_idx,
|
47 |
+
dataloader_idx,
|
48 |
+
):
|
49 |
+
"""Called when the validation batch ends."""
|
50 |
+
|
51 |
+
if outputs is not None:
|
52 |
+
examples = np.min([self.num_examples, outputs["x"].shape[0]])
|
53 |
+
self.outputs.append(outputs)
|
54 |
+
|
55 |
+
if batch_idx == 0:
|
56 |
+
for n in range(examples):
|
57 |
+
if batch_idx == 0:
|
58 |
+
self.log_audio(
|
59 |
+
outputs,
|
60 |
+
n,
|
61 |
+
pl_module.hparams.sample_rate,
|
62 |
+
pl_module.hparams.val_length,
|
63 |
+
trainer.global_step,
|
64 |
+
trainer.logger,
|
65 |
+
)
|
66 |
+
|
67 |
+
def on_validation_end(self, trainer, pl_module):
|
68 |
+
metrics = {
|
69 |
+
"PESQ": [],
|
70 |
+
"MRSTFT": [],
|
71 |
+
"MSD": [],
|
72 |
+
"SCE": [],
|
73 |
+
"CFE": [],
|
74 |
+
"LUFS": [],
|
75 |
+
}
|
76 |
+
for output in self.outputs:
|
77 |
+
for metric_name, metric in self.metrics.items():
|
78 |
+
try:
|
79 |
+
val = metric(output["y_hat"], output["y"])
|
80 |
+
metrics[metric_name].append(val)
|
81 |
+
except:
|
82 |
+
pass
|
83 |
+
|
84 |
+
# log final mean metrics
|
85 |
+
for metric_name, metric in metrics.items():
|
86 |
+
val = np.mean(metric)
|
87 |
+
trainer.logger.experiment.add_scalar(
|
88 |
+
f"metrics/{metric_name}", val, trainer.global_step
|
89 |
+
)
|
90 |
+
|
91 |
+
# clear outputs
|
92 |
+
self.outputs = []
|
93 |
+
|
94 |
+
def compute_metrics(self, metrics_dict, outputs, batch_idx, global_step):
|
95 |
+
# extract audio
|
96 |
+
y = outputs["y"][batch_idx, ...].float()
|
97 |
+
y_hat = outputs["y_hat"][batch_idx, ...].float()
|
98 |
+
|
99 |
+
# compute all metrics
|
100 |
+
for metric_name, metric in self.metrics.items():
|
101 |
+
try:
|
102 |
+
val = metric(y_hat.view(1, 1, -1), y.view(1, 1, -1))
|
103 |
+
metrics_dict[metric_name].append(val)
|
104 |
+
except:
|
105 |
+
pass
|
106 |
+
|
107 |
+
def log_audio(self, outputs, batch_idx, sample_rate, n_fft, global_step, logger):
|
108 |
+
x = outputs["x"][batch_idx, ...].float()
|
109 |
+
y = outputs["y"][batch_idx, ...].float()
|
110 |
+
y_hat = outputs["y_hat"][batch_idx, ...].float()
|
111 |
+
|
112 |
+
if self.peak_normalize:
|
113 |
+
x /= x.abs().max()
|
114 |
+
y /= y.abs().max()
|
115 |
+
y_hat /= y_hat.abs().max()
|
116 |
+
|
117 |
+
logger.experiment.add_audio(
|
118 |
+
f"x/{batch_idx+1}",
|
119 |
+
x[0:1, :],
|
120 |
+
global_step,
|
121 |
+
sample_rate=sample_rate,
|
122 |
+
)
|
123 |
+
|
124 |
+
logger.experiment.add_audio(
|
125 |
+
f"y/{batch_idx+1}",
|
126 |
+
y[0:1, :],
|
127 |
+
global_step,
|
128 |
+
sample_rate=sample_rate,
|
129 |
+
)
|
130 |
+
|
131 |
+
logger.experiment.add_audio(
|
132 |
+
f"y_hat/{batch_idx+1}",
|
133 |
+
y_hat[0:1, :],
|
134 |
+
global_step,
|
135 |
+
sample_rate=sample_rate,
|
136 |
+
)
|
137 |
+
|
138 |
+
if "y_ref" in outputs:
|
139 |
+
y_ref = outputs["y_ref"][batch_idx, ...].float()
|
140 |
+
|
141 |
+
if self.peak_normalize:
|
142 |
+
y_ref /= y_ref.abs().max()
|
143 |
+
|
144 |
+
logger.experiment.add_audio(
|
145 |
+
f"y_ref/{batch_idx+1}",
|
146 |
+
y_ref[0:1, :],
|
147 |
+
global_step,
|
148 |
+
sample_rate=sample_rate,
|
149 |
+
)
|
150 |
+
logger.experiment.add_image(
|
151 |
+
f"spec/{batch_idx+1}",
|
152 |
+
compare_spectra(
|
153 |
+
y_hat[0:1, :],
|
154 |
+
y[0:1, :],
|
155 |
+
x[0:1, :],
|
156 |
+
sample_rate=sample_rate,
|
157 |
+
n_fft=n_fft,
|
158 |
+
),
|
159 |
+
global_step,
|
160 |
+
)
|
161 |
+
|
162 |
+
|
163 |
+
def compare_spectra(
|
164 |
+
deepafx_y_hat, y, x, baseline_y_hat=None, sample_rate=44100, n_fft=16384
|
165 |
+
):
|
166 |
+
legend = ["Corrupted"]
|
167 |
+
signals = [x]
|
168 |
+
if baseline_y_hat is not None:
|
169 |
+
legend.append("Baseline")
|
170 |
+
signals.append(baseline_y_hat)
|
171 |
+
|
172 |
+
legend.append("DeepAFx")
|
173 |
+
signals.append(deepafx_y_hat)
|
174 |
+
legend.append("Target")
|
175 |
+
signals.append(y)
|
176 |
+
|
177 |
+
image = plot_multi_spectrum(
|
178 |
+
ys=signals,
|
179 |
+
legend=legend,
|
180 |
+
sample_rate=sample_rate,
|
181 |
+
n_fft=n_fft,
|
182 |
+
)
|
183 |
+
|
184 |
+
return image
|
deepafx_st/callbacks/ckpt.py
ADDED
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
import shutil
|
4 |
+
import pytorch_lightning as pl
|
5 |
+
|
6 |
+
|
7 |
+
class CopyPretrainedCheckpoints(pl.callbacks.Callback):
|
8 |
+
def __init__(self):
|
9 |
+
super().__init__()
|
10 |
+
|
11 |
+
def on_fit_start(self, trainer, pl_module):
|
12 |
+
"""Before training, move the pre-trained checkpoints
|
13 |
+
to the current checkpoint directory.
|
14 |
+
|
15 |
+
"""
|
16 |
+
# copy any pre-trained checkpoints to new directory
|
17 |
+
if pl_module.hparams.processor_model == "proxy":
|
18 |
+
pretrained_ckpt_dir = os.path.join(
|
19 |
+
pl_module.logger.experiment.log_dir, "pretrained_checkpoints"
|
20 |
+
)
|
21 |
+
if not os.path.isdir(pretrained_ckpt_dir):
|
22 |
+
os.makedirs(pretrained_ckpt_dir)
|
23 |
+
cp_proxy_ckpts = []
|
24 |
+
for proxy_ckpt in pl_module.hparams.proxy_ckpts:
|
25 |
+
new_ckpt = shutil.copy(
|
26 |
+
proxy_ckpt,
|
27 |
+
pretrained_ckpt_dir,
|
28 |
+
)
|
29 |
+
cp_proxy_ckpts.append(new_ckpt)
|
30 |
+
print(f"Moved checkpoint to {new_ckpt}.")
|
31 |
+
# overwrite to the paths in current experiment logs
|
32 |
+
pl_module.hparams.proxy_ckpts = cp_proxy_ckpts
|
33 |
+
print(pl_module.hparams.proxy_ckpts)
|
deepafx_st/callbacks/params.py
ADDED
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import pytorch_lightning as pl
|
3 |
+
import matplotlib.pyplot as plt
|
4 |
+
|
5 |
+
import deepafx_st.utils as utils
|
6 |
+
|
7 |
+
|
8 |
+
class LogParametersCallback(pl.callbacks.Callback):
|
9 |
+
def __init__(self, num_examples=4):
|
10 |
+
super().__init__()
|
11 |
+
self.num_examples = 4
|
12 |
+
|
13 |
+
def on_validation_epoch_start(self, trainer, pl_module):
|
14 |
+
"""At the start of validation init storage for parameters."""
|
15 |
+
self.params = []
|
16 |
+
|
17 |
+
def on_validation_batch_end(
|
18 |
+
self,
|
19 |
+
trainer,
|
20 |
+
pl_module,
|
21 |
+
outputs,
|
22 |
+
batch,
|
23 |
+
batch_idx,
|
24 |
+
dataloader_idx,
|
25 |
+
):
|
26 |
+
"""Called when the validation batch ends.
|
27 |
+
|
28 |
+
Here we log the parameters only from the first batch.
|
29 |
+
|
30 |
+
"""
|
31 |
+
if outputs is not None and batch_idx == 0:
|
32 |
+
examples = np.min([self.num_examples, outputs["x"].shape[0]])
|
33 |
+
for n in range(examples):
|
34 |
+
self.log_parameters(
|
35 |
+
outputs,
|
36 |
+
n,
|
37 |
+
pl_module.processor.ports,
|
38 |
+
trainer.global_step,
|
39 |
+
trainer.logger,
|
40 |
+
True if batch_idx == 0 else False,
|
41 |
+
)
|
42 |
+
|
43 |
+
def on_validation_epoch_end(self, trainer, pl_module):
|
44 |
+
pass
|
45 |
+
|
46 |
+
def log_parameters(self, outputs, batch_idx, ports, global_step, logger, log=True):
|
47 |
+
p = outputs["p"][batch_idx, ...]
|
48 |
+
|
49 |
+
table = ""
|
50 |
+
|
51 |
+
# table += f"""## {plugin["name"]}\n"""
|
52 |
+
table += "| Index| Name | Value | Units | Min | Max | Default | Raw Value | \n"
|
53 |
+
table += "|------|------|------:|:------|----:|----:|--------:| ---------:| \n"
|
54 |
+
|
55 |
+
start_idx = 0
|
56 |
+
# set plugin parameters based on provided normalized parameters
|
57 |
+
for port_list in ports:
|
58 |
+
for pidx, port in enumerate(port_list):
|
59 |
+
param_max = port["max"]
|
60 |
+
param_min = port["min"]
|
61 |
+
param_name = port["name"]
|
62 |
+
param_default = port["default"]
|
63 |
+
param_units = port["units"]
|
64 |
+
|
65 |
+
param_val = p[start_idx]
|
66 |
+
denorm_val = utils.denormalize(param_val, param_max, param_min)
|
67 |
+
|
68 |
+
# add values to table in row
|
69 |
+
table += f"| {start_idx + 1} | {param_name} "
|
70 |
+
if np.abs(denorm_val) > 10:
|
71 |
+
table += f"| {denorm_val:0.1f} "
|
72 |
+
table += f"| {param_units} "
|
73 |
+
table += f"| {param_min:0.1f} | {param_max:0.1f} "
|
74 |
+
table += f"| {param_default:0.1f} "
|
75 |
+
else:
|
76 |
+
table += f"| {denorm_val:0.3f} "
|
77 |
+
table += f"| {param_units} "
|
78 |
+
table += f"| {param_min:0.3f} | {param_max:0.3f} "
|
79 |
+
table += f"| {param_default:0.3f} "
|
80 |
+
|
81 |
+
table += f"| {np.squeeze(param_val):0.2f} | \n"
|
82 |
+
start_idx += 1
|
83 |
+
|
84 |
+
table += "\n\n"
|
85 |
+
|
86 |
+
if log:
|
87 |
+
logger.experiment.add_text(f"params/{batch_idx+1}", table, global_step)
|
deepafx_st/callbacks/plotting.py
ADDED
@@ -0,0 +1,126 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import io
|
2 |
+
import torch
|
3 |
+
import PIL.Image
|
4 |
+
import numpy as np
|
5 |
+
import scipy.signal
|
6 |
+
import librosa.display
|
7 |
+
import matplotlib.pyplot as plt
|
8 |
+
|
9 |
+
from torch.functional import Tensor
|
10 |
+
from torchvision.transforms import ToTensor
|
11 |
+
|
12 |
+
|
13 |
+
def compute_comparison_spectrogram(
|
14 |
+
x: np.ndarray,
|
15 |
+
y: np.ndarray,
|
16 |
+
sample_rate: float = 44100,
|
17 |
+
n_fft: int = 2048,
|
18 |
+
hop_length: int = 1024,
|
19 |
+
) -> Tensor:
|
20 |
+
X = librosa.stft(x, n_fft=n_fft, hop_length=hop_length)
|
21 |
+
X_db = librosa.amplitude_to_db(np.abs(X), ref=np.max)
|
22 |
+
|
23 |
+
Y = librosa.stft(y, n_fft=n_fft, hop_length=hop_length)
|
24 |
+
Y_db = librosa.amplitude_to_db(np.abs(Y), ref=np.max)
|
25 |
+
|
26 |
+
fig, axs = plt.subplots(figsize=(9, 6), nrows=2)
|
27 |
+
img = librosa.display.specshow(
|
28 |
+
X_db,
|
29 |
+
ax=axs[0],
|
30 |
+
hop_length=hop_length,
|
31 |
+
x_axis="time",
|
32 |
+
y_axis="log",
|
33 |
+
sr=sample_rate,
|
34 |
+
)
|
35 |
+
# fig.colorbar(img, ax=axs[0])
|
36 |
+
img = librosa.display.specshow(
|
37 |
+
Y_db,
|
38 |
+
ax=axs[1],
|
39 |
+
hop_length=hop_length,
|
40 |
+
x_axis="time",
|
41 |
+
y_axis="log",
|
42 |
+
sr=sample_rate,
|
43 |
+
)
|
44 |
+
# fig.colorbar(img, ax=axs[1])
|
45 |
+
|
46 |
+
plt.tight_layout()
|
47 |
+
|
48 |
+
buf = io.BytesIO()
|
49 |
+
plt.savefig(buf, format="jpeg")
|
50 |
+
buf.seek(0)
|
51 |
+
image = PIL.Image.open(buf)
|
52 |
+
image = ToTensor()(image)
|
53 |
+
plt.close("all")
|
54 |
+
|
55 |
+
return image
|
56 |
+
|
57 |
+
|
58 |
+
def plot_multi_spectrum(
|
59 |
+
ys=None,
|
60 |
+
Hs=None,
|
61 |
+
legend=[],
|
62 |
+
title="Spectrum",
|
63 |
+
filename=None,
|
64 |
+
sample_rate=44100,
|
65 |
+
n_fft=1024,
|
66 |
+
zero_mean=False,
|
67 |
+
):
|
68 |
+
|
69 |
+
if Hs is None:
|
70 |
+
Hs = []
|
71 |
+
for y in ys:
|
72 |
+
X = get_average_spectrum(y, n_fft)
|
73 |
+
X_sm = smooth_spectrum(X)
|
74 |
+
Hs.append(X_sm)
|
75 |
+
|
76 |
+
bin_width = (sample_rate / 2) / (n_fft // 2)
|
77 |
+
freqs = np.arange(0, (sample_rate / 2) + bin_width, step=bin_width)
|
78 |
+
|
79 |
+
fig, ax1 = plt.subplots()
|
80 |
+
|
81 |
+
for idx, H in enumerate(Hs):
|
82 |
+
H = np.nan_to_num(H)
|
83 |
+
H = np.clip(H, 0, np.max(H))
|
84 |
+
H_dB = 20 * np.log10(H + 1e-8)
|
85 |
+
if zero_mean:
|
86 |
+
H_dB -= np.mean(H_dB)
|
87 |
+
if "Target" in legend[idx]:
|
88 |
+
ax1.plot(freqs, H_dB, linestyle="--", color="k")
|
89 |
+
else:
|
90 |
+
ax1.plot(freqs, H_dB)
|
91 |
+
|
92 |
+
plt.legend(legend)
|
93 |
+
|
94 |
+
ax1.set_xscale("log")
|
95 |
+
ax1.set_ylim([-80, 0])
|
96 |
+
ax1.set_xlim([100, 11000])
|
97 |
+
plt.title(title)
|
98 |
+
plt.ylabel("Magnitude (dB)")
|
99 |
+
plt.xlabel("Frequency (Hz)")
|
100 |
+
plt.grid(c="lightgray", which="both")
|
101 |
+
|
102 |
+
if filename is not None:
|
103 |
+
plt.savefig(f"{filename}.png", dpi=300)
|
104 |
+
|
105 |
+
plt.tight_layout()
|
106 |
+
|
107 |
+
buf = io.BytesIO()
|
108 |
+
plt.savefig(buf, format="jpeg")
|
109 |
+
buf.seek(0)
|
110 |
+
image = PIL.Image.open(buf)
|
111 |
+
image = ToTensor()(image)
|
112 |
+
plt.close("all")
|
113 |
+
|
114 |
+
return image
|
115 |
+
|
116 |
+
|
117 |
+
def smooth_spectrum(H):
|
118 |
+
# apply Savgol filter for smoothed target curve
|
119 |
+
return scipy.signal.savgol_filter(H, 1025, 2)
|
120 |
+
|
121 |
+
|
122 |
+
def get_average_spectrum(x, n_fft):
|
123 |
+
X = torch.stft(x, n_fft, return_complex=True, normalized=True)
|
124 |
+
X = X.abs() # convert to magnitude
|
125 |
+
X = X.mean(dim=-1).view(-1) # average across frames
|
126 |
+
return X
|
deepafx_st/data/audio.py
ADDED
@@ -0,0 +1,177 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import glob
|
3 |
+
import torch
|
4 |
+
import warnings
|
5 |
+
import torchaudio
|
6 |
+
import pyloudnorm as pyln
|
7 |
+
|
8 |
+
|
9 |
+
class AudioFile(object):
|
10 |
+
def __init__(self, filepath, preload=False, half=False, target_loudness=None):
|
11 |
+
"""Base class for audio files to handle metadata and loading.
|
12 |
+
|
13 |
+
Args:
|
14 |
+
filepath (str): Path to audio file to load from disk.
|
15 |
+
preload (bool, optional): If set, load audio data into RAM. Default: False
|
16 |
+
half (bool, optional): If set, store audio data as float16 to save space. Default: False
|
17 |
+
target_loudness (float, optional): Loudness normalize to dB LUFS value. Default:
|
18 |
+
"""
|
19 |
+
super().__init__()
|
20 |
+
|
21 |
+
self.filepath = filepath
|
22 |
+
self.half = half
|
23 |
+
self.target_loudness = target_loudness
|
24 |
+
self.loaded = False
|
25 |
+
|
26 |
+
if preload:
|
27 |
+
self.load()
|
28 |
+
num_frames = self.audio.shape[-1]
|
29 |
+
num_channels = self.audio.shape[0]
|
30 |
+
else:
|
31 |
+
metadata = torchaudio.info(filepath)
|
32 |
+
audio = None
|
33 |
+
self.sample_rate = metadata.sample_rate
|
34 |
+
num_frames = metadata.num_frames
|
35 |
+
num_channels = metadata.num_channels
|
36 |
+
|
37 |
+
self.num_frames = num_frames
|
38 |
+
self.num_channels = num_channels
|
39 |
+
|
40 |
+
def load(self):
|
41 |
+
audio, sr = torchaudio.load(self.filepath, normalize=True)
|
42 |
+
self.audio = audio
|
43 |
+
self.sample_rate = sr
|
44 |
+
|
45 |
+
if self.target_loudness is not None:
|
46 |
+
self.loudness_normalize()
|
47 |
+
|
48 |
+
if self.half:
|
49 |
+
self.audio = audio.half()
|
50 |
+
|
51 |
+
self.loaded = True
|
52 |
+
|
53 |
+
def loudness_normalize(self):
|
54 |
+
meter = pyln.Meter(self.sample_rate)
|
55 |
+
|
56 |
+
# conver mono to stereo
|
57 |
+
if self.audio.shape[0] == 1:
|
58 |
+
tmp_audio = self.audio.repeat(2, 1)
|
59 |
+
else:
|
60 |
+
tmp_audio = self.audio
|
61 |
+
|
62 |
+
# measure integrated loudness
|
63 |
+
input_loudness = meter.integrated_loudness(tmp_audio.numpy().T)
|
64 |
+
|
65 |
+
# compute and apply gain
|
66 |
+
gain_dB = self.target_loudness - input_loudness
|
67 |
+
gain_ln = 10 ** (gain_dB / 20.0)
|
68 |
+
self.audio *= gain_ln
|
69 |
+
|
70 |
+
# check for potentially clipped samples
|
71 |
+
if self.audio.abs().max() >= 1.0:
|
72 |
+
warnings.warn("Possible clipped samples in output.")
|
73 |
+
|
74 |
+
|
75 |
+
class AudioFileDataset(torch.utils.data.Dataset):
|
76 |
+
"""Base class for audio file datasets loaded from disk.
|
77 |
+
|
78 |
+
Datasets can be either paired or unpaired. A paired dataset requires passing the `target_dir` path.
|
79 |
+
|
80 |
+
Args:
|
81 |
+
input_dir (List[str]): List of paths to the directories containing input audio files.
|
82 |
+
target_dir (List[str], optional): List of paths to the directories containing correponding audio files. Default: []
|
83 |
+
subset (str, optional): Dataset subset. One of ["train", "val", "test"]. Default: "train"
|
84 |
+
length (int, optional): Number of samples to load for each example. Default: 65536
|
85 |
+
normalize (bool, optional): Normalize audio amplitiude to -1 to 1. Default: True
|
86 |
+
train_frac (float, optional): Fraction of the files to use for training subset. Default: 0.8
|
87 |
+
val_frac (float, optional): Fraction of the files to use for validation subset. Default: 0.1
|
88 |
+
preload (bool, optional): Read audio files into RAM at the start of training. Default: False
|
89 |
+
num_examples_per_epoch (int, optional): Define an epoch as certain number of audio examples. Default: 10000
|
90 |
+
ext (str, optional): Expected audio file extension. Default: "wav"
|
91 |
+
"""
|
92 |
+
|
93 |
+
def __init__(
|
94 |
+
self,
|
95 |
+
input_dirs,
|
96 |
+
target_dirs=[],
|
97 |
+
subset="train",
|
98 |
+
length=65536,
|
99 |
+
normalize=True,
|
100 |
+
train_per=0.8,
|
101 |
+
val_per=0.1,
|
102 |
+
preload=False,
|
103 |
+
num_examples_per_epoch=10000,
|
104 |
+
ext="wav",
|
105 |
+
):
|
106 |
+
super().__init__()
|
107 |
+
self.input_dirs = input_dirs
|
108 |
+
self.target_dirs = target_dirs
|
109 |
+
self.subset = subset
|
110 |
+
self.length = length
|
111 |
+
self.normalize = normalize
|
112 |
+
self.train_per = train_per
|
113 |
+
self.val_per = val_per
|
114 |
+
self.preload = preload
|
115 |
+
self.num_examples_per_epoch = num_examples_per_epoch
|
116 |
+
self.ext = ext
|
117 |
+
|
118 |
+
self.input_filepaths = []
|
119 |
+
for input_dir in input_dirs:
|
120 |
+
search_path = os.path.join(input_dir, f"*.{ext}")
|
121 |
+
self.input_filepaths += glob.glob(search_path)
|
122 |
+
self.input_filepaths = sorted(self.input_filepaths)
|
123 |
+
|
124 |
+
self.target_filepaths = []
|
125 |
+
for target_dir in target_dirs:
|
126 |
+
search_path = os.path.join(target_dir, f"*.{ext}")
|
127 |
+
self.target_filepaths += glob.glob(search_path)
|
128 |
+
self.target_filepaths = sorted(self.target_filepaths)
|
129 |
+
|
130 |
+
# both sets must have same number of files in paired dataset
|
131 |
+
assert len(self.target_filepaths) == len(self.input_filepaths)
|
132 |
+
|
133 |
+
# get details about audio files
|
134 |
+
self.input_files = []
|
135 |
+
for input_filepath in self.input_filepaths:
|
136 |
+
self.input_files.append(
|
137 |
+
AudioFile(input_filepath, preload=preload, normalize=normalize)
|
138 |
+
)
|
139 |
+
|
140 |
+
self.target_files = []
|
141 |
+
if target_dir is not None:
|
142 |
+
for target_filepath in self.target_filepaths:
|
143 |
+
self.target_files.append(
|
144 |
+
AudioFile(target_filepath, preload=preload, normalize=normalize)
|
145 |
+
)
|
146 |
+
|
147 |
+
def __len__(self):
|
148 |
+
return self.num_examples_per_epoch
|
149 |
+
|
150 |
+
def __getitem__(self, idx):
|
151 |
+
""" """
|
152 |
+
|
153 |
+
# index the current audio file
|
154 |
+
input_file = self.input_files[idx]
|
155 |
+
|
156 |
+
# load the audio data if needed
|
157 |
+
if not input_file.loaded:
|
158 |
+
input_file.load()
|
159 |
+
|
160 |
+
# get a random patch of size `self.length`
|
161 |
+
start_idx = int(torch.rand() * (input_file.num_frames - self.length))
|
162 |
+
stop_idx = start_idx + self.length
|
163 |
+
input_audio = input_file.audio[:, start_idx:stop_idx]
|
164 |
+
|
165 |
+
# if there is a target file, get it (and load)
|
166 |
+
if len(self.target_files) > 0:
|
167 |
+
target_file = self.target_files[idx]
|
168 |
+
|
169 |
+
if not target_file.loaded:
|
170 |
+
target_file.load()
|
171 |
+
|
172 |
+
# use the same cropping indices
|
173 |
+
target_audio = target_file.audio[:, start_idx:stop_idx]
|
174 |
+
|
175 |
+
return input_audio, target_audio
|
176 |
+
else:
|
177 |
+
return input_audio
|
deepafx_st/data/augmentations.py
ADDED
@@ -0,0 +1,235 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torchaudio
|
3 |
+
import numpy as np
|
4 |
+
|
5 |
+
|
6 |
+
def gain(xs, min_dB=-12, max_dB=12):
|
7 |
+
|
8 |
+
gain_dB = (torch.rand(1) * (max_dB - min_dB)) + min_dB
|
9 |
+
gain_ln = 10 ** (gain_dB / 20)
|
10 |
+
|
11 |
+
for idx, x in enumerate(xs):
|
12 |
+
xs[idx] = x * gain_ln
|
13 |
+
|
14 |
+
return xs
|
15 |
+
|
16 |
+
|
17 |
+
def peaking_filter(xs, sr=44100, frequency=1000, width_q=0.707, gain_db=12):
|
18 |
+
|
19 |
+
# gain_db = ((torch.rand(1) * 6) + 6).numpy().squeeze()
|
20 |
+
# width_q = (torch.rand(1) * 4).numpy().squeeze()
|
21 |
+
# frequency = ((torch.rand(1) * 9960) + 40).numpy().squeeze()
|
22 |
+
|
23 |
+
# if torch.rand(1) > 0.5:
|
24 |
+
# gain_db = -gain_db
|
25 |
+
|
26 |
+
effects = [["equalizer", f"{frequency}", f"{width_q}", f"{gain_db}"]]
|
27 |
+
|
28 |
+
for idx, x in enumerate(xs):
|
29 |
+
y, sr = torchaudio.sox_effects.apply_effects_tensor(
|
30 |
+
x, sr, effects, channels_first=True
|
31 |
+
)
|
32 |
+
xs[idx] = y
|
33 |
+
|
34 |
+
return xs
|
35 |
+
|
36 |
+
|
37 |
+
def pitch_shift(xs, min_shift=-200, max_shift=200, sr=44100):
|
38 |
+
|
39 |
+
shift = min_shift + (torch.rand(1)).numpy().squeeze() * (max_shift - min_shift)
|
40 |
+
|
41 |
+
effects = [["pitch", f"{shift}"]]
|
42 |
+
|
43 |
+
for idx, x in enumerate(xs):
|
44 |
+
y, sr = torchaudio.sox_effects.apply_effects_tensor(
|
45 |
+
x, sr, effects, channels_first=True
|
46 |
+
)
|
47 |
+
xs[idx] = y
|
48 |
+
|
49 |
+
return xs
|
50 |
+
|
51 |
+
|
52 |
+
def time_stretch(xs, min_stretch=0.8, max_stretch=1.2, sr=44100):
|
53 |
+
|
54 |
+
stretch = min_stretch + (torch.rand(1)).numpy().squeeze() * (
|
55 |
+
max_stretch - min_stretch
|
56 |
+
)
|
57 |
+
|
58 |
+
effects = [["tempo", f"{stretch}"]]
|
59 |
+
for idx, x in enumerate(xs):
|
60 |
+
y, sr = torchaudio.sox_effects.apply_effects_tensor(
|
61 |
+
x, sr, effects, channels_first=True
|
62 |
+
)
|
63 |
+
xs[idx] = y
|
64 |
+
|
65 |
+
return xs
|
66 |
+
|
67 |
+
|
68 |
+
def frequency_corruption(xs, sr=44100):
|
69 |
+
|
70 |
+
effects = []
|
71 |
+
|
72 |
+
# apply a random number of peaking bands from 0 to 4s
|
73 |
+
bands = [[200, 2000], [800, 4000], [2000, 8000], [4000, int((sr // 2) * 0.9)]]
|
74 |
+
total_gain_db = 0.0
|
75 |
+
for band in bands:
|
76 |
+
if torch.rand(1).sum() > 0.2:
|
77 |
+
frequency = (torch.randint(band[0], band[1], [1])).numpy().squeeze()
|
78 |
+
width_q = ((torch.rand(1) * 10) + 0.1).numpy().squeeze()
|
79 |
+
gain_db = ((torch.rand(1) * 48)).numpy().squeeze()
|
80 |
+
|
81 |
+
if torch.rand(1).sum() > 0.5:
|
82 |
+
gain_db = -gain_db
|
83 |
+
|
84 |
+
total_gain_db += gain_db
|
85 |
+
|
86 |
+
if np.abs(total_gain_db) >= 24:
|
87 |
+
continue
|
88 |
+
|
89 |
+
cmd = ["equalizer", f"{frequency}", f"{width_q}", f"{gain_db}"]
|
90 |
+
effects.append(cmd)
|
91 |
+
|
92 |
+
# low shelf (bass)
|
93 |
+
if torch.rand(1).sum() > 0.2:
|
94 |
+
gain_db = ((torch.rand(1) * 24)).numpy().squeeze()
|
95 |
+
frequency = (torch.randint(20, 200, [1])).numpy().squeeze()
|
96 |
+
if torch.rand(1).sum() > 0.5:
|
97 |
+
gain_db = -gain_db
|
98 |
+
effects.append(["bass", f"{gain_db}", f"{frequency}"])
|
99 |
+
|
100 |
+
# high shelf (treble)
|
101 |
+
if torch.rand(1).sum() > 0.2:
|
102 |
+
gain_db = ((torch.rand(1) * 24)).numpy().squeeze()
|
103 |
+
frequency = (torch.randint(4000, int((sr // 2) * 0.9), [1])).numpy().squeeze()
|
104 |
+
if torch.rand(1).sum() > 0.5:
|
105 |
+
gain_db = -gain_db
|
106 |
+
effects.append(["treble", f"{gain_db}", f"{frequency}"])
|
107 |
+
|
108 |
+
for idx, x in enumerate(xs):
|
109 |
+
y, sr = torchaudio.sox_effects.apply_effects_tensor(
|
110 |
+
x.view(1, -1) * 10 ** (-48 / 20), sr, effects, channels_first=True
|
111 |
+
)
|
112 |
+
# apply gain back
|
113 |
+
y *= 10 ** (48 / 20)
|
114 |
+
|
115 |
+
xs[idx] = y
|
116 |
+
|
117 |
+
return xs
|
118 |
+
|
119 |
+
|
120 |
+
def dynamic_range_corruption(xs, sr=44100):
|
121 |
+
"""Apply an expander."""
|
122 |
+
|
123 |
+
attack = (torch.rand([1]).numpy()[0] * 0.05) + 0.001
|
124 |
+
release = (torch.rand([1]).numpy()[0] * 0.2) + attack
|
125 |
+
knee = (torch.rand([1]).numpy()[0] * 12) + 0.0
|
126 |
+
|
127 |
+
# design the compressor transfer function
|
128 |
+
start = -100.0
|
129 |
+
threshold = -(
|
130 |
+
(torch.rand([1]).numpy()[0] * 20) + 10
|
131 |
+
) # threshold from -30 to -10 dB
|
132 |
+
ratio = (torch.rand([1]).numpy()[0] * 4.0) + 1 # ratio from 1:1 to 5:1
|
133 |
+
|
134 |
+
# compute the transfer curve
|
135 |
+
point = -((-threshold / -ratio) + (-start / ratio) + -threshold)
|
136 |
+
|
137 |
+
# apply some makeup gain
|
138 |
+
makeup = torch.rand([1]).numpy()[0] * 6
|
139 |
+
|
140 |
+
effects = [
|
141 |
+
[
|
142 |
+
"compand",
|
143 |
+
f"{attack},{release}",
|
144 |
+
f"{knee}:{point},{start},{threshold},{threshold}",
|
145 |
+
f"{makeup}",
|
146 |
+
f"{start}",
|
147 |
+
]
|
148 |
+
]
|
149 |
+
|
150 |
+
for idx, x in enumerate(xs):
|
151 |
+
# if the input is clipping normalize it
|
152 |
+
if x.abs().max() >= 1.0:
|
153 |
+
x /= x.abs().max()
|
154 |
+
gain_db = -((torch.rand(1) * 24)).numpy().squeeze()
|
155 |
+
x *= 10 ** (gain_db / 20.0)
|
156 |
+
|
157 |
+
y, sr = torchaudio.sox_effects.apply_effects_tensor(
|
158 |
+
x.view(1, -1), sr, effects, channels_first=True
|
159 |
+
)
|
160 |
+
xs[idx] = y
|
161 |
+
|
162 |
+
return xs
|
163 |
+
|
164 |
+
|
165 |
+
def dynamic_range_compression(xs, sr=44100):
|
166 |
+
"""Apply a compressor."""
|
167 |
+
|
168 |
+
attack = (torch.rand([1]).numpy()[0] * 0.05) + 0.0005
|
169 |
+
release = (torch.rand([1]).numpy()[0] * 0.2) + attack
|
170 |
+
knee = (torch.rand([1]).numpy()[0] * 12) + 0.0
|
171 |
+
|
172 |
+
# design the compressor transfer function
|
173 |
+
start = -100.0
|
174 |
+
threshold = -((torch.rand([1]).numpy()[0] * 52) + 12)
|
175 |
+
# threshold from -64 to -12 dB
|
176 |
+
ratio = (torch.rand([1]).numpy()[0] * 10.0) + 1 # ratio from 1:1 to 10:1
|
177 |
+
|
178 |
+
# compute the transfer curve
|
179 |
+
point = threshold * (1 - (1 / ratio))
|
180 |
+
|
181 |
+
# apply some makeup gain
|
182 |
+
makeup = torch.rand([1]).numpy()[0] * 6
|
183 |
+
|
184 |
+
effects = [
|
185 |
+
[
|
186 |
+
"compand",
|
187 |
+
f"{attack},{release}",
|
188 |
+
f"{knee}:{start},{threshold},{threshold},0,{point}",
|
189 |
+
f"{makeup}",
|
190 |
+
f"{start}",
|
191 |
+
f"{attack}",
|
192 |
+
]
|
193 |
+
]
|
194 |
+
|
195 |
+
for idx, x in enumerate(xs):
|
196 |
+
y, sr = torchaudio.sox_effects.apply_effects_tensor(
|
197 |
+
x.view(1, -1), sr, effects, channels_first=True
|
198 |
+
)
|
199 |
+
xs[idx] = y
|
200 |
+
|
201 |
+
return xs
|
202 |
+
|
203 |
+
|
204 |
+
def lowpass_filter(xs, sr=44100, frequency=4000):
|
205 |
+
effects = [["lowpass", f"{frequency}"]]
|
206 |
+
|
207 |
+
for idx, x in enumerate(xs):
|
208 |
+
y, sr = torchaudio.sox_effects.apply_effects_tensor(
|
209 |
+
x, sr, effects, channels_first=True
|
210 |
+
)
|
211 |
+
xs[idx] = y
|
212 |
+
|
213 |
+
return xs
|
214 |
+
|
215 |
+
|
216 |
+
def apply(xs, sr, augmentations):
|
217 |
+
|
218 |
+
# iterate over augmentation dict
|
219 |
+
for aug, params in augmentations.items():
|
220 |
+
if aug == "gain":
|
221 |
+
xs = gain(xs, **params)
|
222 |
+
elif aug == "peak":
|
223 |
+
xs = peaking_filter(xs, **params)
|
224 |
+
elif aug == "lowpass":
|
225 |
+
xs = lowpass_filter(xs, **params)
|
226 |
+
elif aug == "pitch":
|
227 |
+
xs = pitch_shift(xs, **params)
|
228 |
+
elif aug == "tempo":
|
229 |
+
xs = time_stretch(xs, **params)
|
230 |
+
elif aug == "freq_corrupt":
|
231 |
+
xs = frequency_corruption(xs, **params)
|
232 |
+
else:
|
233 |
+
raise RuntimeError("Invalid augmentation: {aug}")
|
234 |
+
|
235 |
+
return xs
|
deepafx_st/data/dataset.py
ADDED
@@ -0,0 +1,344 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
import csv
|
4 |
+
import glob
|
5 |
+
import torch
|
6 |
+
import random
|
7 |
+
from tqdm import tqdm
|
8 |
+
from typing import List, Any
|
9 |
+
|
10 |
+
from deepafx_st.data.audio import AudioFile
|
11 |
+
import deepafx_st.utils as utils
|
12 |
+
import deepafx_st.data.augmentations as augmentations
|
13 |
+
|
14 |
+
|
15 |
+
class AudioDataset(torch.utils.data.Dataset):
|
16 |
+
"""Audio dataset which returns an input and target file.
|
17 |
+
|
18 |
+
Args:
|
19 |
+
audio_dir (str): Path to the top level of the audio dataset.
|
20 |
+
input_dir (List[str], optional): List of paths to the directories containing input audio files. Default: ["clean"]
|
21 |
+
subset (str, optional): Dataset subset. One of ["train", "val", "test"]. Default: "train"
|
22 |
+
length (int, optional): Number of samples to load for each example. Default: 65536
|
23 |
+
train_frac (float, optional): Fraction of the files to use for training subset. Default: 0.8
|
24 |
+
val_frac (float, optional): Fraction of the files to use for validation subset. Default: 0.1
|
25 |
+
buffer_size_gb (float, optional): Size of audio to read into RAM in GB at any given time. Default: 10.0
|
26 |
+
Note: This is the buffer size PER DataLoader worker. So total RAM = buffer_size_gb * num_workers
|
27 |
+
buffer_reload_rate (int, optional): Number of items to generate before loading next chunk of dataset. Default: 10000
|
28 |
+
half (bool, optional): Sotre audio samples as float 16. Default: False
|
29 |
+
num_examples_per_epoch (int, optional): Define an epoch as certain number of audio examples. Default: 10000
|
30 |
+
random_scale_input (bool, optional): Apply random gain scaling to input utterances. Default: False
|
31 |
+
random_scale_target (bool, optional): Apply same random gain scaling to target utterances. Default: False
|
32 |
+
augmentations (dict, optional): List of augmentation types to apply to inputs. Default: []
|
33 |
+
freq_corrupt (bool, optional): Apply bad EQ filters. Default: False
|
34 |
+
drc_corrupt (bool, optional): Apply an expander to corrupt dynamic range. Default: False
|
35 |
+
ext (str, optional): Expected audio file extension. Default: "wav"
|
36 |
+
"""
|
37 |
+
|
38 |
+
def __init__(
|
39 |
+
self,
|
40 |
+
audio_dir,
|
41 |
+
input_dirs: List[str] = ["cleanraw"],
|
42 |
+
subset: str = "train",
|
43 |
+
length: int = 65536,
|
44 |
+
train_frac: float = 0.8,
|
45 |
+
val_per: float = 0.1,
|
46 |
+
buffer_size_gb: float = 1.0,
|
47 |
+
buffer_reload_rate: float = 1000,
|
48 |
+
half: bool = False,
|
49 |
+
num_examples_per_epoch: int = 10000,
|
50 |
+
random_scale_input: bool = False,
|
51 |
+
random_scale_target: bool = False,
|
52 |
+
augmentations: dict = {},
|
53 |
+
freq_corrupt: bool = False,
|
54 |
+
drc_corrupt: bool = False,
|
55 |
+
ext: str = "wav",
|
56 |
+
):
|
57 |
+
super().__init__()
|
58 |
+
self.audio_dir = audio_dir
|
59 |
+
self.dataset_name = os.path.basename(audio_dir)
|
60 |
+
self.input_dirs = input_dirs
|
61 |
+
self.subset = subset
|
62 |
+
self.length = length
|
63 |
+
self.train_frac = train_frac
|
64 |
+
self.val_per = val_per
|
65 |
+
self.buffer_size_gb = buffer_size_gb
|
66 |
+
self.buffer_reload_rate = buffer_reload_rate
|
67 |
+
self.half = half
|
68 |
+
self.num_examples_per_epoch = num_examples_per_epoch
|
69 |
+
self.random_scale_input = random_scale_input
|
70 |
+
self.random_scale_target = random_scale_target
|
71 |
+
self.augmentations = augmentations
|
72 |
+
self.freq_corrupt = freq_corrupt
|
73 |
+
self.drc_corrupt = drc_corrupt
|
74 |
+
self.ext = ext
|
75 |
+
|
76 |
+
self.input_filepaths = []
|
77 |
+
for input_dir in input_dirs:
|
78 |
+
search_path = os.path.join(audio_dir, input_dir, f"*.{ext}")
|
79 |
+
self.input_filepaths += glob.glob(search_path)
|
80 |
+
self.input_filepaths = sorted(self.input_filepaths)
|
81 |
+
|
82 |
+
# create dataset split based on subset
|
83 |
+
self.input_filepaths = utils.split_dataset(
|
84 |
+
self.input_filepaths,
|
85 |
+
subset,
|
86 |
+
train_frac,
|
87 |
+
)
|
88 |
+
|
89 |
+
# get details about input audio files
|
90 |
+
input_files = {}
|
91 |
+
input_dur_frames = 0
|
92 |
+
for input_filepath in tqdm(self.input_filepaths, ncols=80):
|
93 |
+
file_id = os.path.basename(input_filepath)
|
94 |
+
audio_file = AudioFile(
|
95 |
+
input_filepath,
|
96 |
+
preload=False,
|
97 |
+
half=half,
|
98 |
+
)
|
99 |
+
if audio_file.num_frames < (self.length * 2):
|
100 |
+
continue
|
101 |
+
input_files[file_id] = audio_file
|
102 |
+
input_dur_frames += input_files[file_id].num_frames
|
103 |
+
|
104 |
+
if len(list(input_files.items())) < 1:
|
105 |
+
raise RuntimeError(f"No files found in {search_path}.")
|
106 |
+
|
107 |
+
input_dur_hr = (input_dur_frames / input_files[file_id].sample_rate) / 3600
|
108 |
+
print(
|
109 |
+
f"\nLoaded {len(input_files)} files for {subset} = {input_dur_hr:0.2f} hours."
|
110 |
+
)
|
111 |
+
|
112 |
+
self.sample_rate = input_files[file_id].sample_rate
|
113 |
+
|
114 |
+
# save a csv file with details about the train and test split
|
115 |
+
splits_dir = os.path.join("configs", "splits")
|
116 |
+
if not os.path.isdir(splits_dir):
|
117 |
+
os.makedirs(splits_dir)
|
118 |
+
csv_filepath = os.path.join(splits_dir, f"{self.dataset_name}_{self.subset}_set.csv")
|
119 |
+
|
120 |
+
with open(csv_filepath, "w") as fp:
|
121 |
+
dw = csv.DictWriter(fp, ["file_id", "filepath", "type", "subset"])
|
122 |
+
dw.writeheader()
|
123 |
+
for input_filepath in self.input_filepaths:
|
124 |
+
dw.writerow(
|
125 |
+
{
|
126 |
+
"file_id": self.get_file_id(input_filepath),
|
127 |
+
"filepath": input_filepath,
|
128 |
+
"type": "input",
|
129 |
+
"subset": self.subset,
|
130 |
+
}
|
131 |
+
)
|
132 |
+
|
133 |
+
# some setup for iteratble loading of the dataset into RAM
|
134 |
+
self.items_since_load = self.buffer_reload_rate
|
135 |
+
|
136 |
+
def __len__(self):
|
137 |
+
return self.num_examples_per_epoch
|
138 |
+
|
139 |
+
def load_audio_buffer(self):
|
140 |
+
self.input_files_loaded = {} # clear audio buffer
|
141 |
+
self.items_since_load = 0 # reset iteration counter
|
142 |
+
nbytes_loaded = 0 # counter for data in RAM
|
143 |
+
|
144 |
+
# different subset in each
|
145 |
+
random.shuffle(self.input_filepaths)
|
146 |
+
|
147 |
+
# load files into RAM
|
148 |
+
for input_filepath in self.input_filepaths:
|
149 |
+
file_id = os.path.basename(input_filepath)
|
150 |
+
audio_file = AudioFile(
|
151 |
+
input_filepath,
|
152 |
+
preload=True,
|
153 |
+
half=self.half,
|
154 |
+
)
|
155 |
+
|
156 |
+
if audio_file.num_frames < (self.length * 2):
|
157 |
+
continue
|
158 |
+
|
159 |
+
self.input_files_loaded[file_id] = audio_file
|
160 |
+
|
161 |
+
nbytes = audio_file.audio.element_size() * audio_file.audio.nelement()
|
162 |
+
nbytes_loaded += nbytes
|
163 |
+
|
164 |
+
# check the size of loaded data
|
165 |
+
if nbytes_loaded > self.buffer_size_gb * 1e9:
|
166 |
+
break
|
167 |
+
|
168 |
+
def generate_pair(self):
|
169 |
+
# ------------------------ Input audio ----------------------
|
170 |
+
rand_input_file_id = None
|
171 |
+
input_file = None
|
172 |
+
start_idx = None
|
173 |
+
stop_idx = None
|
174 |
+
while True:
|
175 |
+
rand_input_file_id = self.get_random_file_id(self.input_files_loaded.keys())
|
176 |
+
|
177 |
+
# use this random key to retrieve an input file
|
178 |
+
input_file = self.input_files_loaded[rand_input_file_id]
|
179 |
+
|
180 |
+
# load the audio data if needed
|
181 |
+
if not input_file.loaded:
|
182 |
+
raise RuntimeError("Audio not loaded.")
|
183 |
+
|
184 |
+
# get a random patch of size `self.length` x 2
|
185 |
+
start_idx, stop_idx = self.get_random_patch(
|
186 |
+
input_file, int(self.length * 2)
|
187 |
+
)
|
188 |
+
if start_idx >= 0:
|
189 |
+
break
|
190 |
+
|
191 |
+
input_audio = input_file.audio[:, start_idx:stop_idx].clone().detach()
|
192 |
+
input_audio = input_audio.view(1, -1)
|
193 |
+
|
194 |
+
if self.half:
|
195 |
+
input_audio = input_audio.float()
|
196 |
+
|
197 |
+
# peak normalize to -12 dBFS
|
198 |
+
input_audio /= input_audio.abs().max()
|
199 |
+
input_audio *= 10 ** (-12.0 / 20) # with min 3 dBFS headroom
|
200 |
+
|
201 |
+
if len(list(self.augmentations.items())) > 0:
|
202 |
+
if torch.rand(1).sum() < 0.5:
|
203 |
+
input_audio_aug = augmentations.apply(
|
204 |
+
[input_audio],
|
205 |
+
self.sample_rate,
|
206 |
+
self.augmentations,
|
207 |
+
)[0]
|
208 |
+
else:
|
209 |
+
input_audio_aug = input_audio.clone()
|
210 |
+
else:
|
211 |
+
input_audio_aug = input_audio.clone()
|
212 |
+
|
213 |
+
input_audio_corrupt = input_audio_aug.clone()
|
214 |
+
# apply frequency and dynamic range corrpution (expander)
|
215 |
+
if self.freq_corrupt and torch.rand(1).sum() < 0.75:
|
216 |
+
input_audio_corrupt = augmentations.frequency_corruption(
|
217 |
+
[input_audio_corrupt], self.sample_rate
|
218 |
+
)[0]
|
219 |
+
|
220 |
+
# peak normalize again before passing through dynamic range expander
|
221 |
+
input_audio_corrupt /= input_audio_corrupt.abs().max()
|
222 |
+
input_audio_corrupt *= 10 ** (-12.0 / 20) # with min 3 dBFS headroom
|
223 |
+
|
224 |
+
if self.drc_corrupt and torch.rand(1).sum() < 0.10:
|
225 |
+
input_audio_corrupt = augmentations.dynamic_range_corruption(
|
226 |
+
[input_audio_corrupt], self.sample_rate
|
227 |
+
)[0]
|
228 |
+
|
229 |
+
# ------------------------ Target audio ----------------------
|
230 |
+
# use the same augmented audio clip, add different random EQ and compressor
|
231 |
+
|
232 |
+
target_audio_corrupt = input_audio_aug.clone()
|
233 |
+
# apply frequency and dynamic range corrpution (expander)
|
234 |
+
if self.freq_corrupt and torch.rand(1).sum() < 0.75:
|
235 |
+
target_audio_corrupt = augmentations.frequency_corruption(
|
236 |
+
[target_audio_corrupt], self.sample_rate
|
237 |
+
)[0]
|
238 |
+
|
239 |
+
# peak normalize again before passing through dynamic range compressor
|
240 |
+
input_audio_corrupt /= input_audio_corrupt.abs().max()
|
241 |
+
input_audio_corrupt *= 10 ** (-12.0 / 20) # with min 3 dBFS headroom
|
242 |
+
|
243 |
+
if self.drc_corrupt and torch.rand(1).sum() < 0.75:
|
244 |
+
target_audio_corrupt = augmentations.dynamic_range_compression(
|
245 |
+
[target_audio_corrupt], self.sample_rate
|
246 |
+
)[0]
|
247 |
+
|
248 |
+
return input_audio_corrupt, target_audio_corrupt
|
249 |
+
|
250 |
+
def __getitem__(self, _):
|
251 |
+
""" """
|
252 |
+
|
253 |
+
# increment counter
|
254 |
+
self.items_since_load += 1
|
255 |
+
|
256 |
+
# load next chunk into buffer if needed
|
257 |
+
if self.items_since_load > self.buffer_reload_rate:
|
258 |
+
self.load_audio_buffer()
|
259 |
+
|
260 |
+
# generate pairs for style training
|
261 |
+
input_audio, target_audio = self.generate_pair()
|
262 |
+
|
263 |
+
# ------------------------ Conform length of files -------------------
|
264 |
+
input_audio = utils.conform_length(input_audio, int(self.length * 2))
|
265 |
+
target_audio = utils.conform_length(target_audio, int(self.length * 2))
|
266 |
+
|
267 |
+
# ------------------------ Apply fade in and fade out -------------------
|
268 |
+
input_audio = utils.linear_fade(input_audio, sample_rate=self.sample_rate)
|
269 |
+
target_audio = utils.linear_fade(target_audio, sample_rate=self.sample_rate)
|
270 |
+
|
271 |
+
# ------------------------ Final normalizeation ----------------------
|
272 |
+
# always peak normalize final input to -12 dBFS
|
273 |
+
input_audio /= input_audio.abs().max()
|
274 |
+
input_audio *= 10 ** (-12.0 / 20.0)
|
275 |
+
|
276 |
+
# always peak normalize the target to -12 dBFS
|
277 |
+
target_audio /= target_audio.abs().max()
|
278 |
+
target_audio *= 10 ** (-12.0 / 20.0)
|
279 |
+
|
280 |
+
return input_audio, target_audio
|
281 |
+
|
282 |
+
@staticmethod
|
283 |
+
def get_random_file_id(keys):
|
284 |
+
# generate a random index into the keys of the input files
|
285 |
+
rand_input_idx = torch.randint(0, len(keys) - 1, [1])[0]
|
286 |
+
# find the key (file_id) correponding to the random index
|
287 |
+
rand_input_file_id = list(keys)[rand_input_idx]
|
288 |
+
|
289 |
+
return rand_input_file_id
|
290 |
+
|
291 |
+
@staticmethod
|
292 |
+
def get_random_patch(audio_file, length, check_silence=True):
|
293 |
+
silent = True
|
294 |
+
count = 0
|
295 |
+
while silent:
|
296 |
+
count += 1
|
297 |
+
start_idx = torch.randint(0, audio_file.num_frames - length - 1, [1])[0]
|
298 |
+
# int(torch.rand(1) * (audio_file.num_frames - length))
|
299 |
+
stop_idx = start_idx + length
|
300 |
+
patch = audio_file.audio[:, start_idx:stop_idx].clone().detach()
|
301 |
+
|
302 |
+
length = patch.shape[-1]
|
303 |
+
first_patch = patch[..., : length // 2]
|
304 |
+
second_patch = patch[..., length // 2 :]
|
305 |
+
|
306 |
+
if (
|
307 |
+
(first_patch**2).mean() > 1e-5 and (second_patch**2).mean() > 1e-5
|
308 |
+
) or not check_silence:
|
309 |
+
silent = False
|
310 |
+
|
311 |
+
if count > 100:
|
312 |
+
print("get_random_patch count", count)
|
313 |
+
return -1, -1
|
314 |
+
# break
|
315 |
+
|
316 |
+
return start_idx, stop_idx
|
317 |
+
|
318 |
+
def get_file_id(self, filepath):
|
319 |
+
"""Given a filepath extract the DAPS file id.
|
320 |
+
|
321 |
+
Args:
|
322 |
+
filepath (str): Path to an audio files in the DAPS dataset.
|
323 |
+
|
324 |
+
Returns:
|
325 |
+
file_id (str): DAPS file id of the form <participant_id>_<script_id>
|
326 |
+
file_set (str): The DAPS set to which the file belongs.
|
327 |
+
"""
|
328 |
+
file_id = os.path.basename(filepath).split("_")[:2]
|
329 |
+
file_id = "_".join(file_id)
|
330 |
+
return file_id
|
331 |
+
|
332 |
+
def get_file_set(self, filepath):
|
333 |
+
"""Given a filepath extract the DAPS file set name.
|
334 |
+
|
335 |
+
Args:
|
336 |
+
filepath (str): Path to an audio files in the DAPS dataset.
|
337 |
+
|
338 |
+
Returns:
|
339 |
+
file_set (str): The DAPS set to which the file belongs.
|
340 |
+
"""
|
341 |
+
file_set = os.path.basename(filepath).split("_")[2:]
|
342 |
+
file_set = "_".join(file_set)
|
343 |
+
file_set = file_set.replace(f".{self.ext}", "")
|
344 |
+
return file_set
|
deepafx_st/data/proxy.py
ADDED
@@ -0,0 +1,181 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import json
|
3 |
+
import glob
|
4 |
+
import torch
|
5 |
+
import random
|
6 |
+
from tqdm import tqdm
|
7 |
+
|
8 |
+
# from deepafx_st.plugins.channel import Channel
|
9 |
+
from deepafx_st.processors.processor import Processor
|
10 |
+
from deepafx_st.data.audio import AudioFile
|
11 |
+
import deepafx_st.utils as utils
|
12 |
+
|
13 |
+
|
14 |
+
class DSPProxyDataset(torch.utils.data.Dataset):
|
15 |
+
"""Class for generating input-output audio from Python DSP effects.
|
16 |
+
|
17 |
+
Args:
|
18 |
+
input_dir (List[str]): List of paths to the directories containing input audio files.
|
19 |
+
processor (Processor): Processor object to create proxy of.
|
20 |
+
processor_type (str): Processor name.
|
21 |
+
subset (str, optional): Dataset subset. One of ["train", "val", "test"]. Default: "train"
|
22 |
+
buffer_size_gb (float, optional): Size of audio to read into RAM in GB at any given time. Default: 10.0
|
23 |
+
Note: This is the buffer size PER DataLoader worker. So total RAM = buffer_size_gb * num_workers
|
24 |
+
buffer_reload_rate (int, optional): Number of items to generate before loading next chunk of dataset. Default: 10000
|
25 |
+
length (int, optional): Number of samples to load for each example. Default: 65536
|
26 |
+
num_examples_per_epoch (int, optional): Define an epoch as certain number of audio examples. Default: 10000
|
27 |
+
ext (str, optional): Expected audio file extension. Default: "wav"
|
28 |
+
hard_clip (bool, optional): Hard clip outputs between -1 and 1. Default: True
|
29 |
+
"""
|
30 |
+
|
31 |
+
def __init__(
|
32 |
+
self,
|
33 |
+
input_dir: str,
|
34 |
+
processor: Processor,
|
35 |
+
processor_type: str,
|
36 |
+
subset="train",
|
37 |
+
length=65536,
|
38 |
+
buffer_size_gb=1.0,
|
39 |
+
buffer_reload_rate=1000,
|
40 |
+
half=False,
|
41 |
+
num_examples_per_epoch=10000,
|
42 |
+
ext="wav",
|
43 |
+
soft_clip=True,
|
44 |
+
):
|
45 |
+
super().__init__()
|
46 |
+
self.input_dir = input_dir
|
47 |
+
self.processor = processor
|
48 |
+
self.processor_type = processor_type
|
49 |
+
self.subset = subset
|
50 |
+
self.length = length
|
51 |
+
self.buffer_size_gb = buffer_size_gb
|
52 |
+
self.buffer_reload_rate = buffer_reload_rate
|
53 |
+
self.half = half
|
54 |
+
self.num_examples_per_epoch = num_examples_per_epoch
|
55 |
+
self.ext = ext
|
56 |
+
self.soft_clip = soft_clip
|
57 |
+
|
58 |
+
search_path = os.path.join(input_dir, f"*.{ext}")
|
59 |
+
self.input_filepaths = glob.glob(search_path)
|
60 |
+
self.input_filepaths = sorted(self.input_filepaths)
|
61 |
+
|
62 |
+
if len(self.input_filepaths) < 1:
|
63 |
+
raise RuntimeError(f"No files found in {input_dir}.")
|
64 |
+
|
65 |
+
# get training split
|
66 |
+
self.input_filepaths = utils.split_dataset(
|
67 |
+
self.input_filepaths, self.subset, 0.9
|
68 |
+
)
|
69 |
+
|
70 |
+
# get details about audio files
|
71 |
+
cnt = 0
|
72 |
+
self.input_files = {}
|
73 |
+
for input_filepath in tqdm(self.input_filepaths, ncols=80):
|
74 |
+
file_id = os.path.basename(input_filepath)
|
75 |
+
audio_file = AudioFile(
|
76 |
+
input_filepath,
|
77 |
+
preload=False,
|
78 |
+
half=half,
|
79 |
+
)
|
80 |
+
if audio_file.num_frames < self.length:
|
81 |
+
continue
|
82 |
+
self.input_files[file_id] = audio_file
|
83 |
+
self.sample_rate = self.input_files[file_id].sample_rate
|
84 |
+
cnt += 1
|
85 |
+
if cnt > 1000:
|
86 |
+
break
|
87 |
+
|
88 |
+
# some setup for iteratble loading of the dataset into RAM
|
89 |
+
self.items_since_load = self.buffer_reload_rate
|
90 |
+
|
91 |
+
def __len__(self):
|
92 |
+
return self.num_examples_per_epoch
|
93 |
+
|
94 |
+
def load_audio_buffer(self):
|
95 |
+
self.input_files_loaded = {} # clear audio buffer
|
96 |
+
self.items_since_load = 0 # reset iteration counter
|
97 |
+
nbytes_loaded = 0 # counter for data in RAM
|
98 |
+
|
99 |
+
# different subset in each
|
100 |
+
random.shuffle(self.input_filepaths)
|
101 |
+
|
102 |
+
# load files into RAM
|
103 |
+
for input_filepath in self.input_filepaths:
|
104 |
+
file_id = os.path.basename(input_filepath)
|
105 |
+
audio_file = AudioFile(
|
106 |
+
input_filepath,
|
107 |
+
preload=True,
|
108 |
+
half=self.half,
|
109 |
+
)
|
110 |
+
|
111 |
+
if audio_file.num_frames < self.length:
|
112 |
+
continue
|
113 |
+
|
114 |
+
self.input_files_loaded[file_id] = audio_file
|
115 |
+
|
116 |
+
nbytes = audio_file.audio.element_size() * audio_file.audio.nelement()
|
117 |
+
nbytes_loaded += nbytes
|
118 |
+
|
119 |
+
if nbytes_loaded > self.buffer_size_gb * 1e9:
|
120 |
+
break
|
121 |
+
|
122 |
+
def __getitem__(self, _):
|
123 |
+
""" """
|
124 |
+
|
125 |
+
# increment counter
|
126 |
+
self.items_since_load += 1
|
127 |
+
|
128 |
+
# load next chunk into buffer if needed
|
129 |
+
if self.items_since_load > self.buffer_reload_rate:
|
130 |
+
self.load_audio_buffer()
|
131 |
+
|
132 |
+
rand_input_file_id = utils.get_random_file_id(self.input_files_loaded.keys())
|
133 |
+
# use this random key to retrieve an input file
|
134 |
+
input_file = self.input_files_loaded[rand_input_file_id]
|
135 |
+
|
136 |
+
# load the audio data if needed
|
137 |
+
if not input_file.loaded:
|
138 |
+
input_file.load()
|
139 |
+
|
140 |
+
# get a random patch of size `self.length`
|
141 |
+
# start_idx, stop_idx = utils.get_random_patch(input_file, self.sample_rate, self.length)
|
142 |
+
start_idx, stop_idx = utils.get_random_patch(input_file, self.length)
|
143 |
+
input_audio = input_file.audio[:, start_idx:stop_idx].clone().detach()
|
144 |
+
|
145 |
+
# random scaling
|
146 |
+
input_audio /= input_audio.abs().max()
|
147 |
+
scale_dB = (torch.rand(1).squeeze().numpy() * 12) + 12
|
148 |
+
input_audio *= 10 ** (-scale_dB / 20.0)
|
149 |
+
|
150 |
+
# generate random parameters (uniform) over 0 to 1
|
151 |
+
params = torch.rand(self.processor.num_control_params)
|
152 |
+
|
153 |
+
# expects batch dim
|
154 |
+
# apply plugins with random parameters
|
155 |
+
if self.processor_type == "channel":
|
156 |
+
params[-1] = 0.5 # set makeup gain to 0dB
|
157 |
+
target_audio = self.processor(
|
158 |
+
input_audio.view(1, 1, -1),
|
159 |
+
params.view(1, -1),
|
160 |
+
)
|
161 |
+
target_audio = target_audio.view(1, -1)
|
162 |
+
elif self.processor_type == "peq":
|
163 |
+
target_audio = self.processor(
|
164 |
+
input_audio.view(1, 1, -1).numpy(),
|
165 |
+
params.view(1, -1).numpy(),
|
166 |
+
)
|
167 |
+
target_audio = torch.tensor(target_audio).view(1, -1)
|
168 |
+
elif self.processor_type == "comp":
|
169 |
+
params[-1] = 0.5 # set makeup gain to 0dB
|
170 |
+
target_audio = self.processor(
|
171 |
+
input_audio.view(1, 1, -1).numpy(),
|
172 |
+
params.view(1, -1).numpy(),
|
173 |
+
)
|
174 |
+
target_audio = torch.tensor(target_audio).view(1, -1)
|
175 |
+
|
176 |
+
# clip
|
177 |
+
if self.soft_clip:
|
178 |
+
# target_audio = target_audio.clamp(-2.0, 2.0)
|
179 |
+
target_audio = torch.tanh(target_audio / 2.0) * 2.0
|
180 |
+
|
181 |
+
return input_audio, target_audio, params
|
deepafx_st/data/style.py
ADDED
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import glob
|
3 |
+
import torch
|
4 |
+
import torchaudio
|
5 |
+
from tqdm import tqdm
|
6 |
+
|
7 |
+
|
8 |
+
class StyleDataset(torch.utils.data.Dataset):
|
9 |
+
def __init__(
|
10 |
+
self,
|
11 |
+
audio_dir: str,
|
12 |
+
subset: str = "train",
|
13 |
+
sample_rate: int = 24000,
|
14 |
+
length: int = 131072,
|
15 |
+
) -> None:
|
16 |
+
super().__init__()
|
17 |
+
self.audio_dir = audio_dir
|
18 |
+
self.subset = subset
|
19 |
+
self.sample_rate = sample_rate
|
20 |
+
self.length = length
|
21 |
+
|
22 |
+
self.style_dirs = glob.glob(os.path.join(audio_dir, subset, "*"))
|
23 |
+
self.style_dirs = [sd for sd in self.style_dirs if os.path.isdir(sd)]
|
24 |
+
self.num_classes = len(self.style_dirs)
|
25 |
+
self.class_labels = {"broadcast" : 0, "telephone": 1, "neutral": 2, "bright": 3, "warm": 4}
|
26 |
+
|
27 |
+
self.examples = []
|
28 |
+
for n, style_dir in enumerate(self.style_dirs):
|
29 |
+
|
30 |
+
# get all files in style dir
|
31 |
+
style_filepaths = glob.glob(os.path.join(style_dir, "*.wav"))
|
32 |
+
style_name = os.path.basename(style_dir)
|
33 |
+
for style_filepath in tqdm(style_filepaths, ncols=120):
|
34 |
+
# load audio file
|
35 |
+
x, sr = torchaudio.load(style_filepath)
|
36 |
+
|
37 |
+
# sum to mono if needed
|
38 |
+
if x.shape[0] > 1:
|
39 |
+
x = x.mean(dim=0, keepdim=True)
|
40 |
+
|
41 |
+
# resample
|
42 |
+
if sr != self.sample_rate:
|
43 |
+
x = torchaudio.transforms.Resample(sr, self.sample_rate)(x)
|
44 |
+
|
45 |
+
# crop length after resample
|
46 |
+
if x.shape[-1] >= self.length:
|
47 |
+
x = x[...,:self.length]
|
48 |
+
|
49 |
+
# store example
|
50 |
+
example = (x, self.class_labels[style_name])
|
51 |
+
self.examples.append(example)
|
52 |
+
|
53 |
+
print(f"Loaded {len(self.examples)} examples for {subset} subset.")
|
54 |
+
|
55 |
+
def __len__(self):
|
56 |
+
return len(self.examples)
|
57 |
+
|
58 |
+
def __getitem__(self, idx):
|
59 |
+
example = self.examples[idx]
|
60 |
+
x = example[0]
|
61 |
+
y = example[1]
|
62 |
+
return x, y
|
deepafx_st/metrics.py
ADDED
@@ -0,0 +1,157 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import auraloss
|
3 |
+
import resampy
|
4 |
+
import torchaudio
|
5 |
+
from pesq import pesq
|
6 |
+
import pyloudnorm as pyln
|
7 |
+
|
8 |
+
|
9 |
+
def crest_factor(x):
|
10 |
+
"""Compute the crest factor of waveform."""
|
11 |
+
|
12 |
+
peak, _ = x.abs().max(dim=-1)
|
13 |
+
rms = torch.sqrt((x ** 2).mean(dim=-1))
|
14 |
+
|
15 |
+
return 20 * torch.log(peak / rms.clamp(1e-8))
|
16 |
+
|
17 |
+
|
18 |
+
def rms_energy(x):
|
19 |
+
|
20 |
+
rms = torch.sqrt((x ** 2).mean(dim=-1))
|
21 |
+
|
22 |
+
return 20 * torch.log(rms.clamp(1e-8))
|
23 |
+
|
24 |
+
|
25 |
+
def spectral_centroid(x):
|
26 |
+
"""Compute the crest factor of waveform.
|
27 |
+
|
28 |
+
See: https://gist.github.com/endolith/359724
|
29 |
+
|
30 |
+
"""
|
31 |
+
|
32 |
+
spectrum = torch.fft.rfft(x).abs()
|
33 |
+
normalized_spectrum = spectrum / spectrum.sum()
|
34 |
+
normalized_frequencies = torch.linspace(0, 1, spectrum.shape[-1])
|
35 |
+
spectral_centroid = torch.sum(normalized_frequencies * normalized_spectrum)
|
36 |
+
|
37 |
+
return spectral_centroid
|
38 |
+
|
39 |
+
|
40 |
+
def loudness(x, sample_rate):
|
41 |
+
"""Compute the loudness in dB LUFS of waveform."""
|
42 |
+
meter = pyln.Meter(sample_rate)
|
43 |
+
|
44 |
+
# add stereo dim if needed
|
45 |
+
if x.shape[0] < 2:
|
46 |
+
x = x.repeat(2, 1)
|
47 |
+
|
48 |
+
return torch.tensor(meter.integrated_loudness(x.permute(1, 0).numpy()))
|
49 |
+
|
50 |
+
|
51 |
+
class MelSpectralDistance(torch.nn.Module):
|
52 |
+
def __init__(self, sample_rate, length=65536):
|
53 |
+
super().__init__()
|
54 |
+
self.error = auraloss.freq.MelSTFTLoss(
|
55 |
+
sample_rate,
|
56 |
+
fft_size=length,
|
57 |
+
hop_size=length,
|
58 |
+
win_length=length,
|
59 |
+
w_sc=0,
|
60 |
+
w_log_mag=1,
|
61 |
+
w_lin_mag=1,
|
62 |
+
n_mels=128,
|
63 |
+
scale_invariance=False,
|
64 |
+
)
|
65 |
+
|
66 |
+
# I think scale invariance may not work well,
|
67 |
+
# since aspects of the phase may be considered?
|
68 |
+
|
69 |
+
def forward(self, input, target):
|
70 |
+
return self.error(input, target)
|
71 |
+
|
72 |
+
|
73 |
+
class PESQ(torch.nn.Module):
|
74 |
+
def __init__(self, sample_rate):
|
75 |
+
super().__init__()
|
76 |
+
self.sample_rate = sample_rate
|
77 |
+
|
78 |
+
def forward(self, input, target):
|
79 |
+
if self.sample_rate != 16000:
|
80 |
+
target = resampy.resample(
|
81 |
+
target.view(-1).numpy(),
|
82 |
+
self.sample_rate,
|
83 |
+
16000,
|
84 |
+
)
|
85 |
+
input = resampy.resample(
|
86 |
+
input.view(-1).numpy(),
|
87 |
+
self.sample_rate,
|
88 |
+
16000,
|
89 |
+
)
|
90 |
+
|
91 |
+
return pesq(
|
92 |
+
16000,
|
93 |
+
target,
|
94 |
+
input,
|
95 |
+
"wb",
|
96 |
+
)
|
97 |
+
|
98 |
+
|
99 |
+
class CrestFactorError(torch.nn.Module):
|
100 |
+
def __init__(self):
|
101 |
+
super().__init__()
|
102 |
+
|
103 |
+
def forward(self, input, target):
|
104 |
+
return torch.nn.functional.l1_loss(
|
105 |
+
crest_factor(input),
|
106 |
+
crest_factor(target),
|
107 |
+
).item()
|
108 |
+
|
109 |
+
|
110 |
+
class RMSEnergyError(torch.nn.Module):
|
111 |
+
def __init__(self):
|
112 |
+
super().__init__()
|
113 |
+
|
114 |
+
def forward(self, input, target):
|
115 |
+
return torch.nn.functional.l1_loss(
|
116 |
+
rms_energy(input),
|
117 |
+
rms_energy(target),
|
118 |
+
).item()
|
119 |
+
|
120 |
+
|
121 |
+
class SpectralCentroidError(torch.nn.Module):
|
122 |
+
def __init__(self, sample_rate, n_fft=2048, hop_length=512):
|
123 |
+
super().__init__()
|
124 |
+
|
125 |
+
self.spectral_centroid = torchaudio.transforms.SpectralCentroid(
|
126 |
+
sample_rate,
|
127 |
+
n_fft=n_fft,
|
128 |
+
hop_length=hop_length,
|
129 |
+
)
|
130 |
+
|
131 |
+
def forward(self, input, target):
|
132 |
+
return torch.nn.functional.l1_loss(
|
133 |
+
self.spectral_centroid(input + 1e-16).mean(),
|
134 |
+
self.spectral_centroid(target + 1e-16).mean(),
|
135 |
+
).item()
|
136 |
+
|
137 |
+
|
138 |
+
class LoudnessError(torch.nn.Module):
|
139 |
+
def __init__(self, sample_rate: int, peak_normalize: bool = False):
|
140 |
+
super().__init__()
|
141 |
+
self.sample_rate = sample_rate
|
142 |
+
self.peak_normalize = peak_normalize
|
143 |
+
|
144 |
+
def forward(self, input, target):
|
145 |
+
|
146 |
+
if self.peak_normalize:
|
147 |
+
# peak normalize
|
148 |
+
x = input / input.abs().max()
|
149 |
+
y = target / target.abs().max()
|
150 |
+
else:
|
151 |
+
x = input
|
152 |
+
y = target
|
153 |
+
|
154 |
+
return torch.nn.functional.l1_loss(
|
155 |
+
loudness(x.view(1, -1), self.sample_rate),
|
156 |
+
loudness(y.view(1, -1), self.sample_rate),
|
157 |
+
).item()
|
deepafx_st/models/baselines.py
ADDED
@@ -0,0 +1,280 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torchaudio
|
3 |
+
import scipy.signal
|
4 |
+
import numpy as np
|
5 |
+
import pyloudnorm as pyln
|
6 |
+
import matplotlib.pyplot as plt
|
7 |
+
from deepafx_st.processors.dsp.compressor import compressor
|
8 |
+
|
9 |
+
from tqdm import tqdm
|
10 |
+
|
11 |
+
|
12 |
+
class BaselineEQ(torch.nn.Module):
|
13 |
+
def __init__(
|
14 |
+
self,
|
15 |
+
ntaps: int = 63,
|
16 |
+
n_fft: int = 65536,
|
17 |
+
sample_rate: float = 44100,
|
18 |
+
):
|
19 |
+
super().__init__()
|
20 |
+
self.ntaps = ntaps
|
21 |
+
self.n_fft = n_fft
|
22 |
+
self.sample_rate = sample_rate
|
23 |
+
|
24 |
+
# compute the target spectrum
|
25 |
+
# print("Computing target spectrum...")
|
26 |
+
# self.target_spec, self.sm_target_spec = self.analyze_speech_dataset(filepaths)
|
27 |
+
# self.plot_spectrum(self.target_spec, filename="targetEQ")
|
28 |
+
# self.plot_spectrum(self.sm_target_spec, filename="targetEQsm")
|
29 |
+
|
30 |
+
def forward(self, x, y):
|
31 |
+
|
32 |
+
bs, ch, s = x.size()
|
33 |
+
|
34 |
+
x = x.view(bs * ch, -1)
|
35 |
+
y = y.view(bs * ch, -1)
|
36 |
+
|
37 |
+
in_spec = self.get_average_spectrum(x)
|
38 |
+
ref_spec = self.get_average_spectrum(y)
|
39 |
+
|
40 |
+
sm_in_spec = self.smooth_spectrum(in_spec)
|
41 |
+
sm_ref_spec = self.smooth_spectrum(ref_spec)
|
42 |
+
|
43 |
+
# self.plot_spectrum(in_spec, filename="inSpec")
|
44 |
+
# self.plot_spectrum(sm_in_spec, filename="inSpecsm")
|
45 |
+
|
46 |
+
# design inverse FIR filter to match target EQ
|
47 |
+
freqs = np.linspace(0, 1.0, num=(self.n_fft // 2) + 1)
|
48 |
+
response = sm_ref_spec / sm_in_spec
|
49 |
+
response[-1] = 0.0 # zero gain at nyquist
|
50 |
+
|
51 |
+
b = scipy.signal.firwin2(
|
52 |
+
self.ntaps,
|
53 |
+
freqs * (self.sample_rate / 2),
|
54 |
+
response,
|
55 |
+
fs=self.sample_rate,
|
56 |
+
)
|
57 |
+
|
58 |
+
# scale the coefficients for less intense filter
|
59 |
+
# clearb *= 0.5
|
60 |
+
|
61 |
+
# apply the filter
|
62 |
+
x_filt = scipy.signal.lfilter(b, [1.0], x.numpy())
|
63 |
+
x_filt = torch.tensor(x_filt.astype("float32"))
|
64 |
+
|
65 |
+
if False:
|
66 |
+
# plot the filter response
|
67 |
+
w, h = scipy.signal.freqz(b, fs=self.sample_rate, worN=response.shape[-1])
|
68 |
+
|
69 |
+
fig, ax1 = plt.subplots()
|
70 |
+
ax1.set_title("Digital filter frequency response")
|
71 |
+
ax1.plot(w, 20 * np.log10(abs(h + 1e-8)))
|
72 |
+
ax1.plot(w, 20 * np.log10(abs(response + 1e-8)))
|
73 |
+
|
74 |
+
ax1.set_xscale("log")
|
75 |
+
ax1.set_ylim([-12, 12])
|
76 |
+
plt.grid(c="lightgray")
|
77 |
+
plt.savefig(f"inverse.png")
|
78 |
+
|
79 |
+
x_filt_avg_spec = self.get_average_spectrum(x_filt)
|
80 |
+
sm_x_filt_avg_spec = self.smooth_spectrum(x_filt_avg_spec)
|
81 |
+
y_avg_spec = self.get_average_spectrum(y)
|
82 |
+
sm_y_avg_spec = self.smooth_spectrum(y_avg_spec)
|
83 |
+
compare = torch.stack(
|
84 |
+
[
|
85 |
+
torch.tensor(sm_in_spec),
|
86 |
+
torch.tensor(sm_x_filt_avg_spec),
|
87 |
+
torch.tensor(sm_ref_spec),
|
88 |
+
torch.tensor(sm_y_avg_spec),
|
89 |
+
]
|
90 |
+
)
|
91 |
+
self.plot_multi_spectrum(
|
92 |
+
compare,
|
93 |
+
legend=["in", "out", "target curve", "actual target"],
|
94 |
+
filename="outSpec",
|
95 |
+
)
|
96 |
+
|
97 |
+
return x_filt
|
98 |
+
|
99 |
+
def analyze_speech_dataset(self, filepaths, peak=-3.0):
|
100 |
+
avg_spec = []
|
101 |
+
for filepath in tqdm(filepaths, ncols=80):
|
102 |
+
x, sr = torchaudio.load(filepath)
|
103 |
+
x /= x.abs().max()
|
104 |
+
x *= 10 ** (peak / 20.0)
|
105 |
+
avg_spec.append(self.get_average_spectrum(x))
|
106 |
+
avg_specs = torch.stack(avg_spec)
|
107 |
+
|
108 |
+
avg_spec = avg_specs.mean(dim=0).numpy()
|
109 |
+
avg_spec_std = avg_specs.std(dim=0).numpy()
|
110 |
+
|
111 |
+
# self.plot_multi_spectrum(avg_specs, filename="allTargetEQs")
|
112 |
+
# self.plot_spectrum_stats(avg_spec, avg_spec_std, filename="targetEQstats")
|
113 |
+
|
114 |
+
sm_avg_spec = self.smooth_spectrum(avg_spec)
|
115 |
+
|
116 |
+
return avg_spec, sm_avg_spec
|
117 |
+
|
118 |
+
def smooth_spectrum(self, H):
|
119 |
+
# apply Savgol filter for smoothed target curve
|
120 |
+
return scipy.signal.savgol_filter(H, 1025, 2)
|
121 |
+
|
122 |
+
def get_average_spectrum(self, x):
|
123 |
+
|
124 |
+
# x = x[:, : self.n_fft]
|
125 |
+
X = torch.stft(x, self.n_fft, return_complex=True, normalized=True)
|
126 |
+
# fft_size = self.next_power_of_2(x.shape[-1])
|
127 |
+
# X = torch.fft.rfft(x, n=fft_size)
|
128 |
+
|
129 |
+
X = X.abs() # convert to magnitude
|
130 |
+
X = X.mean(dim=-1).view(-1) # average across frames
|
131 |
+
|
132 |
+
return X
|
133 |
+
|
134 |
+
@staticmethod
|
135 |
+
def next_power_of_2(x):
|
136 |
+
return 1 if x == 0 else int(2 ** np.ceil(np.log2(x)))
|
137 |
+
|
138 |
+
def plot_multi_spectrum(self, Hs, legend=[], filename=None):
|
139 |
+
|
140 |
+
bin_width = (self.sample_rate / 2) / (self.n_fft // 2)
|
141 |
+
freqs = np.arange(0, (self.sample_rate / 2) + bin_width, step=bin_width)
|
142 |
+
|
143 |
+
fig, ax1 = plt.subplots()
|
144 |
+
|
145 |
+
for H in Hs:
|
146 |
+
ax1.plot(
|
147 |
+
freqs,
|
148 |
+
20 * np.log10(abs(H) + 1e-8),
|
149 |
+
)
|
150 |
+
|
151 |
+
plt.legend(legend)
|
152 |
+
|
153 |
+
# avg_spec = Hs.mean(dim=0).numpy()
|
154 |
+
# ax1.plot(freqs, 20 * np.log10(avg_spec), color="k", linewidth=2)
|
155 |
+
|
156 |
+
ax1.set_xscale("log")
|
157 |
+
ax1.set_ylim([-80, 0])
|
158 |
+
plt.grid(c="lightgray")
|
159 |
+
|
160 |
+
if filename is not None:
|
161 |
+
plt.savefig(f"{filename}.png")
|
162 |
+
|
163 |
+
def plot_spectrum_stats(self, H_mean, H_std, filename=None):
|
164 |
+
bin_width = (self.sample_rate / 2) / (self.n_fft // 2)
|
165 |
+
freqs = np.arange(0, (self.sample_rate / 2) + bin_width, step=bin_width)
|
166 |
+
|
167 |
+
fig, ax1 = plt.subplots()
|
168 |
+
ax1.plot(freqs, 20 * np.log10(H_mean))
|
169 |
+
ax1.plot(
|
170 |
+
freqs,
|
171 |
+
(20 * np.log10(H_mean)) + (20 * np.log10(H_std)),
|
172 |
+
linestyle="--",
|
173 |
+
color="k",
|
174 |
+
)
|
175 |
+
ax1.plot(
|
176 |
+
freqs,
|
177 |
+
(20 * np.log10(H_mean)) - (20 * np.log10(H_std)),
|
178 |
+
linestyle="--",
|
179 |
+
color="k",
|
180 |
+
)
|
181 |
+
|
182 |
+
ax1.set_xscale("log")
|
183 |
+
ax1.set_ylim([-80, 0])
|
184 |
+
plt.grid(c="lightgray")
|
185 |
+
|
186 |
+
if filename is not None:
|
187 |
+
plt.savefig(f"{filename}.png")
|
188 |
+
|
189 |
+
def plot_spectrum(self, H, legend=[], filename=None):
|
190 |
+
|
191 |
+
bin_width = (self.sample_rate / 2) / (self.n_fft // 2)
|
192 |
+
freqs = np.arange(0, (self.sample_rate / 2) + bin_width, step=bin_width)
|
193 |
+
|
194 |
+
fig, ax1 = plt.subplots()
|
195 |
+
ax1.plot(freqs, 20 * np.log10(H))
|
196 |
+
ax1.set_xscale("log")
|
197 |
+
ax1.set_ylim([-80, 0])
|
198 |
+
plt.grid(c="lightgray")
|
199 |
+
|
200 |
+
plt.legend(legend)
|
201 |
+
|
202 |
+
if filename is not None:
|
203 |
+
plt.savefig(f"{filename}.png")
|
204 |
+
|
205 |
+
|
206 |
+
class BaslineComp(torch.nn.Module):
|
207 |
+
def __init__(
|
208 |
+
self,
|
209 |
+
sample_rate: float = 44100,
|
210 |
+
):
|
211 |
+
super().__init__()
|
212 |
+
self.sample_rate = sample_rate
|
213 |
+
self.meter = pyln.Meter(sample_rate)
|
214 |
+
|
215 |
+
def forward(self, x, y):
|
216 |
+
|
217 |
+
x_lufs = self.meter.integrated_loudness(x.view(-1).numpy())
|
218 |
+
y_lufs = self.meter.integrated_loudness(y.view(-1).numpy())
|
219 |
+
|
220 |
+
delta_lufs = y_lufs - x_lufs
|
221 |
+
|
222 |
+
threshold = 0.0
|
223 |
+
x_comp = x
|
224 |
+
x_comp_new = x
|
225 |
+
while delta_lufs > 0.5 and threshold > -80.0:
|
226 |
+
x_comp = x_comp_new # use the last setting
|
227 |
+
x_comp_new = compressor(
|
228 |
+
x.view(-1).numpy(),
|
229 |
+
self.sample_rate,
|
230 |
+
threshold=threshold,
|
231 |
+
ratio=3,
|
232 |
+
attack_time=0.001,
|
233 |
+
release_time=0.05,
|
234 |
+
knee_dB=6.0,
|
235 |
+
makeup_gain_dB=0.0,
|
236 |
+
)
|
237 |
+
x_comp_new = torch.tensor(x_comp_new)
|
238 |
+
x_comp_new /= x_comp_new.abs().max()
|
239 |
+
x_comp_new *= 10 ** (-12.0 / 20)
|
240 |
+
x_lufs = self.meter.integrated_loudness(x_comp_new.view(-1).numpy())
|
241 |
+
delta_lufs = y_lufs - x_lufs
|
242 |
+
threshold -= 0.5
|
243 |
+
|
244 |
+
return x_comp.view(1, 1, -1)
|
245 |
+
|
246 |
+
|
247 |
+
class BaselineEQAndComp(torch.nn.Module):
|
248 |
+
def __init__(
|
249 |
+
self,
|
250 |
+
ntaps=63,
|
251 |
+
n_fft=65536,
|
252 |
+
sample_rate=44100,
|
253 |
+
block_size=1024,
|
254 |
+
plugin_config=None,
|
255 |
+
):
|
256 |
+
super().__init__()
|
257 |
+
self.eq = BaselineEQ(ntaps, n_fft, sample_rate)
|
258 |
+
self.comp = BaslineComp(sample_rate)
|
259 |
+
|
260 |
+
def forward(self, x, y):
|
261 |
+
|
262 |
+
with torch.inference_mode():
|
263 |
+
x /= x.abs().max()
|
264 |
+
y /= y.abs().max()
|
265 |
+
x *= 10 ** (-12.0 / 20)
|
266 |
+
y *= 10 ** (-12.0 / 20)
|
267 |
+
|
268 |
+
x = self.eq(x, y)
|
269 |
+
|
270 |
+
x /= x.abs().max()
|
271 |
+
y /= y.abs().max()
|
272 |
+
x *= 10 ** (-12.0 / 20)
|
273 |
+
y *= 10 ** (-12.0 / 20)
|
274 |
+
|
275 |
+
x = self.comp(x, y)
|
276 |
+
|
277 |
+
x /= x.abs().max()
|
278 |
+
x *= 10 ** (-12.0 / 20)
|
279 |
+
|
280 |
+
return x
|
deepafx_st/models/controller.py
ADDED
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
|
3 |
+
class StyleTransferController(torch.nn.Module):
|
4 |
+
def __init__(
|
5 |
+
self,
|
6 |
+
num_control_params,
|
7 |
+
edim,
|
8 |
+
hidden_dim=256,
|
9 |
+
agg_method="mlp",
|
10 |
+
):
|
11 |
+
"""Plugin parameter controller module to map from input to target style.
|
12 |
+
|
13 |
+
Args:
|
14 |
+
num_control_params (int): Number of plugin parameters to predicted.
|
15 |
+
edim (int): Size of the encoder representations.
|
16 |
+
hidden_dim (int, optional): Hidden size of the 3-layer parameter predictor MLP. Default: 256
|
17 |
+
agg_method (str, optional): Input/reference embed aggregation method ["conv" or "linear", "mlp"]. Default: "mlp"
|
18 |
+
"""
|
19 |
+
super().__init__()
|
20 |
+
self.num_control_params = num_control_params
|
21 |
+
self.edim = edim
|
22 |
+
self.hidden_dim = hidden_dim
|
23 |
+
self.agg_method = agg_method
|
24 |
+
|
25 |
+
if agg_method == "conv":
|
26 |
+
self.agg = torch.nn.Conv1d(
|
27 |
+
2,
|
28 |
+
1,
|
29 |
+
kernel_size=129,
|
30 |
+
stride=1,
|
31 |
+
padding="same",
|
32 |
+
bias=False,
|
33 |
+
)
|
34 |
+
mlp_in_dim = edim
|
35 |
+
elif agg_method == "linear":
|
36 |
+
self.agg = torch.nn.Linear(edim * 2, edim)
|
37 |
+
elif agg_method == "mlp":
|
38 |
+
self.agg = None
|
39 |
+
mlp_in_dim = edim * 2
|
40 |
+
else:
|
41 |
+
raise ValueError(f"Invalid agg_method = {self.agg_method}.")
|
42 |
+
|
43 |
+
self.mlp = torch.nn.Sequential(
|
44 |
+
torch.nn.Linear(mlp_in_dim, hidden_dim),
|
45 |
+
torch.nn.LeakyReLU(0.01),
|
46 |
+
torch.nn.Linear(hidden_dim, hidden_dim),
|
47 |
+
torch.nn.LeakyReLU(0.01),
|
48 |
+
torch.nn.Linear(hidden_dim, num_control_params),
|
49 |
+
torch.nn.Sigmoid(), # normalize between 0 and 1
|
50 |
+
)
|
51 |
+
|
52 |
+
def forward(self, e_x, e_y, z=None):
|
53 |
+
"""Forward pass to generate plugin parameters.
|
54 |
+
|
55 |
+
Args:
|
56 |
+
e_x (tensor): Input signal embedding of shape (batch, edim)
|
57 |
+
e_y (tensor): Target signal embedding of shape (batch, edim)
|
58 |
+
Returns:
|
59 |
+
p (tensor): Estimated control parameters of shape (batch, num_control_params)
|
60 |
+
"""
|
61 |
+
|
62 |
+
# use learnable projection
|
63 |
+
if self.agg_method == "conv":
|
64 |
+
e_xy = torch.stack((e_x, e_y), dim=1) # concat on channel dim
|
65 |
+
e_xy = self.agg(e_xy)
|
66 |
+
elif self.agg_method == "linear":
|
67 |
+
e_xy = torch.cat((e_x, e_y), dim=-1) # concat on embed dim
|
68 |
+
e_xy = self.agg(e_xy)
|
69 |
+
else:
|
70 |
+
e_xy = torch.cat((e_x, e_y), dim=-1) # concat on embed dim
|
71 |
+
|
72 |
+
# pass through MLP to project to control parametesr
|
73 |
+
p = self.mlp(e_xy.squeeze(1))
|
74 |
+
|
75 |
+
return p
|
deepafx_st/models/efficient_net/LICENSE
ADDED
@@ -0,0 +1,202 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
Apache License
|
3 |
+
Version 2.0, January 2004
|
4 |
+
http://www.apache.org/licenses/
|
5 |
+
|
6 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
7 |
+
|
8 |
+
1. Definitions.
|
9 |
+
|
10 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
11 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
12 |
+
|
13 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
14 |
+
the copyright owner that is granting the License.
|
15 |
+
|
16 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
17 |
+
other entities that control, are controlled by, or are under common
|
18 |
+
control with that entity. For the purposes of this definition,
|
19 |
+
"control" means (i) the power, direct or indirect, to cause the
|
20 |
+
direction or management of such entity, whether by contract or
|
21 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
22 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
23 |
+
|
24 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
25 |
+
exercising permissions granted by this License.
|
26 |
+
|
27 |
+
"Source" form shall mean the preferred form for making modifications,
|
28 |
+
including but not limited to software source code, documentation
|
29 |
+
source, and configuration files.
|
30 |
+
|
31 |
+
"Object" form shall mean any form resulting from mechanical
|
32 |
+
transformation or translation of a Source form, including but
|
33 |
+
not limited to compiled object code, generated documentation,
|
34 |
+
and conversions to other media types.
|
35 |
+
|
36 |
+
"Work" shall mean the work of authorship, whether in Source or
|
37 |
+
Object form, made available under the License, as indicated by a
|
38 |
+
copyright notice that is included in or attached to the work
|
39 |
+
(an example is provided in the Appendix below).
|
40 |
+
|
41 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
42 |
+
form, that is based on (or derived from) the Work and for which the
|
43 |
+
editorial revisions, annotations, elaborations, or other modifications
|
44 |
+
represent, as a whole, an original work of authorship. For the purposes
|
45 |
+
of this License, Derivative Works shall not include works that remain
|
46 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
47 |
+
the Work and Derivative Works thereof.
|
48 |
+
|
49 |
+
"Contribution" shall mean any work of authorship, including
|
50 |
+
the original version of the Work and any modifications or additions
|
51 |
+
to that Work or Derivative Works thereof, that is intentionally
|
52 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
53 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
54 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
55 |
+
means any form of electronic, verbal, or written communication sent
|
56 |
+
to the Licensor or its representatives, including but not limited to
|
57 |
+
communication on electronic mailing lists, source code control systems,
|
58 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
59 |
+
Licensor for the purpose of discussing and improving the Work, but
|
60 |
+
excluding communication that is conspicuously marked or otherwise
|
61 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
62 |
+
|
63 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
64 |
+
on behalf of whom a Contribution has been received by Licensor and
|
65 |
+
subsequently incorporated within the Work.
|
66 |
+
|
67 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
68 |
+
this License, each Contributor hereby grants to You a perpetual,
|
69 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
70 |
+
copyright license to reproduce, prepare Derivative Works of,
|
71 |
+
publicly display, publicly perform, sublicense, and distribute the
|
72 |
+
Work and such Derivative Works in Source or Object form.
|
73 |
+
|
74 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
75 |
+
this License, each Contributor hereby grants to You a perpetual,
|
76 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
77 |
+
(except as stated in this section) patent license to make, have made,
|
78 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
79 |
+
where such license applies only to those patent claims licensable
|
80 |
+
by such Contributor that are necessarily infringed by their
|
81 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
82 |
+
with the Work to which such Contribution(s) was submitted. If You
|
83 |
+
institute patent litigation against any entity (including a
|
84 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
85 |
+
or a Contribution incorporated within the Work constitutes direct
|
86 |
+
or contributory patent infringement, then any patent licenses
|
87 |
+
granted to You under this License for that Work shall terminate
|
88 |
+
as of the date such litigation is filed.
|
89 |
+
|
90 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
91 |
+
Work or Derivative Works thereof in any medium, with or without
|
92 |
+
modifications, and in Source or Object form, provided that You
|
93 |
+
meet the following conditions:
|
94 |
+
|
95 |
+
(a) You must give any other recipients of the Work or
|
96 |
+
Derivative Works a copy of this License; and
|
97 |
+
|
98 |
+
(b) You must cause any modified files to carry prominent notices
|
99 |
+
stating that You changed the files; and
|
100 |
+
|
101 |
+
(c) You must retain, in the Source form of any Derivative Works
|
102 |
+
that You distribute, all copyright, patent, trademark, and
|
103 |
+
attribution notices from the Source form of the Work,
|
104 |
+
excluding those notices that do not pertain to any part of
|
105 |
+
the Derivative Works; and
|
106 |
+
|
107 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
108 |
+
distribution, then any Derivative Works that You distribute must
|
109 |
+
include a readable copy of the attribution notices contained
|
110 |
+
within such NOTICE file, excluding those notices that do not
|
111 |
+
pertain to any part of the Derivative Works, in at least one
|
112 |
+
of the following places: within a NOTICE text file distributed
|
113 |
+
as part of the Derivative Works; within the Source form or
|
114 |
+
documentation, if provided along with the Derivative Works; or,
|
115 |
+
within a display generated by the Derivative Works, if and
|
116 |
+
wherever such third-party notices normally appear. The contents
|
117 |
+
of the NOTICE file are for informational purposes only and
|
118 |
+
do not modify the License. You may add Your own attribution
|
119 |
+
notices within Derivative Works that You distribute, alongside
|
120 |
+
or as an addendum to the NOTICE text from the Work, provided
|
121 |
+
that such additional attribution notices cannot be construed
|
122 |
+
as modifying the License.
|
123 |
+
|
124 |
+
You may add Your own copyright statement to Your modifications and
|
125 |
+
may provide additional or different license terms and conditions
|
126 |
+
for use, reproduction, or distribution of Your modifications, or
|
127 |
+
for any such Derivative Works as a whole, provided Your use,
|
128 |
+
reproduction, and distribution of the Work otherwise complies with
|
129 |
+
the conditions stated in this License.
|
130 |
+
|
131 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
132 |
+
any Contribution intentionally submitted for inclusion in the Work
|
133 |
+
by You to the Licensor shall be under the terms and conditions of
|
134 |
+
this License, without any additional terms or conditions.
|
135 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
136 |
+
the terms of any separate license agreement you may have executed
|
137 |
+
with Licensor regarding such Contributions.
|
138 |
+
|
139 |
+
6. Trademarks. This License does not grant permission to use the trade
|
140 |
+
names, trademarks, service marks, or product names of the Licensor,
|
141 |
+
except as required for reasonable and customary use in describing the
|
142 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
143 |
+
|
144 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
145 |
+
agreed to in writing, Licensor provides the Work (and each
|
146 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
147 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
148 |
+
implied, including, without limitation, any warranties or conditions
|
149 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
150 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
151 |
+
appropriateness of using or redistributing the Work and assume any
|
152 |
+
risks associated with Your exercise of permissions under this License.
|
153 |
+
|
154 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
155 |
+
whether in tort (including negligence), contract, or otherwise,
|
156 |
+
unless required by applicable law (such as deliberate and grossly
|
157 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
158 |
+
liable to You for damages, including any direct, indirect, special,
|
159 |
+
incidental, or consequential damages of any character arising as a
|
160 |
+
result of this License or out of the use or inability to use the
|
161 |
+
Work (including but not limited to damages for loss of goodwill,
|
162 |
+
work stoppage, computer failure or malfunction, or any and all
|
163 |
+
other commercial damages or losses), even if such Contributor
|
164 |
+
has been advised of the possibility of such damages.
|
165 |
+
|
166 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
167 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
168 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
169 |
+
or other liability obligations and/or rights consistent with this
|
170 |
+
License. However, in accepting such obligations, You may act only
|
171 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
172 |
+
of any other Contributor, and only if You agree to indemnify,
|
173 |
+
defend, and hold each Contributor harmless for any liability
|
174 |
+
incurred by, or claims asserted against, such Contributor by reason
|
175 |
+
of your accepting any such warranty or additional liability.
|
176 |
+
|
177 |
+
END OF TERMS AND CONDITIONS
|
178 |
+
|
179 |
+
APPENDIX: How to apply the Apache License to your work.
|
180 |
+
|
181 |
+
To apply the Apache License to your work, attach the following
|
182 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
183 |
+
replaced with your own identifying information. (Don't include
|
184 |
+
the brackets!) The text should be enclosed in the appropriate
|
185 |
+
comment syntax for the file format. We also recommend that a
|
186 |
+
file or class name and description of purpose be included on the
|
187 |
+
same "printed page" as the copyright notice for easier
|
188 |
+
identification within third-party archives.
|
189 |
+
|
190 |
+
Copyright [yyyy] [name of copyright owner]
|
191 |
+
|
192 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
193 |
+
you may not use this file except in compliance with the License.
|
194 |
+
You may obtain a copy of the License at
|
195 |
+
|
196 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
197 |
+
|
198 |
+
Unless required by applicable law or agreed to in writing, software
|
199 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
200 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
201 |
+
See the License for the specific language governing permissions and
|
202 |
+
limitations under the License.
|
deepafx_st/models/efficient_net/__init__.py
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
__version__ = "0.7.1"
|
2 |
+
from .model import EfficientNet, VALID_MODELS
|
3 |
+
from .utils import (
|
4 |
+
GlobalParams,
|
5 |
+
BlockArgs,
|
6 |
+
BlockDecoder,
|
7 |
+
efficientnet,
|
8 |
+
get_model_params,
|
9 |
+
)
|
deepafx_st/models/efficient_net/model.py
ADDED
@@ -0,0 +1,419 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""model.py - Model and module class for EfficientNet.
|
2 |
+
They are built to mirror those in the official TensorFlow implementation.
|
3 |
+
"""
|
4 |
+
|
5 |
+
# Author: lukemelas (github username)
|
6 |
+
# Github repo: https://github.com/lukemelas/EfficientNet-PyTorch
|
7 |
+
# With adjustments and added comments by workingcoder (github username).
|
8 |
+
|
9 |
+
import torch
|
10 |
+
from torch import nn
|
11 |
+
from torch.nn import functional as F
|
12 |
+
from .utils import (
|
13 |
+
round_filters,
|
14 |
+
round_repeats,
|
15 |
+
drop_connect,
|
16 |
+
get_same_padding_conv2d,
|
17 |
+
get_model_params,
|
18 |
+
efficientnet_params,
|
19 |
+
load_pretrained_weights,
|
20 |
+
Swish,
|
21 |
+
MemoryEfficientSwish,
|
22 |
+
calculate_output_image_size
|
23 |
+
)
|
24 |
+
|
25 |
+
|
26 |
+
VALID_MODELS = (
|
27 |
+
'efficientnet-b0', 'efficientnet-b1', 'efficientnet-b2', 'efficientnet-b3',
|
28 |
+
'efficientnet-b4', 'efficientnet-b5', 'efficientnet-b6', 'efficientnet-b7',
|
29 |
+
'efficientnet-b8',
|
30 |
+
|
31 |
+
# Support the construction of 'efficientnet-l2' without pretrained weights
|
32 |
+
'efficientnet-l2'
|
33 |
+
)
|
34 |
+
|
35 |
+
|
36 |
+
class MBConvBlock(nn.Module):
|
37 |
+
"""Mobile Inverted Residual Bottleneck Block.
|
38 |
+
|
39 |
+
Args:
|
40 |
+
block_args (namedtuple): BlockArgs, defined in utils.py.
|
41 |
+
global_params (namedtuple): GlobalParam, defined in utils.py.
|
42 |
+
image_size (tuple or list): [image_height, image_width].
|
43 |
+
|
44 |
+
References:
|
45 |
+
[1] https://arxiv.org/abs/1704.04861 (MobileNet v1)
|
46 |
+
[2] https://arxiv.org/abs/1801.04381 (MobileNet v2)
|
47 |
+
[3] https://arxiv.org/abs/1905.02244 (MobileNet v3)
|
48 |
+
"""
|
49 |
+
|
50 |
+
def __init__(self, block_args, global_params, image_size=None):
|
51 |
+
super().__init__()
|
52 |
+
self._block_args = block_args
|
53 |
+
self._bn_mom = 1 - global_params.batch_norm_momentum # pytorch's difference from tensorflow
|
54 |
+
self._bn_eps = global_params.batch_norm_epsilon
|
55 |
+
self.has_se = (self._block_args.se_ratio is not None) and (0 < self._block_args.se_ratio <= 1)
|
56 |
+
self.id_skip = block_args.id_skip # whether to use skip connection and drop connect
|
57 |
+
|
58 |
+
# Expansion phase (Inverted Bottleneck)
|
59 |
+
inp = self._block_args.input_filters # number of input channels
|
60 |
+
oup = self._block_args.input_filters * self._block_args.expand_ratio # number of output channels
|
61 |
+
if self._block_args.expand_ratio != 1:
|
62 |
+
Conv2d = get_same_padding_conv2d(image_size=image_size)
|
63 |
+
self._expand_conv = Conv2d(in_channels=inp, out_channels=oup, kernel_size=1, bias=False)
|
64 |
+
self._bn0 = nn.BatchNorm2d(num_features=oup, momentum=self._bn_mom, eps=self._bn_eps)
|
65 |
+
# image_size = calculate_output_image_size(image_size, 1) <-- this wouldn't modify image_size
|
66 |
+
|
67 |
+
# Depthwise convolution phase
|
68 |
+
k = self._block_args.kernel_size
|
69 |
+
s = self._block_args.stride
|
70 |
+
Conv2d = get_same_padding_conv2d(image_size=image_size)
|
71 |
+
self._depthwise_conv = Conv2d(
|
72 |
+
in_channels=oup, out_channels=oup, groups=oup, # groups makes it depthwise
|
73 |
+
kernel_size=k, stride=s, bias=False)
|
74 |
+
self._bn1 = nn.BatchNorm2d(num_features=oup, momentum=self._bn_mom, eps=self._bn_eps)
|
75 |
+
image_size = calculate_output_image_size(image_size, s)
|
76 |
+
|
77 |
+
# Squeeze and Excitation layer, if desired
|
78 |
+
if self.has_se:
|
79 |
+
Conv2d = get_same_padding_conv2d(image_size=(1, 1))
|
80 |
+
num_squeezed_channels = max(1, int(self._block_args.input_filters * self._block_args.se_ratio))
|
81 |
+
self._se_reduce = Conv2d(in_channels=oup, out_channels=num_squeezed_channels, kernel_size=1)
|
82 |
+
self._se_expand = Conv2d(in_channels=num_squeezed_channels, out_channels=oup, kernel_size=1)
|
83 |
+
|
84 |
+
# Pointwise convolution phase
|
85 |
+
final_oup = self._block_args.output_filters
|
86 |
+
Conv2d = get_same_padding_conv2d(image_size=image_size)
|
87 |
+
self._project_conv = Conv2d(in_channels=oup, out_channels=final_oup, kernel_size=1, bias=False)
|
88 |
+
self._bn2 = nn.BatchNorm2d(num_features=final_oup, momentum=self._bn_mom, eps=self._bn_eps)
|
89 |
+
self._swish = MemoryEfficientSwish()
|
90 |
+
|
91 |
+
def forward(self, inputs, drop_connect_rate=None):
|
92 |
+
"""MBConvBlock's forward function.
|
93 |
+
|
94 |
+
Args:
|
95 |
+
inputs (tensor): Input tensor.
|
96 |
+
drop_connect_rate (bool): Drop connect rate (float, between 0 and 1).
|
97 |
+
|
98 |
+
Returns:
|
99 |
+
Output of this block after processing.
|
100 |
+
"""
|
101 |
+
|
102 |
+
# Expansion and Depthwise Convolution
|
103 |
+
x = inputs
|
104 |
+
if self._block_args.expand_ratio != 1:
|
105 |
+
x = self._expand_conv(inputs)
|
106 |
+
x = self._bn0(x)
|
107 |
+
x = self._swish(x)
|
108 |
+
|
109 |
+
x = self._depthwise_conv(x)
|
110 |
+
x = self._bn1(x)
|
111 |
+
x = self._swish(x)
|
112 |
+
|
113 |
+
# Squeeze and Excitation
|
114 |
+
if self.has_se:
|
115 |
+
x_squeezed = F.adaptive_avg_pool2d(x, 1)
|
116 |
+
x_squeezed = self._se_reduce(x_squeezed)
|
117 |
+
x_squeezed = self._swish(x_squeezed)
|
118 |
+
x_squeezed = self._se_expand(x_squeezed)
|
119 |
+
x = torch.sigmoid(x_squeezed) * x
|
120 |
+
|
121 |
+
# Pointwise Convolution
|
122 |
+
x = self._project_conv(x)
|
123 |
+
x = self._bn2(x)
|
124 |
+
|
125 |
+
# Skip connection and drop connect
|
126 |
+
input_filters, output_filters = self._block_args.input_filters, self._block_args.output_filters
|
127 |
+
if self.id_skip and self._block_args.stride == 1 and input_filters == output_filters:
|
128 |
+
# The combination of skip connection and drop connect brings about stochastic depth.
|
129 |
+
if drop_connect_rate:
|
130 |
+
x = drop_connect(x, p=drop_connect_rate, training=self.training)
|
131 |
+
x = x + inputs # skip connection
|
132 |
+
return x
|
133 |
+
|
134 |
+
def set_swish(self, memory_efficient=True):
|
135 |
+
"""Sets swish function as memory efficient (for training) or standard (for export).
|
136 |
+
|
137 |
+
Args:
|
138 |
+
memory_efficient (bool): Whether to use memory-efficient version of swish.
|
139 |
+
"""
|
140 |
+
self._swish = MemoryEfficientSwish() if memory_efficient else Swish()
|
141 |
+
|
142 |
+
|
143 |
+
class EfficientNet(nn.Module):
|
144 |
+
"""EfficientNet model.
|
145 |
+
Most easily loaded with the .from_name or .from_pretrained methods.
|
146 |
+
|
147 |
+
Args:
|
148 |
+
blocks_args (list[namedtuple]): A list of BlockArgs to construct blocks.
|
149 |
+
global_params (namedtuple): A set of GlobalParams shared between blocks.
|
150 |
+
|
151 |
+
References:
|
152 |
+
[1] https://arxiv.org/abs/1905.11946 (EfficientNet)
|
153 |
+
|
154 |
+
Example:
|
155 |
+
>>> import torch
|
156 |
+
>>> from efficientnet.model import EfficientNet
|
157 |
+
>>> inputs = torch.rand(1, 3, 224, 224)
|
158 |
+
>>> model = EfficientNet.from_pretrained('efficientnet-b0')
|
159 |
+
>>> model.eval()
|
160 |
+
>>> outputs = model(inputs)
|
161 |
+
"""
|
162 |
+
|
163 |
+
def __init__(self, blocks_args=None, global_params=None):
|
164 |
+
super().__init__()
|
165 |
+
assert isinstance(blocks_args, list), 'blocks_args should be a list'
|
166 |
+
assert len(blocks_args) > 0, 'block args must be greater than 0'
|
167 |
+
self._global_params = global_params
|
168 |
+
self._blocks_args = blocks_args
|
169 |
+
|
170 |
+
# Batch norm parameters
|
171 |
+
bn_mom = 1 - self._global_params.batch_norm_momentum
|
172 |
+
bn_eps = self._global_params.batch_norm_epsilon
|
173 |
+
|
174 |
+
# Get stem static or dynamic convolution depending on image size
|
175 |
+
image_size = global_params.image_size
|
176 |
+
Conv2d = get_same_padding_conv2d(image_size=image_size)
|
177 |
+
|
178 |
+
# Stem
|
179 |
+
in_channels = 3 # rgb
|
180 |
+
out_channels = round_filters(32, self._global_params) # number of output channels
|
181 |
+
self._conv_stem = Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=False)
|
182 |
+
self._bn0 = nn.BatchNorm2d(num_features=out_channels, momentum=bn_mom, eps=bn_eps)
|
183 |
+
image_size = calculate_output_image_size(image_size, 2)
|
184 |
+
|
185 |
+
# Build blocks
|
186 |
+
self._blocks = nn.ModuleList([])
|
187 |
+
for block_args in self._blocks_args:
|
188 |
+
|
189 |
+
# Update block input and output filters based on depth multiplier.
|
190 |
+
block_args = block_args._replace(
|
191 |
+
input_filters=round_filters(block_args.input_filters, self._global_params),
|
192 |
+
output_filters=round_filters(block_args.output_filters, self._global_params),
|
193 |
+
num_repeat=round_repeats(block_args.num_repeat, self._global_params)
|
194 |
+
)
|
195 |
+
|
196 |
+
# The first block needs to take care of stride and filter size increase.
|
197 |
+
self._blocks.append(MBConvBlock(block_args, self._global_params, image_size=image_size))
|
198 |
+
image_size = calculate_output_image_size(image_size, block_args.stride)
|
199 |
+
if block_args.num_repeat > 1: # modify block_args to keep same output size
|
200 |
+
block_args = block_args._replace(input_filters=block_args.output_filters, stride=1)
|
201 |
+
for _ in range(block_args.num_repeat - 1):
|
202 |
+
self._blocks.append(MBConvBlock(block_args, self._global_params, image_size=image_size))
|
203 |
+
# image_size = calculate_output_image_size(image_size, block_args.stride) # stride = 1
|
204 |
+
|
205 |
+
# Head
|
206 |
+
in_channels = block_args.output_filters # output of final block
|
207 |
+
out_channels = round_filters(1280, self._global_params)
|
208 |
+
Conv2d = get_same_padding_conv2d(image_size=image_size)
|
209 |
+
self._conv_head = Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
|
210 |
+
self._bn1 = nn.BatchNorm2d(num_features=out_channels, momentum=bn_mom, eps=bn_eps)
|
211 |
+
|
212 |
+
# Final linear layer
|
213 |
+
self._avg_pooling = nn.AdaptiveAvgPool2d(1)
|
214 |
+
if self._global_params.include_top:
|
215 |
+
self._dropout = nn.Dropout(self._global_params.dropout_rate)
|
216 |
+
self._fc = nn.Linear(out_channels, self._global_params.num_classes)
|
217 |
+
|
218 |
+
# set activation to memory efficient swish by default
|
219 |
+
self._swish = MemoryEfficientSwish()
|
220 |
+
|
221 |
+
def set_swish(self, memory_efficient=True):
|
222 |
+
"""Sets swish function as memory efficient (for training) or standard (for export).
|
223 |
+
|
224 |
+
Args:
|
225 |
+
memory_efficient (bool): Whether to use memory-efficient version of swish.
|
226 |
+
"""
|
227 |
+
self._swish = MemoryEfficientSwish() if memory_efficient else Swish()
|
228 |
+
for block in self._blocks:
|
229 |
+
block.set_swish(memory_efficient)
|
230 |
+
|
231 |
+
def extract_endpoints(self, inputs):
|
232 |
+
"""Use convolution layer to extract features
|
233 |
+
from reduction levels i in [1, 2, 3, 4, 5].
|
234 |
+
|
235 |
+
Args:
|
236 |
+
inputs (tensor): Input tensor.
|
237 |
+
|
238 |
+
Returns:
|
239 |
+
Dictionary of last intermediate features
|
240 |
+
with reduction levels i in [1, 2, 3, 4, 5].
|
241 |
+
Example:
|
242 |
+
>>> import torch
|
243 |
+
>>> from efficientnet.model import EfficientNet
|
244 |
+
>>> inputs = torch.rand(1, 3, 224, 224)
|
245 |
+
>>> model = EfficientNet.from_pretrained('efficientnet-b0')
|
246 |
+
>>> endpoints = model.extract_endpoints(inputs)
|
247 |
+
>>> print(endpoints['reduction_1'].shape) # torch.Size([1, 16, 112, 112])
|
248 |
+
>>> print(endpoints['reduction_2'].shape) # torch.Size([1, 24, 56, 56])
|
249 |
+
>>> print(endpoints['reduction_3'].shape) # torch.Size([1, 40, 28, 28])
|
250 |
+
>>> print(endpoints['reduction_4'].shape) # torch.Size([1, 112, 14, 14])
|
251 |
+
>>> print(endpoints['reduction_5'].shape) # torch.Size([1, 320, 7, 7])
|
252 |
+
>>> print(endpoints['reduction_6'].shape) # torch.Size([1, 1280, 7, 7])
|
253 |
+
"""
|
254 |
+
endpoints = dict()
|
255 |
+
|
256 |
+
# Stem
|
257 |
+
x = self._swish(self._bn0(self._conv_stem(inputs)))
|
258 |
+
prev_x = x
|
259 |
+
|
260 |
+
# Blocks
|
261 |
+
for idx, block in enumerate(self._blocks):
|
262 |
+
drop_connect_rate = self._global_params.drop_connect_rate
|
263 |
+
if drop_connect_rate:
|
264 |
+
drop_connect_rate *= float(idx) / len(self._blocks) # scale drop connect_rate
|
265 |
+
x = block(x, drop_connect_rate=drop_connect_rate)
|
266 |
+
if prev_x.size(2) > x.size(2):
|
267 |
+
endpoints['reduction_{}'.format(len(endpoints) + 1)] = prev_x
|
268 |
+
elif idx == len(self._blocks) - 1:
|
269 |
+
endpoints['reduction_{}'.format(len(endpoints) + 1)] = x
|
270 |
+
prev_x = x
|
271 |
+
|
272 |
+
# Head
|
273 |
+
x = self._swish(self._bn1(self._conv_head(x)))
|
274 |
+
endpoints['reduction_{}'.format(len(endpoints) + 1)] = x
|
275 |
+
|
276 |
+
return endpoints
|
277 |
+
|
278 |
+
def extract_features(self, inputs):
|
279 |
+
"""use convolution layer to extract feature .
|
280 |
+
|
281 |
+
Args:
|
282 |
+
inputs (tensor): Input tensor.
|
283 |
+
|
284 |
+
Returns:
|
285 |
+
Output of the final convolution
|
286 |
+
layer in the efficientnet model.
|
287 |
+
"""
|
288 |
+
# Stem
|
289 |
+
x = self._swish(self._bn0(self._conv_stem(inputs)))
|
290 |
+
|
291 |
+
# Blocks
|
292 |
+
for idx, block in enumerate(self._blocks):
|
293 |
+
drop_connect_rate = self._global_params.drop_connect_rate
|
294 |
+
if drop_connect_rate:
|
295 |
+
drop_connect_rate *= float(idx) / len(self._blocks) # scale drop connect_rate
|
296 |
+
x = block(x, drop_connect_rate=drop_connect_rate)
|
297 |
+
|
298 |
+
# Head
|
299 |
+
x = self._swish(self._bn1(self._conv_head(x)))
|
300 |
+
|
301 |
+
return x
|
302 |
+
|
303 |
+
def forward(self, inputs):
|
304 |
+
"""EfficientNet's forward function.
|
305 |
+
Calls extract_features to extract features, applies final linear layer, and returns logits.
|
306 |
+
|
307 |
+
Args:
|
308 |
+
inputs (tensor): Input tensor.
|
309 |
+
|
310 |
+
Returns:
|
311 |
+
Output of this model after processing.
|
312 |
+
"""
|
313 |
+
# Convolution layers
|
314 |
+
x = self.extract_features(inputs)
|
315 |
+
# Pooling and final linear layer
|
316 |
+
x = self._avg_pooling(x)
|
317 |
+
if self._global_params.include_top:
|
318 |
+
x = x.flatten(start_dim=1)
|
319 |
+
x = self._dropout(x)
|
320 |
+
x = self._fc(x)
|
321 |
+
return x
|
322 |
+
|
323 |
+
@classmethod
|
324 |
+
def from_name(cls, model_name, in_channels=3, **override_params):
|
325 |
+
"""Create an efficientnet model according to name.
|
326 |
+
|
327 |
+
Args:
|
328 |
+
model_name (str): Name for efficientnet.
|
329 |
+
in_channels (int): Input data's channel number.
|
330 |
+
override_params (other key word params):
|
331 |
+
Params to override model's global_params.
|
332 |
+
Optional key:
|
333 |
+
'width_coefficient', 'depth_coefficient',
|
334 |
+
'image_size', 'dropout_rate',
|
335 |
+
'num_classes', 'batch_norm_momentum',
|
336 |
+
'batch_norm_epsilon', 'drop_connect_rate',
|
337 |
+
'depth_divisor', 'min_depth'
|
338 |
+
|
339 |
+
Returns:
|
340 |
+
An efficientnet model.
|
341 |
+
"""
|
342 |
+
cls._check_model_name_is_valid(model_name)
|
343 |
+
blocks_args, global_params = get_model_params(model_name, override_params)
|
344 |
+
model = cls(blocks_args, global_params)
|
345 |
+
model._change_in_channels(in_channels)
|
346 |
+
return model
|
347 |
+
|
348 |
+
@classmethod
|
349 |
+
def from_pretrained(cls, model_name, weights_path=None, advprop=False,
|
350 |
+
in_channels=3, num_classes=1000, **override_params):
|
351 |
+
"""Create an efficientnet model according to name.
|
352 |
+
|
353 |
+
Args:
|
354 |
+
model_name (str): Name for efficientnet.
|
355 |
+
weights_path (None or str):
|
356 |
+
str: path to pretrained weights file on the local disk.
|
357 |
+
None: use pretrained weights downloaded from the Internet.
|
358 |
+
advprop (bool):
|
359 |
+
Whether to load pretrained weights
|
360 |
+
trained with advprop (valid when weights_path is None).
|
361 |
+
in_channels (int): Input data's channel number.
|
362 |
+
num_classes (int):
|
363 |
+
Number of categories for classification.
|
364 |
+
It controls the output size for final linear layer.
|
365 |
+
override_params (other key word params):
|
366 |
+
Params to override model's global_params.
|
367 |
+
Optional key:
|
368 |
+
'width_coefficient', 'depth_coefficient',
|
369 |
+
'image_size', 'dropout_rate',
|
370 |
+
'batch_norm_momentum',
|
371 |
+
'batch_norm_epsilon', 'drop_connect_rate',
|
372 |
+
'depth_divisor', 'min_depth'
|
373 |
+
|
374 |
+
Returns:
|
375 |
+
A pretrained efficientnet model.
|
376 |
+
"""
|
377 |
+
model = cls.from_name(model_name, num_classes=num_classes, **override_params)
|
378 |
+
load_pretrained_weights(model, model_name, weights_path=weights_path,
|
379 |
+
load_fc=(num_classes == 1000), advprop=advprop)
|
380 |
+
model._change_in_channels(in_channels)
|
381 |
+
return model
|
382 |
+
|
383 |
+
@classmethod
|
384 |
+
def get_image_size(cls, model_name):
|
385 |
+
"""Get the input image size for a given efficientnet model.
|
386 |
+
|
387 |
+
Args:
|
388 |
+
model_name (str): Name for efficientnet.
|
389 |
+
|
390 |
+
Returns:
|
391 |
+
Input image size (resolution).
|
392 |
+
"""
|
393 |
+
cls._check_model_name_is_valid(model_name)
|
394 |
+
_, _, res, _ = efficientnet_params(model_name)
|
395 |
+
return res
|
396 |
+
|
397 |
+
@classmethod
|
398 |
+
def _check_model_name_is_valid(cls, model_name):
|
399 |
+
"""Validates model name.
|
400 |
+
|
401 |
+
Args:
|
402 |
+
model_name (str): Name for efficientnet.
|
403 |
+
|
404 |
+
Returns:
|
405 |
+
bool: Is a valid name or not.
|
406 |
+
"""
|
407 |
+
if model_name not in VALID_MODELS:
|
408 |
+
raise ValueError('model_name should be one of: ' + ', '.join(VALID_MODELS))
|
409 |
+
|
410 |
+
def _change_in_channels(self, in_channels):
|
411 |
+
"""Adjust model's first convolution layer to in_channels, if in_channels not equals 3.
|
412 |
+
|
413 |
+
Args:
|
414 |
+
in_channels (int): Input data's channel number.
|
415 |
+
"""
|
416 |
+
if in_channels != 3:
|
417 |
+
Conv2d = get_same_padding_conv2d(image_size=self._global_params.image_size)
|
418 |
+
out_channels = round_filters(32, self._global_params)
|
419 |
+
self._conv_stem = Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=False)
|
deepafx_st/models/efficient_net/utils.py
ADDED
@@ -0,0 +1,616 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""utils.py - Helper functions for building the model and for loading model parameters.
|
2 |
+
These helper functions are built to mirror those in the official TensorFlow implementation.
|
3 |
+
"""
|
4 |
+
|
5 |
+
# Author: lukemelas (github username)
|
6 |
+
# Github repo: https://github.com/lukemelas/EfficientNet-PyTorch
|
7 |
+
# With adjustments and added comments by workingcoder (github username).
|
8 |
+
|
9 |
+
import re
|
10 |
+
import math
|
11 |
+
import collections
|
12 |
+
from functools import partial
|
13 |
+
import torch
|
14 |
+
from torch import nn
|
15 |
+
from torch.nn import functional as F
|
16 |
+
from torch.utils import model_zoo
|
17 |
+
|
18 |
+
|
19 |
+
################################################################################
|
20 |
+
# Help functions for model architecture
|
21 |
+
################################################################################
|
22 |
+
|
23 |
+
# GlobalParams and BlockArgs: Two namedtuples
|
24 |
+
# Swish and MemoryEfficientSwish: Two implementations of the method
|
25 |
+
# round_filters and round_repeats:
|
26 |
+
# Functions to calculate params for scaling model width and depth ! ! !
|
27 |
+
# get_width_and_height_from_size and calculate_output_image_size
|
28 |
+
# drop_connect: A structural design
|
29 |
+
# get_same_padding_conv2d:
|
30 |
+
# Conv2dDynamicSamePadding
|
31 |
+
# Conv2dStaticSamePadding
|
32 |
+
# get_same_padding_maxPool2d:
|
33 |
+
# MaxPool2dDynamicSamePadding
|
34 |
+
# MaxPool2dStaticSamePadding
|
35 |
+
# It's an additional function, not used in EfficientNet,
|
36 |
+
# but can be used in other model (such as EfficientDet).
|
37 |
+
|
38 |
+
# Parameters for the entire model (stem, all blocks, and head)
|
39 |
+
GlobalParams = collections.namedtuple('GlobalParams', [
|
40 |
+
'width_coefficient', 'depth_coefficient', 'image_size', 'dropout_rate',
|
41 |
+
'num_classes', 'batch_norm_momentum', 'batch_norm_epsilon',
|
42 |
+
'drop_connect_rate', 'depth_divisor', 'min_depth', 'include_top'])
|
43 |
+
|
44 |
+
# Parameters for an individual model block
|
45 |
+
BlockArgs = collections.namedtuple('BlockArgs', [
|
46 |
+
'num_repeat', 'kernel_size', 'stride', 'expand_ratio',
|
47 |
+
'input_filters', 'output_filters', 'se_ratio', 'id_skip'])
|
48 |
+
|
49 |
+
# Set GlobalParams and BlockArgs's defaults
|
50 |
+
GlobalParams.__new__.__defaults__ = (None,) * len(GlobalParams._fields)
|
51 |
+
BlockArgs.__new__.__defaults__ = (None,) * len(BlockArgs._fields)
|
52 |
+
|
53 |
+
# Swish activation function
|
54 |
+
if hasattr(nn, 'SiLU'):
|
55 |
+
Swish = nn.SiLU
|
56 |
+
else:
|
57 |
+
# For compatibility with old PyTorch versions
|
58 |
+
class Swish(nn.Module):
|
59 |
+
def forward(self, x):
|
60 |
+
return x * torch.sigmoid(x)
|
61 |
+
|
62 |
+
|
63 |
+
# A memory-efficient implementation of Swish function
|
64 |
+
class SwishImplementation(torch.autograd.Function):
|
65 |
+
@staticmethod
|
66 |
+
def forward(ctx, i):
|
67 |
+
result = i * torch.sigmoid(i)
|
68 |
+
ctx.save_for_backward(i)
|
69 |
+
return result
|
70 |
+
|
71 |
+
@staticmethod
|
72 |
+
def backward(ctx, grad_output):
|
73 |
+
i = ctx.saved_tensors[0]
|
74 |
+
sigmoid_i = torch.sigmoid(i)
|
75 |
+
return grad_output * (sigmoid_i * (1 + i * (1 - sigmoid_i)))
|
76 |
+
|
77 |
+
|
78 |
+
class MemoryEfficientSwish(nn.Module):
|
79 |
+
def forward(self, x):
|
80 |
+
return SwishImplementation.apply(x)
|
81 |
+
|
82 |
+
|
83 |
+
def round_filters(filters, global_params):
|
84 |
+
"""Calculate and round number of filters based on width multiplier.
|
85 |
+
Use width_coefficient, depth_divisor and min_depth of global_params.
|
86 |
+
|
87 |
+
Args:
|
88 |
+
filters (int): Filters number to be calculated.
|
89 |
+
global_params (namedtuple): Global params of the model.
|
90 |
+
|
91 |
+
Returns:
|
92 |
+
new_filters: New filters number after calculating.
|
93 |
+
"""
|
94 |
+
multiplier = global_params.width_coefficient
|
95 |
+
if not multiplier:
|
96 |
+
return filters
|
97 |
+
# TODO: modify the params names.
|
98 |
+
# maybe the names (width_divisor,min_width)
|
99 |
+
# are more suitable than (depth_divisor,min_depth).
|
100 |
+
divisor = global_params.depth_divisor
|
101 |
+
min_depth = global_params.min_depth
|
102 |
+
filters *= multiplier
|
103 |
+
min_depth = min_depth or divisor # pay attention to this line when using min_depth
|
104 |
+
# follow the formula transferred from official TensorFlow implementation
|
105 |
+
new_filters = max(min_depth, int(filters + divisor / 2) // divisor * divisor)
|
106 |
+
if new_filters < 0.9 * filters: # prevent rounding by more than 10%
|
107 |
+
new_filters += divisor
|
108 |
+
return int(new_filters)
|
109 |
+
|
110 |
+
|
111 |
+
def round_repeats(repeats, global_params):
|
112 |
+
"""Calculate module's repeat number of a block based on depth multiplier.
|
113 |
+
Use depth_coefficient of global_params.
|
114 |
+
|
115 |
+
Args:
|
116 |
+
repeats (int): num_repeat to be calculated.
|
117 |
+
global_params (namedtuple): Global params of the model.
|
118 |
+
|
119 |
+
Returns:
|
120 |
+
new repeat: New repeat number after calculating.
|
121 |
+
"""
|
122 |
+
multiplier = global_params.depth_coefficient
|
123 |
+
if not multiplier:
|
124 |
+
return repeats
|
125 |
+
# follow the formula transferred from official TensorFlow implementation
|
126 |
+
return int(math.ceil(multiplier * repeats))
|
127 |
+
|
128 |
+
|
129 |
+
def drop_connect(inputs, p, training):
|
130 |
+
"""Drop connect.
|
131 |
+
|
132 |
+
Args:
|
133 |
+
input (tensor: BCWH): Input of this structure.
|
134 |
+
p (float: 0.0~1.0): Probability of drop connection.
|
135 |
+
training (bool): The running mode.
|
136 |
+
|
137 |
+
Returns:
|
138 |
+
output: Output after drop connection.
|
139 |
+
"""
|
140 |
+
assert 0 <= p <= 1, 'p must be in range of [0,1]'
|
141 |
+
|
142 |
+
if not training:
|
143 |
+
return inputs
|
144 |
+
|
145 |
+
batch_size = inputs.shape[0]
|
146 |
+
keep_prob = 1 - p
|
147 |
+
|
148 |
+
# generate binary_tensor mask according to probability (p for 0, 1-p for 1)
|
149 |
+
random_tensor = keep_prob
|
150 |
+
random_tensor += torch.rand([batch_size, 1, 1, 1], dtype=inputs.dtype, device=inputs.device)
|
151 |
+
binary_tensor = torch.floor(random_tensor)
|
152 |
+
|
153 |
+
output = inputs / keep_prob * binary_tensor
|
154 |
+
return output
|
155 |
+
|
156 |
+
|
157 |
+
def get_width_and_height_from_size(x):
|
158 |
+
"""Obtain height and width from x.
|
159 |
+
|
160 |
+
Args:
|
161 |
+
x (int, tuple or list): Data size.
|
162 |
+
|
163 |
+
Returns:
|
164 |
+
size: A tuple or list (H,W).
|
165 |
+
"""
|
166 |
+
if isinstance(x, int):
|
167 |
+
return x, x
|
168 |
+
if isinstance(x, list) or isinstance(x, tuple):
|
169 |
+
return x
|
170 |
+
else:
|
171 |
+
raise TypeError()
|
172 |
+
|
173 |
+
|
174 |
+
def calculate_output_image_size(input_image_size, stride):
|
175 |
+
"""Calculates the output image size when using Conv2dSamePadding with a stride.
|
176 |
+
Necessary for static padding. Thanks to mannatsingh for pointing this out.
|
177 |
+
|
178 |
+
Args:
|
179 |
+
input_image_size (int, tuple or list): Size of input image.
|
180 |
+
stride (int, tuple or list): Conv2d operation's stride.
|
181 |
+
|
182 |
+
Returns:
|
183 |
+
output_image_size: A list [H,W].
|
184 |
+
"""
|
185 |
+
if input_image_size is None:
|
186 |
+
return None
|
187 |
+
image_height, image_width = get_width_and_height_from_size(input_image_size)
|
188 |
+
stride = stride if isinstance(stride, int) else stride[0]
|
189 |
+
image_height = int(math.ceil(image_height / stride))
|
190 |
+
image_width = int(math.ceil(image_width / stride))
|
191 |
+
return [image_height, image_width]
|
192 |
+
|
193 |
+
|
194 |
+
# Note:
|
195 |
+
# The following 'SamePadding' functions make output size equal ceil(input size/stride).
|
196 |
+
# Only when stride equals 1, can the output size be the same as input size.
|
197 |
+
# Don't be confused by their function names ! ! !
|
198 |
+
|
199 |
+
def get_same_padding_conv2d(image_size=None):
|
200 |
+
"""Chooses static padding if you have specified an image size, and dynamic padding otherwise.
|
201 |
+
Static padding is necessary for ONNX exporting of models.
|
202 |
+
|
203 |
+
Args:
|
204 |
+
image_size (int or tuple): Size of the image.
|
205 |
+
|
206 |
+
Returns:
|
207 |
+
Conv2dDynamicSamePadding or Conv2dStaticSamePadding.
|
208 |
+
"""
|
209 |
+
if image_size is None:
|
210 |
+
return Conv2dDynamicSamePadding
|
211 |
+
else:
|
212 |
+
return partial(Conv2dStaticSamePadding, image_size=image_size)
|
213 |
+
|
214 |
+
|
215 |
+
class Conv2dDynamicSamePadding(nn.Conv2d):
|
216 |
+
"""2D Convolutions like TensorFlow, for a dynamic image size.
|
217 |
+
The padding is operated in forward function by calculating dynamically.
|
218 |
+
"""
|
219 |
+
|
220 |
+
# Tips for 'SAME' mode padding.
|
221 |
+
# Given the following:
|
222 |
+
# i: width or height
|
223 |
+
# s: stride
|
224 |
+
# k: kernel size
|
225 |
+
# d: dilation
|
226 |
+
# p: padding
|
227 |
+
# Output after Conv2d:
|
228 |
+
# o = floor((i+p-((k-1)*d+1))/s+1)
|
229 |
+
# If o equals i, i = floor((i+p-((k-1)*d+1))/s+1),
|
230 |
+
# => p = (i-1)*s+((k-1)*d+1)-i
|
231 |
+
|
232 |
+
def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, groups=1, bias=True):
|
233 |
+
super().__init__(in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias)
|
234 |
+
self.stride = self.stride if len(self.stride) == 2 else [self.stride[0]] * 2
|
235 |
+
|
236 |
+
def forward(self, x):
|
237 |
+
ih, iw = x.size()[-2:]
|
238 |
+
kh, kw = self.weight.size()[-2:]
|
239 |
+
sh, sw = self.stride
|
240 |
+
oh, ow = math.ceil(ih / sh), math.ceil(iw / sw) # change the output size according to stride ! ! !
|
241 |
+
pad_h = max((oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0)
|
242 |
+
pad_w = max((ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0)
|
243 |
+
if pad_h > 0 or pad_w > 0:
|
244 |
+
x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2])
|
245 |
+
return F.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
|
246 |
+
|
247 |
+
|
248 |
+
class Conv2dStaticSamePadding(nn.Conv2d):
|
249 |
+
"""2D Convolutions like TensorFlow's 'SAME' mode, with the given input image size.
|
250 |
+
The padding mudule is calculated in construction function, then used in forward.
|
251 |
+
"""
|
252 |
+
|
253 |
+
# With the same calculation as Conv2dDynamicSamePadding
|
254 |
+
|
255 |
+
def __init__(self, in_channels, out_channels, kernel_size, stride=1, image_size=None, **kwargs):
|
256 |
+
super().__init__(in_channels, out_channels, kernel_size, stride, **kwargs)
|
257 |
+
self.stride = self.stride if len(self.stride) == 2 else [self.stride[0]] * 2
|
258 |
+
|
259 |
+
# Calculate padding based on image size and save it
|
260 |
+
assert image_size is not None
|
261 |
+
ih, iw = (image_size, image_size) if isinstance(image_size, int) else image_size
|
262 |
+
kh, kw = self.weight.size()[-2:]
|
263 |
+
sh, sw = self.stride
|
264 |
+
oh, ow = math.ceil(ih / sh), math.ceil(iw / sw)
|
265 |
+
pad_h = max((oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0)
|
266 |
+
pad_w = max((ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0)
|
267 |
+
if pad_h > 0 or pad_w > 0:
|
268 |
+
self.static_padding = nn.ZeroPad2d((pad_w // 2, pad_w - pad_w // 2,
|
269 |
+
pad_h // 2, pad_h - pad_h // 2))
|
270 |
+
else:
|
271 |
+
self.static_padding = nn.Identity()
|
272 |
+
|
273 |
+
def forward(self, x):
|
274 |
+
x = self.static_padding(x)
|
275 |
+
x = F.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
|
276 |
+
return x
|
277 |
+
|
278 |
+
|
279 |
+
def get_same_padding_maxPool2d(image_size=None):
|
280 |
+
"""Chooses static padding if you have specified an image size, and dynamic padding otherwise.
|
281 |
+
Static padding is necessary for ONNX exporting of models.
|
282 |
+
|
283 |
+
Args:
|
284 |
+
image_size (int or tuple): Size of the image.
|
285 |
+
|
286 |
+
Returns:
|
287 |
+
MaxPool2dDynamicSamePadding or MaxPool2dStaticSamePadding.
|
288 |
+
"""
|
289 |
+
if image_size is None:
|
290 |
+
return MaxPool2dDynamicSamePadding
|
291 |
+
else:
|
292 |
+
return partial(MaxPool2dStaticSamePadding, image_size=image_size)
|
293 |
+
|
294 |
+
|
295 |
+
class MaxPool2dDynamicSamePadding(nn.MaxPool2d):
|
296 |
+
"""2D MaxPooling like TensorFlow's 'SAME' mode, with a dynamic image size.
|
297 |
+
The padding is operated in forward function by calculating dynamically.
|
298 |
+
"""
|
299 |
+
|
300 |
+
def __init__(self, kernel_size, stride, padding=0, dilation=1, return_indices=False, ceil_mode=False):
|
301 |
+
super().__init__(kernel_size, stride, padding, dilation, return_indices, ceil_mode)
|
302 |
+
self.stride = [self.stride] * 2 if isinstance(self.stride, int) else self.stride
|
303 |
+
self.kernel_size = [self.kernel_size] * 2 if isinstance(self.kernel_size, int) else self.kernel_size
|
304 |
+
self.dilation = [self.dilation] * 2 if isinstance(self.dilation, int) else self.dilation
|
305 |
+
|
306 |
+
def forward(self, x):
|
307 |
+
ih, iw = x.size()[-2:]
|
308 |
+
kh, kw = self.kernel_size
|
309 |
+
sh, sw = self.stride
|
310 |
+
oh, ow = math.ceil(ih / sh), math.ceil(iw / sw)
|
311 |
+
pad_h = max((oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0)
|
312 |
+
pad_w = max((ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0)
|
313 |
+
if pad_h > 0 or pad_w > 0:
|
314 |
+
x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2])
|
315 |
+
return F.max_pool2d(x, self.kernel_size, self.stride, self.padding,
|
316 |
+
self.dilation, self.ceil_mode, self.return_indices)
|
317 |
+
|
318 |
+
|
319 |
+
class MaxPool2dStaticSamePadding(nn.MaxPool2d):
|
320 |
+
"""2D MaxPooling like TensorFlow's 'SAME' mode, with the given input image size.
|
321 |
+
The padding mudule is calculated in construction function, then used in forward.
|
322 |
+
"""
|
323 |
+
|
324 |
+
def __init__(self, kernel_size, stride, image_size=None, **kwargs):
|
325 |
+
super().__init__(kernel_size, stride, **kwargs)
|
326 |
+
self.stride = [self.stride] * 2 if isinstance(self.stride, int) else self.stride
|
327 |
+
self.kernel_size = [self.kernel_size] * 2 if isinstance(self.kernel_size, int) else self.kernel_size
|
328 |
+
self.dilation = [self.dilation] * 2 if isinstance(self.dilation, int) else self.dilation
|
329 |
+
|
330 |
+
# Calculate padding based on image size and save it
|
331 |
+
assert image_size is not None
|
332 |
+
ih, iw = (image_size, image_size) if isinstance(image_size, int) else image_size
|
333 |
+
kh, kw = self.kernel_size
|
334 |
+
sh, sw = self.stride
|
335 |
+
oh, ow = math.ceil(ih / sh), math.ceil(iw / sw)
|
336 |
+
pad_h = max((oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0)
|
337 |
+
pad_w = max((ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0)
|
338 |
+
if pad_h > 0 or pad_w > 0:
|
339 |
+
self.static_padding = nn.ZeroPad2d((pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2))
|
340 |
+
else:
|
341 |
+
self.static_padding = nn.Identity()
|
342 |
+
|
343 |
+
def forward(self, x):
|
344 |
+
x = self.static_padding(x)
|
345 |
+
x = F.max_pool2d(x, self.kernel_size, self.stride, self.padding,
|
346 |
+
self.dilation, self.ceil_mode, self.return_indices)
|
347 |
+
return x
|
348 |
+
|
349 |
+
|
350 |
+
################################################################################
|
351 |
+
# Helper functions for loading model params
|
352 |
+
################################################################################
|
353 |
+
|
354 |
+
# BlockDecoder: A Class for encoding and decoding BlockArgs
|
355 |
+
# efficientnet_params: A function to query compound coefficient
|
356 |
+
# get_model_params and efficientnet:
|
357 |
+
# Functions to get BlockArgs and GlobalParams for efficientnet
|
358 |
+
# url_map and url_map_advprop: Dicts of url_map for pretrained weights
|
359 |
+
# load_pretrained_weights: A function to load pretrained weights
|
360 |
+
|
361 |
+
class BlockDecoder(object):
|
362 |
+
"""Block Decoder for readability,
|
363 |
+
straight from the official TensorFlow repository.
|
364 |
+
"""
|
365 |
+
|
366 |
+
@staticmethod
|
367 |
+
def _decode_block_string(block_string):
|
368 |
+
"""Get a block through a string notation of arguments.
|
369 |
+
|
370 |
+
Args:
|
371 |
+
block_string (str): A string notation of arguments.
|
372 |
+
Examples: 'r1_k3_s11_e1_i32_o16_se0.25_noskip'.
|
373 |
+
|
374 |
+
Returns:
|
375 |
+
BlockArgs: The namedtuple defined at the top of this file.
|
376 |
+
"""
|
377 |
+
assert isinstance(block_string, str)
|
378 |
+
|
379 |
+
ops = block_string.split('_')
|
380 |
+
options = {}
|
381 |
+
for op in ops:
|
382 |
+
splits = re.split(r'(\d.*)', op)
|
383 |
+
if len(splits) >= 2:
|
384 |
+
key, value = splits[:2]
|
385 |
+
options[key] = value
|
386 |
+
|
387 |
+
# Check stride
|
388 |
+
assert (('s' in options and len(options['s']) == 1) or
|
389 |
+
(len(options['s']) == 2 and options['s'][0] == options['s'][1]))
|
390 |
+
|
391 |
+
return BlockArgs(
|
392 |
+
num_repeat=int(options['r']),
|
393 |
+
kernel_size=int(options['k']),
|
394 |
+
stride=[int(options['s'][0])],
|
395 |
+
expand_ratio=int(options['e']),
|
396 |
+
input_filters=int(options['i']),
|
397 |
+
output_filters=int(options['o']),
|
398 |
+
se_ratio=float(options['se']) if 'se' in options else None,
|
399 |
+
id_skip=('noskip' not in block_string))
|
400 |
+
|
401 |
+
@staticmethod
|
402 |
+
def _encode_block_string(block):
|
403 |
+
"""Encode a block to a string.
|
404 |
+
|
405 |
+
Args:
|
406 |
+
block (namedtuple): A BlockArgs type argument.
|
407 |
+
|
408 |
+
Returns:
|
409 |
+
block_string: A String form of BlockArgs.
|
410 |
+
"""
|
411 |
+
args = [
|
412 |
+
'r%d' % block.num_repeat,
|
413 |
+
'k%d' % block.kernel_size,
|
414 |
+
's%d%d' % (block.strides[0], block.strides[1]),
|
415 |
+
'e%s' % block.expand_ratio,
|
416 |
+
'i%d' % block.input_filters,
|
417 |
+
'o%d' % block.output_filters
|
418 |
+
]
|
419 |
+
if 0 < block.se_ratio <= 1:
|
420 |
+
args.append('se%s' % block.se_ratio)
|
421 |
+
if block.id_skip is False:
|
422 |
+
args.append('noskip')
|
423 |
+
return '_'.join(args)
|
424 |
+
|
425 |
+
@staticmethod
|
426 |
+
def decode(string_list):
|
427 |
+
"""Decode a list of string notations to specify blocks inside the network.
|
428 |
+
|
429 |
+
Args:
|
430 |
+
string_list (list[str]): A list of strings, each string is a notation of block.
|
431 |
+
|
432 |
+
Returns:
|
433 |
+
blocks_args: A list of BlockArgs namedtuples of block args.
|
434 |
+
"""
|
435 |
+
assert isinstance(string_list, list)
|
436 |
+
blocks_args = []
|
437 |
+
for block_string in string_list:
|
438 |
+
blocks_args.append(BlockDecoder._decode_block_string(block_string))
|
439 |
+
return blocks_args
|
440 |
+
|
441 |
+
@staticmethod
|
442 |
+
def encode(blocks_args):
|
443 |
+
"""Encode a list of BlockArgs to a list of strings.
|
444 |
+
|
445 |
+
Args:
|
446 |
+
blocks_args (list[namedtuples]): A list of BlockArgs namedtuples of block args.
|
447 |
+
|
448 |
+
Returns:
|
449 |
+
block_strings: A list of strings, each string is a notation of block.
|
450 |
+
"""
|
451 |
+
block_strings = []
|
452 |
+
for block in blocks_args:
|
453 |
+
block_strings.append(BlockDecoder._encode_block_string(block))
|
454 |
+
return block_strings
|
455 |
+
|
456 |
+
|
457 |
+
def efficientnet_params(model_name):
|
458 |
+
"""Map EfficientNet model name to parameter coefficients.
|
459 |
+
|
460 |
+
Args:
|
461 |
+
model_name (str): Model name to be queried.
|
462 |
+
|
463 |
+
Returns:
|
464 |
+
params_dict[model_name]: A (width,depth,res,dropout) tuple.
|
465 |
+
"""
|
466 |
+
params_dict = {
|
467 |
+
# Coefficients: width,depth,res,dropout
|
468 |
+
'efficientnet-b0': (1.0, 1.0, 224, 0.2),
|
469 |
+
'efficientnet-b1': (1.0, 1.1, 240, 0.2),
|
470 |
+
'efficientnet-b2': (1.1, 1.2, 260, 0.3),
|
471 |
+
'efficientnet-b3': (1.2, 1.4, 300, 0.3),
|
472 |
+
'efficientnet-b4': (1.4, 1.8, 380, 0.4),
|
473 |
+
'efficientnet-b5': (1.6, 2.2, 456, 0.4),
|
474 |
+
'efficientnet-b6': (1.8, 2.6, 528, 0.5),
|
475 |
+
'efficientnet-b7': (2.0, 3.1, 600, 0.5),
|
476 |
+
'efficientnet-b8': (2.2, 3.6, 672, 0.5),
|
477 |
+
'efficientnet-l2': (4.3, 5.3, 800, 0.5),
|
478 |
+
}
|
479 |
+
return params_dict[model_name]
|
480 |
+
|
481 |
+
|
482 |
+
def efficientnet(width_coefficient=None, depth_coefficient=None, image_size=None,
|
483 |
+
dropout_rate=0.2, drop_connect_rate=0.2, num_classes=1000, include_top=True):
|
484 |
+
"""Create BlockArgs and GlobalParams for efficientnet model.
|
485 |
+
|
486 |
+
Args:
|
487 |
+
width_coefficient (float)
|
488 |
+
depth_coefficient (float)
|
489 |
+
image_size (int)
|
490 |
+
dropout_rate (float)
|
491 |
+
drop_connect_rate (float)
|
492 |
+
num_classes (int)
|
493 |
+
|
494 |
+
Meaning as the name suggests.
|
495 |
+
|
496 |
+
Returns:
|
497 |
+
blocks_args, global_params.
|
498 |
+
"""
|
499 |
+
|
500 |
+
# Blocks args for the whole model(efficientnet-b0 by default)
|
501 |
+
# It will be modified in the construction of EfficientNet Class according to model
|
502 |
+
blocks_args = [
|
503 |
+
'r1_k3_s11_e1_i32_o16_se0.25',
|
504 |
+
'r2_k3_s22_e6_i16_o24_se0.25',
|
505 |
+
'r2_k5_s22_e6_i24_o40_se0.25',
|
506 |
+
'r3_k3_s22_e6_i40_o80_se0.25',
|
507 |
+
'r3_k5_s11_e6_i80_o112_se0.25',
|
508 |
+
'r4_k5_s22_e6_i112_o192_se0.25',
|
509 |
+
'r1_k3_s11_e6_i192_o320_se0.25',
|
510 |
+
]
|
511 |
+
blocks_args = BlockDecoder.decode(blocks_args)
|
512 |
+
|
513 |
+
global_params = GlobalParams(
|
514 |
+
width_coefficient=width_coefficient,
|
515 |
+
depth_coefficient=depth_coefficient,
|
516 |
+
image_size=image_size,
|
517 |
+
dropout_rate=dropout_rate,
|
518 |
+
|
519 |
+
num_classes=num_classes,
|
520 |
+
batch_norm_momentum=0.99,
|
521 |
+
batch_norm_epsilon=1e-3,
|
522 |
+
drop_connect_rate=drop_connect_rate,
|
523 |
+
depth_divisor=8,
|
524 |
+
min_depth=None,
|
525 |
+
include_top=include_top,
|
526 |
+
)
|
527 |
+
|
528 |
+
return blocks_args, global_params
|
529 |
+
|
530 |
+
|
531 |
+
def get_model_params(model_name, override_params):
|
532 |
+
"""Get the block args and global params for a given model name.
|
533 |
+
|
534 |
+
Args:
|
535 |
+
model_name (str): Model's name.
|
536 |
+
override_params (dict): A dict to modify global_params.
|
537 |
+
|
538 |
+
Returns:
|
539 |
+
blocks_args, global_params
|
540 |
+
"""
|
541 |
+
if model_name.startswith('efficientnet'):
|
542 |
+
w, d, s, p = efficientnet_params(model_name)
|
543 |
+
# note: all models have drop connect rate = 0.2
|
544 |
+
blocks_args, global_params = efficientnet(
|
545 |
+
width_coefficient=w, depth_coefficient=d, dropout_rate=p, image_size=s)
|
546 |
+
else:
|
547 |
+
raise NotImplementedError('model name is not pre-defined: {}'.format(model_name))
|
548 |
+
if override_params:
|
549 |
+
# ValueError will be raised here if override_params has fields not included in global_params.
|
550 |
+
global_params = global_params._replace(**override_params)
|
551 |
+
return blocks_args, global_params
|
552 |
+
|
553 |
+
|
554 |
+
# train with Standard methods
|
555 |
+
# check more details in paper(EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks)
|
556 |
+
url_map = {
|
557 |
+
'efficientnet-b0': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b0-355c32eb.pth',
|
558 |
+
'efficientnet-b1': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b1-f1951068.pth',
|
559 |
+
'efficientnet-b2': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b2-8bb594d6.pth',
|
560 |
+
'efficientnet-b3': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b3-5fb5a3c3.pth',
|
561 |
+
'efficientnet-b4': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b4-6ed6700e.pth',
|
562 |
+
'efficientnet-b5': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b5-b6417697.pth',
|
563 |
+
'efficientnet-b6': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b6-c76e70fd.pth',
|
564 |
+
'efficientnet-b7': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b7-dcc49843.pth',
|
565 |
+
}
|
566 |
+
|
567 |
+
# train with Adversarial Examples(AdvProp)
|
568 |
+
# check more details in paper(Adversarial Examples Improve Image Recognition)
|
569 |
+
url_map_advprop = {
|
570 |
+
'efficientnet-b0': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b0-b64d5a18.pth',
|
571 |
+
'efficientnet-b1': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b1-0f3ce85a.pth',
|
572 |
+
'efficientnet-b2': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b2-6e9d97e5.pth',
|
573 |
+
'efficientnet-b3': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b3-cdd7c0f4.pth',
|
574 |
+
'efficientnet-b4': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b4-44fb3a87.pth',
|
575 |
+
'efficientnet-b5': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b5-86493f6b.pth',
|
576 |
+
'efficientnet-b6': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b6-ac80338e.pth',
|
577 |
+
'efficientnet-b7': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b7-4652b6dd.pth',
|
578 |
+
'efficientnet-b8': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b8-22a8fe65.pth',
|
579 |
+
}
|
580 |
+
|
581 |
+
# TODO: add the petrained weights url map of 'efficientnet-l2'
|
582 |
+
|
583 |
+
|
584 |
+
def load_pretrained_weights(model, model_name, weights_path=None, load_fc=True, advprop=False, verbose=True):
|
585 |
+
"""Loads pretrained weights from weights path or download using url.
|
586 |
+
|
587 |
+
Args:
|
588 |
+
model (Module): The whole model of efficientnet.
|
589 |
+
model_name (str): Model name of efficientnet.
|
590 |
+
weights_path (None or str):
|
591 |
+
str: path to pretrained weights file on the local disk.
|
592 |
+
None: use pretrained weights downloaded from the Internet.
|
593 |
+
load_fc (bool): Whether to load pretrained weights for fc layer at the end of the model.
|
594 |
+
advprop (bool): Whether to load pretrained weights
|
595 |
+
trained with advprop (valid when weights_path is None).
|
596 |
+
"""
|
597 |
+
if isinstance(weights_path, str):
|
598 |
+
state_dict = torch.load(weights_path)
|
599 |
+
else:
|
600 |
+
# AutoAugment or Advprop (different preprocessing)
|
601 |
+
url_map_ = url_map_advprop if advprop else url_map
|
602 |
+
state_dict = model_zoo.load_url(url_map_[model_name])
|
603 |
+
|
604 |
+
if load_fc:
|
605 |
+
ret = model.load_state_dict(state_dict, strict=False)
|
606 |
+
assert not ret.missing_keys, 'Missing keys when loading pretrained weights: {}'.format(ret.missing_keys)
|
607 |
+
else:
|
608 |
+
state_dict.pop('_fc.weight')
|
609 |
+
state_dict.pop('_fc.bias')
|
610 |
+
ret = model.load_state_dict(state_dict, strict=False)
|
611 |
+
assert set(ret.missing_keys) == set(
|
612 |
+
['_fc.weight', '_fc.bias']), 'Missing keys when loading pretrained weights: {}'.format(ret.missing_keys)
|
613 |
+
assert not ret.unexpected_keys, 'Missing keys when loading pretrained weights: {}'.format(ret.unexpected_keys)
|
614 |
+
|
615 |
+
if verbose:
|
616 |
+
print('Loaded pretrained weights for {}'.format(model_name))
|
deepafx_st/models/encoder.py
ADDED
@@ -0,0 +1,113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
|
3 |
+
from deepafx_st.models.mobilenetv2 import MobileNetV2
|
4 |
+
from deepafx_st.models.efficient_net import EfficientNet
|
5 |
+
|
6 |
+
|
7 |
+
class SpectralEncoder(torch.nn.Module):
|
8 |
+
def __init__(
|
9 |
+
self,
|
10 |
+
num_params,
|
11 |
+
sample_rate,
|
12 |
+
encoder_model="mobilenet_v2",
|
13 |
+
embed_dim=1028,
|
14 |
+
width_mult=1,
|
15 |
+
min_level_db=-80,
|
16 |
+
):
|
17 |
+
"""Encoder operating on spectrograms.
|
18 |
+
|
19 |
+
Args:
|
20 |
+
num_params (int): Number of processor parameters to generate.
|
21 |
+
sample_rate (float): Audio sample rate for computing melspectrogram.
|
22 |
+
encoder_model (str, optional): Encoder model architecture. Default: "mobilenet_v2"
|
23 |
+
embed_dim (int, optional): Dimentionality of the encoder representations.
|
24 |
+
width_mult (int, optional): Encoder size. Default: 1
|
25 |
+
min_level_db (float, optional): Minimal dB value for the spectrogram. Default: -80
|
26 |
+
"""
|
27 |
+
super().__init__()
|
28 |
+
self.num_params = num_params
|
29 |
+
self.sample_rate = sample_rate
|
30 |
+
self.encoder_model = encoder_model
|
31 |
+
self.embed_dim = embed_dim
|
32 |
+
self.width_mult = width_mult
|
33 |
+
self.min_level_db = min_level_db
|
34 |
+
|
35 |
+
# load model from torch.hub
|
36 |
+
if encoder_model == "mobilenet_v2":
|
37 |
+
self.encoder = MobileNetV2(embed_dim=embed_dim, width_mult=width_mult)
|
38 |
+
elif encoder_model == "efficient_net":
|
39 |
+
self.encoder = EfficientNet.from_name(
|
40 |
+
"efficientnet-b2",
|
41 |
+
in_channels=1,
|
42 |
+
image_size=(128, 65),
|
43 |
+
include_top=False,
|
44 |
+
)
|
45 |
+
self.embedding_projection = torch.nn.Conv2d(
|
46 |
+
in_channels=1408,
|
47 |
+
out_channels=embed_dim,
|
48 |
+
kernel_size=(1, 1),
|
49 |
+
stride=(1, 1),
|
50 |
+
padding=(0, 0),
|
51 |
+
bias=True,
|
52 |
+
)
|
53 |
+
|
54 |
+
else:
|
55 |
+
raise ValueError(f"Invalid encoder_model: {encoder_model}.")
|
56 |
+
|
57 |
+
self.window = torch.nn.Parameter(torch.hann_window(4096))
|
58 |
+
|
59 |
+
def forward(self, x):
|
60 |
+
"""
|
61 |
+
Args:
|
62 |
+
x (Tensor): Input waveform of shape [batch x channels x samples]
|
63 |
+
|
64 |
+
Returns:
|
65 |
+
e (Tensor): Latent embedding produced by Encoder. [batch x embed_dim]
|
66 |
+
"""
|
67 |
+
bs, chs, samp = x.size()
|
68 |
+
|
69 |
+
# compute spectrogram of waveform
|
70 |
+
X = torch.stft(
|
71 |
+
x.view(bs, -1),
|
72 |
+
4096,
|
73 |
+
2048,
|
74 |
+
window=self.window,
|
75 |
+
return_complex=True,
|
76 |
+
)
|
77 |
+
X_db = torch.pow(X.abs() + 1e-8, 0.3)
|
78 |
+
X_db_norm = X_db
|
79 |
+
|
80 |
+
# standardize (0, 1) 0.322970 0.278452
|
81 |
+
X_db_norm -= 0.322970
|
82 |
+
X_db_norm /= 0.278452
|
83 |
+
X_db_norm = X_db_norm.unsqueeze(1).permute(0, 1, 3, 2)
|
84 |
+
|
85 |
+
if self.encoder_model == "mobilenet_v2":
|
86 |
+
# repeat channels by 3 to fit vision model
|
87 |
+
X_db_norm = X_db_norm.repeat(1, 3, 1, 1)
|
88 |
+
|
89 |
+
# pass melspectrogram through encoder
|
90 |
+
e = self.encoder(X_db_norm)
|
91 |
+
|
92 |
+
# apply avg pooling across time for encoder embeddings
|
93 |
+
e = torch.nn.functional.adaptive_avg_pool2d(e, 1).reshape(e.shape[0], -1)
|
94 |
+
|
95 |
+
# normalize by L2 norm
|
96 |
+
norm = torch.norm(e, p=2, dim=-1, keepdim=True)
|
97 |
+
e_norm = e / norm
|
98 |
+
|
99 |
+
elif self.encoder_model == "efficient_net":
|
100 |
+
|
101 |
+
# Efficient Net internal downsamples by 32 on time and freq axis, then average pools the rest
|
102 |
+
e = self.encoder(X_db_norm)
|
103 |
+
|
104 |
+
# Adding 1x1 conv to project down or up to the requested embedding size
|
105 |
+
e = self.embedding_projection(e)
|
106 |
+
e = torch.squeeze(e, dim=3)
|
107 |
+
e = torch.squeeze(e, dim=2)
|
108 |
+
|
109 |
+
# normalize by L2 norm
|
110 |
+
norm = torch.norm(e, p=2, dim=-1, keepdim=True)
|
111 |
+
e_norm = e / norm
|
112 |
+
|
113 |
+
return e_norm
|
deepafx_st/models/mobilenetv2.py
ADDED
@@ -0,0 +1,226 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# BSD 3-Clause License
|
2 |
+
|
3 |
+
# Copyright (c) Soumith Chintala 2016,
|
4 |
+
# All rights reserved.
|
5 |
+
|
6 |
+
# Redistribution and use in source and binary forms, with or without
|
7 |
+
# modification, are permitted provided that the following conditions are met:
|
8 |
+
|
9 |
+
# * Redistributions of source code must retain the above copyright notice, this
|
10 |
+
# list of conditions and the following disclaimer.
|
11 |
+
|
12 |
+
# * Redistributions in binary form must reproduce the above copyright notice,
|
13 |
+
# this list of conditions and the following disclaimer in the documentation
|
14 |
+
# and/or other materials provided with the distribution.
|
15 |
+
|
16 |
+
# * Neither the name of the copyright holder nor the names of its
|
17 |
+
# contributors may be used to endorse or promote products derived from
|
18 |
+
# this software without specific prior written permission.
|
19 |
+
|
20 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
21 |
+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
22 |
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
23 |
+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
24 |
+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
25 |
+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
26 |
+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
27 |
+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
28 |
+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
29 |
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
30 |
+
|
31 |
+
# Adaptation of the PyTorch torchvision MobileNetV2 without a classifier.
|
32 |
+
# See source here: https://pytorch.org/vision/0.8/_modules/torchvision/models/mobilenet.html#mobilenet_v2
|
33 |
+
from torch import nn
|
34 |
+
|
35 |
+
|
36 |
+
def _make_divisible(v, divisor, min_value=None):
|
37 |
+
"""
|
38 |
+
This function is taken from the original tf repo.
|
39 |
+
It ensures that all layers have a channel number that is divisible by 8
|
40 |
+
It can be seen here:
|
41 |
+
https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
|
42 |
+
:param v:
|
43 |
+
:param divisor:
|
44 |
+
:param min_value:
|
45 |
+
:return:
|
46 |
+
"""
|
47 |
+
if min_value is None:
|
48 |
+
min_value = divisor
|
49 |
+
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
|
50 |
+
# Make sure that round down does not go down by more than 10%.
|
51 |
+
if new_v < 0.9 * v:
|
52 |
+
new_v += divisor
|
53 |
+
return new_v
|
54 |
+
|
55 |
+
|
56 |
+
class ConvBNReLU(nn.Sequential):
|
57 |
+
def __init__(
|
58 |
+
self, in_planes, out_planes, kernel_size=3, stride=1, groups=1, norm_layer=None
|
59 |
+
):
|
60 |
+
padding = (kernel_size - 1) // 2
|
61 |
+
if norm_layer is None:
|
62 |
+
norm_layer = nn.BatchNorm2d
|
63 |
+
super(ConvBNReLU, self).__init__(
|
64 |
+
nn.Conv2d(
|
65 |
+
in_planes,
|
66 |
+
out_planes,
|
67 |
+
kernel_size,
|
68 |
+
stride,
|
69 |
+
padding,
|
70 |
+
groups=groups,
|
71 |
+
bias=False,
|
72 |
+
),
|
73 |
+
norm_layer(out_planes),
|
74 |
+
nn.ReLU6(inplace=True),
|
75 |
+
)
|
76 |
+
|
77 |
+
|
78 |
+
class InvertedResidual(nn.Module):
|
79 |
+
def __init__(self, inp, oup, stride, expand_ratio, norm_layer=None):
|
80 |
+
super(InvertedResidual, self).__init__()
|
81 |
+
self.stride = stride
|
82 |
+
assert stride in [1, 2]
|
83 |
+
|
84 |
+
if norm_layer is None:
|
85 |
+
norm_layer = nn.BatchNorm2d
|
86 |
+
|
87 |
+
hidden_dim = int(round(inp * expand_ratio))
|
88 |
+
self.use_res_connect = self.stride == 1 and inp == oup
|
89 |
+
|
90 |
+
layers = []
|
91 |
+
if expand_ratio != 1:
|
92 |
+
# pw
|
93 |
+
layers.append(
|
94 |
+
ConvBNReLU(inp, hidden_dim, kernel_size=1, norm_layer=norm_layer)
|
95 |
+
)
|
96 |
+
layers.extend(
|
97 |
+
[
|
98 |
+
# dw
|
99 |
+
ConvBNReLU(
|
100 |
+
hidden_dim,
|
101 |
+
hidden_dim,
|
102 |
+
stride=stride,
|
103 |
+
groups=hidden_dim,
|
104 |
+
norm_layer=norm_layer,
|
105 |
+
),
|
106 |
+
# pw-linear
|
107 |
+
nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
|
108 |
+
norm_layer(oup),
|
109 |
+
]
|
110 |
+
)
|
111 |
+
self.conv = nn.Sequential(*layers)
|
112 |
+
|
113 |
+
def forward(self, x):
|
114 |
+
if self.use_res_connect:
|
115 |
+
return x + self.conv(x)
|
116 |
+
else:
|
117 |
+
return self.conv(x)
|
118 |
+
|
119 |
+
|
120 |
+
class MobileNetV2(nn.Module):
|
121 |
+
def __init__(
|
122 |
+
self,
|
123 |
+
embed_dim=1028,
|
124 |
+
width_mult=1.0,
|
125 |
+
inverted_residual_setting=None,
|
126 |
+
round_nearest=8,
|
127 |
+
block=None,
|
128 |
+
norm_layer=None,
|
129 |
+
):
|
130 |
+
"""
|
131 |
+
MobileNet V2 main class
|
132 |
+
|
133 |
+
Args:
|
134 |
+
embed_dim (int): Number of channels in the final output.
|
135 |
+
width_mult (float): Width multiplier - adjusts number of channels in each layer by this amount
|
136 |
+
inverted_residual_setting: Network structure
|
137 |
+
round_nearest (int): Round the number of channels in each layer to be a multiple of this number
|
138 |
+
Set to 1 to turn off rounding
|
139 |
+
block: Module specifying inverted residual building block for mobilenet
|
140 |
+
norm_layer: Module specifying the normalization layer to use
|
141 |
+
|
142 |
+
"""
|
143 |
+
super(MobileNetV2, self).__init__()
|
144 |
+
|
145 |
+
if block is None:
|
146 |
+
block = InvertedResidual
|
147 |
+
|
148 |
+
if norm_layer is None:
|
149 |
+
norm_layer = nn.BatchNorm2d
|
150 |
+
|
151 |
+
input_channel = 32
|
152 |
+
last_channel = embed_dim / width_mult
|
153 |
+
|
154 |
+
if inverted_residual_setting is None:
|
155 |
+
inverted_residual_setting = [
|
156 |
+
# t, c, n, s
|
157 |
+
[1, 16, 1, 1],
|
158 |
+
[6, 24, 2, 2],
|
159 |
+
[6, 32, 3, 2],
|
160 |
+
[6, 64, 4, 2],
|
161 |
+
[6, 96, 3, 1],
|
162 |
+
[6, 160, 3, 2],
|
163 |
+
[6, 320, 1, 1],
|
164 |
+
]
|
165 |
+
|
166 |
+
# only check the first element, assuming user knows t,c,n,s are required
|
167 |
+
if (
|
168 |
+
len(inverted_residual_setting) == 0
|
169 |
+
or len(inverted_residual_setting[0]) != 4
|
170 |
+
):
|
171 |
+
raise ValueError(
|
172 |
+
"inverted_residual_setting should be non-empty "
|
173 |
+
"or a 4-element list, got {}".format(inverted_residual_setting)
|
174 |
+
)
|
175 |
+
|
176 |
+
# building first layer
|
177 |
+
input_channel = _make_divisible(input_channel * width_mult, round_nearest)
|
178 |
+
self.last_channel = _make_divisible(
|
179 |
+
last_channel * max(1.0, width_mult), round_nearest
|
180 |
+
)
|
181 |
+
features = [ConvBNReLU(3, input_channel, stride=2, norm_layer=norm_layer)]
|
182 |
+
# building inverted residual blocks
|
183 |
+
for t, c, n, s in inverted_residual_setting:
|
184 |
+
output_channel = _make_divisible(c * width_mult, round_nearest)
|
185 |
+
for i in range(n):
|
186 |
+
stride = s if i == 0 else 1
|
187 |
+
features.append(
|
188 |
+
block(
|
189 |
+
input_channel,
|
190 |
+
output_channel,
|
191 |
+
stride,
|
192 |
+
expand_ratio=t,
|
193 |
+
norm_layer=norm_layer,
|
194 |
+
)
|
195 |
+
)
|
196 |
+
input_channel = output_channel
|
197 |
+
# building last several layers
|
198 |
+
features.append(
|
199 |
+
ConvBNReLU(
|
200 |
+
input_channel, self.last_channel, kernel_size=1, norm_layer=norm_layer
|
201 |
+
)
|
202 |
+
)
|
203 |
+
# make it nn.Sequential
|
204 |
+
self.features = nn.Sequential(*features)
|
205 |
+
|
206 |
+
# weight initialization
|
207 |
+
for m in self.modules():
|
208 |
+
if isinstance(m, nn.Conv2d):
|
209 |
+
nn.init.kaiming_normal_(m.weight, mode="fan_out")
|
210 |
+
if m.bias is not None:
|
211 |
+
nn.init.zeros_(m.bias)
|
212 |
+
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
|
213 |
+
nn.init.ones_(m.weight)
|
214 |
+
nn.init.zeros_(m.bias)
|
215 |
+
elif isinstance(m, nn.Linear):
|
216 |
+
nn.init.normal_(m.weight, 0, 0.01)
|
217 |
+
nn.init.zeros_(m.bias)
|
218 |
+
|
219 |
+
def _forward_impl(self, x):
|
220 |
+
# This exists since TorchScript doesn't support inheritance, so the superclass method
|
221 |
+
# (this one) needs to have a name other than `forward` that can be accessed in a subclass
|
222 |
+
return self.features(x)
|
223 |
+
# return the features directly, no classifier or pooling
|
224 |
+
|
225 |
+
def forward(self, x):
|
226 |
+
return self._forward_impl(x)
|
deepafx_st/probes/cdpam_encoder.py
ADDED
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# MIT License
|
2 |
+
|
3 |
+
# Copyright (c) 2021 Pranay Manocha
|
4 |
+
|
5 |
+
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
6 |
+
# of this software and associated documentation files (the "Software"), to deal
|
7 |
+
# in the Software without restriction, including without limitation the rights
|
8 |
+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9 |
+
# copies of the Software, and to permit persons to whom the Software is
|
10 |
+
# furnished to do so, subject to the following conditions:
|
11 |
+
|
12 |
+
# The above copyright notice and this permission notice shall be included in all
|
13 |
+
# copies or substantial portions of the Software.
|
14 |
+
|
15 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16 |
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17 |
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18 |
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19 |
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20 |
+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21 |
+
# SOFTWARE.
|
22 |
+
|
23 |
+
# code adapated from https://github.com/pranaymanocha/PerceptualAudio
|
24 |
+
|
25 |
+
import cdpam
|
26 |
+
import torch
|
27 |
+
|
28 |
+
|
29 |
+
class CDPAMEncoder(torch.nn.Module):
|
30 |
+
def __init__(self, cdpam_ckpt: str):
|
31 |
+
super().__init__()
|
32 |
+
|
33 |
+
# pre-trained model parameterss
|
34 |
+
encoder_layers = 16
|
35 |
+
encoder_filters = 64
|
36 |
+
input_size = 512
|
37 |
+
proj_ndim = [512, 256]
|
38 |
+
ndim = [16, 6]
|
39 |
+
classif_BN = 0
|
40 |
+
classif_act = "no"
|
41 |
+
proj_dp = 0.1
|
42 |
+
proj_BN = 1
|
43 |
+
classif_dp = 0.05
|
44 |
+
|
45 |
+
model = cdpam.models.FINnet(
|
46 |
+
encoder_layers=encoder_layers,
|
47 |
+
encoder_filters=encoder_filters,
|
48 |
+
ndim=ndim,
|
49 |
+
classif_dp=classif_dp,
|
50 |
+
classif_BN=classif_BN,
|
51 |
+
classif_act=classif_act,
|
52 |
+
input_size=input_size,
|
53 |
+
)
|
54 |
+
|
55 |
+
state = torch.load(cdpam_ckpt, map_location="cpu")["state"]
|
56 |
+
model.load_state_dict(state)
|
57 |
+
model.eval()
|
58 |
+
|
59 |
+
self.model = model
|
60 |
+
self.embed_dim = 512
|
61 |
+
|
62 |
+
def forward(self, x):
|
63 |
+
|
64 |
+
with torch.no_grad():
|
65 |
+
_, a1, c1 = self.model.base_encoder.forward(x)
|
66 |
+
a1 = torch.nn.functional.normalize(a1, dim=1)
|
67 |
+
|
68 |
+
return a1
|
deepafx_st/probes/probe_system.py
ADDED
@@ -0,0 +1,307 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import julius
|
3 |
+
import torchopenl3
|
4 |
+
import torchmetrics
|
5 |
+
import pytorch_lightning as pl
|
6 |
+
from typing import Tuple, List, Dict
|
7 |
+
from argparse import ArgumentParser
|
8 |
+
|
9 |
+
from deepafx_st.probes.cdpam_encoder import CDPAMEncoder
|
10 |
+
from deepafx_st.probes.random_mel import RandomMelProjection
|
11 |
+
|
12 |
+
import deepafx_st.utils as utils
|
13 |
+
from deepafx_st.utils import DSPMode
|
14 |
+
from deepafx_st.system import System
|
15 |
+
from deepafx_st.data.style import StyleDataset
|
16 |
+
|
17 |
+
|
18 |
+
class ProbeSystem(pl.LightningModule):
|
19 |
+
def __init__(
|
20 |
+
self,
|
21 |
+
audio_dir=None,
|
22 |
+
num_classes=5,
|
23 |
+
task="style",
|
24 |
+
encoder_type="deepafx_st_autodiff",
|
25 |
+
deepafx_st_autodiff_ckpt=None,
|
26 |
+
deepafx_st_spsa_ckpt=None,
|
27 |
+
deepafx_st_proxy0_ckpt=None,
|
28 |
+
probe_type="linear",
|
29 |
+
batch_size=32,
|
30 |
+
lr=3e-4,
|
31 |
+
lr_patience=20,
|
32 |
+
patience=10,
|
33 |
+
preload=False,
|
34 |
+
sample_rate=24000,
|
35 |
+
shuffle=True,
|
36 |
+
num_workers=16,
|
37 |
+
**kwargs,
|
38 |
+
):
|
39 |
+
super().__init__()
|
40 |
+
self.save_hyperparameters()
|
41 |
+
|
42 |
+
if "deepafx_st" in self.hparams.encoder_type:
|
43 |
+
|
44 |
+
if "autodiff" in self.hparams.encoder_type:
|
45 |
+
self.hparams.deepafx_st_ckpt = self.hparams.deepafx_st_autodiff_ckpt
|
46 |
+
elif "spsa" in self.hparams.encoder_type:
|
47 |
+
self.hparams.deepafx_st_ckpt = self.hparams.deepafx_st_spsa_ckpt
|
48 |
+
elif "proxy0" in self.hparams.encoder_type:
|
49 |
+
self.hparams.deepafx_st_ckpt = self.hparams.deepafx_st_proxy0_ckpt
|
50 |
+
|
51 |
+
else:
|
52 |
+
raise RuntimeError(f"Invalid encoder_type: {self.hparams.encoder_type}")
|
53 |
+
|
54 |
+
if self.hparams.deepafx_st_ckpt is None:
|
55 |
+
raise RuntimeError(
|
56 |
+
f"Must supply {self.hparams.encoder_type}_ckpt checkpoint."
|
57 |
+
)
|
58 |
+
use_dsp = DSPMode.NONE
|
59 |
+
system = System.load_from_checkpoint(
|
60 |
+
self.hparams.deepafx_st_ckpt,
|
61 |
+
use_dsp=use_dsp,
|
62 |
+
batch_size=self.hparams.batch_size,
|
63 |
+
spsa_parallel=False,
|
64 |
+
proxy_ckpts=[],
|
65 |
+
strict=False,
|
66 |
+
)
|
67 |
+
system.eval()
|
68 |
+
self.encoder = system.encoder
|
69 |
+
self.hparams.embed_dim = self.encoder.embed_dim
|
70 |
+
|
71 |
+
# freeze weights
|
72 |
+
for name, param in self.encoder.named_parameters():
|
73 |
+
param.requires_grad = False
|
74 |
+
|
75 |
+
elif self.hparams.encoder_type == "openl3":
|
76 |
+
self.encoder = torchopenl3.models.load_audio_embedding_model(
|
77 |
+
input_repr=self.hparams.openl3_input_repr,
|
78 |
+
embedding_size=self.hparams.openl3_embedding_size,
|
79 |
+
content_type=self.hparams.openl3_content_type,
|
80 |
+
)
|
81 |
+
self.hparams.embed_dim = 6144
|
82 |
+
elif self.hparams.encoder_type == "random_mel":
|
83 |
+
self.encoder = RandomMelProjection(
|
84 |
+
self.hparams.sample_rate,
|
85 |
+
self.hparams.random_mel_embedding_size,
|
86 |
+
self.hparams.random_mel_n_mels,
|
87 |
+
self.hparams.random_mel_n_fft,
|
88 |
+
self.hparams.random_mel_hop_size,
|
89 |
+
)
|
90 |
+
self.hparams.embed_dim = self.hparams.random_mel_embedding_size
|
91 |
+
elif self.hparams.encoder_type == "cdpam":
|
92 |
+
self.encoder = CDPAMEncoder(self.hparams.cdpam_ckpt)
|
93 |
+
self.encoder.eval()
|
94 |
+
self.hparams.embed_dim = self.encoder.embed_dim
|
95 |
+
else:
|
96 |
+
raise ValueError(f"Invalid encoder_type: {self.hparams.encoder_type}")
|
97 |
+
|
98 |
+
if self.hparams.probe_type == "linear":
|
99 |
+
if self.hparams.task == "style":
|
100 |
+
self.probe = torch.nn.Sequential(
|
101 |
+
torch.nn.Linear(self.hparams.embed_dim, self.hparams.num_classes),
|
102 |
+
# torch.nn.Softmax(-1),
|
103 |
+
)
|
104 |
+
elif self.hparams.probe_type == "mlp":
|
105 |
+
if self.hparams.task == "style":
|
106 |
+
self.probe = torch.nn.Sequential(
|
107 |
+
torch.nn.Linear(self.hparams.embed_dim, 512),
|
108 |
+
torch.nn.ReLU(),
|
109 |
+
torch.nn.Linear(512, 512),
|
110 |
+
torch.nn.ReLU(),
|
111 |
+
torch.nn.Linear(512, self.hparams.num_classes),
|
112 |
+
)
|
113 |
+
self.accuracy = torchmetrics.Accuracy()
|
114 |
+
self.f1_score = torchmetrics.F1Score(self.hparams.num_classes)
|
115 |
+
|
116 |
+
def forward(self, x):
|
117 |
+
bs, chs, samp = x.size()
|
118 |
+
with torch.no_grad():
|
119 |
+
if "deepafx_st" in self.hparams.encoder_type:
|
120 |
+
x /= x.abs().max()
|
121 |
+
x *= 10 ** (-12.0 / 20) # with min 12 dBFS headroom
|
122 |
+
e = self.encoder(x)
|
123 |
+
norm = torch.norm(e, p=2, dim=-1, keepdim=True)
|
124 |
+
e = e / norm
|
125 |
+
elif self.hparams.encoder_type == "openl3":
|
126 |
+
# x = julius.resample_frac(x, self.hparams.sample_rate, 48000)
|
127 |
+
e, ts = torchopenl3.get_audio_embedding(
|
128 |
+
x,
|
129 |
+
48000,
|
130 |
+
model=self.encoder,
|
131 |
+
input_repr="mel128",
|
132 |
+
content_type="music",
|
133 |
+
)
|
134 |
+
e = e.permute(0, 2, 1)
|
135 |
+
e = e.mean(dim=-1)
|
136 |
+
# normalize by L2 norm
|
137 |
+
norm = torch.norm(e, p=2, dim=-1, keepdim=True)
|
138 |
+
e = e / norm
|
139 |
+
elif self.hparams.encoder_type == "random_mel":
|
140 |
+
e = self.encoder(x)
|
141 |
+
norm = torch.norm(e, p=2, dim=-1, keepdim=True)
|
142 |
+
e = e / norm
|
143 |
+
elif self.hparams.encoder_type == "cdpam":
|
144 |
+
# x = julius.resample_frac(x, self.hparams.sample_rate, 22050)
|
145 |
+
x = torch.round(x * 32768)
|
146 |
+
e = self.encoder(x)
|
147 |
+
|
148 |
+
return self.probe(e)
|
149 |
+
|
150 |
+
def common_step(
|
151 |
+
self,
|
152 |
+
batch: Tuple,
|
153 |
+
batch_idx: int,
|
154 |
+
optimizer_idx: int = 0,
|
155 |
+
train: bool = True,
|
156 |
+
):
|
157 |
+
loss = 0
|
158 |
+
x, y = batch
|
159 |
+
|
160 |
+
y_hat = self(x)
|
161 |
+
|
162 |
+
# compute CE
|
163 |
+
if self.hparams.task == "style":
|
164 |
+
loss = torch.nn.functional.cross_entropy(y_hat, y)
|
165 |
+
|
166 |
+
if not train:
|
167 |
+
# store audio data
|
168 |
+
data_dict = {"x": x.float().cpu()}
|
169 |
+
else:
|
170 |
+
data_dict = {}
|
171 |
+
|
172 |
+
self.log(
|
173 |
+
"train_loss" if train else "val_loss",
|
174 |
+
loss,
|
175 |
+
on_step=True,
|
176 |
+
on_epoch=True,
|
177 |
+
prog_bar=False,
|
178 |
+
logger=True,
|
179 |
+
sync_dist=True,
|
180 |
+
)
|
181 |
+
|
182 |
+
if not train and self.hparams.task == "style":
|
183 |
+
self.log("val_acc_step", self.accuracy(y_hat, y))
|
184 |
+
self.log("val_f1_step", self.f1_score(y_hat, y))
|
185 |
+
|
186 |
+
return loss, data_dict
|
187 |
+
|
188 |
+
def training_step(self, batch, batch_idx, optimizer_idx=0):
|
189 |
+
loss, _ = self.common_step(batch, batch_idx)
|
190 |
+
return loss
|
191 |
+
|
192 |
+
def validation_step(self, batch, batch_idx):
|
193 |
+
loss, data_dict = self.common_step(batch, batch_idx, train=False)
|
194 |
+
|
195 |
+
if batch_idx == 0:
|
196 |
+
return data_dict
|
197 |
+
|
198 |
+
def validation_epoch_end(self, outputs) -> None:
|
199 |
+
if self.hparams.task == "style":
|
200 |
+
self.log("val_acc_epoch", self.accuracy.compute())
|
201 |
+
self.log("val_f1_epoch", self.f1_score.compute())
|
202 |
+
|
203 |
+
return super().validation_epoch_end(outputs)
|
204 |
+
|
205 |
+
def configure_optimizers(self):
|
206 |
+
optimizer = torch.optim.AdamW(
|
207 |
+
self.probe.parameters(),
|
208 |
+
lr=self.hparams.lr,
|
209 |
+
betas=(0.9, 0.999),
|
210 |
+
)
|
211 |
+
|
212 |
+
ms1 = int(self.hparams.max_epochs * 0.8)
|
213 |
+
ms2 = int(self.hparams.max_epochs * 0.95)
|
214 |
+
print(
|
215 |
+
"Learning rate schedule:",
|
216 |
+
f"0 {self.hparams.lr:0.2e} -> ",
|
217 |
+
f"{ms1} {self.hparams.lr*0.1:0.2e} -> ",
|
218 |
+
f"{ms2} {self.hparams.lr*0.01:0.2e}",
|
219 |
+
)
|
220 |
+
scheduler = torch.optim.lr_scheduler.MultiStepLR(
|
221 |
+
optimizer,
|
222 |
+
milestones=[ms1, ms2],
|
223 |
+
gamma=0.1,
|
224 |
+
)
|
225 |
+
|
226 |
+
return [optimizer], {"scheduler": scheduler, "monitor": "val_loss"}
|
227 |
+
|
228 |
+
def train_dataloader(self):
|
229 |
+
|
230 |
+
if self.hparams.task == "style":
|
231 |
+
train_dataset = StyleDataset(
|
232 |
+
self.hparams.audio_dir,
|
233 |
+
"train",
|
234 |
+
sample_rate=self.hparams.encoder_sample_rate,
|
235 |
+
)
|
236 |
+
|
237 |
+
g = torch.Generator()
|
238 |
+
g.manual_seed(0)
|
239 |
+
|
240 |
+
return torch.utils.data.DataLoader(
|
241 |
+
train_dataset,
|
242 |
+
num_workers=self.hparams.num_workers,
|
243 |
+
batch_size=self.hparams.batch_size,
|
244 |
+
shuffle=True,
|
245 |
+
worker_init_fn=utils.seed_worker,
|
246 |
+
generator=g,
|
247 |
+
pin_memory=True,
|
248 |
+
)
|
249 |
+
|
250 |
+
def val_dataloader(self):
|
251 |
+
|
252 |
+
if self.hparams.task == "style":
|
253 |
+
val_dataset = StyleDataset(
|
254 |
+
self.hparams.audio_dir,
|
255 |
+
subset="val",
|
256 |
+
sample_rate=self.hparams.encoder_sample_rate,
|
257 |
+
)
|
258 |
+
|
259 |
+
g = torch.Generator()
|
260 |
+
g.manual_seed(0)
|
261 |
+
|
262 |
+
return torch.utils.data.DataLoader(
|
263 |
+
val_dataset,
|
264 |
+
num_workers=self.hparams.num_workers,
|
265 |
+
batch_size=self.hparams.batch_size,
|
266 |
+
worker_init_fn=utils.seed_worker,
|
267 |
+
generator=g,
|
268 |
+
pin_memory=True,
|
269 |
+
)
|
270 |
+
|
271 |
+
# add any model hyperparameters here
|
272 |
+
@staticmethod
|
273 |
+
def add_model_specific_args(parent_parser):
|
274 |
+
parser = ArgumentParser(parents=[parent_parser], add_help=False)
|
275 |
+
# --- Model ---
|
276 |
+
parser.add_argument("--encoder_type", type=str, default="deeapfx2")
|
277 |
+
parser.add_argument("--probe_type", type=str, default="linear")
|
278 |
+
parser.add_argument("--task", type=str, default="style")
|
279 |
+
parser.add_argument("--encoder_sample_rate", type=int, default=24000)
|
280 |
+
# --- deeapfx2 ---
|
281 |
+
parser.add_argument("--deepafx_st_autodiff_ckpt", type=str)
|
282 |
+
parser.add_argument("--deepafx_st_spsa_ckpt", type=str)
|
283 |
+
parser.add_argument("--deepafx_st_proxy0_ckpt", type=str)
|
284 |
+
|
285 |
+
# --- cdpam ---
|
286 |
+
parser.add_argument("--cdpam_ckpt", type=str)
|
287 |
+
# --- openl3 ---
|
288 |
+
parser.add_argument("--openl3_input_repr", type=str, default="mel128")
|
289 |
+
parser.add_argument("--openl3_content_type", type=str, default="env")
|
290 |
+
parser.add_argument("--openl3_embedding_size", type=int, default=6144)
|
291 |
+
# --- random_mel ---
|
292 |
+
parser.add_argument("--random_mel_embedding_size", type=str, default=4096)
|
293 |
+
parser.add_argument("--random_mel_n_fft", type=str, default=4096)
|
294 |
+
parser.add_argument("--random_mel_hop_size", type=str, default=1024)
|
295 |
+
parser.add_argument("--random_mel_n_mels", type=str, default=128)
|
296 |
+
# --- Training ---
|
297 |
+
parser.add_argument("--audio_dir", type=str)
|
298 |
+
parser.add_argument("--num_classes", type=int, default=5)
|
299 |
+
parser.add_argument("--batch_size", type=int, default=32)
|
300 |
+
parser.add_argument("--lr", type=float, default=3e-4)
|
301 |
+
parser.add_argument("--lr_patience", type=int, default=20)
|
302 |
+
parser.add_argument("--patience", type=int, default=10)
|
303 |
+
parser.add_argument("--preload", action="store_true")
|
304 |
+
parser.add_argument("--sample_rate", type=int, default=24000)
|
305 |
+
parser.add_argument("--num_workers", type=int, default=8)
|
306 |
+
|
307 |
+
return parser
|
deepafx_st/probes/random_mel.py
ADDED
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import torch
|
3 |
+
import librosa
|
4 |
+
|
5 |
+
# based on https://github.com/neuralaudio/hear-baseline/blob/main/hearbaseline/naive.py
|
6 |
+
|
7 |
+
|
8 |
+
class RandomMelProjection(torch.nn.Module):
|
9 |
+
def __init__(
|
10 |
+
self,
|
11 |
+
sample_rate,
|
12 |
+
embed_dim=4096,
|
13 |
+
n_mels=128,
|
14 |
+
n_fft=4096,
|
15 |
+
hop_size=1024,
|
16 |
+
seed=0,
|
17 |
+
epsilon=1e-4,
|
18 |
+
):
|
19 |
+
super().__init__()
|
20 |
+
self.sample_rate = sample_rate
|
21 |
+
self.embed_dim = embed_dim
|
22 |
+
self.n_mels = n_mels
|
23 |
+
self.n_fft = n_fft
|
24 |
+
self.hop_size = hop_size
|
25 |
+
self.seed = seed
|
26 |
+
self.epsilon = epsilon
|
27 |
+
|
28 |
+
# Set random seed
|
29 |
+
torch.random.manual_seed(self.seed)
|
30 |
+
|
31 |
+
# Create a Hann window buffer to apply to frames prior to FFT.
|
32 |
+
self.register_buffer("window", torch.hann_window(self.n_fft))
|
33 |
+
|
34 |
+
# Create a mel filter buffer.
|
35 |
+
mel_scale = torch.tensor(
|
36 |
+
librosa.filters.mel(
|
37 |
+
self.sample_rate,
|
38 |
+
n_fft=self.n_fft,
|
39 |
+
n_mels=self.n_mels,
|
40 |
+
)
|
41 |
+
)
|
42 |
+
self.register_buffer("mel_scale", mel_scale)
|
43 |
+
|
44 |
+
# Projection matrices.
|
45 |
+
normalization = math.sqrt(self.n_mels)
|
46 |
+
self.projection = torch.nn.Parameter(
|
47 |
+
torch.rand(self.n_mels, self.embed_dim) / normalization,
|
48 |
+
requires_grad=False,
|
49 |
+
)
|
50 |
+
|
51 |
+
def forward(self, x):
|
52 |
+
bs, chs, samp = x.size()
|
53 |
+
|
54 |
+
x = torch.stft(
|
55 |
+
x.view(bs, -1),
|
56 |
+
self.n_fft,
|
57 |
+
self.hop_size,
|
58 |
+
window=self.window,
|
59 |
+
return_complex=True,
|
60 |
+
)
|
61 |
+
x = x.unsqueeze(1).permute(0, 1, 3, 2)
|
62 |
+
|
63 |
+
# Apply the mel-scale filter to the power spectrum.
|
64 |
+
x = torch.matmul(x.abs(), self.mel_scale.transpose(0, 1))
|
65 |
+
|
66 |
+
# power scale
|
67 |
+
x = torch.pow(x + self.epsilon, 0.3)
|
68 |
+
|
69 |
+
# apply random projection
|
70 |
+
e = x.matmul(self.projection)
|
71 |
+
|
72 |
+
# take mean across temporal dim
|
73 |
+
e = e.mean(dim=2).view(bs, -1)
|
74 |
+
|
75 |
+
return e
|
76 |
+
|
77 |
+
def compute_frame_embedding(self, x):
|
78 |
+
# Compute the real-valued Fourier transform on windowed input signal.
|
79 |
+
x = torch.fft.rfft(x * self.window)
|
80 |
+
|
81 |
+
# Convert to a power spectrum.
|
82 |
+
x = torch.abs(x) ** 2.0
|
83 |
+
|
84 |
+
# Apply the mel-scale filter to the power spectrum.
|
85 |
+
x = torch.matmul(x, self.mel_scale.transpose(0, 1))
|
86 |
+
|
87 |
+
# Convert to a log mel spectrum.
|
88 |
+
x = torch.log(x + self.epsilon)
|
89 |
+
|
90 |
+
# Apply projection to get a 4096 dimension embedding
|
91 |
+
embedding = x.matmul(self.projection)
|
92 |
+
|
93 |
+
return embedding
|
deepafx_st/processors/autodiff/__init__.py
ADDED
File without changes
|
deepafx_st/processors/autodiff/channel.py
ADDED
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
|
3 |
+
from deepafx_st.processors.autodiff.compressor import Compressor
|
4 |
+
from deepafx_st.processors.autodiff.peq import ParametricEQ
|
5 |
+
from deepafx_st.processors.autodiff.fir import FIRFilter
|
6 |
+
|
7 |
+
|
8 |
+
class AutodiffChannel(torch.nn.Module):
|
9 |
+
def __init__(self, sample_rate):
|
10 |
+
super().__init__()
|
11 |
+
|
12 |
+
self.peq = ParametricEQ(sample_rate)
|
13 |
+
self.comp = Compressor(sample_rate)
|
14 |
+
self.ports = [self.peq.ports, self.comp.ports]
|
15 |
+
self.num_control_params = (
|
16 |
+
self.peq.num_control_params + self.comp.num_control_params
|
17 |
+
)
|
18 |
+
|
19 |
+
def forward(self, x, p, sample_rate=24000, **kwargs):
|
20 |
+
|
21 |
+
# split params between EQ and Comp.
|
22 |
+
p_peq = p[:, : self.peq.num_control_params]
|
23 |
+
p_comp = p[:, self.peq.num_control_params :]
|
24 |
+
|
25 |
+
y = self.peq(x, p_peq, sample_rate)
|
26 |
+
y = self.comp(y, p_comp, sample_rate)
|
27 |
+
|
28 |
+
return y
|
deepafx_st/processors/autodiff/compressor.py
ADDED
@@ -0,0 +1,169 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import torch
|
3 |
+
import scipy.signal
|
4 |
+
|
5 |
+
import deepafx_st.processors.autodiff.signal
|
6 |
+
from deepafx_st.processors.processor import Processor
|
7 |
+
|
8 |
+
|
9 |
+
@torch.jit.script
|
10 |
+
def compressor(
|
11 |
+
x: torch.Tensor,
|
12 |
+
sample_rate: float,
|
13 |
+
threshold: torch.Tensor,
|
14 |
+
ratio: torch.Tensor,
|
15 |
+
attack_time: torch.Tensor,
|
16 |
+
release_time: torch.Tensor,
|
17 |
+
knee_dB: torch.Tensor,
|
18 |
+
makeup_gain_dB: torch.Tensor,
|
19 |
+
eps: float = 1e-8,
|
20 |
+
):
|
21 |
+
"""Note the `release` parameter is not used."""
|
22 |
+
# print(f"autodiff comp fs = {sample_rate}")
|
23 |
+
|
24 |
+
s = x.size() # should be one 1d
|
25 |
+
|
26 |
+
threshold = threshold.squeeze()
|
27 |
+
ratio = ratio.squeeze()
|
28 |
+
attack_time = attack_time.squeeze()
|
29 |
+
makeup_gain_dB = makeup_gain_dB.squeeze()
|
30 |
+
|
31 |
+
# uni-polar dB signal
|
32 |
+
# Turn the input signal into a uni-polar signal on the dB scale
|
33 |
+
x_G = 20 * torch.log10(torch.abs(x) + 1e-8) # x_uni casts type
|
34 |
+
|
35 |
+
# Ensure there are no values of negative infinity
|
36 |
+
x_G = torch.clamp(x_G, min=-96)
|
37 |
+
|
38 |
+
# Static characteristics with knee
|
39 |
+
y_G = torch.zeros(s).type_as(x)
|
40 |
+
|
41 |
+
ratio = ratio.view(-1)
|
42 |
+
threshold = threshold.view(-1)
|
43 |
+
attack_time = attack_time.view(-1)
|
44 |
+
release_time = release_time.view(-1)
|
45 |
+
knee_dB = knee_dB.view(-1)
|
46 |
+
makeup_gain_dB = makeup_gain_dB.view(-1)
|
47 |
+
|
48 |
+
# Below knee
|
49 |
+
idx = torch.where((2 * (x_G - threshold)) < -knee_dB)[0]
|
50 |
+
y_G[idx] = x_G[idx]
|
51 |
+
|
52 |
+
# At knee
|
53 |
+
idx = torch.where((2 * torch.abs(x_G - threshold)) <= knee_dB)[0]
|
54 |
+
y_G[idx] = x_G[idx] + (
|
55 |
+
(1 / ratio) * (((x_G[idx] - threshold + knee_dB) / 2) ** 2)
|
56 |
+
) / (2 * knee_dB)
|
57 |
+
|
58 |
+
# Above knee threshold
|
59 |
+
idx = torch.where((2 * (x_G - threshold)) > knee_dB)[0]
|
60 |
+
y_G[idx] = threshold + ((x_G[idx] - threshold) / ratio)
|
61 |
+
|
62 |
+
x_L = x_G - y_G
|
63 |
+
|
64 |
+
# design 1-pole butterworth lowpass
|
65 |
+
fc = 1.0 / (attack_time * sample_rate)
|
66 |
+
b, a = deepafx_st.processors.autodiff.signal.butter(fc)
|
67 |
+
|
68 |
+
# apply FIR approx of IIR filter
|
69 |
+
y_L = deepafx_st.processors.autodiff.signal.approx_iir_filter(b, a, x_L)
|
70 |
+
|
71 |
+
lin_y_L = torch.pow(10.0, -y_L / 20.0) # convert back to linear
|
72 |
+
y = lin_y_L * x # apply gain
|
73 |
+
|
74 |
+
# apply makeup gain
|
75 |
+
y *= torch.pow(10.0, makeup_gain_dB / 20.0)
|
76 |
+
|
77 |
+
return y
|
78 |
+
|
79 |
+
|
80 |
+
class Compressor(Processor):
|
81 |
+
def __init__(
|
82 |
+
self,
|
83 |
+
sample_rate,
|
84 |
+
max_threshold=0.0,
|
85 |
+
min_threshold=-80,
|
86 |
+
max_ratio=20.0,
|
87 |
+
min_ratio=1.0,
|
88 |
+
max_attack=0.1,
|
89 |
+
min_attack=0.0001,
|
90 |
+
max_release=1.0,
|
91 |
+
min_release=0.005,
|
92 |
+
max_knee=12.0,
|
93 |
+
min_knee=0.0,
|
94 |
+
max_mkgain=48.0,
|
95 |
+
min_mkgain=-48.0,
|
96 |
+
eps=1e-8,
|
97 |
+
):
|
98 |
+
""" """
|
99 |
+
super().__init__()
|
100 |
+
self.sample_rate = sample_rate
|
101 |
+
self.eps = eps
|
102 |
+
self.ports = [
|
103 |
+
{
|
104 |
+
"name": "Threshold",
|
105 |
+
"min": min_threshold,
|
106 |
+
"max": max_threshold,
|
107 |
+
"default": -12.0,
|
108 |
+
"units": "dB",
|
109 |
+
},
|
110 |
+
{
|
111 |
+
"name": "Ratio",
|
112 |
+
"min": min_ratio,
|
113 |
+
"max": max_ratio,
|
114 |
+
"default": 2.0,
|
115 |
+
"units": "",
|
116 |
+
},
|
117 |
+
{
|
118 |
+
"name": "Attack",
|
119 |
+
"min": min_attack,
|
120 |
+
"max": max_attack,
|
121 |
+
"default": 0.001,
|
122 |
+
"units": "s",
|
123 |
+
},
|
124 |
+
{
|
125 |
+
# this is a dummy parameter
|
126 |
+
"name": "Release (dummy)",
|
127 |
+
"min": min_release,
|
128 |
+
"max": max_release,
|
129 |
+
"default": 0.045,
|
130 |
+
"units": "s",
|
131 |
+
},
|
132 |
+
{
|
133 |
+
"name": "Knee",
|
134 |
+
"min": min_knee,
|
135 |
+
"max": max_knee,
|
136 |
+
"default": 6.0,
|
137 |
+
"units": "dB",
|
138 |
+
},
|
139 |
+
{
|
140 |
+
"name": "Makeup Gain",
|
141 |
+
"min": min_mkgain,
|
142 |
+
"max": max_mkgain,
|
143 |
+
"default": 0.0,
|
144 |
+
"units": "dB",
|
145 |
+
},
|
146 |
+
]
|
147 |
+
|
148 |
+
self.num_control_params = len(self.ports)
|
149 |
+
|
150 |
+
def forward(self, x, p, sample_rate=24000, **kwargs):
|
151 |
+
"""
|
152 |
+
|
153 |
+
Assume that parameters in p are normalized between 0 and 1.
|
154 |
+
|
155 |
+
x (tensor): Shape batch x 1 x samples
|
156 |
+
p (tensor): shape batch x params
|
157 |
+
|
158 |
+
"""
|
159 |
+
bs, ch, s = x.size()
|
160 |
+
|
161 |
+
inputs = torch.split(x, 1, 0)
|
162 |
+
params = torch.split(p, 1, 0)
|
163 |
+
|
164 |
+
y = [] # loop over batch dimension
|
165 |
+
for input, param in zip(inputs, params):
|
166 |
+
denorm_param = self.denormalize_params(param.view(-1))
|
167 |
+
y.append(compressor(input.view(-1), sample_rate, *denorm_param))
|
168 |
+
|
169 |
+
return torch.stack(y, dim=0).view(bs, 1, -1)
|
deepafx_st/processors/autodiff/fir.py
ADDED
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
|
3 |
+
|
4 |
+
class FIRFilter(torch.nn.Module):
|
5 |
+
def __init__(self, num_control_params=63):
|
6 |
+
super().__init__()
|
7 |
+
self.num_control_params = num_control_params
|
8 |
+
self.adaptor = torch.nn.Linear(num_control_params, num_control_params)
|
9 |
+
#self.batched_lfilter = torch.vmap(self.lfilter)
|
10 |
+
|
11 |
+
def forward(self, x, b, **kwargs):
|
12 |
+
"""Forward pass by appling FIR filter to each batch element.
|
13 |
+
|
14 |
+
Args:
|
15 |
+
x (tensor): Input signals with shape (batch x 1 x samples)
|
16 |
+
b (tensor): Matrix of FIR filter coefficients with shape (batch x ntaps)
|
17 |
+
|
18 |
+
"""
|
19 |
+
bs, ch, s = x.size()
|
20 |
+
b = self.adaptor(b)
|
21 |
+
|
22 |
+
# pad input
|
23 |
+
x = torch.nn.functional.pad(x, (b.shape[-1] // 2, b.shape[-1] // 2))
|
24 |
+
|
25 |
+
# add extra dim for virutal batch dim
|
26 |
+
x = x.view(bs, 1, ch, -1)
|
27 |
+
b = b.view(bs, 1, 1, -1)
|
28 |
+
|
29 |
+
# exlcuding vmap for now
|
30 |
+
y = self.batched_lfilter(x, b).view(bs, ch, s)
|
31 |
+
|
32 |
+
return y
|
33 |
+
|
34 |
+
@staticmethod
|
35 |
+
def lfilter(x, b):
|
36 |
+
return torch.nn.functional.conv1d(x, b)
|
37 |
+
|
38 |
+
|
39 |
+
class FrequencyDomainFIRFilter(torch.nn.Module):
|
40 |
+
def __init__(self, num_control_params=31):
|
41 |
+
super().__init__()
|
42 |
+
self.num_control_params = num_control_params
|
43 |
+
self.adaptor = torch.nn.Linear(num_control_params, num_control_params)
|
44 |
+
|
45 |
+
def forward(self, x, b, **kwargs):
|
46 |
+
"""Forward pass by appling FIR filter to each batch element.
|
47 |
+
|
48 |
+
Args:
|
49 |
+
x (tensor): Input signals with shape (batch x 1 x samples)
|
50 |
+
b (tensor): Matrix of FIR filter coefficients with shape (batch x ntaps)
|
51 |
+
"""
|
52 |
+
bs, c, s = x.size()
|
53 |
+
|
54 |
+
b = self.adaptor(b)
|
55 |
+
|
56 |
+
# transform input to freq. domain
|
57 |
+
X = torch.fft.rfft(x.view(bs, -1))
|
58 |
+
|
59 |
+
# frequency response of filter
|
60 |
+
H = torch.fft.rfft(b.view(bs, -1))
|
61 |
+
|
62 |
+
# apply filter as multiplication in freq. domain
|
63 |
+
Y = X * H
|
64 |
+
|
65 |
+
# transform back to time domain
|
66 |
+
y = torch.fft.ifft(Y).view(bs, 1, -1)
|
67 |
+
|
68 |
+
return y
|
deepafx_st/processors/autodiff/peq.py
ADDED
@@ -0,0 +1,274 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
|
3 |
+
import deepafx_st.processors.autodiff.signal
|
4 |
+
from deepafx_st.processors.processor import Processor
|
5 |
+
|
6 |
+
|
7 |
+
@torch.jit.script
|
8 |
+
def parametric_eq(
|
9 |
+
x: torch.Tensor,
|
10 |
+
sample_rate: float,
|
11 |
+
low_shelf_gain_dB: torch.Tensor,
|
12 |
+
low_shelf_cutoff_freq: torch.Tensor,
|
13 |
+
low_shelf_q_factor: torch.Tensor,
|
14 |
+
first_band_gain_dB: torch.Tensor,
|
15 |
+
first_band_cutoff_freq: torch.Tensor,
|
16 |
+
first_band_q_factor: torch.Tensor,
|
17 |
+
second_band_gain_dB: torch.Tensor,
|
18 |
+
second_band_cutoff_freq: torch.Tensor,
|
19 |
+
second_band_q_factor: torch.Tensor,
|
20 |
+
third_band_gain_dB: torch.Tensor,
|
21 |
+
third_band_cutoff_freq: torch.Tensor,
|
22 |
+
third_band_q_factor: torch.Tensor,
|
23 |
+
fourth_band_gain_dB: torch.Tensor,
|
24 |
+
fourth_band_cutoff_freq: torch.Tensor,
|
25 |
+
fourth_band_q_factor: torch.Tensor,
|
26 |
+
high_shelf_gain_dB: torch.Tensor,
|
27 |
+
high_shelf_cutoff_freq: torch.Tensor,
|
28 |
+
high_shelf_q_factor: torch.Tensor,
|
29 |
+
):
|
30 |
+
"""Six-band parametric EQ.
|
31 |
+
|
32 |
+
Low-shelf -> Band 1 -> Band 2 -> Band 3 -> Band 4 -> High-shelf
|
33 |
+
|
34 |
+
Args:
|
35 |
+
x (torch.Tensor): 1d signal.
|
36 |
+
|
37 |
+
|
38 |
+
"""
|
39 |
+
a_s, b_s = [], []
|
40 |
+
#print(f"autodiff peq fs = {sample_rate}")
|
41 |
+
|
42 |
+
# -------- apply low-shelf filter --------
|
43 |
+
b, a = deepafx_st.processors.autodiff.signal.biqaud(
|
44 |
+
low_shelf_gain_dB,
|
45 |
+
low_shelf_cutoff_freq,
|
46 |
+
low_shelf_q_factor,
|
47 |
+
sample_rate,
|
48 |
+
"low_shelf",
|
49 |
+
)
|
50 |
+
b_s.append(b)
|
51 |
+
a_s.append(a)
|
52 |
+
|
53 |
+
# -------- apply first-band peaking filter --------
|
54 |
+
b, a = deepafx_st.processors.autodiff.signal.biqaud(
|
55 |
+
first_band_gain_dB,
|
56 |
+
first_band_cutoff_freq,
|
57 |
+
first_band_q_factor,
|
58 |
+
sample_rate,
|
59 |
+
"peaking",
|
60 |
+
)
|
61 |
+
b_s.append(b)
|
62 |
+
a_s.append(a)
|
63 |
+
|
64 |
+
# -------- apply second-band peaking filter --------
|
65 |
+
b, a = deepafx_st.processors.autodiff.signal.biqaud(
|
66 |
+
second_band_gain_dB,
|
67 |
+
second_band_cutoff_freq,
|
68 |
+
second_band_q_factor,
|
69 |
+
sample_rate,
|
70 |
+
"peaking",
|
71 |
+
)
|
72 |
+
b_s.append(b)
|
73 |
+
a_s.append(a)
|
74 |
+
|
75 |
+
# -------- apply third-band peaking filter --------
|
76 |
+
b, a = deepafx_st.processors.autodiff.signal.biqaud(
|
77 |
+
third_band_gain_dB,
|
78 |
+
third_band_cutoff_freq,
|
79 |
+
third_band_q_factor,
|
80 |
+
sample_rate,
|
81 |
+
"peaking",
|
82 |
+
)
|
83 |
+
b_s.append(b)
|
84 |
+
a_s.append(a)
|
85 |
+
|
86 |
+
# -------- apply fourth-band peaking filter --------
|
87 |
+
b, a = deepafx_st.processors.autodiff.signal.biqaud(
|
88 |
+
fourth_band_gain_dB,
|
89 |
+
fourth_band_cutoff_freq,
|
90 |
+
fourth_band_q_factor,
|
91 |
+
sample_rate,
|
92 |
+
"peaking",
|
93 |
+
)
|
94 |
+
b_s.append(b)
|
95 |
+
a_s.append(a)
|
96 |
+
|
97 |
+
# -------- apply high-shelf filter --------
|
98 |
+
b, a = deepafx_st.processors.autodiff.signal.biqaud(
|
99 |
+
high_shelf_gain_dB,
|
100 |
+
high_shelf_cutoff_freq,
|
101 |
+
high_shelf_q_factor,
|
102 |
+
sample_rate,
|
103 |
+
"high_shelf",
|
104 |
+
)
|
105 |
+
b_s.append(b)
|
106 |
+
a_s.append(a)
|
107 |
+
|
108 |
+
x = deepafx_st.processors.autodiff.signal.approx_iir_filter_cascade(
|
109 |
+
b_s, a_s, x.view(-1)
|
110 |
+
)
|
111 |
+
|
112 |
+
return x
|
113 |
+
|
114 |
+
|
115 |
+
class ParametricEQ(Processor):
|
116 |
+
def __init__(
|
117 |
+
self,
|
118 |
+
sample_rate,
|
119 |
+
min_gain_dB=-24.0,
|
120 |
+
default_gain_dB=0.0,
|
121 |
+
max_gain_dB=24.0,
|
122 |
+
min_q_factor=0.1,
|
123 |
+
default_q_factor=0.707,
|
124 |
+
max_q_factor=10,
|
125 |
+
eps=1e-8,
|
126 |
+
):
|
127 |
+
""" """
|
128 |
+
super().__init__()
|
129 |
+
self.sample_rate = sample_rate
|
130 |
+
self.eps = eps
|
131 |
+
self.ports = [
|
132 |
+
{
|
133 |
+
"name": "Lowshelf gain",
|
134 |
+
"min": min_gain_dB,
|
135 |
+
"max": max_gain_dB,
|
136 |
+
"default": default_gain_dB,
|
137 |
+
"units": "dB",
|
138 |
+
},
|
139 |
+
{
|
140 |
+
"name": "Lowshelf cutoff",
|
141 |
+
"min": 20.0,
|
142 |
+
"max": 200.0,
|
143 |
+
"default": 100.0,
|
144 |
+
"units": "Hz",
|
145 |
+
},
|
146 |
+
{
|
147 |
+
"name": "Lowshelf Q",
|
148 |
+
"min": min_q_factor,
|
149 |
+
"max": max_q_factor,
|
150 |
+
"default": default_q_factor,
|
151 |
+
"units": "",
|
152 |
+
},
|
153 |
+
{
|
154 |
+
"name": "First band gain",
|
155 |
+
"min": min_gain_dB,
|
156 |
+
"max": max_gain_dB,
|
157 |
+
"default": default_gain_dB,
|
158 |
+
"units": "dB",
|
159 |
+
},
|
160 |
+
{
|
161 |
+
"name": "First band cutoff",
|
162 |
+
"min": 200.0,
|
163 |
+
"max": 2000.0,
|
164 |
+
"default": 400.0,
|
165 |
+
"units": "Hz",
|
166 |
+
},
|
167 |
+
{
|
168 |
+
"name": "First band Q",
|
169 |
+
"min": min_q_factor,
|
170 |
+
"max": max_q_factor,
|
171 |
+
"default": 0.707,
|
172 |
+
"units": "",
|
173 |
+
},
|
174 |
+
{
|
175 |
+
"name": "Second band gain",
|
176 |
+
"min": min_gain_dB,
|
177 |
+
"max": max_gain_dB,
|
178 |
+
"default": default_gain_dB,
|
179 |
+
"units": "dB",
|
180 |
+
},
|
181 |
+
{
|
182 |
+
"name": "Second band cutoff",
|
183 |
+
"min": 200.0,
|
184 |
+
"max": 4000.0,
|
185 |
+
"default": 1000.0,
|
186 |
+
"units": "Hz",
|
187 |
+
},
|
188 |
+
{
|
189 |
+
"name": "Second band Q",
|
190 |
+
"min": min_q_factor,
|
191 |
+
"max": max_q_factor,
|
192 |
+
"default": default_q_factor,
|
193 |
+
"units": "",
|
194 |
+
},
|
195 |
+
{
|
196 |
+
"name": "Third band gain",
|
197 |
+
"min": min_gain_dB,
|
198 |
+
"max": max_gain_dB,
|
199 |
+
"default": default_gain_dB,
|
200 |
+
"units": "dB",
|
201 |
+
},
|
202 |
+
{
|
203 |
+
"name": "Third band cutoff",
|
204 |
+
"min": 2000.0,
|
205 |
+
"max": 8000.0,
|
206 |
+
"default": 4000.0,
|
207 |
+
"units": "Hz",
|
208 |
+
},
|
209 |
+
{
|
210 |
+
"name": "Third band Q",
|
211 |
+
"min": min_q_factor,
|
212 |
+
"max": max_q_factor,
|
213 |
+
"default": default_q_factor,
|
214 |
+
"units": "",
|
215 |
+
},
|
216 |
+
{
|
217 |
+
"name": "Fourth band gain",
|
218 |
+
"min": min_gain_dB,
|
219 |
+
"max": max_gain_dB,
|
220 |
+
"default": default_gain_dB,
|
221 |
+
"units": "dB",
|
222 |
+
},
|
223 |
+
{
|
224 |
+
"name": "Fourth band cutoff",
|
225 |
+
"min": 4000.0,
|
226 |
+
"max": (24000 // 2) * 0.9,
|
227 |
+
"default": 8000.0,
|
228 |
+
"units": "Hz",
|
229 |
+
},
|
230 |
+
{
|
231 |
+
"name": "Fourth band Q",
|
232 |
+
"min": min_q_factor,
|
233 |
+
"max": max_q_factor,
|
234 |
+
"default": default_q_factor,
|
235 |
+
"units": "",
|
236 |
+
},
|
237 |
+
{
|
238 |
+
"name": "Highshelf gain",
|
239 |
+
"min": min_gain_dB,
|
240 |
+
"max": max_gain_dB,
|
241 |
+
"default": default_gain_dB,
|
242 |
+
"units": "dB",
|
243 |
+
},
|
244 |
+
{
|
245 |
+
"name": "Highshelf cutoff",
|
246 |
+
"min": 4000.0,
|
247 |
+
"max": (24000 // 2) * 0.9,
|
248 |
+
"default": 8000.0,
|
249 |
+
"units": "Hz",
|
250 |
+
},
|
251 |
+
{
|
252 |
+
"name": "Highshelf Q",
|
253 |
+
"min": min_q_factor,
|
254 |
+
"max": max_q_factor,
|
255 |
+
"default": default_q_factor,
|
256 |
+
"units": "",
|
257 |
+
},
|
258 |
+
]
|
259 |
+
|
260 |
+
self.num_control_params = len(self.ports)
|
261 |
+
|
262 |
+
def forward(self, x, p, sample_rate=24000, **kwargs):
|
263 |
+
|
264 |
+
bs, chs, s = x.size()
|
265 |
+
|
266 |
+
inputs = torch.split(x, 1, 0)
|
267 |
+
params = torch.split(p, 1, 0)
|
268 |
+
|
269 |
+
y = [] # loop over batch dimension
|
270 |
+
for input, param in zip(inputs, params):
|
271 |
+
denorm_param = self.denormalize_params(param.view(-1))
|
272 |
+
y.append(parametric_eq(input.view(-1), sample_rate, *denorm_param))
|
273 |
+
|
274 |
+
return torch.stack(y, dim=0).view(bs, 1, -1)
|
deepafx_st/processors/autodiff/signal.py
ADDED
@@ -0,0 +1,194 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import torch
|
3 |
+
from typing import List
|
4 |
+
|
5 |
+
|
6 |
+
def butter(fc, fs: float = 2.0):
|
7 |
+
"""
|
8 |
+
|
9 |
+
Recall Butterworth polynomials
|
10 |
+
N = 1 s + 1
|
11 |
+
N = 2 s^2 + sqrt(2s) + 1
|
12 |
+
N = 3 (s^2 + s + 1)(s + 1)
|
13 |
+
N = 4 (s^2 + 0.76536s + 1)(s^2 + 1.84776s + 1)
|
14 |
+
|
15 |
+
Scaling
|
16 |
+
LP to LP: s -> s/w_c
|
17 |
+
LP to HP: s -> w_c/s
|
18 |
+
|
19 |
+
Bilinear transform:
|
20 |
+
s = 2/T_d * (1 - z^-1)/(1 + z^-1)
|
21 |
+
|
22 |
+
For 1-pole butterworth lowpass
|
23 |
+
|
24 |
+
1 / (s + 1) 1-pole prototype
|
25 |
+
1 / (s/w_c + 1) LP to LP
|
26 |
+
1 / (2/T_d * (1 - z^-1)/(1 + z^-1))/w_c + 1) Bilinear transform
|
27 |
+
|
28 |
+
"""
|
29 |
+
|
30 |
+
# apply pre-warping to the cutoff
|
31 |
+
T_d = 1 / fs
|
32 |
+
w_d = (2 * math.pi * fc) / fs
|
33 |
+
# sys.exit()
|
34 |
+
w_c = (2 / T_d) * torch.tan(w_d / 2)
|
35 |
+
|
36 |
+
a0 = 2 + (T_d * w_c)
|
37 |
+
a1 = (T_d * w_c) - 2
|
38 |
+
b0 = T_d * w_c
|
39 |
+
b1 = T_d * w_c
|
40 |
+
|
41 |
+
b = torch.stack([b0, b1], dim=0).view(-1)
|
42 |
+
a = torch.stack([a0, a1], dim=0).view(-1)
|
43 |
+
|
44 |
+
# normalize
|
45 |
+
b = b.type_as(fc) / a0
|
46 |
+
a = a.type_as(fc) / a0
|
47 |
+
|
48 |
+
return b, a
|
49 |
+
|
50 |
+
|
51 |
+
def biqaud(
|
52 |
+
gain_dB: torch.Tensor,
|
53 |
+
cutoff_freq: torch.Tensor,
|
54 |
+
q_factor: torch.Tensor,
|
55 |
+
sample_rate: float,
|
56 |
+
filter_type: str = "peaking",
|
57 |
+
):
|
58 |
+
|
59 |
+
# convert inputs to Tensors if needed
|
60 |
+
# gain_dB = torch.tensor([gain_dB])
|
61 |
+
# cutoff_freq = torch.tensor([cutoff_freq])
|
62 |
+
# q_factor = torch.tensor([q_factor])
|
63 |
+
|
64 |
+
A = 10 ** (gain_dB / 40.0)
|
65 |
+
w0 = 2 * math.pi * (cutoff_freq / sample_rate)
|
66 |
+
alpha = torch.sin(w0) / (2 * q_factor)
|
67 |
+
cos_w0 = torch.cos(w0)
|
68 |
+
sqrt_A = torch.sqrt(A)
|
69 |
+
|
70 |
+
if filter_type == "high_shelf":
|
71 |
+
b0 = A * ((A + 1) + (A - 1) * cos_w0 + 2 * sqrt_A * alpha)
|
72 |
+
b1 = -2 * A * ((A - 1) + (A + 1) * cos_w0)
|
73 |
+
b2 = A * ((A + 1) + (A - 1) * cos_w0 - 2 * sqrt_A * alpha)
|
74 |
+
a0 = (A + 1) - (A - 1) * cos_w0 + 2 * sqrt_A * alpha
|
75 |
+
a1 = 2 * ((A - 1) - (A + 1) * cos_w0)
|
76 |
+
a2 = (A + 1) - (A - 1) * cos_w0 - 2 * sqrt_A * alpha
|
77 |
+
elif filter_type == "low_shelf":
|
78 |
+
b0 = A * ((A + 1) - (A - 1) * cos_w0 + 2 * sqrt_A * alpha)
|
79 |
+
b1 = 2 * A * ((A - 1) - (A + 1) * cos_w0)
|
80 |
+
b2 = A * ((A + 1) - (A - 1) * cos_w0 - 2 * sqrt_A * alpha)
|
81 |
+
a0 = (A + 1) + (A - 1) * cos_w0 + 2 * sqrt_A * alpha
|
82 |
+
a1 = -2 * ((A - 1) + (A + 1) * cos_w0)
|
83 |
+
a2 = (A + 1) + (A - 1) * cos_w0 - 2 * sqrt_A * alpha
|
84 |
+
elif filter_type == "peaking":
|
85 |
+
b0 = 1 + alpha * A
|
86 |
+
b1 = -2 * cos_w0
|
87 |
+
b2 = 1 - alpha * A
|
88 |
+
a0 = 1 + (alpha / A)
|
89 |
+
a1 = -2 * cos_w0
|
90 |
+
a2 = 1 - (alpha / A)
|
91 |
+
else:
|
92 |
+
raise ValueError(f"Invalid filter_type: {filter_type}.")
|
93 |
+
|
94 |
+
b = torch.stack([b0, b1, b2], dim=0).view(-1)
|
95 |
+
a = torch.stack([a0, a1, a2], dim=0).view(-1)
|
96 |
+
|
97 |
+
# normalize
|
98 |
+
b = b.type_as(gain_dB) / a0
|
99 |
+
a = a.type_as(gain_dB) / a0
|
100 |
+
|
101 |
+
return b, a
|
102 |
+
|
103 |
+
|
104 |
+
def freqz(b, a, n_fft: int = 512):
|
105 |
+
|
106 |
+
B = torch.fft.rfft(b, n_fft)
|
107 |
+
A = torch.fft.rfft(a, n_fft)
|
108 |
+
|
109 |
+
H = B / A
|
110 |
+
|
111 |
+
return H
|
112 |
+
|
113 |
+
|
114 |
+
def freq_domain_filter(x, H, n_fft):
|
115 |
+
|
116 |
+
X = torch.fft.rfft(x, n_fft)
|
117 |
+
|
118 |
+
# move H to same device as input x
|
119 |
+
H = H.type_as(X)
|
120 |
+
|
121 |
+
Y = X * H
|
122 |
+
|
123 |
+
y = torch.fft.irfft(Y, n_fft)
|
124 |
+
|
125 |
+
return y
|
126 |
+
|
127 |
+
|
128 |
+
def approx_iir_filter(b, a, x):
|
129 |
+
"""Approimxate the application of an IIR filter.
|
130 |
+
|
131 |
+
Args:
|
132 |
+
b (Tensor): The numerator coefficients.
|
133 |
+
|
134 |
+
"""
|
135 |
+
|
136 |
+
# round up to nearest power of 2 for FFT
|
137 |
+
# n_fft = 2 ** math.ceil(math.log2(x.shape[-1] + x.shape[-1] - 1))
|
138 |
+
|
139 |
+
n_fft = 2 ** torch.ceil(torch.log2(torch.tensor(x.shape[-1] + x.shape[-1] - 1)))
|
140 |
+
n_fft = n_fft.int()
|
141 |
+
|
142 |
+
# move coefficients to same device as x
|
143 |
+
b = b.type_as(x).view(-1)
|
144 |
+
a = a.type_as(x).view(-1)
|
145 |
+
|
146 |
+
# compute complex response
|
147 |
+
H = freqz(b, a, n_fft=n_fft).view(-1)
|
148 |
+
|
149 |
+
# apply filter
|
150 |
+
y = freq_domain_filter(x, H, n_fft)
|
151 |
+
|
152 |
+
# crop
|
153 |
+
y = y[: x.shape[-1]]
|
154 |
+
|
155 |
+
return y
|
156 |
+
|
157 |
+
|
158 |
+
def approx_iir_filter_cascade(
|
159 |
+
b_s: List[torch.Tensor],
|
160 |
+
a_s: List[torch.Tensor],
|
161 |
+
x: torch.Tensor,
|
162 |
+
):
|
163 |
+
"""Apply a cascade of IIR filters.
|
164 |
+
|
165 |
+
Args:
|
166 |
+
b (list[Tensor]): List of tensors of shape (3)
|
167 |
+
a (list[Tensor]): List of tensors of (3)
|
168 |
+
x (torch.Tensor): 1d Tensor.
|
169 |
+
"""
|
170 |
+
|
171 |
+
if len(b_s) != len(a_s):
|
172 |
+
raise RuntimeError(
|
173 |
+
f"Must have same number of coefficients. Got b: {len(b_s)} and a: {len(a_s)}."
|
174 |
+
)
|
175 |
+
|
176 |
+
# round up to nearest power of 2 for FFT
|
177 |
+
# n_fft = 2 ** math.ceil(math.log2(x.shape[-1] + x.shape[-1] - 1))
|
178 |
+
n_fft = 2 ** torch.ceil(torch.log2(torch.tensor(x.shape[-1] + x.shape[-1] - 1)))
|
179 |
+
n_fft = n_fft.int()
|
180 |
+
|
181 |
+
# this could be done in parallel
|
182 |
+
b = torch.stack(b_s, dim=0).type_as(x)
|
183 |
+
a = torch.stack(a_s, dim=0).type_as(x)
|
184 |
+
|
185 |
+
H = freqz(b, a, n_fft=n_fft)
|
186 |
+
H = torch.prod(H, dim=0).view(-1)
|
187 |
+
|
188 |
+
# apply filter
|
189 |
+
y = freq_domain_filter(x, H, n_fft)
|
190 |
+
|
191 |
+
# crop
|
192 |
+
y = y[: x.shape[-1]]
|
193 |
+
|
194 |
+
return y
|
deepafx_st/processors/dsp/compressor.py
ADDED
@@ -0,0 +1,177 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
import torch
|
3 |
+
import numpy as np
|
4 |
+
import scipy.signal
|
5 |
+
from numba import jit
|
6 |
+
|
7 |
+
from deepafx_st.processors.processor import Processor
|
8 |
+
|
9 |
+
|
10 |
+
# Adapted from: https://github.com/drscotthawley/signaltrain/blob/master/signaltrain/audio.py
|
11 |
+
@jit(nopython=True)
|
12 |
+
def my_clip_min(
|
13 |
+
x: np.ndarray,
|
14 |
+
clip_min: float,
|
15 |
+
): # does the work of np.clip(), which numba doesn't support yet
|
16 |
+
# TODO: keep an eye on Numba PR https://github.com/numba/numba/pull/3468 that fixes this
|
17 |
+
inds = np.where(x < clip_min)
|
18 |
+
x[inds] = clip_min
|
19 |
+
return x
|
20 |
+
|
21 |
+
|
22 |
+
@jit(nopython=True)
|
23 |
+
def compressor(
|
24 |
+
x: np.ndarray,
|
25 |
+
sample_rate: float,
|
26 |
+
threshold: float = -24.0,
|
27 |
+
ratio: float = 2.0,
|
28 |
+
attack_time: float = 0.01,
|
29 |
+
release_time: float = 0.01,
|
30 |
+
knee_dB: float = 0.0,
|
31 |
+
makeup_gain_dB: float = 0.0,
|
32 |
+
dtype=np.float32,
|
33 |
+
):
|
34 |
+
"""
|
35 |
+
|
36 |
+
Args:
|
37 |
+
x (np.ndarray): Input signal.
|
38 |
+
sample_rate (float): Sample rate in Hz.
|
39 |
+
threshold (float): Threhold in dB.
|
40 |
+
ratio (float): Ratio (should be >=1 , i.e. ratio:1).
|
41 |
+
attack_time (float): Attack time in seconds.
|
42 |
+
release_time (float): Release time in seconds.
|
43 |
+
knee_dB (float): Knee.
|
44 |
+
makeup_gain_dB (float): Makeup Gain.
|
45 |
+
dtype (type): Output type. Default: np.float32
|
46 |
+
|
47 |
+
Returns:
|
48 |
+
y (np.ndarray): Output signal.
|
49 |
+
|
50 |
+
"""
|
51 |
+
# print(f"dsp comp fs = {sample_rate}")
|
52 |
+
|
53 |
+
N = len(x)
|
54 |
+
dtype = x.dtype
|
55 |
+
y = np.zeros(N, dtype=dtype)
|
56 |
+
|
57 |
+
# Initialize separate attack and release times
|
58 |
+
# Where do these numbers come from
|
59 |
+
alpha_A = np.exp(-np.log(9) / (sample_rate * attack_time))
|
60 |
+
alpha_R = np.exp(-np.log(9) / (sample_rate * release_time))
|
61 |
+
|
62 |
+
# Turn the input signal into a uni-polar signal on the dB scale
|
63 |
+
x_G = 20 * np.log10(np.abs(x) + 1e-8) # x_uni casts type
|
64 |
+
|
65 |
+
# Ensure there are no values of negative infinity
|
66 |
+
x_G = my_clip_min(x_G, -96)
|
67 |
+
|
68 |
+
# Static characteristics with knee
|
69 |
+
y_G = np.zeros(N, dtype=dtype)
|
70 |
+
|
71 |
+
# Below knee
|
72 |
+
idx = np.where((2 * (x_G - threshold)) < -knee_dB)
|
73 |
+
y_G[idx] = x_G[idx]
|
74 |
+
|
75 |
+
# At knee
|
76 |
+
idx = np.where((2 * np.abs(x_G - threshold)) <= knee_dB)
|
77 |
+
y_G[idx] = x_G[idx] + (
|
78 |
+
(1 / ratio) * (((x_G[idx] - threshold + knee_dB) / 2) ** 2)
|
79 |
+
) / (2 * knee_dB)
|
80 |
+
|
81 |
+
# Above knee threshold
|
82 |
+
idx = np.where((2 * (x_G - threshold)) > knee_dB)
|
83 |
+
y_G[idx] = threshold + ((x_G[idx] - threshold) / ratio)
|
84 |
+
|
85 |
+
x_L = x_G - y_G
|
86 |
+
|
87 |
+
# this loop is slow but not vectorizable due to its cumulative, sequential nature. @autojit makes it fast(er).
|
88 |
+
y_L = np.zeros(N, dtype=dtype)
|
89 |
+
for n in range(1, N):
|
90 |
+
# smooth over the gainChange
|
91 |
+
if x_L[n] > y_L[n - 1]: # attack mode
|
92 |
+
y_L[n] = (alpha_A * y_L[n - 1]) + ((1 - alpha_A) * x_L[n])
|
93 |
+
else: # release
|
94 |
+
y_L[n] = (alpha_R * y_L[n - 1]) + ((1 - alpha_R) * x_L[n])
|
95 |
+
|
96 |
+
# Convert to linear amplitude scalar; i.e. map from dB to amplitude
|
97 |
+
lin_y_L = np.power(10.0, (-y_L / 20.0))
|
98 |
+
y = lin_y_L * x # Apply linear amplitude to input sample
|
99 |
+
|
100 |
+
y *= np.power(10.0, makeup_gain_dB / 20.0) # apply makeup gain
|
101 |
+
|
102 |
+
return y.astype(dtype)
|
103 |
+
|
104 |
+
|
105 |
+
class Compressor(Processor):
|
106 |
+
def __init__(
|
107 |
+
self,
|
108 |
+
sample_rate,
|
109 |
+
max_threshold=0.0,
|
110 |
+
min_threshold=-80,
|
111 |
+
max_ratio=20.0,
|
112 |
+
min_ratio=1.0,
|
113 |
+
max_attack=0.1,
|
114 |
+
min_attack=0.0001,
|
115 |
+
max_release=1.0,
|
116 |
+
min_release=0.005,
|
117 |
+
max_knee=12.0,
|
118 |
+
min_knee=0.0,
|
119 |
+
max_mkgain=48.0,
|
120 |
+
min_mkgain=-48.0,
|
121 |
+
eps=1e-8,
|
122 |
+
):
|
123 |
+
""" """
|
124 |
+
super().__init__()
|
125 |
+
self.sample_rate = sample_rate
|
126 |
+
self.eps = eps
|
127 |
+
self.ports = [
|
128 |
+
{
|
129 |
+
"name": "Threshold",
|
130 |
+
"min": min_threshold,
|
131 |
+
"max": max_threshold,
|
132 |
+
"default": -12.0,
|
133 |
+
"units": "",
|
134 |
+
},
|
135 |
+
{
|
136 |
+
"name": "Ratio",
|
137 |
+
"min": min_ratio,
|
138 |
+
"max": max_ratio,
|
139 |
+
"default": 2.0,
|
140 |
+
"units": "",
|
141 |
+
},
|
142 |
+
{
|
143 |
+
"name": "Attack Time",
|
144 |
+
"min": min_attack,
|
145 |
+
"max": max_attack,
|
146 |
+
"default": 0.001,
|
147 |
+
"units": "s",
|
148 |
+
},
|
149 |
+
{
|
150 |
+
"name": "Release Time",
|
151 |
+
"min": min_release,
|
152 |
+
"max": max_release,
|
153 |
+
"default": 0.045,
|
154 |
+
"units": "s",
|
155 |
+
},
|
156 |
+
{
|
157 |
+
"name": "Knee",
|
158 |
+
"min": min_knee,
|
159 |
+
"max": max_knee,
|
160 |
+
"default": 6.0,
|
161 |
+
"units": "dB",
|
162 |
+
},
|
163 |
+
{
|
164 |
+
"name": "Makeup Gain",
|
165 |
+
"min": min_mkgain,
|
166 |
+
"max": max_mkgain,
|
167 |
+
"default": 0.0,
|
168 |
+
"units": "dB",
|
169 |
+
},
|
170 |
+
]
|
171 |
+
|
172 |
+
self.num_control_params = len(self.ports)
|
173 |
+
self.process_fn = compressor
|
174 |
+
|
175 |
+
def forward(self, x, p, sample_rate=24000, **kwargs):
|
176 |
+
"All processing in the forward is in numpy."
|
177 |
+
return self.run_series(x, p, sample_rate)
|
deepafx_st/processors/dsp/peq.py
ADDED
@@ -0,0 +1,323 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import numpy as np
|
3 |
+
import scipy.signal
|
4 |
+
from numba import jit
|
5 |
+
|
6 |
+
from deepafx_st.processors.processor import Processor
|
7 |
+
|
8 |
+
|
9 |
+
@jit(nopython=True)
|
10 |
+
def biqaud(
|
11 |
+
gain_dB: float,
|
12 |
+
cutoff_freq: float,
|
13 |
+
q_factor: float,
|
14 |
+
sample_rate: float,
|
15 |
+
filter_type: str,
|
16 |
+
):
|
17 |
+
"""Use design parameters to generate coeffieicnets for a specific filter type.
|
18 |
+
|
19 |
+
Args:
|
20 |
+
gain_dB (float): Shelving filter gain in dB.
|
21 |
+
cutoff_freq (float): Cutoff frequency in Hz.
|
22 |
+
q_factor (float): Q factor.
|
23 |
+
sample_rate (float): Sample rate in Hz.
|
24 |
+
filter_type (str): Filter type.
|
25 |
+
One of "low_shelf", "high_shelf", or "peaking"
|
26 |
+
|
27 |
+
Returns:
|
28 |
+
b (np.ndarray): Numerator filter coefficients stored as [b0, b1, b2]
|
29 |
+
a (np.ndarray): Denominator filter coefficients stored as [a0, a1, a2]
|
30 |
+
"""
|
31 |
+
|
32 |
+
A = 10 ** (gain_dB / 40.0)
|
33 |
+
w0 = 2.0 * np.pi * (cutoff_freq / sample_rate)
|
34 |
+
alpha = np.sin(w0) / (2.0 * q_factor)
|
35 |
+
|
36 |
+
cos_w0 = np.cos(w0)
|
37 |
+
sqrt_A = np.sqrt(A)
|
38 |
+
|
39 |
+
if filter_type == "high_shelf":
|
40 |
+
b0 = A * ((A + 1) + (A - 1) * cos_w0 + 2 * sqrt_A * alpha)
|
41 |
+
b1 = -2 * A * ((A - 1) + (A + 1) * cos_w0)
|
42 |
+
b2 = A * ((A + 1) + (A - 1) * cos_w0 - 2 * sqrt_A * alpha)
|
43 |
+
a0 = (A + 1) - (A - 1) * cos_w0 + 2 * sqrt_A * alpha
|
44 |
+
a1 = 2 * ((A - 1) - (A + 1) * cos_w0)
|
45 |
+
a2 = (A + 1) - (A - 1) * cos_w0 - 2 * sqrt_A * alpha
|
46 |
+
elif filter_type == "low_shelf":
|
47 |
+
b0 = A * ((A + 1) - (A - 1) * cos_w0 + 2 * sqrt_A * alpha)
|
48 |
+
b1 = 2 * A * ((A - 1) - (A + 1) * cos_w0)
|
49 |
+
b2 = A * ((A + 1) - (A - 1) * cos_w0 - 2 * sqrt_A * alpha)
|
50 |
+
a0 = (A + 1) + (A - 1) * cos_w0 + 2 * sqrt_A * alpha
|
51 |
+
a1 = -2 * ((A - 1) + (A + 1) * cos_w0)
|
52 |
+
a2 = (A + 1) + (A - 1) * cos_w0 - 2 * sqrt_A * alpha
|
53 |
+
elif filter_type == "peaking":
|
54 |
+
b0 = 1 + alpha * A
|
55 |
+
b1 = -2 * cos_w0
|
56 |
+
b2 = 1 - alpha * A
|
57 |
+
a0 = 1 + alpha / A
|
58 |
+
a1 = -2 * cos_w0
|
59 |
+
a2 = 1 - alpha / A
|
60 |
+
else:
|
61 |
+
pass
|
62 |
+
# raise ValueError(f"Invalid filter_type: {filter_type}.")
|
63 |
+
|
64 |
+
b = np.array([b0, b1, b2]) / a0
|
65 |
+
a = np.array([a0, a1, a2]) / a0
|
66 |
+
|
67 |
+
return b, a
|
68 |
+
|
69 |
+
|
70 |
+
# Adapted from https://github.com/csteinmetz1/pyloudnorm/blob/master/pyloudnorm/iirfilter.py
|
71 |
+
def parametric_eq(
|
72 |
+
x: np.ndarray,
|
73 |
+
sample_rate: float,
|
74 |
+
low_shelf_gain_dB: float = 0.0,
|
75 |
+
low_shelf_cutoff_freq: float = 80.0,
|
76 |
+
low_shelf_q_factor: float = 0.707,
|
77 |
+
first_band_gain_dB: float = 0.0,
|
78 |
+
first_band_cutoff_freq: float = 300.0,
|
79 |
+
first_band_q_factor: float = 0.707,
|
80 |
+
second_band_gain_dB: float = 0.0,
|
81 |
+
second_band_cutoff_freq: float = 1000.0,
|
82 |
+
second_band_q_factor: float = 0.707,
|
83 |
+
third_band_gain_dB: float = 0.0,
|
84 |
+
third_band_cutoff_freq: float = 4000.0,
|
85 |
+
third_band_q_factor: float = 0.707,
|
86 |
+
fourth_band_gain_dB: float = 0.0,
|
87 |
+
fourth_band_cutoff_freq: float = 8000.0,
|
88 |
+
fourth_band_q_factor: float = 0.707,
|
89 |
+
high_shelf_gain_dB: float = 0.0,
|
90 |
+
high_shelf_cutoff_freq: float = 1000.0,
|
91 |
+
high_shelf_q_factor: float = 0.707,
|
92 |
+
dtype=np.float32,
|
93 |
+
):
|
94 |
+
"""Six-band parametric EQ.
|
95 |
+
|
96 |
+
Low-shelf -> Band 1 -> Band 2 -> Band 3 -> Band 4 -> High-shelf
|
97 |
+
|
98 |
+
Args:
|
99 |
+
|
100 |
+
|
101 |
+
"""
|
102 |
+
# print(f"autodiff peq fs = {sample_rate}")
|
103 |
+
|
104 |
+
# -------- apply low-shelf filter --------
|
105 |
+
b, a = biqaud(
|
106 |
+
low_shelf_gain_dB,
|
107 |
+
low_shelf_cutoff_freq,
|
108 |
+
low_shelf_q_factor,
|
109 |
+
sample_rate,
|
110 |
+
"low_shelf",
|
111 |
+
)
|
112 |
+
sos0 = np.concatenate((b, a))
|
113 |
+
x = scipy.signal.lfilter(b, a, x)
|
114 |
+
|
115 |
+
# -------- apply first-band peaking filter --------
|
116 |
+
b, a = biqaud(
|
117 |
+
first_band_gain_dB,
|
118 |
+
first_band_cutoff_freq,
|
119 |
+
first_band_q_factor,
|
120 |
+
sample_rate,
|
121 |
+
"peaking",
|
122 |
+
)
|
123 |
+
sos1 = np.concatenate((b, a))
|
124 |
+
x = scipy.signal.lfilter(b, a, x)
|
125 |
+
|
126 |
+
# -------- apply second-band peaking filter --------
|
127 |
+
b, a = biqaud(
|
128 |
+
second_band_gain_dB,
|
129 |
+
second_band_cutoff_freq,
|
130 |
+
second_band_q_factor,
|
131 |
+
sample_rate,
|
132 |
+
"peaking",
|
133 |
+
)
|
134 |
+
sos2 = np.concatenate((b, a))
|
135 |
+
x = scipy.signal.lfilter(b, a, x)
|
136 |
+
|
137 |
+
# -------- apply third-band peaking filter --------
|
138 |
+
b, a = biqaud(
|
139 |
+
third_band_gain_dB,
|
140 |
+
third_band_cutoff_freq,
|
141 |
+
third_band_q_factor,
|
142 |
+
sample_rate,
|
143 |
+
"peaking",
|
144 |
+
)
|
145 |
+
sos3 = np.concatenate((b, a))
|
146 |
+
x = scipy.signal.lfilter(b, a, x)
|
147 |
+
|
148 |
+
# -------- apply fourth-band peaking filter --------
|
149 |
+
b, a = biqaud(
|
150 |
+
fourth_band_gain_dB,
|
151 |
+
fourth_band_cutoff_freq,
|
152 |
+
fourth_band_q_factor,
|
153 |
+
sample_rate,
|
154 |
+
"peaking",
|
155 |
+
)
|
156 |
+
sos4 = np.concatenate((b, a))
|
157 |
+
x = scipy.signal.lfilter(b, a, x)
|
158 |
+
|
159 |
+
# -------- apply high-shelf filter --------
|
160 |
+
b, a = biqaud(
|
161 |
+
high_shelf_gain_dB,
|
162 |
+
high_shelf_cutoff_freq,
|
163 |
+
high_shelf_q_factor,
|
164 |
+
sample_rate,
|
165 |
+
"high_shelf",
|
166 |
+
)
|
167 |
+
sos5 = np.concatenate((b, a))
|
168 |
+
x = scipy.signal.lfilter(b, a, x)
|
169 |
+
|
170 |
+
return x.astype(dtype)
|
171 |
+
|
172 |
+
|
173 |
+
class ParametricEQ(Processor):
|
174 |
+
def __init__(
|
175 |
+
self,
|
176 |
+
sample_rate,
|
177 |
+
min_gain_dB=-24.0,
|
178 |
+
default_gain_dB=0.0,
|
179 |
+
max_gain_dB=24.0,
|
180 |
+
min_q_factor=0.1,
|
181 |
+
default_q_factor=0.707,
|
182 |
+
max_q_factor=10,
|
183 |
+
eps=1e-8,
|
184 |
+
):
|
185 |
+
""" """
|
186 |
+
super().__init__()
|
187 |
+
self.sample_rate = sample_rate
|
188 |
+
self.eps = eps
|
189 |
+
self.ports = [
|
190 |
+
{
|
191 |
+
"name": "Lowshelf gain",
|
192 |
+
"min": min_gain_dB,
|
193 |
+
"max": max_gain_dB,
|
194 |
+
"default": default_gain_dB,
|
195 |
+
"units": "dB",
|
196 |
+
},
|
197 |
+
{
|
198 |
+
"name": "Lowshelf cutoff",
|
199 |
+
"min": 20.0,
|
200 |
+
"max": 200.0,
|
201 |
+
"default": 100.0,
|
202 |
+
"units": "Hz",
|
203 |
+
},
|
204 |
+
{
|
205 |
+
"name": "Lowshelf Q",
|
206 |
+
"min": min_q_factor,
|
207 |
+
"max": max_q_factor,
|
208 |
+
"default": default_q_factor,
|
209 |
+
"units": "",
|
210 |
+
},
|
211 |
+
{
|
212 |
+
"name": "First band gain",
|
213 |
+
"min": min_gain_dB,
|
214 |
+
"max": max_gain_dB,
|
215 |
+
"default": default_gain_dB,
|
216 |
+
"units": "dB",
|
217 |
+
},
|
218 |
+
{
|
219 |
+
"name": "First band cutoff",
|
220 |
+
"min": 200.0,
|
221 |
+
"max": 2000.0,
|
222 |
+
"default": 400.0,
|
223 |
+
"units": "Hz",
|
224 |
+
},
|
225 |
+
{
|
226 |
+
"name": "First band Q",
|
227 |
+
"min": min_q_factor,
|
228 |
+
"max": max_q_factor,
|
229 |
+
"default": 0.707,
|
230 |
+
"units": "",
|
231 |
+
},
|
232 |
+
{
|
233 |
+
"name": "Second band gain",
|
234 |
+
"min": min_gain_dB,
|
235 |
+
"max": max_gain_dB,
|
236 |
+
"default": default_gain_dB,
|
237 |
+
"units": "dB",
|
238 |
+
},
|
239 |
+
{
|
240 |
+
"name": "Second band cutoff",
|
241 |
+
"min": 800.0,
|
242 |
+
"max": 4000.0,
|
243 |
+
"default": 1000.0,
|
244 |
+
"units": "Hz",
|
245 |
+
},
|
246 |
+
{
|
247 |
+
"name": "Second band Q",
|
248 |
+
"min": min_q_factor,
|
249 |
+
"max": max_q_factor,
|
250 |
+
"default": default_q_factor,
|
251 |
+
"units": "",
|
252 |
+
},
|
253 |
+
{
|
254 |
+
"name": "Third band gain",
|
255 |
+
"min": min_gain_dB,
|
256 |
+
"max": max_gain_dB,
|
257 |
+
"default": default_gain_dB,
|
258 |
+
"units": "dB",
|
259 |
+
},
|
260 |
+
{
|
261 |
+
"name": "Third band cutoff",
|
262 |
+
"min": 2000.0,
|
263 |
+
"max": 8000.0,
|
264 |
+
"default": 4000.0,
|
265 |
+
"units": "Hz",
|
266 |
+
},
|
267 |
+
{
|
268 |
+
"name": "Third band Q",
|
269 |
+
"min": min_q_factor,
|
270 |
+
"max": max_q_factor,
|
271 |
+
"default": default_q_factor,
|
272 |
+
"units": "",
|
273 |
+
},
|
274 |
+
{
|
275 |
+
"name": "Fourth band gain",
|
276 |
+
"min": min_gain_dB,
|
277 |
+
"max": max_gain_dB,
|
278 |
+
"default": default_gain_dB,
|
279 |
+
"units": "dB",
|
280 |
+
},
|
281 |
+
{
|
282 |
+
"name": "Fourth band cutoff",
|
283 |
+
"min": 4000.0,
|
284 |
+
"max": (24000 // 2) * 0.9,
|
285 |
+
"default": 8000.0,
|
286 |
+
"units": "Hz",
|
287 |
+
},
|
288 |
+
{
|
289 |
+
"name": "Fourth band Q",
|
290 |
+
"min": min_q_factor,
|
291 |
+
"max": max_q_factor,
|
292 |
+
"default": default_q_factor,
|
293 |
+
"units": "",
|
294 |
+
},
|
295 |
+
{
|
296 |
+
"name": "Highshelf gain",
|
297 |
+
"min": min_gain_dB,
|
298 |
+
"max": max_gain_dB,
|
299 |
+
"default": default_gain_dB,
|
300 |
+
"units": "dB",
|
301 |
+
},
|
302 |
+
{
|
303 |
+
"name": "Highshelf cutoff",
|
304 |
+
"min": 4000.0,
|
305 |
+
"max": (24000 // 2) * 0.9,
|
306 |
+
"default": 8000.0,
|
307 |
+
"units": "Hz",
|
308 |
+
},
|
309 |
+
{
|
310 |
+
"name": "Highshelf Q",
|
311 |
+
"min": min_q_factor,
|
312 |
+
"max": max_q_factor,
|
313 |
+
"default": default_q_factor,
|
314 |
+
"units": "",
|
315 |
+
},
|
316 |
+
]
|
317 |
+
|
318 |
+
self.num_control_params = len(self.ports)
|
319 |
+
self.process_fn = parametric_eq
|
320 |
+
|
321 |
+
def forward(self, x, p, sample_rate=24000, **kwargs):
|
322 |
+
"All processing in the forward is in numpy."
|
323 |
+
return self.run_series(x, p, sample_rate)
|
deepafx_st/processors/processor.py
ADDED
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import multiprocessing
|
3 |
+
from abc import ABC, abstractmethod
|
4 |
+
import deepafx_st.utils as utils
|
5 |
+
import numpy as np
|
6 |
+
|
7 |
+
|
8 |
+
class Processor(torch.nn.Module, ABC):
|
9 |
+
"""Processor base class."""
|
10 |
+
|
11 |
+
def __init__(
|
12 |
+
self,
|
13 |
+
):
|
14 |
+
super().__init__()
|
15 |
+
|
16 |
+
def denormalize_params(self, p):
|
17 |
+
"""This method takes a tensor of parameters scaled from 0-1 and
|
18 |
+
restores them back to the original parameter range."""
|
19 |
+
|
20 |
+
# check if the number of parameters is correct
|
21 |
+
params = p # torch.split(p, 1, -1)
|
22 |
+
if len(params) != self.num_control_params:
|
23 |
+
raise RuntimeError(
|
24 |
+
f"Invalid number of parameters. ",
|
25 |
+
f"Expected {self.num_control_params} but found {len(params)} {params.shape}.",
|
26 |
+
)
|
27 |
+
|
28 |
+
# iterate over the parameters and expand from 0-1 to full range
|
29 |
+
denorm_params = []
|
30 |
+
for param, port in zip(params, self.ports):
|
31 |
+
# check if parameter exceeds range
|
32 |
+
if param > 1.0 or param < 0.0:
|
33 |
+
raise RuntimeError(
|
34 |
+
f"""Parameter '{port["name"]}' exceeds range: {param}"""
|
35 |
+
)
|
36 |
+
|
37 |
+
# denormalize and store result
|
38 |
+
denorm_params.append(utils.denormalize(param, port["max"], port["min"]))
|
39 |
+
|
40 |
+
return denorm_params
|
41 |
+
|
42 |
+
def normalize_params(self, *params):
|
43 |
+
"""This method creates a vector of parameters normalized from 0-1."""
|
44 |
+
|
45 |
+
# check if the number of parameters is correct
|
46 |
+
if len(params) != self.num_control_params:
|
47 |
+
raise RuntimeError(
|
48 |
+
f"Invalid number of parameters. ",
|
49 |
+
f"Expected {self.num_control_params} but found {len(params)}.",
|
50 |
+
)
|
51 |
+
|
52 |
+
norm_params = []
|
53 |
+
for param, port in zip(params, self.ports):
|
54 |
+
norm_params.append(utils.normalize(param, port["max"], port["min"]))
|
55 |
+
|
56 |
+
p = torch.tensor(norm_params).view(1, -1)
|
57 |
+
|
58 |
+
return p
|
59 |
+
|
60 |
+
# def run_series(self, inputs, params):
|
61 |
+
# """Run the process function in a loop given a list of inputs and parameters"""
|
62 |
+
# p_b_denorm = [p for p in self.denormalize_params(params)]
|
63 |
+
# y = self.process_fn(inputs, self.sample_rate, *p_b_denorm)
|
64 |
+
# return y
|
65 |
+
|
66 |
+
def run_series(self, inputs, params, sample_rate=24000):
|
67 |
+
"""Run the process function in a loop given a list of inputs and parameters"""
|
68 |
+
if params.ndim == 1:
|
69 |
+
params = np.reshape(params, (1, -1))
|
70 |
+
inputs = np.reshape(inputs, (1, -1))
|
71 |
+
bs = inputs.shape[0]
|
72 |
+
ys = []
|
73 |
+
params = np.clip(params, 0, 1)
|
74 |
+
for bidx in range(bs):
|
75 |
+
p_b_denorm = [p for p in self.denormalize_params(params[bidx, :])]
|
76 |
+
y = self.process_fn(
|
77 |
+
inputs[bidx, ...].reshape(-1),
|
78 |
+
sample_rate,
|
79 |
+
*p_b_denorm,
|
80 |
+
)
|
81 |
+
ys.append(y)
|
82 |
+
y = np.stack(ys, axis=0)
|
83 |
+
return y
|
84 |
+
|
85 |
+
@abstractmethod
|
86 |
+
def forward(self, x, p):
|
87 |
+
pass
|
deepafx_st/processors/proxy/channel.py
ADDED
@@ -0,0 +1,130 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from deepafx_st.processors.proxy.proxy_system import ProxySystem
|
3 |
+
from deepafx_st.utils import DSPMode
|
4 |
+
|
5 |
+
|
6 |
+
class ProxyChannel(torch.nn.Module):
|
7 |
+
def __init__(
|
8 |
+
self,
|
9 |
+
proxy_system_ckpts: list,
|
10 |
+
freeze_proxies: bool = True,
|
11 |
+
dsp_mode: DSPMode = DSPMode.NONE,
|
12 |
+
num_tcns: int = 2,
|
13 |
+
tcn_nblocks: int = 4,
|
14 |
+
tcn_dilation_growth: int = 8,
|
15 |
+
tcn_channel_width: int = 64,
|
16 |
+
tcn_kernel_size: int = 13,
|
17 |
+
sample_rate: int = 24000,
|
18 |
+
):
|
19 |
+
super().__init__()
|
20 |
+
self.freeze_proxies = freeze_proxies
|
21 |
+
self.dsp_mode = dsp_mode
|
22 |
+
self.num_tcns = num_tcns
|
23 |
+
|
24 |
+
# load the proxies
|
25 |
+
self.proxies = torch.nn.ModuleList()
|
26 |
+
self.num_control_params = 0
|
27 |
+
self.ports = []
|
28 |
+
for proxy_system_ckpt in proxy_system_ckpts:
|
29 |
+
proxy = ProxySystem.load_from_checkpoint(proxy_system_ckpt)
|
30 |
+
# freeze model parameters
|
31 |
+
if freeze_proxies:
|
32 |
+
for param in proxy.parameters():
|
33 |
+
param.requires_grad = False
|
34 |
+
self.proxies.append(proxy)
|
35 |
+
if proxy.hparams.processor == "channel":
|
36 |
+
self.ports = proxy.processor.ports
|
37 |
+
else:
|
38 |
+
self.ports.append(proxy.processor.ports)
|
39 |
+
self.num_control_params += proxy.processor.num_control_params
|
40 |
+
|
41 |
+
if len(proxy_system_ckpts) == 0:
|
42 |
+
if self.num_tcns == 2:
|
43 |
+
peq_proxy = ProxySystem(
|
44 |
+
processor="peq",
|
45 |
+
output_gain=False,
|
46 |
+
nblocks=tcn_nblocks,
|
47 |
+
dilation_growth=tcn_dilation_growth,
|
48 |
+
kernel_size=tcn_kernel_size,
|
49 |
+
channel_width=tcn_channel_width,
|
50 |
+
sample_rate=sample_rate,
|
51 |
+
)
|
52 |
+
self.proxies.append(peq_proxy)
|
53 |
+
self.ports.append(peq_proxy.processor.ports)
|
54 |
+
self.num_control_params += peq_proxy.processor.num_control_params
|
55 |
+
comp_proxy = ProxySystem(
|
56 |
+
processor="comp",
|
57 |
+
output_gain=True,
|
58 |
+
nblocks=tcn_nblocks,
|
59 |
+
dilation_growth=tcn_dilation_growth,
|
60 |
+
kernel_size=tcn_kernel_size,
|
61 |
+
channel_width=tcn_channel_width,
|
62 |
+
sample_rate=sample_rate,
|
63 |
+
)
|
64 |
+
self.proxies.append(comp_proxy)
|
65 |
+
self.ports.append(comp_proxy.processor.ports)
|
66 |
+
self.num_control_params += comp_proxy.processor.num_control_params
|
67 |
+
elif self.num_tcns == 1:
|
68 |
+
channel_proxy = ProxySystem(
|
69 |
+
processor="channel",
|
70 |
+
output_gain=True,
|
71 |
+
nblocks=tcn_nblocks,
|
72 |
+
dilation_growth=tcn_dilation_growth,
|
73 |
+
kernel_size=tcn_kernel_size,
|
74 |
+
channel_width=tcn_channel_width,
|
75 |
+
sample_rate=sample_rate,
|
76 |
+
)
|
77 |
+
self.proxies.append(channel_proxy)
|
78 |
+
for port_list in channel_proxy.processor.ports:
|
79 |
+
self.ports.append(port_list)
|
80 |
+
self.num_control_params += channel_proxy.processor.num_control_params
|
81 |
+
else:
|
82 |
+
raise ValueError(f"num_tcns must be <= 2. Asked for {self.num_tcns}.")
|
83 |
+
|
84 |
+
def forward(
|
85 |
+
self,
|
86 |
+
x: torch.Tensor,
|
87 |
+
p: torch.Tensor,
|
88 |
+
dsp_mode: DSPMode = DSPMode.NONE,
|
89 |
+
sample_rate: int = 24000,
|
90 |
+
**kwargs,
|
91 |
+
):
|
92 |
+
# loop over the proxies and pass parameters
|
93 |
+
stop_idx = 0
|
94 |
+
for proxy in self.proxies:
|
95 |
+
start_idx = stop_idx
|
96 |
+
stop_idx += proxy.processor.num_control_params
|
97 |
+
p_subset = p[:, start_idx:stop_idx]
|
98 |
+
if dsp_mode.name == DSPMode.NONE.name:
|
99 |
+
x = proxy(
|
100 |
+
x,
|
101 |
+
p_subset,
|
102 |
+
use_dsp=False,
|
103 |
+
)
|
104 |
+
elif dsp_mode.name == DSPMode.INFER.name:
|
105 |
+
x = proxy(
|
106 |
+
x,
|
107 |
+
p_subset,
|
108 |
+
use_dsp=True,
|
109 |
+
sample_rate=sample_rate,
|
110 |
+
)
|
111 |
+
elif dsp_mode.name == DSPMode.TRAIN_INFER.name:
|
112 |
+
# Mimic gumbel softmax implementation to replace grads similar to
|
113 |
+
# https://gist.github.com/yzh119/fd2146d2aeb329d067568a493b20172f
|
114 |
+
x_hard = proxy(
|
115 |
+
x,
|
116 |
+
p_subset,
|
117 |
+
use_dsp=True,
|
118 |
+
sample_rate=sample_rate,
|
119 |
+
)
|
120 |
+
x = proxy(
|
121 |
+
x,
|
122 |
+
p_subset,
|
123 |
+
use_dsp=False,
|
124 |
+
sample_rate=sample_rate,
|
125 |
+
)
|
126 |
+
x = (x_hard - x).detach() + x
|
127 |
+
else:
|
128 |
+
assert 0, "invalid dsp model for proxy"
|
129 |
+
|
130 |
+
return x
|
deepafx_st/processors/proxy/proxy_system.py
ADDED
@@ -0,0 +1,289 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from re import X
|
2 |
+
import torch
|
3 |
+
import auraloss
|
4 |
+
import pytorch_lightning as pl
|
5 |
+
from typing import Tuple, List, Dict
|
6 |
+
from argparse import ArgumentParser
|
7 |
+
|
8 |
+
|
9 |
+
import deepafx_st.utils as utils
|
10 |
+
from deepafx_st.data.proxy import DSPProxyDataset
|
11 |
+
from deepafx_st.processors.proxy.tcn import ConditionalTCN
|
12 |
+
from deepafx_st.processors.spsa.channel import SPSAChannel
|
13 |
+
from deepafx_st.processors.dsp.peq import ParametricEQ
|
14 |
+
from deepafx_st.processors.dsp.compressor import Compressor
|
15 |
+
|
16 |
+
|
17 |
+
class ProxySystem(pl.LightningModule):
|
18 |
+
def __init__(
|
19 |
+
self,
|
20 |
+
causal=True,
|
21 |
+
nblocks=4,
|
22 |
+
dilation_growth=8,
|
23 |
+
kernel_size=13,
|
24 |
+
channel_width=64,
|
25 |
+
input_dir=None,
|
26 |
+
processor="channel",
|
27 |
+
batch_size=32,
|
28 |
+
lr=3e-4,
|
29 |
+
lr_patience=20,
|
30 |
+
patience=10,
|
31 |
+
preload=False,
|
32 |
+
sample_rate=24000,
|
33 |
+
shuffle=True,
|
34 |
+
train_length=65536,
|
35 |
+
train_examples_per_epoch=10000,
|
36 |
+
val_length=131072,
|
37 |
+
val_examples_per_epoch=1000,
|
38 |
+
num_workers=16,
|
39 |
+
output_gain=False,
|
40 |
+
**kwargs,
|
41 |
+
):
|
42 |
+
super().__init__()
|
43 |
+
self.save_hyperparameters()
|
44 |
+
#print(f"Proxy Processor: {processor} @ fs={sample_rate} Hz")
|
45 |
+
|
46 |
+
# construct both the true DSP...
|
47 |
+
if self.hparams.processor == "peq":
|
48 |
+
self.processor = ParametricEQ(self.hparams.sample_rate)
|
49 |
+
elif self.hparams.processor == "comp":
|
50 |
+
self.processor = Compressor(self.hparams.sample_rate)
|
51 |
+
elif self.hparams.processor == "channel":
|
52 |
+
self.processor = SPSAChannel(self.hparams.sample_rate)
|
53 |
+
|
54 |
+
# and the neural network proxy
|
55 |
+
self.proxy = ConditionalTCN(
|
56 |
+
self.hparams.sample_rate,
|
57 |
+
num_control_params=self.processor.num_control_params,
|
58 |
+
causal=self.hparams.causal,
|
59 |
+
nblocks=self.hparams.nblocks,
|
60 |
+
channel_width=self.hparams.channel_width,
|
61 |
+
kernel_size=self.hparams.kernel_size,
|
62 |
+
dilation_growth=self.hparams.dilation_growth,
|
63 |
+
)
|
64 |
+
|
65 |
+
self.receptive_field = self.proxy.compute_receptive_field()
|
66 |
+
|
67 |
+
self.recon_losses = {}
|
68 |
+
self.recon_loss_weights = {}
|
69 |
+
|
70 |
+
self.recon_losses["mrstft"] = auraloss.freq.MultiResolutionSTFTLoss(
|
71 |
+
fft_sizes=[32, 128, 512, 2048, 8192, 32768],
|
72 |
+
hop_sizes=[16, 64, 256, 1024, 4096, 16384],
|
73 |
+
win_lengths=[32, 128, 512, 2048, 8192, 32768],
|
74 |
+
w_sc=0.0,
|
75 |
+
w_phs=0.0,
|
76 |
+
w_lin_mag=1.0,
|
77 |
+
w_log_mag=1.0,
|
78 |
+
)
|
79 |
+
self.recon_loss_weights["mrstft"] = 1.0
|
80 |
+
|
81 |
+
self.recon_losses["l1"] = torch.nn.L1Loss()
|
82 |
+
self.recon_loss_weights["l1"] = 100.0
|
83 |
+
|
84 |
+
def forward(self, x, p, use_dsp=False, sample_rate=24000, **kwargs):
|
85 |
+
"""Use the pre-trained neural network proxy effect."""
|
86 |
+
bs, chs, samp = x.size()
|
87 |
+
if not use_dsp:
|
88 |
+
y = self.proxy(x, p)
|
89 |
+
# manually apply the makeup gain parameter
|
90 |
+
if self.hparams.output_gain and not self.hparams.processor == "peq":
|
91 |
+
gain_db = (p[..., -1] * 96) - 48
|
92 |
+
gain_ln = 10 ** (gain_db / 20.0)
|
93 |
+
y *= gain_ln.view(bs, chs, 1)
|
94 |
+
else:
|
95 |
+
with torch.no_grad():
|
96 |
+
bs, chs, s = x.shape
|
97 |
+
|
98 |
+
if self.hparams.output_gain and not self.hparams.processor == "peq":
|
99 |
+
# override makeup gain
|
100 |
+
gain_db = (p[..., -1] * 96) - 48
|
101 |
+
gain_ln = 10 ** (gain_db / 20.0)
|
102 |
+
p[..., -1] = 0.5
|
103 |
+
|
104 |
+
if self.hparams.processor == "channel":
|
105 |
+
y_temp = self.processor(x.cpu(), p.cpu())
|
106 |
+
y_temp = y_temp.view(bs, chs, s).type_as(x)
|
107 |
+
else:
|
108 |
+
y_temp = self.processor(
|
109 |
+
x.cpu().numpy(),
|
110 |
+
p.cpu().numpy(),
|
111 |
+
sample_rate,
|
112 |
+
)
|
113 |
+
y_temp = torch.tensor(y_temp).view(bs, chs, s).type_as(x)
|
114 |
+
|
115 |
+
y = y_temp.type_as(x).view(bs, 1, -1)
|
116 |
+
|
117 |
+
if self.hparams.output_gain and not self.hparams.processor == "peq":
|
118 |
+
y *= gain_ln.view(bs, chs, 1)
|
119 |
+
|
120 |
+
return y
|
121 |
+
|
122 |
+
def common_step(
|
123 |
+
self,
|
124 |
+
batch: Tuple,
|
125 |
+
batch_idx: int,
|
126 |
+
optimizer_idx: int = 0,
|
127 |
+
train: bool = True,
|
128 |
+
):
|
129 |
+
loss = 0
|
130 |
+
x, y, p = batch
|
131 |
+
|
132 |
+
y_hat = self(x, p)
|
133 |
+
|
134 |
+
# compute loss
|
135 |
+
for loss_idx, (loss_name, loss_fn) in enumerate(self.recon_losses.items()):
|
136 |
+
tmp_loss = loss_fn(y_hat.float(), y.float())
|
137 |
+
loss += self.recon_loss_weights[loss_name] * tmp_loss
|
138 |
+
|
139 |
+
self.log(
|
140 |
+
f"train_loss/{loss_name}" if train else f"val_loss/{loss_name}",
|
141 |
+
tmp_loss,
|
142 |
+
on_step=True,
|
143 |
+
on_epoch=True,
|
144 |
+
prog_bar=False,
|
145 |
+
logger=True,
|
146 |
+
sync_dist=True,
|
147 |
+
)
|
148 |
+
|
149 |
+
if not train:
|
150 |
+
# store audio data
|
151 |
+
data_dict = {
|
152 |
+
"x": x.float().cpu(),
|
153 |
+
"y": y.float().cpu(),
|
154 |
+
"p": p.float().cpu(),
|
155 |
+
"y_hat": y_hat.float().cpu(),
|
156 |
+
}
|
157 |
+
else:
|
158 |
+
data_dict = {}
|
159 |
+
|
160 |
+
self.log(
|
161 |
+
"train_loss" if train else "val_loss",
|
162 |
+
loss,
|
163 |
+
on_step=True,
|
164 |
+
on_epoch=True,
|
165 |
+
prog_bar=False,
|
166 |
+
logger=True,
|
167 |
+
sync_dist=True,
|
168 |
+
)
|
169 |
+
|
170 |
+
return loss, data_dict
|
171 |
+
|
172 |
+
def training_step(self, batch, batch_idx, optimizer_idx=0):
|
173 |
+
loss, _ = self.common_step(batch, batch_idx)
|
174 |
+
return loss
|
175 |
+
|
176 |
+
def validation_step(self, batch, batch_idx):
|
177 |
+
loss, data_dict = self.common_step(batch, batch_idx, train=False)
|
178 |
+
|
179 |
+
if batch_idx == 0:
|
180 |
+
return data_dict
|
181 |
+
|
182 |
+
def configure_optimizers(self):
|
183 |
+
optimizer = torch.optim.Adam(
|
184 |
+
self.proxy.parameters(),
|
185 |
+
lr=self.hparams.lr,
|
186 |
+
betas=(0.9, 0.999),
|
187 |
+
)
|
188 |
+
|
189 |
+
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
|
190 |
+
optimizer,
|
191 |
+
patience=self.hparams.lr_patience,
|
192 |
+
verbose=True,
|
193 |
+
)
|
194 |
+
|
195 |
+
return [optimizer], {"scheduler": scheduler, "monitor": "val_loss"}
|
196 |
+
|
197 |
+
def train_dataloader(self):
|
198 |
+
|
199 |
+
train_dataset = DSPProxyDataset(
|
200 |
+
self.hparams.input_dir,
|
201 |
+
self.processor,
|
202 |
+
self.hparams.processor, # name
|
203 |
+
subset="train",
|
204 |
+
length=self.hparams.train_length,
|
205 |
+
num_examples_per_epoch=self.hparams.train_examples_per_epoch,
|
206 |
+
half=True if self.hparams.precision == 16 else False,
|
207 |
+
buffer_size_gb=self.hparams.buffer_size_gb,
|
208 |
+
buffer_reload_rate=self.hparams.buffer_reload_rate,
|
209 |
+
)
|
210 |
+
|
211 |
+
g = torch.Generator()
|
212 |
+
g.manual_seed(0)
|
213 |
+
|
214 |
+
return torch.utils.data.DataLoader(
|
215 |
+
train_dataset,
|
216 |
+
num_workers=self.hparams.num_workers,
|
217 |
+
batch_size=self.hparams.batch_size,
|
218 |
+
worker_init_fn=utils.seed_worker,
|
219 |
+
generator=g,
|
220 |
+
pin_memory=True,
|
221 |
+
)
|
222 |
+
|
223 |
+
def val_dataloader(self):
|
224 |
+
|
225 |
+
val_dataset = DSPProxyDataset(
|
226 |
+
self.hparams.input_dir,
|
227 |
+
self.processor,
|
228 |
+
self.hparams.processor, # name
|
229 |
+
subset="val",
|
230 |
+
length=self.hparams.val_length,
|
231 |
+
num_examples_per_epoch=self.hparams.val_examples_per_epoch,
|
232 |
+
half=True if self.hparams.precision == 16 else False,
|
233 |
+
buffer_size_gb=self.hparams.buffer_size_gb,
|
234 |
+
buffer_reload_rate=self.hparams.buffer_reload_rate,
|
235 |
+
)
|
236 |
+
|
237 |
+
g = torch.Generator()
|
238 |
+
g.manual_seed(0)
|
239 |
+
|
240 |
+
return torch.utils.data.DataLoader(
|
241 |
+
val_dataset,
|
242 |
+
num_workers=self.hparams.num_workers,
|
243 |
+
batch_size=self.hparams.batch_size,
|
244 |
+
worker_init_fn=utils.seed_worker,
|
245 |
+
generator=g,
|
246 |
+
pin_memory=True,
|
247 |
+
)
|
248 |
+
|
249 |
+
@staticmethod
|
250 |
+
def count_control_params(plugin_config):
|
251 |
+
num_control_params = 0
|
252 |
+
|
253 |
+
for plugin in plugin_config["plugins"]:
|
254 |
+
for port in plugin["ports"]:
|
255 |
+
if port["optim"]:
|
256 |
+
num_control_params += 1
|
257 |
+
|
258 |
+
return num_control_params
|
259 |
+
|
260 |
+
# add any model hyperparameters here
|
261 |
+
@staticmethod
|
262 |
+
def add_model_specific_args(parent_parser):
|
263 |
+
parser = ArgumentParser(parents=[parent_parser], add_help=False)
|
264 |
+
# --- Model ---
|
265 |
+
parser.add_argument("--causal", action="store_true")
|
266 |
+
parser.add_argument("--output_gain", action="store_true")
|
267 |
+
parser.add_argument("--dilation_growth", type=int, default=8)
|
268 |
+
parser.add_argument("--nblocks", type=int, default=4)
|
269 |
+
parser.add_argument("--kernel_size", type=int, default=13)
|
270 |
+
parser.add_argument("--channel_width", type=int, default=13)
|
271 |
+
# --- Training ---
|
272 |
+
parser.add_argument("--input_dir", type=str)
|
273 |
+
parser.add_argument("--processor", type=str)
|
274 |
+
parser.add_argument("--batch_size", type=int, default=32)
|
275 |
+
parser.add_argument("--lr", type=float, default=3e-4)
|
276 |
+
parser.add_argument("--lr_patience", type=int, default=20)
|
277 |
+
parser.add_argument("--patience", type=int, default=10)
|
278 |
+
parser.add_argument("--preload", action="store_true")
|
279 |
+
parser.add_argument("--sample_rate", type=int, default=24000)
|
280 |
+
parser.add_argument("--shuffle", type=bool, default=True)
|
281 |
+
parser.add_argument("--train_length", type=int, default=65536)
|
282 |
+
parser.add_argument("--train_examples_per_epoch", type=int, default=10000)
|
283 |
+
parser.add_argument("--val_length", type=int, default=131072)
|
284 |
+
parser.add_argument("--val_examples_per_epoch", type=int, default=1000)
|
285 |
+
parser.add_argument("--num_workers", type=int, default=8)
|
286 |
+
parser.add_argument("--buffer_reload_rate", type=int, default=1000)
|
287 |
+
parser.add_argument("--buffer_size_gb", type=float, default=1.0)
|
288 |
+
|
289 |
+
return parser
|
deepafx_st/processors/proxy/tcn.py
ADDED
@@ -0,0 +1,199 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2022 Christian J. Steinmetz
|
2 |
+
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
# TCN implementation adapted from:
|
16 |
+
# https://github.com/csteinmetz1/micro-tcn/blob/main/microtcn/tcn.py
|
17 |
+
|
18 |
+
import torch
|
19 |
+
from argparse import ArgumentParser
|
20 |
+
|
21 |
+
from deepafx_st.utils import center_crop, causal_crop
|
22 |
+
|
23 |
+
|
24 |
+
class FiLM(torch.nn.Module):
|
25 |
+
def __init__(self, num_features, cond_dim):
|
26 |
+
super().__init__()
|
27 |
+
self.num_features = num_features
|
28 |
+
self.bn = torch.nn.BatchNorm1d(num_features, affine=False)
|
29 |
+
self.adaptor = torch.nn.Linear(cond_dim, num_features * 2)
|
30 |
+
|
31 |
+
def forward(self, x, cond):
|
32 |
+
|
33 |
+
# project conditioning to 2 x num. conv channels
|
34 |
+
cond = self.adaptor(cond)
|
35 |
+
|
36 |
+
# split the projection into gain and bias
|
37 |
+
g, b = torch.chunk(cond, 2, dim=-1)
|
38 |
+
|
39 |
+
# add virtual channel dim if needed
|
40 |
+
if g.ndim == 2:
|
41 |
+
g = g.unsqueeze(1)
|
42 |
+
b = b.unsqueeze(1)
|
43 |
+
|
44 |
+
# reshape for application
|
45 |
+
g = g.permute(0, 2, 1)
|
46 |
+
b = b.permute(0, 2, 1)
|
47 |
+
|
48 |
+
x = self.bn(x) # apply BatchNorm without affine
|
49 |
+
x = (x * g) + b # then apply conditional affine
|
50 |
+
|
51 |
+
return x
|
52 |
+
|
53 |
+
|
54 |
+
class ConditionalTCNBlock(torch.nn.Module):
|
55 |
+
def __init__(
|
56 |
+
self, in_ch, out_ch, cond_dim, kernel_size=3, dilation=1, causal=False, **kwargs
|
57 |
+
):
|
58 |
+
super().__init__()
|
59 |
+
|
60 |
+
self.in_ch = in_ch
|
61 |
+
self.out_ch = out_ch
|
62 |
+
self.kernel_size = kernel_size
|
63 |
+
self.dilation = dilation
|
64 |
+
self.causal = causal
|
65 |
+
|
66 |
+
self.conv1 = torch.nn.Conv1d(
|
67 |
+
in_ch,
|
68 |
+
out_ch,
|
69 |
+
kernel_size=kernel_size,
|
70 |
+
padding=0,
|
71 |
+
dilation=dilation,
|
72 |
+
bias=True,
|
73 |
+
)
|
74 |
+
self.film = FiLM(out_ch, cond_dim)
|
75 |
+
self.relu = torch.nn.PReLU(out_ch)
|
76 |
+
self.res = torch.nn.Conv1d(
|
77 |
+
in_ch, out_ch, kernel_size=1, groups=in_ch, bias=False
|
78 |
+
)
|
79 |
+
|
80 |
+
def forward(self, x, p):
|
81 |
+
x_in = x
|
82 |
+
|
83 |
+
x = self.conv1(x)
|
84 |
+
x = self.film(x, p) # apply FiLM conditioning
|
85 |
+
x = self.relu(x)
|
86 |
+
x_res = self.res(x_in)
|
87 |
+
|
88 |
+
if self.causal:
|
89 |
+
x = x + causal_crop(x_res, x.shape[-1])
|
90 |
+
else:
|
91 |
+
x = x + center_crop(x_res, x.shape[-1])
|
92 |
+
|
93 |
+
return x
|
94 |
+
|
95 |
+
|
96 |
+
class ConditionalTCN(torch.nn.Module):
|
97 |
+
"""Temporal convolutional network with conditioning module.
|
98 |
+
Args:
|
99 |
+
sample_rate (float): Audio sample rate.
|
100 |
+
num_control_params (int, optional): Dimensionality of the conditioning signal. Default: 24
|
101 |
+
ninputs (int, optional): Number of input channels (mono = 1, stereo 2). Default: 1
|
102 |
+
noutputs (int, optional): Number of output channels (mono = 1, stereo 2). Default: 1
|
103 |
+
nblocks (int, optional): Number of total TCN blocks. Default: 10
|
104 |
+
kernel_size (int, optional: Width of the convolutional kernels. Default: 3
|
105 |
+
dialation_growth (int, optional): Compute the dilation factor at each block as dilation_growth ** (n % stack_size). Default: 1
|
106 |
+
channel_growth (int, optional): Compute the output channels at each black as in_ch * channel_growth. Default: 2
|
107 |
+
channel_width (int, optional): When channel_growth = 1 all blocks use convolutions with this many channels. Default: 64
|
108 |
+
stack_size (int, optional): Number of blocks that constitute a single stack of blocks. Default: 10
|
109 |
+
causal (bool, optional): Causal TCN configuration does not consider future input values. Default: False
|
110 |
+
"""
|
111 |
+
|
112 |
+
def __init__(
|
113 |
+
self,
|
114 |
+
sample_rate,
|
115 |
+
num_control_params=24,
|
116 |
+
ninputs=1,
|
117 |
+
noutputs=1,
|
118 |
+
nblocks=10,
|
119 |
+
kernel_size=15,
|
120 |
+
dilation_growth=2,
|
121 |
+
channel_growth=1,
|
122 |
+
channel_width=64,
|
123 |
+
stack_size=10,
|
124 |
+
causal=False,
|
125 |
+
skip_connections=False,
|
126 |
+
**kwargs,
|
127 |
+
):
|
128 |
+
super().__init__()
|
129 |
+
self.num_control_params = num_control_params
|
130 |
+
self.ninputs = ninputs
|
131 |
+
self.noutputs = noutputs
|
132 |
+
self.nblocks = nblocks
|
133 |
+
self.kernel_size = kernel_size
|
134 |
+
self.dilation_growth = dilation_growth
|
135 |
+
self.channel_growth = channel_growth
|
136 |
+
self.channel_width = channel_width
|
137 |
+
self.stack_size = stack_size
|
138 |
+
self.causal = causal
|
139 |
+
self.skip_connections = skip_connections
|
140 |
+
self.sample_rate = sample_rate
|
141 |
+
|
142 |
+
self.blocks = torch.nn.ModuleList()
|
143 |
+
for n in range(nblocks):
|
144 |
+
in_ch = out_ch if n > 0 else ninputs
|
145 |
+
|
146 |
+
if self.channel_growth > 1:
|
147 |
+
out_ch = in_ch * self.channel_growth
|
148 |
+
else:
|
149 |
+
out_ch = self.channel_width
|
150 |
+
|
151 |
+
dilation = self.dilation_growth ** (n % self.stack_size)
|
152 |
+
|
153 |
+
self.blocks.append(
|
154 |
+
ConditionalTCNBlock(
|
155 |
+
in_ch,
|
156 |
+
out_ch,
|
157 |
+
self.num_control_params,
|
158 |
+
kernel_size=self.kernel_size,
|
159 |
+
dilation=dilation,
|
160 |
+
padding="same" if self.causal else "valid",
|
161 |
+
causal=self.causal,
|
162 |
+
)
|
163 |
+
)
|
164 |
+
|
165 |
+
self.output = torch.nn.Conv1d(out_ch, noutputs, kernel_size=1)
|
166 |
+
self.receptive_field = self.compute_receptive_field()
|
167 |
+
# print(
|
168 |
+
# f"TCN receptive field: {self.receptive_field} samples",
|
169 |
+
# f" or {(self.receptive_field/self.sample_rate)*1e3:0.3f} ms",
|
170 |
+
# )
|
171 |
+
|
172 |
+
def forward(self, x, p, **kwargs):
|
173 |
+
|
174 |
+
# causally pad input signal
|
175 |
+
x = torch.nn.functional.pad(x, (self.receptive_field - 1, 0))
|
176 |
+
|
177 |
+
# iterate over blocks passing conditioning
|
178 |
+
for idx, block in enumerate(self.blocks):
|
179 |
+
x = block(x, p)
|
180 |
+
if self.skip_connections:
|
181 |
+
if idx == 0:
|
182 |
+
skips = x
|
183 |
+
else:
|
184 |
+
skips = center_crop(skips, x[-1]) + x
|
185 |
+
else:
|
186 |
+
skips = 0
|
187 |
+
|
188 |
+
# final 1x1 convolution to collapse channels
|
189 |
+
out = self.output(x + skips)
|
190 |
+
|
191 |
+
return out
|
192 |
+
|
193 |
+
def compute_receptive_field(self):
|
194 |
+
"""Compute the receptive field in samples."""
|
195 |
+
rf = self.kernel_size
|
196 |
+
for n in range(1, self.nblocks):
|
197 |
+
dilation = self.dilation_growth ** (n % self.stack_size)
|
198 |
+
rf = rf + ((self.kernel_size - 1) * dilation)
|
199 |
+
return rf
|
deepafx_st/processors/spsa/channel.py
ADDED
@@ -0,0 +1,179 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import numpy as np
|
3 |
+
import torch.multiprocessing as mp
|
4 |
+
|
5 |
+
from deepafx_st.processors.dsp.peq import ParametricEQ
|
6 |
+
from deepafx_st.processors.dsp.compressor import Compressor
|
7 |
+
from deepafx_st.processors.spsa.spsa_func import SPSAFunction
|
8 |
+
from deepafx_st.utils import rademacher
|
9 |
+
|
10 |
+
|
11 |
+
def dsp_func(x, p, dsp, sample_rate=24000):
|
12 |
+
|
13 |
+
(peq, comp), meta = dsp
|
14 |
+
|
15 |
+
p_peq = p[:meta]
|
16 |
+
p_comp = p[meta:]
|
17 |
+
|
18 |
+
y = peq(x, p_peq, sample_rate)
|
19 |
+
y = comp(y, p_comp, sample_rate)
|
20 |
+
|
21 |
+
return y
|
22 |
+
|
23 |
+
|
24 |
+
class SPSAChannel(torch.nn.Module):
|
25 |
+
"""
|
26 |
+
|
27 |
+
Args:
|
28 |
+
sample_rate (float): Sample rate of the plugin instance
|
29 |
+
parallel (bool, optional): Use parallel workers for DSP.
|
30 |
+
|
31 |
+
By default, this utilizes parallelized instances of the plugin channel,
|
32 |
+
where the number of workers is equal to the batch size.
|
33 |
+
"""
|
34 |
+
|
35 |
+
def __init__(
|
36 |
+
self,
|
37 |
+
sample_rate: int,
|
38 |
+
parallel: bool = False,
|
39 |
+
batch_size: int = 8,
|
40 |
+
):
|
41 |
+
super().__init__()
|
42 |
+
|
43 |
+
self.batch_size = batch_size
|
44 |
+
self.parallel = parallel
|
45 |
+
|
46 |
+
if self.parallel:
|
47 |
+
self.apply_func = SPSAFunction.apply
|
48 |
+
|
49 |
+
procs = {}
|
50 |
+
for b in range(self.batch_size):
|
51 |
+
|
52 |
+
peq = ParametricEQ(sample_rate)
|
53 |
+
comp = Compressor(sample_rate)
|
54 |
+
dsp = ((peq, comp), peq.num_control_params)
|
55 |
+
|
56 |
+
parent_conn, child_conn = mp.Pipe()
|
57 |
+
p = mp.Process(target=SPSAChannel.worker_pipe, args=(child_conn, dsp))
|
58 |
+
p.start()
|
59 |
+
procs[b] = [p, parent_conn, child_conn]
|
60 |
+
#print(b, p)
|
61 |
+
|
62 |
+
# Update stuff for external public members TODO: fix
|
63 |
+
self.ports = [peq.ports, comp.ports]
|
64 |
+
self.num_control_params = (
|
65 |
+
comp.num_control_params + peq.num_control_params
|
66 |
+
)
|
67 |
+
|
68 |
+
self.procs = procs
|
69 |
+
#print(self.procs)
|
70 |
+
|
71 |
+
else:
|
72 |
+
self.peq = ParametricEQ(sample_rate)
|
73 |
+
self.comp = Compressor(sample_rate)
|
74 |
+
self.apply_func = SPSAFunction.apply
|
75 |
+
self.ports = [self.peq.ports, self.comp.ports]
|
76 |
+
self.num_control_params = (
|
77 |
+
self.comp.num_control_params + self.peq.num_control_params
|
78 |
+
)
|
79 |
+
self.dsp = ((self.peq, self.comp), self.peq.num_control_params)
|
80 |
+
|
81 |
+
# add one param for wet/dry mix
|
82 |
+
# self.num_control_params += 1
|
83 |
+
|
84 |
+
def __del__(self):
|
85 |
+
if hasattr(self, "procs"):
|
86 |
+
for proc_idx, proc in self.procs.items():
|
87 |
+
#print(f"Closing {proc_idx}...")
|
88 |
+
proc[0].terminate()
|
89 |
+
|
90 |
+
def forward(self, x, p, epsilon=0.001, sample_rate=24000, **kwargs):
|
91 |
+
"""
|
92 |
+
Args:
|
93 |
+
x (Tensor): Input signal with shape: [batch x channels x samples]
|
94 |
+
p (Tensor): Audio effect control parameters with shape: [batch x parameters]
|
95 |
+
epsilon (float, optional): Twiddle parameter range for SPSA gradient estimation.
|
96 |
+
|
97 |
+
Returns:
|
98 |
+
y (Tensor): Processed audio signal.
|
99 |
+
|
100 |
+
"""
|
101 |
+
if self.parallel:
|
102 |
+
y = self.apply_func(x, p, None, epsilon, self, sample_rate)
|
103 |
+
|
104 |
+
else:
|
105 |
+
# this will process on CPU in NumPy
|
106 |
+
y = self.apply_func(x, p, None, epsilon, self, sample_rate)
|
107 |
+
|
108 |
+
return y.type_as(x)
|
109 |
+
|
110 |
+
@staticmethod
|
111 |
+
def static_backward(dsp, value):
|
112 |
+
|
113 |
+
(
|
114 |
+
batch_index,
|
115 |
+
x,
|
116 |
+
params,
|
117 |
+
needs_input_grad,
|
118 |
+
needs_param_grad,
|
119 |
+
grad_output,
|
120 |
+
epsilon,
|
121 |
+
) = value
|
122 |
+
|
123 |
+
grads_input = None
|
124 |
+
grads_params = None
|
125 |
+
ps = params.shape[-1]
|
126 |
+
factors = [1.0]
|
127 |
+
|
128 |
+
# estimate gradient w.r.t input
|
129 |
+
if needs_input_grad:
|
130 |
+
delta_k = rademacher(x.shape).numpy()
|
131 |
+
J_plus = dsp_func(x + epsilon * delta_k, params, dsp)
|
132 |
+
J_minus = dsp_func(x - epsilon * delta_k, params, dsp)
|
133 |
+
grads_input = (J_plus - J_minus) / (2.0 * epsilon)
|
134 |
+
|
135 |
+
# estimate gradient w.r.t params
|
136 |
+
grads_params_runs = []
|
137 |
+
if needs_param_grad:
|
138 |
+
for factor in factors:
|
139 |
+
params_sublist = []
|
140 |
+
delta_k = rademacher(params.shape).numpy()
|
141 |
+
|
142 |
+
# compute output in two random directions of the parameter space
|
143 |
+
params_plus = np.clip(params + (factor * epsilon * delta_k), 0, 1)
|
144 |
+
J_plus = dsp_func(x, params_plus, dsp)
|
145 |
+
|
146 |
+
params_minus = np.clip(params - (factor * epsilon * delta_k), 0, 1)
|
147 |
+
J_minus = dsp_func(x, params_minus, dsp)
|
148 |
+
grad_param = J_plus - J_minus
|
149 |
+
|
150 |
+
# compute gradient for each parameter as a function of epsilon and random direction
|
151 |
+
for sub_p_idx in range(ps):
|
152 |
+
grad_p = grad_param / (2 * epsilon * delta_k[sub_p_idx])
|
153 |
+
params_sublist.append(np.sum(grad_output * grad_p))
|
154 |
+
|
155 |
+
grads_params = np.array(params_sublist)
|
156 |
+
grads_params_runs.append(grads_params)
|
157 |
+
|
158 |
+
# average gradients
|
159 |
+
grads_params = np.mean(grads_params_runs, axis=0)
|
160 |
+
|
161 |
+
return grads_input, grads_params
|
162 |
+
|
163 |
+
@staticmethod
|
164 |
+
def static_forward(dsp, value):
|
165 |
+
batch_index, x, p, sample_rate = value
|
166 |
+
y = dsp_func(x, p, dsp, sample_rate)
|
167 |
+
return y
|
168 |
+
|
169 |
+
@staticmethod
|
170 |
+
def worker_pipe(child_conn, dsp):
|
171 |
+
|
172 |
+
while True:
|
173 |
+
msg, value = child_conn.recv()
|
174 |
+
if msg == "forward":
|
175 |
+
child_conn.send(SPSAChannel.static_forward(dsp, value))
|
176 |
+
elif msg == "backward":
|
177 |
+
child_conn.send(SPSAChannel.static_backward(dsp, value))
|
178 |
+
elif msg == "shutdown":
|
179 |
+
break
|
deepafx_st/processors/spsa/eps_scheduler.py
ADDED
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
|
3 |
+
|
4 |
+
class EpsilonScheduler:
|
5 |
+
def __init__(
|
6 |
+
self,
|
7 |
+
epsilon: float = 0.001,
|
8 |
+
patience: int = 10,
|
9 |
+
factor: float = 0.5,
|
10 |
+
verbose: bool = False,
|
11 |
+
):
|
12 |
+
self.epsilon = epsilon
|
13 |
+
self.patience = patience
|
14 |
+
self.factor = factor
|
15 |
+
self.best = 1e16
|
16 |
+
self.count = 0
|
17 |
+
self.verbose = verbose
|
18 |
+
|
19 |
+
def step(self, metric: float):
|
20 |
+
|
21 |
+
if metric < self.best:
|
22 |
+
self.best = metric
|
23 |
+
self.count = 0
|
24 |
+
else:
|
25 |
+
self.count += 1
|
26 |
+
if self.verbose:
|
27 |
+
print(f"Train loss has not improved for {self.count} epochs.")
|
28 |
+
if self.count >= self.patience:
|
29 |
+
self.count = 0
|
30 |
+
self.epsilon *= self.factor
|
31 |
+
if self.verbose:
|
32 |
+
print(f"Reducing epsilon to {self.epsilon:0.2e}...")
|
deepafx_st/processors/spsa/spsa_func.py
ADDED
@@ -0,0 +1,131 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
|
3 |
+
|
4 |
+
def spsa_func(input, params, process, i, sample_rate=24000):
|
5 |
+
return process(input.cpu(), params.cpu(), i, sample_rate).type_as(input)
|
6 |
+
|
7 |
+
|
8 |
+
class SPSAFunction(torch.autograd.Function):
|
9 |
+
@staticmethod
|
10 |
+
def forward(
|
11 |
+
ctx,
|
12 |
+
input,
|
13 |
+
params,
|
14 |
+
process,
|
15 |
+
epsilon,
|
16 |
+
thread_context,
|
17 |
+
sample_rate=24000,
|
18 |
+
):
|
19 |
+
"""Apply processor to a batch of tensors using given parameters.
|
20 |
+
|
21 |
+
Args:
|
22 |
+
input (Tensor): Audio with shape: batch x 2 x samples
|
23 |
+
params (Tensor): Processor parameters with shape: batch x params
|
24 |
+
process (function): Function that will apply processing.
|
25 |
+
epsilon (float): Perturbation strength for SPSA computation.
|
26 |
+
|
27 |
+
Returns:
|
28 |
+
output (Tensor): Processed audio with same shape as input.
|
29 |
+
"""
|
30 |
+
ctx.save_for_backward(input, params)
|
31 |
+
ctx.epsilon = epsilon
|
32 |
+
ctx.process = process
|
33 |
+
ctx.thread_context = thread_context
|
34 |
+
|
35 |
+
if thread_context.parallel:
|
36 |
+
|
37 |
+
for i in range(input.shape[0]):
|
38 |
+
msg = (
|
39 |
+
"forward",
|
40 |
+
(
|
41 |
+
i,
|
42 |
+
input[i].view(-1).detach().cpu().numpy(),
|
43 |
+
params[i].view(-1).detach().cpu().numpy(),
|
44 |
+
sample_rate,
|
45 |
+
),
|
46 |
+
)
|
47 |
+
thread_context.procs[i][1].send(msg)
|
48 |
+
|
49 |
+
z = torch.empty_like(input)
|
50 |
+
for i in range(input.shape[0]):
|
51 |
+
z[i] = torch.from_numpy(thread_context.procs[i][1].recv())
|
52 |
+
else:
|
53 |
+
z = torch.empty_like(input)
|
54 |
+
for i in range(input.shape[0]):
|
55 |
+
value = (
|
56 |
+
i,
|
57 |
+
input[i].view(-1).detach().cpu().numpy(),
|
58 |
+
params[i].view(-1).detach().cpu().numpy(),
|
59 |
+
sample_rate,
|
60 |
+
)
|
61 |
+
z[i] = torch.from_numpy(
|
62 |
+
thread_context.static_forward(thread_context.dsp, value)
|
63 |
+
)
|
64 |
+
|
65 |
+
return z
|
66 |
+
|
67 |
+
@staticmethod
|
68 |
+
def backward(ctx, grad_output):
|
69 |
+
"""Estimate gradients using SPSA."""
|
70 |
+
|
71 |
+
input, params = ctx.saved_tensors
|
72 |
+
epsilon = ctx.epsilon
|
73 |
+
needs_input_grad = ctx.needs_input_grad[0]
|
74 |
+
needs_param_grad = ctx.needs_input_grad[1]
|
75 |
+
thread_context = ctx.thread_context
|
76 |
+
|
77 |
+
grads_input = None
|
78 |
+
grads_params = None
|
79 |
+
|
80 |
+
# Receive grads
|
81 |
+
if needs_input_grad:
|
82 |
+
grads_input = torch.empty_like(input)
|
83 |
+
if needs_param_grad:
|
84 |
+
grads_params = torch.empty_like(params)
|
85 |
+
|
86 |
+
if thread_context.parallel:
|
87 |
+
|
88 |
+
for i in range(input.shape[0]):
|
89 |
+
msg = (
|
90 |
+
"backward",
|
91 |
+
(
|
92 |
+
i,
|
93 |
+
input[i].view(-1).detach().cpu().numpy(),
|
94 |
+
params[i].view(-1).detach().cpu().numpy(),
|
95 |
+
needs_input_grad,
|
96 |
+
needs_param_grad,
|
97 |
+
grad_output[i].view(-1).detach().cpu().numpy(),
|
98 |
+
epsilon,
|
99 |
+
),
|
100 |
+
)
|
101 |
+
thread_context.procs[i][1].send(msg)
|
102 |
+
|
103 |
+
# Wait for output
|
104 |
+
for i in range(input.shape[0]):
|
105 |
+
temp1, temp2 = thread_context.procs[i][1].recv()
|
106 |
+
|
107 |
+
if temp1 is not None:
|
108 |
+
grads_input[i] = torch.from_numpy(temp1)
|
109 |
+
|
110 |
+
if temp2 is not None:
|
111 |
+
grads_params[i] = torch.from_numpy(temp2)
|
112 |
+
|
113 |
+
return grads_input, grads_params, None, None, None, None
|
114 |
+
else:
|
115 |
+
for i in range(input.shape[0]):
|
116 |
+
value = (
|
117 |
+
i,
|
118 |
+
input[i].view(-1).detach().cpu().numpy(),
|
119 |
+
params[i].view(-1).detach().cpu().numpy(),
|
120 |
+
needs_input_grad,
|
121 |
+
needs_param_grad,
|
122 |
+
grad_output[i].view(-1).detach().cpu().numpy(),
|
123 |
+
epsilon,
|
124 |
+
)
|
125 |
+
temp1, temp2 = thread_context.static_backward(thread_context.dsp, value)
|
126 |
+
if temp1 is not None:
|
127 |
+
grads_input[i] = torch.from_numpy(temp1)
|
128 |
+
|
129 |
+
if temp2 is not None:
|
130 |
+
grads_params[i] = torch.from_numpy(temp2)
|
131 |
+
return grads_input, grads_params, None, None, None, None
|
deepafx_st/system.py
ADDED
@@ -0,0 +1,563 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import auraloss
|
3 |
+
import torchaudio
|
4 |
+
from itertools import chain
|
5 |
+
import pytorch_lightning as pl
|
6 |
+
from argparse import ArgumentParser
|
7 |
+
from typing import Tuple, List, Dict
|
8 |
+
|
9 |
+
import deepafx_st.utils as utils
|
10 |
+
from deepafx_st.utils import DSPMode
|
11 |
+
from deepafx_st.data.dataset import AudioDataset
|
12 |
+
from deepafx_st.models.encoder import SpectralEncoder
|
13 |
+
from deepafx_st.models.controller import StyleTransferController
|
14 |
+
from deepafx_st.processors.spsa.channel import SPSAChannel
|
15 |
+
from deepafx_st.processors.spsa.eps_scheduler import EpsilonScheduler
|
16 |
+
from deepafx_st.processors.proxy.channel import ProxyChannel
|
17 |
+
from deepafx_st.processors.autodiff.channel import AutodiffChannel
|
18 |
+
|
19 |
+
|
20 |
+
class System(pl.LightningModule):
|
21 |
+
def __init__(
|
22 |
+
self,
|
23 |
+
ext="wav",
|
24 |
+
dsp_sample_rate=24000,
|
25 |
+
**kwargs,
|
26 |
+
):
|
27 |
+
super().__init__()
|
28 |
+
self.save_hyperparameters()
|
29 |
+
|
30 |
+
self.eps_scheduler = EpsilonScheduler(
|
31 |
+
self.hparams.spsa_epsilon,
|
32 |
+
self.hparams.spsa_patience,
|
33 |
+
self.hparams.spsa_factor,
|
34 |
+
self.hparams.spsa_verbose,
|
35 |
+
)
|
36 |
+
|
37 |
+
self.hparams.dsp_mode = DSPMode.NONE
|
38 |
+
|
39 |
+
# first construct the processor, since this will dictate encoder
|
40 |
+
if self.hparams.processor_model == "spsa":
|
41 |
+
self.processor = SPSAChannel(
|
42 |
+
self.hparams.dsp_sample_rate,
|
43 |
+
self.hparams.spsa_parallel,
|
44 |
+
self.hparams.batch_size,
|
45 |
+
)
|
46 |
+
elif self.hparams.processor_model == "autodiff":
|
47 |
+
self.processor = AutodiffChannel(self.hparams.dsp_sample_rate)
|
48 |
+
elif self.hparams.processor_model == "proxy0":
|
49 |
+
# print('self.hparams.proxy_ckpts,',self.hparams.proxy_ckpts)
|
50 |
+
self.hparams.dsp_mode = DSPMode.NONE
|
51 |
+
self.processor = ProxyChannel(
|
52 |
+
self.hparams.proxy_ckpts,
|
53 |
+
self.hparams.freeze_proxies,
|
54 |
+
self.hparams.dsp_mode,
|
55 |
+
sample_rate=self.hparams.dsp_sample_rate,
|
56 |
+
)
|
57 |
+
elif self.hparams.processor_model == "proxy1":
|
58 |
+
# print('self.hparams.proxy_ckpts,',self.hparams.proxy_ckpts)
|
59 |
+
self.hparams.dsp_mode = DSPMode.INFER
|
60 |
+
self.processor = ProxyChannel(
|
61 |
+
self.hparams.proxy_ckpts,
|
62 |
+
self.hparams.freeze_proxies,
|
63 |
+
self.hparams.dsp_mode,
|
64 |
+
sample_rate=self.hparams.dsp_sample_rate,
|
65 |
+
)
|
66 |
+
elif self.hparams.processor_model == "proxy2":
|
67 |
+
# print('self.hparams.proxy_ckpts,',self.hparams.proxy_ckpts)
|
68 |
+
self.hparams.dsp_mode = DSPMode.TRAIN_INFER
|
69 |
+
self.processor = ProxyChannel(
|
70 |
+
self.hparams.proxy_ckpts,
|
71 |
+
self.hparams.freeze_proxies,
|
72 |
+
self.hparams.dsp_mode,
|
73 |
+
sample_rate=self.hparams.dsp_sample_rate,
|
74 |
+
)
|
75 |
+
elif self.hparams.processor_model == "tcn1":
|
76 |
+
# self.processor = ConditionalTCN(self.hparams.sample_rate)
|
77 |
+
self.hparams.dsp_mode = DSPMode.NONE
|
78 |
+
self.processor = ProxyChannel(
|
79 |
+
[],
|
80 |
+
freeze_proxies=False,
|
81 |
+
dsp_mode=self.hparams.dsp_mode,
|
82 |
+
tcn_nblocks=self.hparams.tcn_nblocks,
|
83 |
+
tcn_dilation_growth=self.hparams.tcn_dilation_growth,
|
84 |
+
tcn_channel_width=self.hparams.tcn_channel_width,
|
85 |
+
tcn_kernel_size=self.hparams.tcn_kernel_size,
|
86 |
+
num_tcns=1,
|
87 |
+
sample_rate=self.hparams.sample_rate,
|
88 |
+
)
|
89 |
+
elif self.hparams.processor_model == "tcn2":
|
90 |
+
self.hparams.dsp_mode = DSPMode.NONE
|
91 |
+
self.processor = ProxyChannel(
|
92 |
+
[],
|
93 |
+
freeze_proxies=False,
|
94 |
+
dsp_mode=self.hparams.dsp_mode,
|
95 |
+
tcn_nblocks=self.hparams.tcn_nblocks,
|
96 |
+
tcn_dilation_growth=self.hparams.tcn_dilation_growth,
|
97 |
+
tcn_channel_width=self.hparams.tcn_channel_width,
|
98 |
+
tcn_kernel_size=self.hparams.tcn_kernel_size,
|
99 |
+
num_tcns=2,
|
100 |
+
sample_rate=self.hparams.sample_rate,
|
101 |
+
)
|
102 |
+
else:
|
103 |
+
raise ValueError(f"Invalid processor_model: {self.hparams.processor_model}")
|
104 |
+
|
105 |
+
if self.hparams.encoder_ckpt is not None:
|
106 |
+
# load encoder weights from a pre-trained system
|
107 |
+
system = System.load_from_checkpoint(self.hparams.encoder_ckpt)
|
108 |
+
self.encoder = system.encoder
|
109 |
+
self.hparams.encoder_embed_dim = system.encoder.embed_dim
|
110 |
+
else:
|
111 |
+
self.encoder = SpectralEncoder(
|
112 |
+
self.processor.num_control_params,
|
113 |
+
self.hparams.sample_rate,
|
114 |
+
encoder_model=self.hparams.encoder_model,
|
115 |
+
embed_dim=self.hparams.encoder_embed_dim,
|
116 |
+
width_mult=self.hparams.encoder_width_mult,
|
117 |
+
)
|
118 |
+
|
119 |
+
if self.hparams.encoder_freeze:
|
120 |
+
for param in self.encoder.parameters():
|
121 |
+
param.requires_grad = False
|
122 |
+
|
123 |
+
self.controller = StyleTransferController(
|
124 |
+
self.processor.num_control_params,
|
125 |
+
self.hparams.encoder_embed_dim,
|
126 |
+
)
|
127 |
+
|
128 |
+
if len(self.hparams.recon_losses) != len(self.hparams.recon_loss_weights):
|
129 |
+
raise ValueError("Must supply same number of weights as losses.")
|
130 |
+
|
131 |
+
self.recon_losses = torch.nn.ModuleDict()
|
132 |
+
for recon_loss in self.hparams.recon_losses:
|
133 |
+
if recon_loss == "mrstft":
|
134 |
+
self.recon_losses[recon_loss] = auraloss.freq.MultiResolutionSTFTLoss(
|
135 |
+
fft_sizes=[32, 128, 512, 2048, 8192, 32768],
|
136 |
+
hop_sizes=[16, 64, 256, 1024, 4096, 16384],
|
137 |
+
win_lengths=[32, 128, 512, 2048, 8192, 32768],
|
138 |
+
w_sc=0.0,
|
139 |
+
w_phs=0.0,
|
140 |
+
w_lin_mag=1.0,
|
141 |
+
w_log_mag=1.0,
|
142 |
+
)
|
143 |
+
elif recon_loss == "mrstft-md":
|
144 |
+
self.recon_losses[recon_loss] = auraloss.freq.MultiResolutionSTFTLoss(
|
145 |
+
fft_sizes=[128, 512, 2048, 8192],
|
146 |
+
hop_sizes=[32, 128, 512, 2048], # 1 / 4
|
147 |
+
win_lengths=[128, 512, 2048, 8192],
|
148 |
+
w_sc=0.0,
|
149 |
+
w_phs=0.0,
|
150 |
+
w_lin_mag=1.0,
|
151 |
+
w_log_mag=1.0,
|
152 |
+
)
|
153 |
+
elif recon_loss == "mrstft-sm":
|
154 |
+
self.recon_losses[recon_loss] = auraloss.freq.MultiResolutionSTFTLoss(
|
155 |
+
fft_sizes=[512, 2048, 8192],
|
156 |
+
hop_sizes=[256, 1024, 4096], # 1 / 4
|
157 |
+
win_lengths=[512, 2048, 8192],
|
158 |
+
w_sc=0.0,
|
159 |
+
w_phs=0.0,
|
160 |
+
w_lin_mag=1.0,
|
161 |
+
w_log_mag=1.0,
|
162 |
+
)
|
163 |
+
elif recon_loss == "melfft":
|
164 |
+
self.recon_losses[recon_loss] = auraloss.freq.MelSTFTLoss(
|
165 |
+
self.hparams.sample_rate,
|
166 |
+
fft_size=self.hparams.train_length,
|
167 |
+
hop_size=self.hparams.train_length // 2,
|
168 |
+
win_length=self.hparams.train_length,
|
169 |
+
n_mels=128,
|
170 |
+
w_sc=0.0,
|
171 |
+
device="cuda" if self.hparams.gpus > 0 else "cpu",
|
172 |
+
)
|
173 |
+
elif recon_loss == "melstft":
|
174 |
+
self.recon_losses[recon_loss] = auraloss.freq.MelSTFTLoss(
|
175 |
+
self.hparams.sample_rate,
|
176 |
+
device="cuda" if self.hparams.gpus > 0 else "cpu",
|
177 |
+
)
|
178 |
+
elif recon_loss == "l1":
|
179 |
+
self.recon_losses[recon_loss] = torch.nn.L1Loss()
|
180 |
+
elif recon_loss == "sisdr":
|
181 |
+
self.recon_losses[recon_loss] = auraloss.time.SISDRLoss()
|
182 |
+
else:
|
183 |
+
raise ValueError(
|
184 |
+
f"Invalid reconstruction loss: {self.hparams.recon_losses}"
|
185 |
+
)
|
186 |
+
|
187 |
+
def forward(
|
188 |
+
self,
|
189 |
+
x: torch.Tensor,
|
190 |
+
y: torch.Tensor = None,
|
191 |
+
e_y: torch.Tensor = None,
|
192 |
+
z: torch.Tensor = None,
|
193 |
+
dsp_mode: DSPMode = DSPMode.NONE,
|
194 |
+
analysis_length: int = 0,
|
195 |
+
sample_rate: int = 24000,
|
196 |
+
):
|
197 |
+
"""Forward pass through the system subnetworks.
|
198 |
+
|
199 |
+
Args:
|
200 |
+
x (tensor): Input audio tensor with shape (batch x 1 x samples)
|
201 |
+
y (tensor): Target audio tensor with shape (batch x 1 x samples)
|
202 |
+
e_y (tensor): Target embedding with shape (batch x edim)
|
203 |
+
z (tensor): Bottleneck latent.
|
204 |
+
dsp_mode (DSPMode): Mode of operation for the DSP blocks.
|
205 |
+
analysis_length (optional, int): Only analyze the first N samples.
|
206 |
+
sample_rate (optional, int): Desired sampling rate for the DSP blocks.
|
207 |
+
|
208 |
+
You must supply target audio `y`, `z`, or an embedding for the target `e_y`.
|
209 |
+
|
210 |
+
Returns:
|
211 |
+
y_hat (tensor): Output audio.
|
212 |
+
p (tensor):
|
213 |
+
e (tensor):
|
214 |
+
|
215 |
+
"""
|
216 |
+
bs, chs, samp = x.size()
|
217 |
+
|
218 |
+
if sample_rate != self.hparams.sample_rate:
|
219 |
+
x_enc = torchaudio.transforms.Resample(
|
220 |
+
sample_rate, self.hparams.sample_rate
|
221 |
+
).to(x.device)(x)
|
222 |
+
if y is not None:
|
223 |
+
y_enc = torchaudio.transforms.Resample(
|
224 |
+
sample_rate, self.hparams.sample_rate
|
225 |
+
).to(x.device)(y)
|
226 |
+
else:
|
227 |
+
x_enc = x
|
228 |
+
y_enc = y
|
229 |
+
|
230 |
+
if analysis_length > 0:
|
231 |
+
x_enc = x_enc[..., :analysis_length]
|
232 |
+
if y is not None:
|
233 |
+
y_enc = y_enc[..., :analysis_length]
|
234 |
+
|
235 |
+
e_x = self.encoder(x_enc) # generate latent embedding for input
|
236 |
+
|
237 |
+
if y is not None:
|
238 |
+
e_y = self.encoder(y_enc) # generate latent embedding for target
|
239 |
+
elif e_y is None:
|
240 |
+
raise RuntimeError("Must supply y, z, or e_y. None supplied.")
|
241 |
+
|
242 |
+
# learnable comparision
|
243 |
+
p = self.controller(e_x, e_y, z=z)
|
244 |
+
|
245 |
+
# process audio conditioned on parameters
|
246 |
+
# if there are multiple channels process them using same parameters
|
247 |
+
y_hat = torch.zeros(x.shape).type_as(x)
|
248 |
+
for ch_idx in range(chs):
|
249 |
+
y_hat_ch = self.processor(
|
250 |
+
x[:, ch_idx : ch_idx + 1, :],
|
251 |
+
p,
|
252 |
+
epsilon=self.eps_scheduler.epsilon,
|
253 |
+
dsp_mode=dsp_mode,
|
254 |
+
sample_rate=sample_rate,
|
255 |
+
)
|
256 |
+
y_hat[:, ch_idx : ch_idx + 1, :] = y_hat_ch
|
257 |
+
|
258 |
+
return y_hat, p, e_x
|
259 |
+
|
260 |
+
def common_paired_step(
|
261 |
+
self,
|
262 |
+
batch: Tuple,
|
263 |
+
batch_idx: int,
|
264 |
+
optimizer_idx: int = 0,
|
265 |
+
train: bool = False,
|
266 |
+
):
|
267 |
+
"""Model step used for validation and training.
|
268 |
+
|
269 |
+
Args:
|
270 |
+
batch (Tuple[Tensor, Tensor]): Batch items containing input audio (x) and target audio (y).
|
271 |
+
batch_idx (int): Index of the batch within the current epoch.
|
272 |
+
optimizer_idx (int): Index of the optimizer, this step is called once for each optimizer.
|
273 |
+
The firs optimizer corresponds to the generator and the second optimizer,
|
274 |
+
corresponds to the adversarial loss (when in use).
|
275 |
+
train (bool): Whether step is called during training (True) or validation (False).
|
276 |
+
"""
|
277 |
+
x, y = batch
|
278 |
+
loss = 0
|
279 |
+
dsp_mode = self.hparams.dsp_mode
|
280 |
+
|
281 |
+
if train and dsp_mode.INFER.name == DSPMode.INFER.name:
|
282 |
+
dsp_mode = DSPMode.NONE
|
283 |
+
|
284 |
+
# proces input audio through model
|
285 |
+
if self.hparams.style_transfer:
|
286 |
+
length = x.shape[-1]
|
287 |
+
|
288 |
+
x_A = x[..., : length // 2]
|
289 |
+
x_B = x[..., length // 2 :]
|
290 |
+
|
291 |
+
y_A = y[..., : length // 2]
|
292 |
+
y_B = y[..., length // 2 :]
|
293 |
+
|
294 |
+
if torch.rand(1).sum() > 0.5:
|
295 |
+
y_ref = y_B
|
296 |
+
y = y_A
|
297 |
+
x = x_A
|
298 |
+
else:
|
299 |
+
y_ref = y_A
|
300 |
+
y = y_B
|
301 |
+
x = x_B
|
302 |
+
|
303 |
+
y_hat, p, e = self(x, y=y_ref, dsp_mode=dsp_mode)
|
304 |
+
else:
|
305 |
+
y_ref = None
|
306 |
+
y_hat, p, e = self(x, dsp_mode=dsp_mode)
|
307 |
+
|
308 |
+
# compute reconstruction loss terms
|
309 |
+
for loss_idx, (loss_name, recon_loss_fn) in enumerate(
|
310 |
+
self.recon_losses.items()
|
311 |
+
):
|
312 |
+
temp_loss = recon_loss_fn(y_hat, y) # reconstruction loss
|
313 |
+
loss += float(self.hparams.recon_loss_weights[loss_idx]) * temp_loss
|
314 |
+
|
315 |
+
self.log(
|
316 |
+
("train" if train else "val") + f"_loss/{loss_name}",
|
317 |
+
temp_loss,
|
318 |
+
on_step=True,
|
319 |
+
on_epoch=True,
|
320 |
+
prog_bar=False,
|
321 |
+
logger=True,
|
322 |
+
sync_dist=True,
|
323 |
+
)
|
324 |
+
|
325 |
+
# log the overall aggregate loss
|
326 |
+
self.log(
|
327 |
+
("train" if train else "val") + "_loss/loss",
|
328 |
+
loss,
|
329 |
+
on_step=True,
|
330 |
+
on_epoch=True,
|
331 |
+
prog_bar=False,
|
332 |
+
logger=True,
|
333 |
+
sync_dist=True,
|
334 |
+
)
|
335 |
+
|
336 |
+
# store audio data
|
337 |
+
data_dict = {
|
338 |
+
"x": x.cpu(),
|
339 |
+
"y": y.cpu(),
|
340 |
+
"p": p.cpu(),
|
341 |
+
"e": e.cpu(),
|
342 |
+
"y_hat": y_hat.cpu(),
|
343 |
+
}
|
344 |
+
|
345 |
+
if y_ref is not None:
|
346 |
+
data_dict["y_ref"] = y_ref.cpu()
|
347 |
+
|
348 |
+
return loss, data_dict
|
349 |
+
|
350 |
+
def training_step(self, batch, batch_idx, optimizer_idx=0):
|
351 |
+
loss, _ = self.common_paired_step(
|
352 |
+
batch,
|
353 |
+
batch_idx,
|
354 |
+
optimizer_idx,
|
355 |
+
train=True,
|
356 |
+
)
|
357 |
+
|
358 |
+
return loss
|
359 |
+
|
360 |
+
def training_epoch_end(self, training_step_outputs):
|
361 |
+
if self.hparams.spsa_schedule and self.hparams.processor_model == "spsa":
|
362 |
+
self.eps_scheduler.step(
|
363 |
+
self.trainer.callback_metrics[self.hparams.train_monitor],
|
364 |
+
)
|
365 |
+
|
366 |
+
def validation_step(self, batch, batch_idx):
|
367 |
+
loss, data_dict = self.common_paired_step(batch, batch_idx)
|
368 |
+
|
369 |
+
return data_dict
|
370 |
+
|
371 |
+
def optimizer_step(
|
372 |
+
self,
|
373 |
+
epoch,
|
374 |
+
batch_idx,
|
375 |
+
optimizer,
|
376 |
+
optimizer_idx,
|
377 |
+
optimizer_closure,
|
378 |
+
on_tpu=False,
|
379 |
+
using_native_amp=False,
|
380 |
+
using_lbfgs=False,
|
381 |
+
):
|
382 |
+
if optimizer_idx == 0:
|
383 |
+
optimizer.step(closure=optimizer_closure)
|
384 |
+
|
385 |
+
def configure_optimizers(self):
|
386 |
+
# we need additional optimizer for the discriminator
|
387 |
+
optimizers = []
|
388 |
+
g_optimizer = torch.optim.Adam(
|
389 |
+
chain(
|
390 |
+
self.encoder.parameters(),
|
391 |
+
self.processor.parameters(),
|
392 |
+
self.controller.parameters(),
|
393 |
+
),
|
394 |
+
lr=self.hparams.lr,
|
395 |
+
betas=(0.9, 0.999),
|
396 |
+
)
|
397 |
+
optimizers.append(g_optimizer)
|
398 |
+
|
399 |
+
g_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
|
400 |
+
g_optimizer,
|
401 |
+
patience=self.hparams.lr_patience,
|
402 |
+
verbose=True,
|
403 |
+
)
|
404 |
+
ms1 = int(self.hparams.max_epochs * 0.8)
|
405 |
+
ms2 = int(self.hparams.max_epochs * 0.95)
|
406 |
+
print(
|
407 |
+
"Learning rate schedule:",
|
408 |
+
f"0 {self.hparams.lr:0.2e} -> ",
|
409 |
+
f"{ms1} {self.hparams.lr*0.1:0.2e} -> ",
|
410 |
+
f"{ms2} {self.hparams.lr*0.01:0.2e}",
|
411 |
+
)
|
412 |
+
g_scheduler = torch.optim.lr_scheduler.MultiStepLR(
|
413 |
+
g_optimizer,
|
414 |
+
milestones=[ms1, ms2],
|
415 |
+
gamma=0.1,
|
416 |
+
)
|
417 |
+
|
418 |
+
lr_schedulers = {
|
419 |
+
"scheduler": g_scheduler,
|
420 |
+
}
|
421 |
+
|
422 |
+
return optimizers, lr_schedulers
|
423 |
+
|
424 |
+
def train_dataloader(self):
|
425 |
+
|
426 |
+
train_dataset = AudioDataset(
|
427 |
+
self.hparams.audio_dir,
|
428 |
+
subset="train",
|
429 |
+
train_frac=self.hparams.train_frac,
|
430 |
+
half=self.hparams.half,
|
431 |
+
length=self.hparams.train_length,
|
432 |
+
input_dirs=self.hparams.input_dirs,
|
433 |
+
random_scale_input=self.hparams.random_scale_input,
|
434 |
+
random_scale_target=self.hparams.random_scale_target,
|
435 |
+
buffer_size_gb=self.hparams.buffer_size_gb,
|
436 |
+
buffer_reload_rate=self.hparams.buffer_reload_rate,
|
437 |
+
num_examples_per_epoch=self.hparams.train_examples_per_epoch,
|
438 |
+
augmentations={
|
439 |
+
"pitch": {"sr": self.hparams.sample_rate},
|
440 |
+
"tempo": {"sr": self.hparams.sample_rate},
|
441 |
+
},
|
442 |
+
freq_corrupt=self.hparams.freq_corrupt,
|
443 |
+
drc_corrupt=self.hparams.drc_corrupt,
|
444 |
+
ext=self.hparams.ext,
|
445 |
+
)
|
446 |
+
|
447 |
+
g = torch.Generator()
|
448 |
+
g.manual_seed(0)
|
449 |
+
|
450 |
+
return torch.utils.data.DataLoader(
|
451 |
+
train_dataset,
|
452 |
+
num_workers=self.hparams.num_workers,
|
453 |
+
batch_size=self.hparams.batch_size,
|
454 |
+
worker_init_fn=utils.seed_worker,
|
455 |
+
generator=g,
|
456 |
+
pin_memory=True,
|
457 |
+
persistent_workers=True,
|
458 |
+
timeout=60,
|
459 |
+
)
|
460 |
+
|
461 |
+
def val_dataloader(self):
|
462 |
+
|
463 |
+
val_dataset = AudioDataset(
|
464 |
+
self.hparams.audio_dir,
|
465 |
+
subset="val",
|
466 |
+
half=self.hparams.half,
|
467 |
+
train_frac=self.hparams.train_frac,
|
468 |
+
length=self.hparams.val_length,
|
469 |
+
input_dirs=self.hparams.input_dirs,
|
470 |
+
buffer_size_gb=self.hparams.buffer_size_gb,
|
471 |
+
buffer_reload_rate=self.hparams.buffer_reload_rate,
|
472 |
+
random_scale_input=self.hparams.random_scale_input,
|
473 |
+
random_scale_target=self.hparams.random_scale_target,
|
474 |
+
num_examples_per_epoch=self.hparams.val_examples_per_epoch,
|
475 |
+
augmentations={},
|
476 |
+
freq_corrupt=self.hparams.freq_corrupt,
|
477 |
+
drc_corrupt=self.hparams.drc_corrupt,
|
478 |
+
ext=self.hparams.ext,
|
479 |
+
)
|
480 |
+
|
481 |
+
self.val_dataset = val_dataset
|
482 |
+
|
483 |
+
g = torch.Generator()
|
484 |
+
g.manual_seed(0)
|
485 |
+
|
486 |
+
return torch.utils.data.DataLoader(
|
487 |
+
val_dataset,
|
488 |
+
num_workers=1,
|
489 |
+
batch_size=self.hparams.batch_size,
|
490 |
+
worker_init_fn=utils.seed_worker,
|
491 |
+
generator=g,
|
492 |
+
pin_memory=True,
|
493 |
+
persistent_workers=True,
|
494 |
+
timeout=60,
|
495 |
+
)
|
496 |
+
def shutdown(self):
|
497 |
+
del self.processor
|
498 |
+
|
499 |
+
# add any model hyperparameters here
|
500 |
+
@staticmethod
|
501 |
+
def add_model_specific_args(parent_parser):
|
502 |
+
parser = ArgumentParser(parents=[parent_parser], add_help=False)
|
503 |
+
# --- Training ---
|
504 |
+
parser.add_argument("--batch_size", type=int, default=32)
|
505 |
+
parser.add_argument("--lr", type=float, default=3e-4)
|
506 |
+
parser.add_argument("--lr_patience", type=int, default=20)
|
507 |
+
parser.add_argument("--recon_losses", nargs="+", default=["l1"])
|
508 |
+
parser.add_argument("--recon_loss_weights", nargs="+", default=[1.0])
|
509 |
+
# --- Controller ---
|
510 |
+
parser.add_argument(
|
511 |
+
"--processor_model",
|
512 |
+
type=str,
|
513 |
+
help="autodiff, spsa, tcn1, tcn2, proxy0, proxy1, proxy2",
|
514 |
+
)
|
515 |
+
parser.add_argument("--controller_hidden_dim", type=int, default=256)
|
516 |
+
parser.add_argument("--style_transfer", action="store_true")
|
517 |
+
# --- Encoder ---
|
518 |
+
parser.add_argument("--encoder_model", type=str, default="mobilenet_v2")
|
519 |
+
parser.add_argument("--encoder_embed_dim", type=int, default=128)
|
520 |
+
parser.add_argument("--encoder_width_mult", type=int, default=2)
|
521 |
+
parser.add_argument("--encoder_ckpt", type=str, default=None)
|
522 |
+
parser.add_argument("--encoder_freeze", action="store_true", default=False)
|
523 |
+
# --- TCN ---
|
524 |
+
parser.add_argument("--tcn_causal", action="store_true")
|
525 |
+
parser.add_argument("--tcn_nblocks", type=int, default=4)
|
526 |
+
parser.add_argument("--tcn_dilation_growth", type=int, default=8)
|
527 |
+
parser.add_argument("--tcn_channel_width", type=int, default=32)
|
528 |
+
parser.add_argument("--tcn_kernel_size", type=int, default=13)
|
529 |
+
# --- SPSA ---
|
530 |
+
parser.add_argument("--plugin_config_file", type=str, default=None)
|
531 |
+
parser.add_argument("--spsa_epsilon", type=float, default=0.001)
|
532 |
+
parser.add_argument("--spsa_schedule", action="store_true")
|
533 |
+
parser.add_argument("--spsa_patience", type=int, default=10)
|
534 |
+
parser.add_argument("--spsa_verbose", action="store_true")
|
535 |
+
parser.add_argument("--spsa_factor", type=float, default=0.5)
|
536 |
+
parser.add_argument("--spsa_parallel", action="store_true")
|
537 |
+
# --- Proxy ----
|
538 |
+
parser.add_argument("--proxy_ckpts", nargs="+")
|
539 |
+
parser.add_argument("--freeze_proxies", action="store_true", default=False)
|
540 |
+
parser.add_argument("--use_dsp", action="store_true", default=False)
|
541 |
+
parser.add_argument("--dsp_mode", choices=DSPMode, type=DSPMode)
|
542 |
+
# --- Dataset ---
|
543 |
+
parser.add_argument("--audio_dir", type=str)
|
544 |
+
parser.add_argument("--ext", type=str, default="wav")
|
545 |
+
parser.add_argument("--input_dirs", nargs="+")
|
546 |
+
parser.add_argument("--buffer_reload_rate", type=int, default=1000)
|
547 |
+
parser.add_argument("--buffer_size_gb", type=float, default=1.0)
|
548 |
+
parser.add_argument("--sample_rate", type=int, default=24000)
|
549 |
+
parser.add_argument("--dsp_sample_rate", type=int, default=24000)
|
550 |
+
parser.add_argument("--shuffle", type=bool, default=True)
|
551 |
+
parser.add_argument("--random_scale_input", action="store_true")
|
552 |
+
parser.add_argument("--random_scale_target", action="store_true")
|
553 |
+
parser.add_argument("--freq_corrupt", action="store_true")
|
554 |
+
parser.add_argument("--drc_corrupt", action="store_true")
|
555 |
+
parser.add_argument("--train_length", type=int, default=65536)
|
556 |
+
parser.add_argument("--train_frac", type=float, default=0.8)
|
557 |
+
parser.add_argument("--half", action="store_true")
|
558 |
+
parser.add_argument("--train_examples_per_epoch", type=int, default=10000)
|
559 |
+
parser.add_argument("--val_length", type=int, default=131072)
|
560 |
+
parser.add_argument("--val_examples_per_epoch", type=int, default=1000)
|
561 |
+
parser.add_argument("--num_workers", type=int, default=16)
|
562 |
+
|
563 |
+
return parser
|
deepafx_st/utils.py
ADDED
@@ -0,0 +1,277 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Adapted from:
|
2 |
+
# https://github.com/csteinmetz1/micro-tcn/blob/main/microtcn/utils.py
|
3 |
+
import os
|
4 |
+
import csv
|
5 |
+
import torch
|
6 |
+
import fnmatch
|
7 |
+
import numpy as np
|
8 |
+
import random
|
9 |
+
from enum import Enum
|
10 |
+
import pyloudnorm as pyln
|
11 |
+
|
12 |
+
|
13 |
+
class DSPMode(Enum):
|
14 |
+
NONE = "none"
|
15 |
+
TRAIN_INFER = "train_infer"
|
16 |
+
INFER = "infer"
|
17 |
+
|
18 |
+
def __str__(self):
|
19 |
+
return self.value
|
20 |
+
|
21 |
+
|
22 |
+
def loudness_normalize(x, sample_rate, target_loudness=-24.0):
|
23 |
+
x = x.view(1, -1)
|
24 |
+
stereo_audio = x.repeat(2, 1).permute(1, 0).numpy()
|
25 |
+
meter = pyln.Meter(sample_rate)
|
26 |
+
loudness = meter.integrated_loudness(stereo_audio)
|
27 |
+
norm_x = pyln.normalize.loudness(
|
28 |
+
stereo_audio,
|
29 |
+
loudness,
|
30 |
+
target_loudness,
|
31 |
+
)
|
32 |
+
x = torch.tensor(norm_x).permute(1, 0)
|
33 |
+
x = x[0, :].view(1, -1)
|
34 |
+
|
35 |
+
return x
|
36 |
+
|
37 |
+
|
38 |
+
def get_random_file_id(keys):
|
39 |
+
# generate a random index into the keys of the input files
|
40 |
+
rand_input_idx = torch.randint(0, len(keys) - 1, [1])[0]
|
41 |
+
# find the key (file_id) correponding to the random index
|
42 |
+
rand_input_file_id = list(keys)[rand_input_idx]
|
43 |
+
|
44 |
+
return rand_input_file_id
|
45 |
+
|
46 |
+
|
47 |
+
def get_random_patch(audio_file, length, check_silence=True):
|
48 |
+
silent = True
|
49 |
+
while silent:
|
50 |
+
start_idx = int(torch.rand(1) * (audio_file.num_frames - length))
|
51 |
+
stop_idx = start_idx + length
|
52 |
+
patch = audio_file.audio[:, start_idx:stop_idx].clone().detach()
|
53 |
+
if (patch ** 2).mean() > 1e-4 or not check_silence:
|
54 |
+
silent = False
|
55 |
+
|
56 |
+
return start_idx, stop_idx
|
57 |
+
|
58 |
+
|
59 |
+
def seed_worker(worker_id):
|
60 |
+
worker_seed = torch.initial_seed() % 2 ** 32
|
61 |
+
np.random.seed(worker_seed)
|
62 |
+
random.seed(worker_seed)
|
63 |
+
|
64 |
+
|
65 |
+
def getFilesPath(directory, extension):
|
66 |
+
|
67 |
+
n_path = []
|
68 |
+
for path, subdirs, files in os.walk(directory):
|
69 |
+
for name in files:
|
70 |
+
if fnmatch.fnmatch(name, extension):
|
71 |
+
n_path.append(os.path.join(path, name))
|
72 |
+
n_path.sort()
|
73 |
+
|
74 |
+
return n_path
|
75 |
+
|
76 |
+
|
77 |
+
def count_parameters(model, trainable_only=True):
|
78 |
+
|
79 |
+
if trainable_only:
|
80 |
+
if len(list(model.parameters())) > 0:
|
81 |
+
params = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
82 |
+
else:
|
83 |
+
params = 0
|
84 |
+
else:
|
85 |
+
if len(list(model.parameters())) > 0:
|
86 |
+
params = sum(p.numel() for p in model.parameters())
|
87 |
+
else:
|
88 |
+
params = 0
|
89 |
+
|
90 |
+
return params
|
91 |
+
|
92 |
+
|
93 |
+
def system_summary(system):
|
94 |
+
print(f"Encoder: {count_parameters(system.encoder)/1e6:0.2f} M")
|
95 |
+
print(f"Processor: {count_parameters(system.processor)/1e6:0.2f} M")
|
96 |
+
|
97 |
+
if hasattr(system, "adv_loss_fn"):
|
98 |
+
for idx, disc in enumerate(system.adv_loss_fn.discriminators):
|
99 |
+
print(f"Discriminator {idx+1}: {count_parameters(disc)/1e6:0.2f} M")
|
100 |
+
|
101 |
+
|
102 |
+
def center_crop(x, length: int):
|
103 |
+
if x.shape[-1] != length:
|
104 |
+
start = (x.shape[-1] - length) // 2
|
105 |
+
stop = start + length
|
106 |
+
x = x[..., start:stop]
|
107 |
+
return x
|
108 |
+
|
109 |
+
|
110 |
+
def causal_crop(x, length: int):
|
111 |
+
if x.shape[-1] != length:
|
112 |
+
stop = x.shape[-1] - 1
|
113 |
+
start = stop - length
|
114 |
+
x = x[..., start:stop]
|
115 |
+
return x
|
116 |
+
|
117 |
+
|
118 |
+
def denormalize(norm_val, max_val, min_val):
|
119 |
+
return (norm_val * (max_val - min_val)) + min_val
|
120 |
+
|
121 |
+
|
122 |
+
def normalize(denorm_val, max_val, min_val):
|
123 |
+
return (denorm_val - min_val) / (max_val - min_val)
|
124 |
+
|
125 |
+
|
126 |
+
def get_random_patch(audio_file, length, energy_treshold=1e-4):
|
127 |
+
"""Produce sample indicies for a random patch of size `length`.
|
128 |
+
|
129 |
+
This function will check the energy of the selected patch to
|
130 |
+
ensure that it is not complete silence. If silence is found,
|
131 |
+
it will continue searching for a non-silent patch.
|
132 |
+
|
133 |
+
Args:
|
134 |
+
audio_file (AudioFile): Audio file object.
|
135 |
+
length (int): Number of samples in random patch.
|
136 |
+
|
137 |
+
Returns:
|
138 |
+
start_idx (int): Starting sample index
|
139 |
+
stop_idx (int): Stop sample index
|
140 |
+
"""
|
141 |
+
|
142 |
+
silent = True
|
143 |
+
while silent:
|
144 |
+
start_idx = int(torch.rand(1) * (audio_file.num_frames - length))
|
145 |
+
stop_idx = start_idx + length
|
146 |
+
patch = audio_file.audio[:, start_idx:stop_idx]
|
147 |
+
if (patch ** 2).mean() > energy_treshold:
|
148 |
+
silent = False
|
149 |
+
|
150 |
+
return start_idx, stop_idx
|
151 |
+
|
152 |
+
|
153 |
+
def split_dataset(file_list, subset, train_frac):
|
154 |
+
"""Given a list of files, split into train/val/test sets.
|
155 |
+
|
156 |
+
Args:
|
157 |
+
file_list (list): List of audio files.
|
158 |
+
subset (str): One of "train", "val", or "test".
|
159 |
+
train_frac (float): Fraction of the dataset to use for training.
|
160 |
+
|
161 |
+
Returns:
|
162 |
+
file_list (list): List of audio files corresponding to subset.
|
163 |
+
"""
|
164 |
+
assert train_frac > 0.1 and train_frac < 1.0
|
165 |
+
|
166 |
+
total_num_examples = len(file_list)
|
167 |
+
|
168 |
+
train_num_examples = int(total_num_examples * train_frac)
|
169 |
+
val_num_examples = int(total_num_examples * (1 - train_frac) / 2)
|
170 |
+
test_num_examples = total_num_examples - (train_num_examples + val_num_examples)
|
171 |
+
|
172 |
+
if train_num_examples < 0:
|
173 |
+
raise ValueError(
|
174 |
+
f"No examples in training set. Try increasing train_frac: {train_frac}."
|
175 |
+
)
|
176 |
+
elif val_num_examples < 0:
|
177 |
+
raise ValueError(
|
178 |
+
f"No examples in validation set. Try decreasing train_frac: {train_frac}."
|
179 |
+
)
|
180 |
+
elif test_num_examples < 0:
|
181 |
+
raise ValueError(
|
182 |
+
f"No examples in test set. Try decreasing train_frac: {train_frac}."
|
183 |
+
)
|
184 |
+
|
185 |
+
if subset == "train":
|
186 |
+
start_idx = 0
|
187 |
+
stop_idx = train_num_examples
|
188 |
+
elif subset == "val":
|
189 |
+
start_idx = train_num_examples
|
190 |
+
stop_idx = start_idx + val_num_examples
|
191 |
+
elif subset == "test":
|
192 |
+
start_idx = train_num_examples + val_num_examples
|
193 |
+
stop_idx = start_idx + test_num_examples + 1
|
194 |
+
else:
|
195 |
+
raise ValueError("Invalid subset: {subset}.")
|
196 |
+
|
197 |
+
return file_list[start_idx:stop_idx]
|
198 |
+
|
199 |
+
|
200 |
+
def rademacher(size):
|
201 |
+
"""Generates random samples from a Rademacher distribution +-1
|
202 |
+
|
203 |
+
Args:
|
204 |
+
size (int):
|
205 |
+
|
206 |
+
"""
|
207 |
+
m = torch.distributions.binomial.Binomial(1, 0.5)
|
208 |
+
x = m.sample(size)
|
209 |
+
x[x == 0] = -1
|
210 |
+
return x
|
211 |
+
|
212 |
+
|
213 |
+
def get_subset(csv_file):
|
214 |
+
subset_files = []
|
215 |
+
with open(csv_file) as fp:
|
216 |
+
reader = csv.DictReader(fp)
|
217 |
+
for row in reader:
|
218 |
+
subset_files.append(row["filepath"])
|
219 |
+
|
220 |
+
return list(set(subset_files))
|
221 |
+
|
222 |
+
|
223 |
+
def conform_length(x: torch.Tensor, length: int):
|
224 |
+
"""Crop or pad input on last dim to match `length`."""
|
225 |
+
if x.shape[-1] < length:
|
226 |
+
padsize = length - x.shape[-1]
|
227 |
+
x = torch.nn.functional.pad(x, (0, padsize))
|
228 |
+
elif x.shape[-1] > length:
|
229 |
+
x = x[..., :length]
|
230 |
+
|
231 |
+
return x
|
232 |
+
|
233 |
+
|
234 |
+
def linear_fade(
|
235 |
+
x: torch.Tensor,
|
236 |
+
fade_ms: float = 50.0,
|
237 |
+
sample_rate: float = 22050,
|
238 |
+
):
|
239 |
+
"""Apply fade in and fade out to last dim."""
|
240 |
+
fade_samples = int(fade_ms * 1e-3 * 22050)
|
241 |
+
|
242 |
+
fade_in = torch.linspace(0.0, 1.0, steps=fade_samples)
|
243 |
+
fade_out = torch.linspace(1.0, 0.0, steps=fade_samples)
|
244 |
+
|
245 |
+
# fade in
|
246 |
+
x[..., :fade_samples] *= fade_in
|
247 |
+
|
248 |
+
# fade out
|
249 |
+
x[..., -fade_samples:] *= fade_out
|
250 |
+
|
251 |
+
return x
|
252 |
+
|
253 |
+
|
254 |
+
# def get_random_patch(x, sample_rate, length_samples):
|
255 |
+
# length = length_samples
|
256 |
+
# silent = True
|
257 |
+
# while silent:
|
258 |
+
# start_idx = np.random.randint(0, x.shape[-1] - length - 1)
|
259 |
+
# stop_idx = start_idx + length
|
260 |
+
# x_crop = x[0:1, start_idx:stop_idx]
|
261 |
+
|
262 |
+
# # check for silence
|
263 |
+
# frames = length // sample_rate
|
264 |
+
# silent_frames = []
|
265 |
+
# for n in range(frames):
|
266 |
+
# start_idx = n * sample_rate
|
267 |
+
# stop_idx = start_idx + sample_rate
|
268 |
+
# x_frame = x_crop[0:1, start_idx:stop_idx]
|
269 |
+
# if (x_frame ** 2).mean() > 3e-4:
|
270 |
+
# silent_frames.append(False)
|
271 |
+
# else:
|
272 |
+
# silent_frames.append(True)
|
273 |
+
# silent = True if any(silent_frames) else False
|
274 |
+
|
275 |
+
# x_crop /= x_crop.abs().max()
|
276 |
+
|
277 |
+
# return x_crop
|
deepafx_st/version.py
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# !/usr/bin/env python
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
'''Version info'''
|
4 |
+
|
5 |
+
short_version = '0.0'
|
6 |
+
version = '0.0.1'
|
packages.txt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
libsndfile1
|
2 |
+
sox
|
3 |
+
ffmpeg
|
requirements.txt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
git+https://github.com/adobe-research/DeepAFx-ST.git
|
2 |
+
gradio
|
3 |
+
huggingface_hub
|