BorisovMaksim commited on
Commit
e2b0b28
·
1 Parent(s): 95d8ea8

rewrote demucs model

Browse files

changed configs default values
refactoring

app.py CHANGED
@@ -24,15 +24,13 @@ def denoising_transform(audio):
24
  )
25
  model.predict(audio, tgt_path)
26
  return tgt_path
27
- # model.predict(src_path, tgt_path)
28
- # return tgt_path
29
 
30
 
31
  inputs = gr.inputs.Audio(label="Source Audio", source="microphone", type='filepath')
32
  outputs = gr.outputs.Audio(label="Target Audio", type='filepath')
33
 
34
  title = "Denoising"
35
- #"""
36
  gr.Interface(
37
  denoising_transform, inputs, outputs, title=title,
38
  allow_flagging='never'
@@ -40,6 +38,4 @@ gr.Interface(
40
  server_name='localhost',
41
  server_port=7871,
42
  share=True
43
- #ssl_keyfile='example.key',
44
- #ssl_certfile="example.crt",
45
  )
 
24
  )
25
  model.predict(audio, tgt_path)
26
  return tgt_path
27
+
 
28
 
29
 
30
  inputs = gr.inputs.Audio(label="Source Audio", source="microphone", type='filepath')
31
  outputs = gr.outputs.Audio(label="Target Audio", type='filepath')
32
 
33
  title = "Denoising"
 
34
  gr.Interface(
35
  denoising_transform, inputs, outputs, title=title,
36
  allow_flagging='never'
 
38
  server_name='localhost',
39
  server_port=7871,
40
  share=True
 
 
41
  )
conf/config.yaml CHANGED
@@ -1,11 +1,11 @@
1
  defaults:
2
  - dataset: valentini
3
  - model: demucs
4
- - loss: mse
5
  - optimizer: sgd
6
 
7
  training:
8
- num_epochs: 5
9
  model_save_path: /media/public/checkpoints
10
 
11
  dataloader:
@@ -13,6 +13,8 @@ dataloader:
13
  sample_rate: 16000
14
  train_batch_size: 12
15
  valid_batch_size: 12
 
 
16
 
17
  validation:
18
  path: /media/public/datasets/denoising/DS_10283_2791/noisy_testset_wav
@@ -20,13 +22,13 @@ validation:
20
 
21
 
22
  wandb:
 
23
  run_name: default
24
  project: denoising
25
  log_interval: 100
26
  api_key: local-e23d01ece807cb31e69b2cf4137e4998e4b9856f
27
  host: http://localhost:8080/
28
  notes: "Experiment note"
29
- tags:
30
- - baseline
31
 
32
- gpu: 0
 
1
  defaults:
2
  - dataset: valentini
3
  - model: demucs
4
+ - loss: L1_Multi_STFT
5
  - optimizer: sgd
6
 
7
  training:
8
+ num_epochs: 10
9
  model_save_path: /media/public/checkpoints
10
 
11
  dataloader:
 
13
  sample_rate: 16000
14
  train_batch_size: 12
15
  valid_batch_size: 12
16
+ normalize: False
17
+ num_workers: 12
18
 
19
  validation:
20
  path: /media/public/datasets/denoising/DS_10283_2791/noisy_testset_wav
 
22
 
23
 
24
  wandb:
25
+ entity: borisovmaksim
26
  run_name: default
27
  project: denoising
28
  log_interval: 100
29
  api_key: local-e23d01ece807cb31e69b2cf4137e4998e4b9856f
30
  host: http://localhost:8080/
31
  notes: "Experiment note"
32
+ resume: False
 
33
 
34
+ gpu: 1
conf/model/demucs.yaml CHANGED
@@ -1,5 +1,6 @@
1
  demucs:
2
  H: 64
 
3
  encoder:
4
  conv1:
5
  kernel_size: 8
 
1
  demucs:
2
  H: 64
3
+ L: 3
4
  encoder:
5
  conv1:
6
  kernel_size: 8
