Spaces:
Runtime error
Runtime error
Commit
·
e2b0b28
1
Parent(s):
95d8ea8
rewrote demucs model
Browse files
changed configs default values
refactoring
- app.py +1 -5
- conf/config.yaml +7 -5
- conf/model/demucs.yaml +1 -0
- datasets/__init__.py +1 -2
- denoisers/SpectralGating.py +2 -0
- denoisers/demucs.py +33 -31
- losses.py +3 -3
- main.py +0 -1
- testing/metrics.py +13 -6
app.py
CHANGED
@@ -24,15 +24,13 @@ def denoising_transform(audio):
|
|
24 |
)
|
25 |
model.predict(audio, tgt_path)
|
26 |
return tgt_path
|
27 |
-
|
28 |
-
# return tgt_path
|
29 |
|
30 |
|
31 |
inputs = gr.inputs.Audio(label="Source Audio", source="microphone", type='filepath')
|
32 |
outputs = gr.outputs.Audio(label="Target Audio", type='filepath')
|
33 |
|
34 |
title = "Denoising"
|
35 |
-
#"""
|
36 |
gr.Interface(
|
37 |
denoising_transform, inputs, outputs, title=title,
|
38 |
allow_flagging='never'
|
@@ -40,6 +38,4 @@ gr.Interface(
|
|
40 |
server_name='localhost',
|
41 |
server_port=7871,
|
42 |
share=True
|
43 |
-
#ssl_keyfile='example.key',
|
44 |
-
#ssl_certfile="example.crt",
|
45 |
)
|
|
|
24 |
)
|
25 |
model.predict(audio, tgt_path)
|
26 |
return tgt_path
|
27 |
+
|
|
|
28 |
|
29 |
|
30 |
inputs = gr.inputs.Audio(label="Source Audio", source="microphone", type='filepath')
|
31 |
outputs = gr.outputs.Audio(label="Target Audio", type='filepath')
|
32 |
|
33 |
title = "Denoising"
|
|
|
34 |
gr.Interface(
|
35 |
denoising_transform, inputs, outputs, title=title,
|
36 |
allow_flagging='never'
|
|
|
38 |
server_name='localhost',
|
39 |
server_port=7871,
|
40 |
share=True
|
|
|
|
|
41 |
)
|
conf/config.yaml
CHANGED
@@ -1,11 +1,11 @@
|
|
1 |
defaults:
|
2 |
- dataset: valentini
|
3 |
- model: demucs
|
4 |
-
- loss:
|
5 |
- optimizer: sgd
|
6 |
|
7 |
training:
|
8 |
-
num_epochs:
|
9 |
model_save_path: /media/public/checkpoints
|
10 |
|
11 |
dataloader:
|
@@ -13,6 +13,8 @@ dataloader:
|
|
13 |
sample_rate: 16000
|
14 |
train_batch_size: 12
|
15 |
valid_batch_size: 12
|
|
|
|
|
16 |
|
17 |
validation:
|
18 |
path: /media/public/datasets/denoising/DS_10283_2791/noisy_testset_wav
|
@@ -20,13 +22,13 @@ validation:
|
|
20 |
|
21 |
|
22 |
wandb:
|
|
|
23 |
run_name: default
|
24 |
project: denoising
|
25 |
log_interval: 100
|
26 |
api_key: local-e23d01ece807cb31e69b2cf4137e4998e4b9856f
|
27 |
host: http://localhost:8080/
|
28 |
notes: "Experiment note"
|
29 |
-
|
30 |
-
- baseline
|
31 |
|
32 |
-
gpu:
|
|
|
1 |
defaults:
|
2 |
- dataset: valentini
|
3 |
- model: demucs
|
4 |
+
- loss: L1_Multi_STFT
|
5 |
- optimizer: sgd
|
6 |
|
7 |
training:
|
8 |
+
num_epochs: 10
|
9 |
model_save_path: /media/public/checkpoints
|
10 |
|
11 |
dataloader:
|
|
|
13 |
sample_rate: 16000
|
14 |
train_batch_size: 12
|
15 |
valid_batch_size: 12
|
16 |
+
normalize: False
|
17 |
+
num_workers: 12
|
18 |
|
19 |
validation:
|
20 |
path: /media/public/datasets/denoising/DS_10283_2791/noisy_testset_wav
|
|
|
22 |
|
23 |
|
24 |
wandb:
|
25 |
+
entity: borisovmaksim
|
26 |
run_name: default
|
27 |
project: denoising
|
28 |
log_interval: 100
|
29 |
api_key: local-e23d01ece807cb31e69b2cf4137e4998e4b9856f
|
30 |
host: http://localhost:8080/
|
31 |
notes: "Experiment note"
|
32 |
+
resume: False
|
|
|
33 |
|
34 |
+
gpu: 1
|
conf/model/demucs.yaml
CHANGED
@@ -1,5 +1,6 @@
|
|
1 |
demucs:
|
2 |
H: 64
|
|
|
3 |
encoder:
|
4 |
conv1:
|
5 |
kernel_size: 8
|
|
|
1 |
demucs:
|
2 |
H: 64
|
3 |
+
L: 3
|
4 |
encoder:
|
5 |
conv1:
|
6 |
kernel_size: 8
|
datasets/__init__.py
CHANGED
@@ -9,10 +9,9 @@ DATASETS_POOL = {
|
|
9 |
}
|
10 |
|
11 |
|
12 |
-
|
13 |
def get_datasets(cfg) -> Tuple[Dataset, Dataset]:
|
14 |
name, dataset_params = list(cfg['dataset'].items())[0]
|
15 |
-
transform = Transform(
|
16 |
train_dataset = DATASETS_POOL[name](valid=False, transform=transform, **dataset_params)
|
17 |
valid_dataset = DATASETS_POOL[name](valid=True, transform=transform, **dataset_params)
|
18 |
return train_dataset, valid_dataset
|
|
|
9 |
}
|
10 |
|
11 |
|
|
|
12 |
def get_datasets(cfg) -> Tuple[Dataset, Dataset]:
|
13 |
name, dataset_params = list(cfg['dataset'].items())[0]
|
14 |
+
transform = Transform(input_sample_rate=dataset_params['sample_rate'], **cfg['dataloader'])
|
15 |
train_dataset = DATASETS_POOL[name](valid=False, transform=transform, **dataset_params)
|
16 |
valid_dataset = DATASETS_POOL[name](valid=True, transform=transform, **dataset_params)
|
17 |
return train_dataset, valid_dataset
|
denoisers/SpectralGating.py
CHANGED
@@ -18,6 +18,8 @@ class SpectralGating(torch.nn.Module):
|
|
18 |
torchaudio.save(out_path, reduced_noise, rate)
|
19 |
return out_path
|
20 |
|
|
|
|
|
21 |
|
22 |
|
23 |
|
|
|
18 |
torchaudio.save(out_path, reduced_noise, rate)
|
19 |
return out_path
|
20 |
|
21 |
+
|
22 |
+
|
23 |
|
24 |
|
25 |
|
denoisers/demucs.py
CHANGED
@@ -1,6 +1,7 @@
|
|
1 |
import torch
|
2 |
from torch.nn.functional import pad
|
3 |
-
|
|
|
4 |
|
5 |
|
6 |
class Encoder(torch.nn.Module):
|
@@ -25,9 +26,9 @@ class Encoder(torch.nn.Module):
|
|
25 |
|
26 |
|
27 |
class Decoder(torch.nn.Module):
|
28 |
-
def __init__(self, in_channels, out_channels, cfg):
|
29 |
super(Decoder, self).__init__()
|
30 |
-
|
31 |
self.conv1 = torch.nn.Conv1d(in_channels=in_channels, out_channels=2 * in_channels,
|
32 |
kernel_size=cfg['conv1']['kernel_size'],
|
33 |
stride=cfg['conv1']['stride'])
|
@@ -39,45 +40,46 @@ class Decoder(torch.nn.Module):
|
|
39 |
|
40 |
def forward(self, x):
|
41 |
x = self.glu(self.conv1(x))
|
42 |
-
x = self.
|
|
|
|
|
43 |
return x
|
44 |
|
45 |
|
46 |
class Demucs(torch.nn.Module):
|
47 |
def __init__(self, cfg):
|
48 |
super(Demucs, self).__init__()
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
54 |
|
55 |
self.lstm = torch.nn.LSTM(
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
self.decoder1 = Decoder(in_channels=4*H, out_channels=2*H, cfg=cfg['decoder'])
|
60 |
-
self.decoder2 = Decoder(in_channels=2*H, out_channels=H, cfg=cfg['decoder'])
|
61 |
-
self.decoder3 = Decoder(in_channels=H, out_channels=1, cfg=cfg['decoder'])
|
62 |
|
63 |
def forward(self, x):
|
64 |
-
|
65 |
-
|
66 |
-
|
|
|
|
|
67 |
|
68 |
-
x, _ = self.lstm(
|
69 |
x = x.permute(0, 2, 1)
|
70 |
-
x = self.decoder1(x + out3)
|
71 |
-
x = x[:, :, :out2.shape[-1]]
|
72 |
-
x = self.decoder2(x + out2)
|
73 |
-
x = x[:, :, :-1]
|
74 |
-
out1 = out1[:, :, :-1]
|
75 |
-
if x.shape[-1] > out1.shape[-1]:
|
76 |
-
x = x[:, :, :out1.shape[-1]]
|
77 |
-
elif x.shape[-1] < out1.shape[-1]:
|
78 |
-
out1 = out1[:, :, :x.shape[-1]]
|
79 |
|
80 |
-
|
|
|
|
|
|
|
|
|
81 |
return x
|
82 |
-
|
83 |
-
|
|
|
1 |
import torch
|
2 |
from torch.nn.functional import pad
|
3 |
+
from utils import pad_cut_batch_audio
|
4 |
+
import torch.nn as nn
|
5 |
|
6 |
|
7 |
class Encoder(torch.nn.Module):
|
|
|
26 |
|
27 |
|
28 |
class Decoder(torch.nn.Module):
|
29 |
+
def __init__(self, in_channels, out_channels, cfg, is_last=False):
|
30 |
super(Decoder, self).__init__()
|
31 |
+
self.is_last = is_last
|
32 |
self.conv1 = torch.nn.Conv1d(in_channels=in_channels, out_channels=2 * in_channels,
|
33 |
kernel_size=cfg['conv1']['kernel_size'],
|
34 |
stride=cfg['conv1']['stride'])
|
|
|
40 |
|
41 |
def forward(self, x):
|
42 |
x = self.glu(self.conv1(x))
|
43 |
+
x = self.conv2(x)
|
44 |
+
if not self.is_last:
|
45 |
+
x = self.relu(x)
|
46 |
return x
|
47 |
|
48 |
|
49 |
class Demucs(torch.nn.Module):
|
50 |
def __init__(self, cfg):
|
51 |
super(Demucs, self).__init__()
|
52 |
+
self.L = cfg['L']
|
53 |
+
|
54 |
+
encoders = [Encoder(in_channels=1, out_channels=cfg['H'], cfg=cfg['encoder'])]
|
55 |
+
decoders = [Decoder(in_channels=cfg['H'], out_channels=1, cfg=cfg['decoder'], is_last=True)]
|
56 |
+
for i in range(self.L - 1):
|
57 |
+
encoders.append(Encoder(in_channels=(2 ** i) * cfg['H'],
|
58 |
+
out_channels=(2 ** (i + 1)) * cfg['H'],
|
59 |
+
cfg=cfg['encoder']))
|
60 |
+
decoders.append(Decoder(in_channels=(2 ** (i + 1)) * cfg['H'],
|
61 |
+
out_channels=(2 ** i) * cfg['H'],
|
62 |
+
cfg=cfg['decoder']))
|
63 |
+
self.encoders = nn.ModuleList(encoders)
|
64 |
+
self.decoders = nn.ModuleList(decoders)
|
65 |
|
66 |
self.lstm = torch.nn.LSTM(
|
67 |
+
input_size=(2 ** (self.L - 1)) * cfg['H'],
|
68 |
+
hidden_size=(2 ** (self.L - 1)) * cfg['H'], num_layers=2, batch_first=True)
|
|
|
|
|
|
|
|
|
69 |
|
70 |
def forward(self, x):
|
71 |
+
outs = [x]
|
72 |
+
for i in range(self.L):
|
73 |
+
out = self.encoders[i](outs[-1])
|
74 |
+
outs.append(out)
|
75 |
+
model_input = outs.pop(0)
|
76 |
|
77 |
+
x, _ = self.lstm(outs[-1].permute(0, 2, 1))
|
78 |
x = x.permute(0, 2, 1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
79 |
|
80 |
+
for i in reversed(range(self.L)):
|
81 |
+
decoder = self.decoders[i]
|
82 |
+
x = pad_cut_batch_audio(x, outs[i].shape)
|
83 |
+
x = decoder(x + outs[i])
|
84 |
+
x = pad_cut_batch_audio(x, model_input.shape)
|
85 |
return x
|
|
|
|
losses.py
CHANGED
@@ -107,9 +107,9 @@ class MultiResolutionSTFTLoss(torch.nn.Module):
|
|
107 |
"""Multi resolution STFT loss module."""
|
108 |
|
109 |
def __init__(self,
|
110 |
-
fft_sizes=
|
111 |
-
hop_sizes=
|
112 |
-
win_lengths=
|
113 |
window="hann_window", factor_sc=0.1, factor_mag=0.1):
|
114 |
"""Initialize Multi resolution STFT loss module.
|
115 |
Args:
|
|
|
107 |
"""Multi resolution STFT loss module."""
|
108 |
|
109 |
def __init__(self,
|
110 |
+
fft_sizes=(1024, 2048, 512),
|
111 |
+
hop_sizes=(120, 240, 50),
|
112 |
+
win_lengths=(600, 1200, 240),
|
113 |
window="hann_window", factor_sc=0.1, factor_mag=0.1):
|
114 |
"""Initialize Multi resolution STFT loss module.
|
115 |
Args:
|
main.py
CHANGED
@@ -3,7 +3,6 @@ from omegaconf import DictConfig
|
|
3 |
from train import train
|
4 |
|
5 |
|
6 |
-
|
7 |
@hydra.main(version_base=None, config_path="conf", config_name="config")
|
8 |
def main(cfg: DictConfig):
|
9 |
train(cfg)
|
|
|
3 |
from train import train
|
4 |
|
5 |
|
|
|
6 |
@hydra.main(version_base=None, config_path="conf", config_name="config")
|
7 |
def main(cfg: DictConfig):
|
8 |
train(cfg)
|
testing/metrics.py
CHANGED
@@ -1,19 +1,27 @@
|
|
1 |
import pesq
|
2 |
from torchmetrics.audio.pesq import PerceptualEvaluationSpeechQuality
|
3 |
from torchmetrics.audio.stoi import ShortTimeObjectiveIntelligibility
|
|
|
4 |
import torch
|
5 |
import torchaudio
|
6 |
from torchmetrics import SignalNoiseRatio
|
7 |
|
8 |
-
class Metrics:
|
9 |
-
def __init__(self,
|
10 |
-
|
11 |
-
self.
|
|
|
|
|
|
|
|
|
12 |
self.snr = SignalNoiseRatio()
|
13 |
|
14 |
-
def
|
15 |
pesq_scores, stoi_scores = 0, 0
|
16 |
for denoised_wav, clean_wav in zip(denoised, clean):
|
|
|
|
|
|
|
17 |
try:
|
18 |
pesq_scores += self.nb_pesq(denoised_wav, clean_wav).item()
|
19 |
stoi_scores += self.stoi(denoised_wav, clean_wav).item()
|
@@ -22,7 +30,6 @@ class Metrics:
|
|
22 |
except ValueError as e:
|
23 |
print(e)
|
24 |
|
25 |
-
|
26 |
return {'PESQ': pesq_scores,
|
27 |
'STOI': stoi_scores}
|
28 |
|
|
|
1 |
import pesq
|
2 |
from torchmetrics.audio.pesq import PerceptualEvaluationSpeechQuality
|
3 |
from torchmetrics.audio.stoi import ShortTimeObjectiveIntelligibility
|
4 |
+
from torchaudio.transforms import Resample
|
5 |
import torch
|
6 |
import torchaudio
|
7 |
from torchmetrics import SignalNoiseRatio
|
8 |
|
9 |
+
class Metrics(torch.nn.Module):
|
10 |
+
def __init__(self, source_rate, target_rate=16000, *args, **kwargs):
|
11 |
+
super().__init__(*args, **kwargs)
|
12 |
+
self.source_rate = source_rate
|
13 |
+
self.target_rate = target_rate
|
14 |
+
self.resampler = Resample(orig_freq=source_rate, new_freq=target_rate)
|
15 |
+
self.nb_pesq = PerceptualEvaluationSpeechQuality(target_rate, 'wb')
|
16 |
+
self.stoi = ShortTimeObjectiveIntelligibility(target_rate, False)
|
17 |
self.snr = SignalNoiseRatio()
|
18 |
|
19 |
+
def forward(self, denoised, clean):
|
20 |
pesq_scores, stoi_scores = 0, 0
|
21 |
for denoised_wav, clean_wav in zip(denoised, clean):
|
22 |
+
if self.source_rate != self.target_rate:
|
23 |
+
denoised_wav = self.resampler(denoised_wav)
|
24 |
+
clean_wav = self.resampler(clean_wav)
|
25 |
try:
|
26 |
pesq_scores += self.nb_pesq(denoised_wav, clean_wav).item()
|
27 |
stoi_scores += self.stoi(denoised_wav, clean_wav).item()
|
|
|
30 |
except ValueError as e:
|
31 |
print(e)
|
32 |
|
|
|
33 |
return {'PESQ': pesq_scores,
|
34 |
'STOI': stoi_scores}
|
35 |
|