BorisovMaksim commited on
Commit
bd0a813
·
1 Parent(s): 3f204d4
.gitignore CHANGED
@@ -2,3 +2,4 @@
2
  .ipynb_checkpoints/**
3
  nohup.out
4
  __pycache__/**
 
 
2
  .ipynb_checkpoints/**
3
  nohup.out
4
  __pycache__/**
5
+ cache_wav/
EDA.ipynb CHANGED
The diff for this file is too large to render. See raw diff
 
README.md CHANGED
@@ -2,9 +2,9 @@
2
 
3
 
4
  # Testing
5
- | | valentini_PESQ | valentini_STOI |
6
- |:---------------:|:--------------------:|:--------------:|
7
- | ideal denoising | 1.9709 | 0.9211 |
8
- | baseline | 1.5693 | 0.9504 |
9
 
10
 
 
2
 
3
 
4
  # Testing
5
+ | | valentini_PESQ | valentini_STOI |
6
+ |:---------------:|:--------------:|:--------------:|
7
+ | ideal denoising | 1.9709 | 0.9211 |
8
+ | baseline | 1.7433 | 0.8844 |
9
 
10
 
app.py CHANGED
@@ -9,35 +9,37 @@ import logging
9
  import gradio as gr
10
 
11
 
 
12
  from denoisers.SpectralGating import SpectralGating
13
 
14
  model = SpectralGating()
15
 
16
 
17
  def denoising_transform(audio):
18
- src_path = "cache_wav/source/{}.wav".format(str(uuid.uuid4()))
19
- tgt_path = "cache_wav/target/{}.wav".format(str(uuid.uuid4()))
20
- # os.rename(audio.name, src_path)
21
  (ffmpeg.input(audio)
22
  .output(src_path, acodec='pcm_s16le', ac=1, ar=22050)
23
  .run()
24
  )
25
-
26
- model.predict(src_path, tgt_path)
27
  return tgt_path
 
 
28
 
29
 
30
  inputs = gr.inputs.Audio(label="Source Audio", source="microphone", type='filepath')
31
  outputs = gr.outputs.Audio(label="Target Audio", type='filepath')
32
 
33
- title = "Chinese-to-English Direct Speech-to-Speech Translation (BETA)"
34
  #"""
35
  gr.Interface(
36
  denoising_transform, inputs, outputs, title=title,
37
- allow_flagging='never',
38
  ).launch(
39
  server_name='localhost',
40
  server_port=7871,
 
41
  #ssl_keyfile='example.key',
42
  #ssl_certfile="example.crt",
43
  )
 
9
  import gradio as gr
10
 
11
 
12
+
13
  from denoisers.SpectralGating import SpectralGating
14
 
15
  model = SpectralGating()
16
 
17
 
18
  def denoising_transform(audio):
19
+ src_path = "cache_wav/original/{}.wav".format(str(uuid.uuid4()))
20
+ tgt_path = "cache_wav/denoised/{}.wav".format(str(uuid.uuid4()))
 
21
  (ffmpeg.input(audio)
22
  .output(src_path, acodec='pcm_s16le', ac=1, ar=22050)
23
  .run()
24
  )
25
+ model.predict(audio, tgt_path)
 
26
  return tgt_path
27
+ # model.predict(src_path, tgt_path)
28
+ # return tgt_path
29
 
30
 
31
  inputs = gr.inputs.Audio(label="Source Audio", source="microphone", type='filepath')
32
  outputs = gr.outputs.Audio(label="Target Audio", type='filepath')
33
 
34
+ title = "Denoising"
35
  #"""
36
  gr.Interface(
37
  denoising_transform, inputs, outputs, title=title,
38
+ allow_flagging='never'
39
  ).launch(
40
  server_name='localhost',
41
  server_port=7871,
42
+ share=True
43
  #ssl_keyfile='example.key',
44
  #ssl_certfile="example.crt",
45
  )
datasets.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch.utils.data import Dataset
3
+ from pathlib import Path
4
+ from utils import load_wav
5
+
6
+
7
+ class Valentini(Dataset):
8
+ def __init__(self, dataset_path='/media/public/datasets/denoising/DS_10283_2791/', transform=None,
9
+ valid=False):
10
+ clean_path = Path(dataset_path) / 'clean_trainset_56spk_wav'
11
+ noisy_path = Path(dataset_path) / 'noisy_trainset_56spk_wav'
12
+ clean_wavs = list(clean_path.glob("*"))
13
+ noisy_wavs = list(noisy_path.glob("*"))
14
+ valid_threshold = int(len(clean_wavs) * 0.2)
15
+ if valid:
16
+ self.clean_wavs = clean_wavs[:valid_threshold]
17
+ self.noisy_wavs = noisy_wavs[:valid_threshold]
18
+ else:
19
+ self.clean_wavs = clean_wavs[valid_threshold:]
20
+ self.noisy_wavs = noisy_wavs[valid_threshold:]
21
+
22
+ assert len(self.clean_wavs) == len(self.noisy_wavs)
23
+
24
+ self.transform = transform
25
+
26
+ def __len__(self):
27
+ return len(self.clean_wavs)
28
+
29
+ def __getitem__(self, idx):
30
+ noisy_wav = load_wav(self.noisy_wavs[idx])
31
+ clean_wav = load_wav(self.clean_wavs[idx])
32
+
33
+ if self.transform:
34
+ random_seed = torch.randint(100, (1,))[0]
35
+ torch.manual_seed(random_seed)
36
+ noisy_wav = self.transform(noisy_wav)
37
+ torch.manual_seed(random_seed)
38
+ clean_wav = self.transform(clean_wav)
39
+ return noisy_wav, clean_wav
denoisers/SpectralGating.py CHANGED
@@ -16,7 +16,7 @@ class SpectralGating(torch.nn.Module):
16
  data, rate = torchaudio.load(wav_path)
17
  reduced_noise = torch.Tensor(nr.reduce_noise(y=data, sr=rate))
18
  torchaudio.save(out_path, reduced_noise, rate)
19
- return reduced_noise
20
 
21
 
22
 
 
16
  data, rate = torchaudio.load(wav_path)
17
  reduced_noise = torch.Tensor(nr.reduce_noise(y=data, sr=rate))
18
  torchaudio.save(out_path, reduced_noise, rate)
19
+ return out_path
20
 
21
 
22
 
denoisers/__pycache__/SpectralGating.cpython-38.pyc ADDED
Binary file (1.08 kB). View file
 
denoisers/demucs.py CHANGED
@@ -1,36 +1,34 @@
1
  import torch
2
-
3
 
4
  class Encoder(torch.nn.Module):
5
- def __init__(self, in_channels, out_channels,
6
- kernel_size_1=8, stride_1=4,
7
- kernel_size_2=1, stride_2=1):
8
  super(Encoder, self).__init__()
9
 
10
  self.conv1 = torch.nn.Conv1d(in_channels=in_channels, out_channels=out_channels,
11
- kernel_size=kernel_size_1, stride=stride_1)
12
  self.relu1 = torch.nn.ReLU()
13
  self.conv2 = torch.nn.Conv1d(in_channels=out_channels, out_channels=2 * out_channels,
14
- kernel_size=kernel_size_2, stride=stride_2)
15
- self.glu = torch.nn.GLU()
16
 
17
  def forward(self, x):
18
  x = self.relu1(self.conv1(x))
 
 
19
  x = self.glu(self.conv2(x))
20
  return x
21
 
22
 
23
  class Decoder(torch.nn.Module):
24
- def __init__(self, in_channels, out_channels,
25
- kernel_size_1=3, stride_1=1,
26
- kernel_size_2=8, stride_2=4):
27
  super(Decoder, self).__init__()
28
 
29
  self.conv1 = torch.nn.Conv1d(in_channels=in_channels, out_channels=2 * in_channels,
30
- kernel_size=kernel_size_1, stride=stride_1)
31
- self.glu = torch.nn.GLU()
32
  self.conv2 = torch.nn.ConvTranspose1d(in_channels=in_channels, out_channels=out_channels,
33
- kernel_size=kernel_size_2, stride=stride_2)
34
  self.relu = torch.nn.ReLU()
35
 
36
  def forward(self, x):
@@ -40,28 +38,39 @@ class Decoder(torch.nn.Module):
40
 
41
 
42
  class Demucs(torch.nn.Module):
43
- def __init__(self):
44
  super(Demucs, self).__init__()
45
 
46
- self.encoder1 = Encoder(in_channels=1, out_channels=64)
47
- self.encoder2 = Encoder(in_channels=64, out_channels=128)
48
- self.encoder3 = Encoder(in_channels=128, out_channels=256)
49
 
50
- self.lstm = torch.nn.LSTM(input_size=256, hidden_size=256, num_layers=2)
 
 
51
 
52
- self.decoder1 = Decoder(in_channels=256, out_channels=128)
53
- self.decoder2 = Decoder(in_channels=128, out_channels=64)
54
- self.decoder3 = Decoder(in_channels=64, out_channels=1)
55
 
56
  def forward(self, x):
57
  out1 = self.encoder1(x)
58
  out2 = self.encoder2(out1)
59
  out3 = self.encoder3(out2)
60
 
61
- x = self.lstm(out3)
62
-
63
  x = self.decoder1(x + out3)
 
64
  x = self.decoder2(x + out2)
65
- x = self.decoder3(x + out1)
 
 
 
 
 
66
 
 
67
  return x
 
 
 
1
  import torch
2
+ from torch.nn.functional import pad
3
 
4
  class Encoder(torch.nn.Module):
5
+ def __init__(self, in_channels, out_channels):
 
 
6
  super(Encoder, self).__init__()
7
 
8
  self.conv1 = torch.nn.Conv1d(in_channels=in_channels, out_channels=out_channels,
9
+ kernel_size=8, stride=2)
10
  self.relu1 = torch.nn.ReLU()
11
  self.conv2 = torch.nn.Conv1d(in_channels=out_channels, out_channels=2 * out_channels,
12
+ kernel_size=1, stride=1)
13
+ self.glu = torch.nn.GLU(dim=-2)
14
 
15
  def forward(self, x):
16
  x = self.relu1(self.conv1(x))
17
+ if x.shape[-1] % 2 == 1:
18
+ x = pad(x, (0, 1))
19
  x = self.glu(self.conv2(x))
20
  return x
21
 
22
 
23
  class Decoder(torch.nn.Module):
24
+ def __init__(self, in_channels, out_channels):
 
 
25
  super(Decoder, self).__init__()
26
 
27
  self.conv1 = torch.nn.Conv1d(in_channels=in_channels, out_channels=2 * in_channels,
28
+ kernel_size=1, stride=1)
29
+ self.glu = torch.nn.GLU(dim=-2)
30
  self.conv2 = torch.nn.ConvTranspose1d(in_channels=in_channels, out_channels=out_channels,
31
+ kernel_size=8, stride=2)
32
  self.relu = torch.nn.ReLU()
33
 
34
  def forward(self, x):
 
38
 
39
 
40
  class Demucs(torch.nn.Module):
41
+ def __init__(self, H):
42
  super(Demucs, self).__init__()
43
 
44
+ self.encoder1 = Encoder(in_channels=1, out_channels=H)
45
+ self.encoder2 = Encoder(in_channels=H, out_channels=2*H)
46
+ self.encoder3 = Encoder(in_channels=2*H, out_channels=4*H)
47
 
48
+ self.lstm = torch.nn.LSTM(
49
+ input_size=4*H,
50
+ hidden_size=4*H, num_layers=2, batch_first=True)
51
 
52
+ self.decoder1 = Decoder(in_channels=4*H, out_channels=2*H)
53
+ self.decoder2 = Decoder(in_channels=2*H, out_channels=H)
54
+ self.decoder3 = Decoder(in_channels=H, out_channels=1)
55
 
56
  def forward(self, x):
57
  out1 = self.encoder1(x)
58
  out2 = self.encoder2(out1)
59
  out3 = self.encoder3(out2)
60
 
61
+ x, _ = self.lstm(out3.permute(0, 2, 1))
62
+ x = x.permute(0, 2, 1)
63
  x = self.decoder1(x + out3)
64
+ x = x[:, :, :out2.shape[-1]]
65
  x = self.decoder2(x + out2)
66
+ x = x[:, :, :-1]
67
+ out1 = out1[:, :, :-1]
68
+ if x.shape[-1] > out1.shape[-1]:
69
+ x = x[:, :, :out1.shape[-1]]
70
+ elif x.shape[-1] < out1.shape[-1]:
71
+ out1 = out1[:, :, :x.shape[-1]]
72
 
73
+ x = self.decoder3(x + out1)
74
  return x
75
+
76
+
evaluation.py CHANGED
@@ -28,10 +28,10 @@ def evaluate_on_dataset(model_name, dataset_path, dataset_type):
28
  noisy_wav = load_wav(noisy_path)
29
 
30
  if model_name is None:
31
- scores = metrics.calculate(noisy_wav, clean_wav)
32
  else:
33
  denoised_wav = model(noisy_wav)
34
- scores = metrics.calculate(noisy_wav, denoised_wav)
35
 
36
  mean_scores['PESQ'] += scores['PESQ']
37
  mean_scores['STOI'] += scores['STOI']
 
28
  noisy_wav = load_wav(noisy_path)
29
 
30
  if model_name is None:
31
+ scores = metrics.calculate(denoised=noisy_wav, clean=clean_wav)
32
  else:
33
  denoised_wav = model(noisy_wav)
34
+ scores = metrics.calculate(denoised=denoised_wav, clean=clean_wav)
35
 
36
  mean_scores['PESQ'] += scores['PESQ']
37
  mean_scores['STOI'] += scores['STOI']
metrics.py CHANGED
@@ -2,16 +2,17 @@ from torchmetrics.audio.pesq import PerceptualEvaluationSpeechQuality
2
  from torchmetrics.audio.stoi import ShortTimeObjectiveIntelligibility
3
  import torch
4
  import torchaudio
5
-
6
 
7
 
8
  class Metrics:
9
  def __init__(self, rate=16000):
10
  self.nb_pesq = PerceptualEvaluationSpeechQuality(rate, 'wb')
11
  self.stoi = ShortTimeObjectiveIntelligibility(rate, False)
 
12
 
13
- def calculate(self, preds, target):
14
- return {'PESQ': self.nb_pesq(preds, target),
15
- 'STOI': self.stoi(preds, target)}
16
 
17
 
 
2
  from torchmetrics.audio.stoi import ShortTimeObjectiveIntelligibility
3
  import torch
4
  import torchaudio
5
+ from torchmetrics import SignalNoiseRatio
6
 
7
 
8
  class Metrics:
9
  def __init__(self, rate=16000):
10
  self.nb_pesq = PerceptualEvaluationSpeechQuality(rate, 'wb')
11
  self.stoi = ShortTimeObjectiveIntelligibility(rate, False)
12
+ self.snr = SignalNoiseRatio()
13
 
14
+ def calculate(self, denoised, clean):
15
+ return {'PESQ': self.nb_pesq(denoised, clean),
16
+ 'STOI': self.stoi(denoised, clean)}
17
 
18
 
train.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from torch.utils.tensorboard import SummaryWriter
3
+ import torch
4
+ from torch.nn import Sequential
5
+ from torch.utils.data import DataLoader
6
+ from datasets import Valentini
7
+ from datetime import datetime
8
+ from torchvision.transforms import RandomCrop
9
+ from utils import load_wav
10
+ from denoisers.demucs import Demucs
11
+ from pathlib import Path
12
+
13
+ os.environ['CUDA_VISIBLE_DEVICES'] = "1"
14
+
15
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
16
+
17
+ model = Demucs(H=64).to(device)
18
+
19
+ DATASET_PATH = Path('/media/public/datasets/denoising/DS_10283_2791/')
20
+ VALID_WAVS = {'hard': 'p257_171.wav',
21
+ 'medium': 'p232_071.wav',
22
+ 'easy': 'p232_284.wav'}
23
+ MAX_SECONDS = 3.2
24
+ SAMPLE_RATE = 16000
25
+
26
+ transform = Sequential(RandomCrop((1, int(MAX_SECONDS * SAMPLE_RATE)), pad_if_needed=True))
27
+
28
+ training_loader = DataLoader(Valentini(valid=False, transform=transform), batch_size=12, shuffle=True)
29
+ validation_loader = DataLoader(Valentini(valid=True, transform=transform), batch_size=12, shuffle=True)
30
+
31
+ optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
32
+ loss_fn = torch.nn.MSELoss()
33
+
34
+
35
+ def train_one_epoch(epoch_index, tb_writer):
36
+ running_loss = 0.
37
+ last_loss = 0.
38
+
39
+ for i, data in enumerate(training_loader):
40
+ inputs, labels = data
41
+ inputs, labels = inputs.to(device), labels.to(device)
42
+
43
+ optimizer.zero_grad()
44
+
45
+ outputs = model(inputs)
46
+
47
+ loss = loss_fn(outputs, labels)
48
+ loss.backward()
49
+
50
+ optimizer.step()
51
+
52
+ running_loss += loss.item()
53
+ if i % 1000 == 999:
54
+ last_loss = running_loss / 100 # loss per batch
55
+ print(' batch {} loss: {}'.format(i + 1, last_loss))
56
+ tb_x = epoch_index * len(training_loader) + i + 1
57
+ tb_writer.add_scalar('Loss/train', last_loss, tb_x)
58
+ running_loss = 0.
59
+
60
+
61
+ return last_loss
62
+
63
+
64
+ def train():
65
+ timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
66
+ writer = SummaryWriter('runs/denoising_trainer_{}'.format(timestamp))
67
+ epoch_number = 0
68
+
69
+ EPOCHS = 5
70
+
71
+ best_vloss = 1_000_000.
72
+
73
+ for tag, wav_path in VALID_WAVS.items():
74
+ wav = load_wav(DATASET_PATH / 'noisy_testset_wav' / wav_path)
75
+ writer.add_audio(tag=tag, snd_tensor=wav, sample_rate=SAMPLE_RATE)
76
+ writer.flush()
77
+
78
+ for epoch in range(EPOCHS):
79
+ print('EPOCH {}:'.format(epoch_number + 1))
80
+
81
+ # Make sure gradient tracking is on, and do a pass over the data
82
+ model.train(True)
83
+ avg_loss = train_one_epoch(epoch_number, writer)
84
+
85
+ # We don't need gradients on to do reporting
86
+ model.train(False)
87
+
88
+ running_vloss = 0.0
89
+ with torch.no_grad():
90
+ for i, vdata in enumerate(validation_loader):
91
+ vinputs, vlabels = vdata
92
+ vinputs, vlabels = vinputs.to(device), vlabels.to(device)
93
+ voutputs = model(vinputs)
94
+ vloss = loss_fn(voutputs, vlabels)
95
+ running_vloss += vloss
96
+
97
+ avg_vloss = running_vloss / (i + 1)
98
+ print('LOSS train {} valid {}'.format(avg_loss, avg_vloss))
99
+
100
+ writer.add_scalars('Training vs. Validation Loss',
101
+ {'Training': avg_loss, 'Validation': avg_vloss},
102
+ epoch_number + 1)
103
+ for tag, wav_path in VALID_WAVS.items():
104
+ wav = load_wav(DATASET_PATH / 'noisy_testset_wav' / wav_path)
105
+ wav = torch.reshape(wav, (1, 1, -1)).to(device)
106
+ prediction = model(wav)
107
+ writer.add_audio(tag=f"Model predicted {tag} on epoch {epoch}",
108
+ snd_tensor=prediction,
109
+ sample_rate=SAMPLE_RATE)
110
+ writer.flush()
111
+
112
+ if avg_vloss < best_vloss:
113
+ best_vloss = avg_vloss
114
+ model_path = 'checkpoints/model_{}_{}'.format(timestamp, epoch_number)
115
+ torch.save(model.state_dict(), model_path)
116
+
117
+ epoch_number += 1
118
+
119
+
120
+ if __name__ == '__main__':
121
+ train()