datasets/__init__.py CHANGED
@@ -9,10 +9,9 @@ DATASETS_POOL = {
9
  }
10
 
11
 
12
-
13
  def get_datasets(cfg) -> Tuple[Dataset, Dataset]:
14
  name, dataset_params = list(cfg['dataset'].items())[0]
15
- transform = Transform(input_sr=dataset_params['sample_rate'], **cfg['dataloader'])
16
  train_dataset = DATASETS_POOL[name](valid=False, transform=transform, **dataset_params)
17
  valid_dataset = DATASETS_POOL[name](valid=True, transform=transform, **dataset_params)
18
  return train_dataset, valid_dataset
 
9
  }
10
 
11
 
 
12
  def get_datasets(cfg) -> Tuple[Dataset, Dataset]:
13
  name, dataset_params = list(cfg['dataset'].items())[0]
14
+ transform = Transform(input_sample_rate=dataset_params['sample_rate'], **cfg['dataloader'])
15
  train_dataset = DATASETS_POOL[name](valid=False, transform=transform, **dataset_params)
16
  valid_dataset = DATASETS_POOL[name](valid=True, transform=transform, **dataset_params)
17
  return train_dataset, valid_dataset
denoisers/SpectralGating.py CHANGED
@@ -18,6 +18,8 @@ class SpectralGating(torch.nn.Module):
18
  torchaudio.save(out_path, reduced_noise, rate)
19
  return out_path
20
 
 
 
21
 
22
 
23
 
 
18
  torchaudio.save(out_path, reduced_noise, rate)
19
  return out_path
20
 
21
+
22
+
23
 
24
 
25
 
denoisers/demucs.py CHANGED
@@ -1,6 +1,7 @@
1
  import torch
2
  from torch.nn.functional import pad
3
-
 
4
 
5
 
6
  class Encoder(torch.nn.Module):
@@ -25,9 +26,9 @@ class Encoder(torch.nn.Module):
25
 
26
 
27
  class Decoder(torch.nn.Module):
28
- def __init__(self, in_channels, out_channels, cfg):
29
  super(Decoder, self).__init__()
30
-
31
  self.conv1 = torch.nn.Conv1d(in_channels=in_channels, out_channels=2 * in_channels,
32
  kernel_size=cfg['conv1']['kernel_size'],
33
  stride=cfg['conv1']['stride'])
@@ -39,45 +40,46 @@ class Decoder(torch.nn.Module):
39
 
40
  def forward(self, x):
41
  x = self.glu(self.conv1(x))
42
- x = self.relu(self.conv2(x))
 
 
43
  return x
44
 
45
 
46
  class Demucs(torch.nn.Module):
47
  def __init__(self, cfg):
48
  super(Demucs, self).__init__()
49
- H = cfg['H']
50
-
51
- self.encoder1 = Encoder(in_channels=1, out_channels=H, cfg=cfg['encoder'])
52
- self.encoder2 = Encoder(in_channels=H, out_channels=2*H, cfg=cfg['encoder'])
53
- self.encoder3 = Encoder(in_channels=2*H, out_channels=4*H, cfg=cfg['encoder'])
 
 
 
 
 
 
 
 
54
 
55
  self.lstm = torch.nn.LSTM(
56
- input_size=4*H,
57
- hidden_size=4*H, num_layers=2, batch_first=True)
58
-
59
- self.decoder1 = Decoder(in_channels=4*H, out_channels=2*H, cfg=cfg['decoder'])
60
- self.decoder2 = Decoder(in_channels=2*H, out_channels=H, cfg=cfg['decoder'])
61
- self.decoder3 = Decoder(in_channels=H, out_channels=1, cfg=cfg['decoder'])
62
 
63
  def forward(self, x):
64
- out1 = self.encoder1(x)
65
- out2 = self.encoder2(out1)
66
- out3 = self.encoder3(out2)
 
 
67
 
68
- x, _ = self.lstm(out3.permute(0, 2, 1))
69
  x = x.permute(0, 2, 1)
70
- x = self.decoder1(x + out3)
71
- x = x[:, :, :out2.shape[-1]]
72
- x = self.decoder2(x + out2)
73
- x = x[:, :, :-1]
74
- out1 = out1[:, :, :-1]
75
- if x.shape[-1] > out1.shape[-1]:
76
- x = x[:, :, :out1.shape[-1]]
77
- elif x.shape[-1] < out1.shape[-1]:
78
- out1 = out1[:, :, :x.shape[-1]]
79
 
80
- x = self.decoder3(x + out1)
 
 
 
 
81
  return x
82
-
83
-
 
1
  import torch
2
  from torch.nn.functional import pad
3
+ from utils import pad_cut_batch_audio
4
+ import torch.nn as nn
5
 
6
 
7
  class Encoder(torch.nn.Module):
 
26
 
27
 
28
  class Decoder(torch.nn.Module):
29
+ def __init__(self, in_channels, out_channels, cfg, is_last=False):
30
  super(Decoder, self).__init__()
31
+ self.is_last = is_last
32
  self.conv1 = torch.nn.Conv1d(in_channels=in_channels, out_channels=2 * in_channels,
33
  kernel_size=cfg['conv1']['kernel_size'],
34
  stride=cfg['conv1']['stride'])
 
40
 
41
  def forward(self, x):
42
  x = self.glu(self.conv1(x))
43
+ x = self.conv2(x)
44
+ if not self.is_last:
45
+ x = self.relu(x)
46
  return x
47
 
48
 
49
  class Demucs(torch.nn.Module):
50
  def __init__(self, cfg):
51
  super(Demucs, self).__init__()
52
+ self.L = cfg['L']
53
+
54
+ encoders = [Encoder(in_channels=1, out_channels=cfg['H'], cfg=cfg['encoder'])]
55
+ decoders = [Decoder(in_channels=cfg['H'], out_channels=1, cfg=cfg['decoder'], is_last=True)]
56
+ for i in range(self.L - 1):
57
+ encoders.append(Encoder(in_channels=(2 ** i) * cfg['H'],
58
+ out_channels=(2 ** (i + 1)) * cfg['H'],
59
+ cfg=cfg['encoder']))
60
+ decoders.append(Decoder(in_channels=(2 ** (i + 1)) * cfg['H'],
61
+ out_channels=(2 ** i) * cfg['H'],
62
+ cfg=cfg['decoder']))
63
+ self.encoders = nn.ModuleList(encoders)
64
+ self.decoders = nn.ModuleList(decoders)
65
 
66
  self.lstm = torch.nn.LSTM(
67
+ input_size=(2 ** (self.L - 1)) * cfg['H'],
68
+ hidden_size=(2 ** (self.L - 1)) * cfg['H'], num_layers=2, batch_first=True)
 
 
 
 
69
 
70
  def forward(self, x):
71
+ outs = [x]
72
+ for i in range(self.L):
73
+ out = self.encoders[i](outs[-1])
74
+ outs.append(out)
75
+ model_input = outs.pop(0)
76
 
77
+ x, _ = self.lstm(outs[-1].permute(0, 2, 1))
78
  x = x.permute(0, 2, 1)
 
 
 
 
 
 
 
 
 
79
 
80
+ for i in reversed(range(self.L)):
81
+ decoder = self.decoders[i]
82
+ x = pad_cut_batch_audio(x, outs[i].shape)
83
+ x = decoder(x + outs[i])
84
+ x = pad_cut_batch_audio(x, model_input.shape)
85
  return x
 
 
losses.py CHANGED
@@ -107,9 +107,9 @@ class MultiResolutionSTFTLoss(torch.nn.Module):
107
  """Multi resolution STFT loss module."""
108
 
109
  def __init__(self,
110
- fft_sizes=[1024, 2048, 512],
111
- hop_sizes=[120, 240, 50],
112
- win_lengths=[600, 1200, 240],
113
  window="hann_window", factor_sc=0.1, factor_mag=0.1):
114
  """Initialize Multi resolution STFT loss module.
115
  Args:
 
107
  """Multi resolution STFT loss module."""
108
 
109
  def __init__(self,
110
+ fft_sizes=(1024, 2048, 512),
111
+ hop_sizes=(120, 240, 50),
112
+ win_lengths=(600, 1200, 240),
113
  window="hann_window", factor_sc=0.1, factor_mag=0.1):
114
  """Initialize Multi resolution STFT loss module.
115
  Args:
main.py CHANGED
@@ -3,7 +3,6 @@ from omegaconf import DictConfig
3
  from train import train
4
 
5
 
6
-
7
  @hydra.main(version_base=None, config_path="conf", config_name="config")
8
  def main(cfg: DictConfig):
9
  train(cfg)
 
3
  from train import train
4
 
5
 
 
6
  @hydra.main(version_base=None, config_path="conf", config_name="config")
7
  def main(cfg: DictConfig):
8
  train(cfg)
testing/metrics.py CHANGED
@@ -1,19 +1,27 @@
1
  import pesq
2
  from torchmetrics.audio.pesq import PerceptualEvaluationSpeechQuality
3
  from torchmetrics.audio.stoi import ShortTimeObjectiveIntelligibility
 
4
  import torch
5
  import torchaudio
6
  from torchmetrics import SignalNoiseRatio
7
 
8
- class Metrics:
9
- def __init__(self, rate=16000):
10
- self.nb_pesq = PerceptualEvaluationSpeechQuality(rate, 'wb')
11
- self.stoi = ShortTimeObjectiveIntelligibility(rate, False)
 
 
 
 
12
  self.snr = SignalNoiseRatio()
13
 
14
- def calculate(self, denoised, clean):
15
  pesq_scores, stoi_scores = 0, 0
16
  for denoised_wav, clean_wav in zip(denoised, clean):
 
 
 
17
  try:
18
  pesq_scores += self.nb_pesq(denoised_wav, clean_wav).item()
19
  stoi_scores += self.stoi(denoised_wav, clean_wav).item()
@@ -22,7 +30,6 @@ class Metrics:
22
  except ValueError as e:
23
  print(e)
24
 
25
-
26
  return {'PESQ': pesq_scores,
27
  'STOI': stoi_scores}
28
 
 
1
  import pesq
2
  from torchmetrics.audio.pesq import PerceptualEvaluationSpeechQuality
3
  from torchmetrics.audio.stoi import ShortTimeObjectiveIntelligibility
4
+ from torchaudio.transforms import Resample
5
  import torch
6
  import torchaudio
7
  from torchmetrics import SignalNoiseRatio
8
 
9
+ class Metrics(torch.nn.Module):
10
+ def __init__(self, source_rate, target_rate=16000, *args, **kwargs):
11
+ super().__init__(*args, **kwargs)
12
+ self.source_rate = source_rate
13
+ self.target_rate = target_rate
14
+ self.resampler = Resample(orig_freq=source_rate, new_freq=target_rate)
15
+ self.nb_pesq = PerceptualEvaluationSpeechQuality(target_rate, 'wb')
16
+ self.stoi = ShortTimeObjectiveIntelligibility(target_rate, False)
17
  self.snr = SignalNoiseRatio()
18
 
19
+ def forward(self, denoised, clean):
20
  pesq_scores, stoi_scores = 0, 0
21
  for denoised_wav, clean_wav in zip(denoised, clean):
22
+ if self.source_rate != self.target_rate:
23
+ denoised_wav = self.resampler(denoised_wav)
24
+ clean_wav = self.resampler(clean_wav)
25
  try:
26
  pesq_scores += self.nb_pesq(denoised_wav, clean_wav).item()
27
  stoi_scores += self.stoi(denoised_wav, clean_wav).item()
 
30
  except ValueError as e:
31
  print(e)
32
 
 
33
  return {'PESQ': pesq_scores,
34
  'STOI': stoi_scores}
35