BorisovMaksim committed on
Commit
c45e107
·
1 Parent(s): ba6efe9

deleted files that are not needed for app

Browse files
.github/workflows/python-app.yml DELETED
@@ -1,39 +0,0 @@
1
- # This workflow will install Python dependencies, run tests and lint with a single version of Python
2
- # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python
3
-
4
- name: Python application
5
-
6
- on:
7
- push:
8
- branches: [ "main" ]
9
- pull_request:
10
- branches: [ "main" ]
11
-
12
- permissions:
13
- contents: read
14
-
15
- jobs:
16
- build:
17
-
18
- runs-on: ubuntu-latest
19
-
20
- steps:
21
- - uses: actions/checkout@v3
22
- - name: Set up Python 3.8
23
- uses: actions/setup-python@v3
24
- with:
25
- python-version: "3.8"
26
- - name: Install dependencies
27
- run: |
28
- python -m pip install --upgrade pip
29
- pip install flake8 pytest
30
- if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
31
- - name: Lint with flake8
32
- run: |
33
- # stop the build if there are Python syntax errors or undefined names
34
- flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
35
- # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
36
- flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
37
- - name: Test with pytest
38
- run: |
39
- pytest
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
app.py CHANGED
@@ -10,6 +10,20 @@ import torchaudio
10
  import yaml
11
 
12
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
 
14
  def run_app(model_filename, config_filename):
15
  model_path = hf_hub_download(repo_id="BorisovMaksim/demucs", filename=model_filename)
@@ -20,37 +34,27 @@ def run_app(model_filename, config_filename):
20
  checkpoint = torch.load(model_path, map_location=torch.device('cpu'))
21
  model.load_state_dict(checkpoint['model_state_dict'])
22
 
23
- def denoising_transform(audio):
24
- # Path(__file__).parent.resolve()
25
- src_path = Path("cache_wav/original/{}.wav".format(str(uuid.uuid4())))
26
- tgt_path = Path("cache_wav/denoised/{}.wav".format(str(uuid.uuid4())))
27
- src_path.parent.mkdir(exist_ok=True, parents=True)
28
- tgt_path.parent.mkdir(exist_ok=True, parents=True)
29
- (ffmpeg.input(audio)
30
- .output(src_path.as_posix(), acodec='pcm_s16le', ac=1, ar=22050)
31
- .run()
32
- )
33
- wav, rate = torchaudio.load(audio)
34
- reduced_noise = model.predict(wav)
35
- torchaudio.save(tgt_path, reduced_noise, rate)
36
- return tgt_path
37
-
38
- demo = gr.Interface(
39
- fn=denoising_transform,
40
  inputs=gr.Audio(label="Source Audio", source="microphone", type='filepath'),
41
- outputs=gr.Audio(label="Target Audio", type='filepath'),
42
- examples=[
43
- ["testing/wavs/p232_071.wav"],
44
- ["testing/wavs/p232_284.wav"],
45
- ],
46
- title="Denoising"
47
  )
 
 
 
 
 
 
 
 
 
 
 
 
48
 
49
 
50
- demo.launch()
51
-
52
  if __name__ == "__main__":
53
  model_filename = "original_sr/Demucs_original_sr_epoch3.pt"
54
  config_filename = "original_sr/config.yaml"
55
  run_app(model_filename, config_filename)
56
-
 
10
  import yaml
11
 
12
 
13
+ def denoising_transform(audio, model):
14
+ src_path = Path("cache_wav/original/{}.wav".format(str(uuid.uuid4())))
15
+ tgt_path = Path("cache_wav/denoised/{}.wav".format(str(uuid.uuid4())))
16
+ src_path.parent.mkdir(exist_ok=True, parents=True)
17
+ tgt_path.parent.mkdir(exist_ok=True, parents=True)
18
+ (ffmpeg.input(audio)
19
+ .output(src_path.as_posix(), acodec='pcm_s16le', ac=1, ar=22050)
20
+ .run()
21
+ )
22
+ wav, rate = torchaudio.load(audio)
23
+ reduced_noise = model.predict(wav)
24
+ torchaudio.save(tgt_path, reduced_noise, rate)
25
+ return tgt_path
26
+
27
 
28
  def run_app(model_filename, config_filename):
29
  model_path = hf_hub_download(repo_id="BorisovMaksim/demucs", filename=model_filename)
 
34
  checkpoint = torch.load(model_path, map_location=torch.device('cpu'))
35
  model.load_state_dict(checkpoint['model_state_dict'])
36
 
37
+ interface_demucs = gr.Interface(
38
+ fn=lambda x: denoising_transform(x, model),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
  inputs=gr.Audio(label="Source Audio", source="microphone", type='filepath'),
40
+ outputs=gr.Audio(label="Demucs", type='filepath'),
41
+ allow_flagging='never'
 
 
 
 
42
  )
43
+ interface_spectral_gating = gr.Interface(
44
+ fn=lambda x: denoising_transform(x, SpectralGating()),
45
+ inputs=gr.Audio(label="Source Audio", source="microphone", type='filepath'),
46
+ outputs=gr.Audio(label="Spectral Gating", type='filepath'),
47
+ allow_flagging='never'
48
+ )
49
+ gr.Parallel(interface_demucs, interface_spectral_gating,
50
+ title="Denoising",
51
+ examples=[
52
+ ["testing/wavs/p232_071.wav"],
53
+ ["testing/wavs/p232_284.wav"],
54
+ ]).launch()
55
 
56
 
 
 
57
  if __name__ == "__main__":
58
  model_filename = "original_sr/Demucs_original_sr_epoch3.pt"
59
  config_filename = "original_sr/config.yaml"
60
  run_app(model_filename, config_filename)
 
checkpoing_saver.py DELETED
@@ -1,53 +0,0 @@
1
- import os
2
- import numpy as np
3
- import logging
4
- import torch
5
- import wandb
6
-
7
- class CheckpointSaver:
8
- def __init__(self, dirpath, run_name='', decreasing=True, top_n=5):
9
- """
10
- dirpath: Directory path where to store all model weights
11
- decreasing: If decreasing is `True`, then lower metric is better
12
- top_n: Total number of models to track based on validation metric value
13
- """
14
- if not os.path.exists(dirpath): os.makedirs(dirpath)
15
- self.dirpath = dirpath
16
- self.top_n = top_n
17
- self.decreasing = decreasing
18
- self.top_model_paths = []
19
- self.best_metric_val = np.Inf if decreasing else -np.Inf
20
- self.run_name = run_name
21
-
22
-
23
- def __call__(self, model, epoch, metric_val, optimizer, loss):
24
- model_path = os.path.join(self.dirpath, model.__class__.__name__ + f'_{self.run_name}_epoch{epoch}.pt')
25
- save = metric_val < self.best_metric_val if self.decreasing else metric_val > self.best_metric_val
26
- if save:
27
- logging.info(
28
- f"Current metric value better than {metric_val} better than best {self.best_metric_val}, saving model at {model_path}, & logging model weights to W&B.")
29
- self.best_metric_val = metric_val
30
- torch.save(
31
- { # Save our checkpoint loc
32
- 'epoch': epoch,
33
- 'model_state_dict': model.state_dict(),
34
- 'optimizer_state_dict': optimizer.state_dict(),
35
- 'loss': loss,
36
- }, model_path)
37
- self.log_artifact(f'model-ckpt-epoch-{epoch}.pt', model_path, metric_val)
38
- self.top_model_paths.append({'path': model_path, 'score': metric_val})
39
- self.top_model_paths = sorted(self.top_model_paths, key=lambda o: o['score'], reverse=not self.decreasing)
40
- if len(self.top_model_paths) > self.top_n:
41
- self.cleanup()
42
-
43
- def log_artifact(self, filename, model_path, metric_val):
44
- artifact = wandb.Artifact(filename, type='model', metadata={'Validation score': metric_val})
45
- artifact.add_file(model_path)
46
- wandb.run.log_artifact(artifact)
47
-
48
- def cleanup(self):
49
- to_remove = self.top_model_paths[self.top_n:]
50
- logging.info(f"Removing extra models.. {to_remove}")
51
- for o in to_remove:
52
- os.remove(o['path'])
53
- self.top_model_paths = self.top_model_paths[:self.top_n]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
conf/config.yaml DELETED
@@ -1,34 +0,0 @@
1
- defaults:
2
- - dataset: valentini
3
- - model: demucs
4
- - loss: L1_Multi_STFT
5
- - optimizer: sgd
6
-
7
- training:
8
- num_epochs: 10
9
- model_save_path: /media/public/checkpoints
10
-
11
- dataloader:
12
- max_seconds: 2
13
- sample_rate: 16000
14
- train_batch_size: 12
15
- valid_batch_size: 12
16
- normalize: False
17
- num_workers: 12
18
-
19
- validation:
20
- path: /media/public/datasets/denoising/DS_10283_2791/noisy_testset_wav
21
- sample_rate: 48000
22
-
23
-
24
- wandb:
25
- entity: borisovmaksim
26
- run_name: default
27
- project: denoising
28
- log_interval: 100
29
- api_key: local-e23d01ece807cb31e69b2cf4137e4998e4b9856f
30
- host: http://localhost:8080/
31
- notes: "Experiment note"
32
- resume: False
33
-
34
- gpu: 1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
conf/dataset/valentini.yaml DELETED
@@ -1,4 +0,0 @@
1
- valentini:
2
- dataset_path: /media/public/datasets/denoising/DS_10283_2791/
3
- val_fraction: 0.2
4
- sample_rate: 48000
 
 
 
 
 
conf/loss/L1_Multi_STFT.yaml DELETED
@@ -1 +0,0 @@
1
- name: L1_Multi_STFT
 
 
conf/loss/mse.yaml DELETED
@@ -1 +0,0 @@
1
- name: mse
 
 
conf/model/demucs.yaml DELETED
@@ -1,19 +0,0 @@
1
- demucs:
2
- H: 64
3
- L: 3
4
- encoder:
5
- conv1:
6
- kernel_size: 8
7
- stride: 2
8
- conv2:
9
- kernel_size: 1
10
- stride: 1
11
-
12
- decoder:
13
- conv1:
14
- kernel_size: 1
15
- stride: 1
16
- conv2:
17
- kernel_size: 8
18
- stride: 2
19
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
conf/optimizer/adam.yaml DELETED
File without changes
conf/optimizer/sgd.yaml DELETED
@@ -1,3 +0,0 @@
1
- sgd:
2
- lr: 0.001
3
- momentum: 0.9
 
 
 
 
datasets/__init__.py DELETED
@@ -1,17 +0,0 @@
1
- from typing import Tuple
2
- from torch.utils.data import Dataset
3
-
4
- from datasets.valentini import Valentini
5
- from transforms import Transform
6
-
7
- DATASETS_POOL = {
8
- 'valentini': Valentini
9
- }
10
-
11
-
12
- def get_datasets(cfg) -> Tuple[Dataset, Dataset]:
13
- name, dataset_params = list(cfg['dataset'].items())[0]
14
- transform = Transform(input_sample_rate=dataset_params['sample_rate'], **cfg['dataloader'])
15
- train_dataset = DATASETS_POOL[name](valid=False, transform=transform, **dataset_params)
16
- valid_dataset = DATASETS_POOL[name](valid=True, transform=transform, **dataset_params)
17
- return train_dataset, valid_dataset
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
datasets/minimal.py DELETED
@@ -1,23 +0,0 @@
1
- import torch
2
- from torch.utils.data import Dataset
3
- from pathlib import Path
4
- import torchaudio
5
- import numpy as np
6
- from torchaudio.transforms import Resample
7
-
8
-
9
- class Minimal(Dataset):
10
- def __init__(self, cfg):
11
- self.wavs = ['p232_284.wav', 'p232_071.wav', 'p257_171.wav']
12
- self.dataset_path = cfg['validation']['path']
13
- self.target_rate = cfg['dataloader']['sample_rate']
14
- self.resampler = Resample(orig_freq=cfg['validation']['sample_rate'],
15
- new_freq=cfg['dataloader']['sample_rate'])
16
-
17
- def __len__(self):
18
- return len(self.wavs)
19
-
20
- def __getitem__(self, idx):
21
- wav, rate = torchaudio.load(Path(self.dataset_path) / self.wavs[idx])
22
- wav = self.resampler(wav)
23
- return wav, self.target_rate
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
datasets/valentini.py DELETED
@@ -1,44 +0,0 @@
1
- import torch
2
- from torch.utils.data import Dataset
3
- from pathlib import Path
4
- import torchaudio
5
-
6
- MAX_RANDOM_SEED = 1000
7
-
8
- class Valentini(Dataset):
9
- def __init__(self, dataset_path, val_fraction, transform=None, valid=False, *args, **kwargs):
10
- clean_path = Path(dataset_path) / 'clean_trainset_56spk_wav'
11
- noisy_path = Path(dataset_path) / 'noisy_trainset_56spk_wav'
12
- clean_wavs = list(clean_path.glob("*"))
13
- noisy_wavs = list(noisy_path.glob("*"))
14
- valid_threshold = int(len(clean_wavs) * val_fraction)
15
- if valid:
16
- self.clean_wavs = clean_wavs[:valid_threshold]
17
- self.noisy_wavs = noisy_wavs[:valid_threshold]
18
- else:
19
- self.clean_wavs = clean_wavs[valid_threshold:]
20
- self.noisy_wavs = noisy_wavs[valid_threshold:]
21
-
22
- assert len(self.clean_wavs) == len(self.noisy_wavs)
23
-
24
- self.transform = transform
25
- self.valid = valid
26
-
27
- def __len__(self):
28
- return len(self.clean_wavs)
29
-
30
- def __getitem__(self, idx):
31
- noisy_wav, noisy_sr = torchaudio.load(self.noisy_wavs[idx])
32
- clean_wav, clean_sr = torchaudio.load(self.clean_wavs[idx])
33
-
34
- if self.transform:
35
- random_seed = 0 if self.valid else torch.randint(MAX_RANDOM_SEED, (1,))[0]
36
- torch.manual_seed(random_seed)
37
- noisy_wav = self.transform(noisy_wav)
38
- torch.manual_seed(random_seed)
39
- clean_wav = self.transform(clean_wav)
40
- return noisy_wav, clean_wav
41
-
42
-
43
-
44
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
denoisers/SpectralGating.py CHANGED
@@ -4,7 +4,7 @@ import torchaudio
4
 
5
 
6
  class SpectralGating(torch.nn.Module):
7
- def __init__(self, rate=16000):
8
  super(SpectralGating, self).__init__()
9
  self.rate = rate
10
 
@@ -12,11 +12,8 @@ class SpectralGating(torch.nn.Module):
12
  reduced_noise = torch.Tensor(nr.reduce_noise(y=wav, sr=self.rate))
13
  return reduced_noise
14
 
15
- def predict(self, wav_path, out_path):
16
- data, rate = torchaudio.load(wav_path)
17
- reduced_noise = torch.Tensor(nr.reduce_noise(y=data, sr=rate))
18
- torchaudio.save(out_path, reduced_noise, rate)
19
- return out_path
20
 
21
 
22
 
 
4
 
5
 
6
  class SpectralGating(torch.nn.Module):
7
+ def __init__(self, rate=48000):
8
  super(SpectralGating, self).__init__()
9
  self.rate = rate
10
 
 
12
  reduced_noise = torch.Tensor(nr.reduce_noise(y=wav, sr=self.rate))
13
  return reduced_noise
14
 
15
+ def predict(self, wav):
16
+ return self.forward(wav)
 
 
 
17
 
18
 
19
 
losses.py DELETED
@@ -1,185 +0,0 @@
1
- # -*- coding: utf-8 -*-
2
- # Copyright (c) Facebook, Inc. and its affiliates.
3
- # All rights reserved.
4
- #
5
- # This source code is licensed under the license found in the
6
- # LICENSE file in the root directory of this source tree.
7
-
8
- # Original copyright 2019 Tomoki Hayashi
9
- # MIT License (https://opensource.org/licenses/MIT)
10
-
11
-
12
- import torch
13
- import torch.nn.functional as F
14
-
15
- # device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
16
-
17
- """STFT-based Loss modules."""
18
-
19
-
20
- def stft(x, fft_size, hop_size, win_length, window):
21
- """Perform STFT and convert to magnitude spectrogram.
22
- Args:
23
- x (Tensor): Input signal tensor (B, T).
24
- fft_size (int): FFT size.
25
- hop_size (int): Hop size.
26
- win_length (int): Window length.
27
- window (str): Window function type.
28
- Returns:
29
- Tensor: Magnitude spectrogram (B, #frames, fft_size // 2 + 1).
30
- """
31
- x_stft = torch.stft(x[:, 0, :], fft_size, hop_size, win_length, window, return_complex=True)
32
- x_stft = torch.view_as_real(x_stft)
33
- real = x_stft[..., 0]
34
- imag = x_stft[..., 1]
35
-
36
- # NOTE(kan-bayashi): clamp is needed to avoid nan or inf
37
- return torch.sqrt(torch.clamp(real ** 2 + imag ** 2, min=1e-7)).transpose(2, 1)
38
-
39
-
40
- class SpectralConvergengeLoss(torch.nn.Module):
41
- """Spectral convergence loss module."""
42
-
43
- def __init__(self):
44
- """Initilize spectral convergence loss module."""
45
- super(SpectralConvergengeLoss, self).__init__()
46
-
47
- def forward(self, x_mag, y_mag):
48
- """Calculate forward propagation.
49
- Args:
50
- x_mag (Tensor): Magnitude spectrogram of predicted signal (B, #frames, #freq_bins).
51
- y_mag (Tensor): Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins).
52
- Returns:
53
- Tensor: Spectral convergence loss value.
54
- """
55
- return torch.norm(y_mag - x_mag, p="fro") / torch.norm(y_mag, p="fro")
56
-
57
-
58
- class LogSTFTMagnitudeLoss(torch.nn.Module):
59
- """Log STFT magnitude loss module."""
60
-
61
- def __init__(self):
62
- """Initilize los STFT magnitude loss module."""
63
- super(LogSTFTMagnitudeLoss, self).__init__()
64
-
65
- def forward(self, x_mag, y_mag):
66
- """Calculate forward propagation.
67
- Args:
68
- x_mag (Tensor): Magnitude spectrogram of predicted signal (B, #frames, #freq_bins).
69
- y_mag (Tensor): Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins).
70
- Returns:
71
- Tensor: Log STFT magnitude loss value.
72
- """
73
- return F.l1_loss(torch.log(y_mag), torch.log(x_mag))
74
-
75
-
76
- class STFTLoss(torch.nn.Module):
77
- """STFT loss module."""
78
-
79
- def __init__(self, fft_size=1024, shift_size=120, win_length=600, window="hann_window"):
80
- """Initialize STFT loss module."""
81
- super(STFTLoss, self).__init__()
82
- self.fft_size = fft_size
83
- self.shift_size = shift_size
84
- self.win_length = win_length
85
- self.register_buffer("window", getattr(torch, window)(win_length))
86
- self.spectral_convergenge_loss = SpectralConvergengeLoss()
87
- self.log_stft_magnitude_loss = LogSTFTMagnitudeLoss()
88
-
89
- def forward(self, x, y):
90
- """Calculate forward propagation.
91
- Args:
92
- x (Tensor): Predicted signal (B, T).
93
- y (Tensor): Groundtruth signal (B, T).
94
- Returns:
95
- Tensor: Spectral convergence loss value.
96
- Tensor: Log STFT magnitude loss value.
97
- """
98
- x_mag = stft(x, self.fft_size, self.shift_size, self.win_length, self.window)
99
- y_mag = stft(y, self.fft_size, self.shift_size, self.win_length, self.window)
100
- sc_loss = self.spectral_convergenge_loss(x_mag, y_mag)
101
- mag_loss = self.log_stft_magnitude_loss(x_mag, y_mag)
102
-
103
- return sc_loss, mag_loss
104
-
105
-
106
- class MultiResolutionSTFTLoss(torch.nn.Module):
107
- """Multi resolution STFT loss module."""
108
-
109
- def __init__(self,
110
- fft_sizes=(1024, 2048, 512),
111
- hop_sizes=(120, 240, 50),
112
- win_lengths=(600, 1200, 240),
113
- window="hann_window", factor_sc=0.1, factor_mag=0.1):
114
- """Initialize Multi resolution STFT loss module.
115
- Args:
116
- fft_sizes (list): List of FFT sizes.
117
- hop_sizes (list): List of hop sizes.
118
- win_lengths (list): List of window lengths.
119
- window (str): Window function type.
120
- factor (float): a balancing factor across different losses.
121
- """
122
- super(MultiResolutionSTFTLoss, self).__init__()
123
- assert len(fft_sizes) == len(hop_sizes) == len(win_lengths)
124
- self.stft_losses = torch.nn.ModuleList()
125
- for fs, ss, wl in zip(fft_sizes, hop_sizes, win_lengths):
126
- self.stft_losses += [STFTLoss(fs, ss, wl, window)]
127
- self.factor_sc = factor_sc
128
- self.factor_mag = factor_mag
129
-
130
- def forward(self, x, y):
131
- """Calculate forward propagation.
132
- Args:
133
- x (Tensor): Predicted signal (B, T).
134
- y (Tensor): Groundtruth signal (B, T).
135
- Returns:
136
- Tensor: Multi resolution spectral convergence loss value.
137
- Tensor: Multi resolution log STFT magnitude loss value.
138
- """
139
- sc_loss = 0.0
140
- mag_loss = 0.0
141
- for f in self.stft_losses:
142
- sc_l, mag_l = f(x, y)
143
- sc_loss += sc_l
144
- mag_loss += mag_l
145
- sc_loss /= len(self.stft_losses)
146
- mag_loss /= len(self.stft_losses)
147
-
148
- return self.factor_sc*sc_loss, self.factor_mag*mag_loss
149
-
150
-
151
-
152
-
153
- class L1_Multi_STFT(torch.nn.Module):
154
- """STFT loss module."""
155
-
156
- def __init__(self):
157
- """Initialize STFT loss module."""
158
- super(L1_Multi_STFT, self).__init__()
159
- self.multi_STFT_loss = MultiResolutionSTFTLoss()
160
- self.l1_loss = torch.nn.L1Loss()
161
-
162
- def forward(self, x, y):
163
- """Calculate forward propagation.
164
- Args:
165
- x (Tensor): Predicted signal (B, T).
166
- y (Tensor): Groundtruth signal (B, T).
167
- Returns:
168
- Tensor: Spectral convergence loss value.
169
- Tensor: Log STFT magnitude loss value.
170
- """
171
- sc_loss, mag_loss = self.multi_STFT_loss(x, y)
172
- l1_loss = self.l1_loss(x, y)
173
- return sc_loss + mag_loss + l1_loss
174
-
175
-
176
- LOSSES = {
177
- 'mse': torch.nn.MSELoss(),
178
- 'L1': torch.nn.L1Loss(),
179
- 'Multi_STFT': MultiResolutionSTFTLoss(),
180
- 'L1_Multi_STFT': L1_Multi_STFT()
181
- }
182
-
183
-
184
- def get_loss(loss_config, device):
185
- return LOSSES[loss_config['name']].to(device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
main.py DELETED
@@ -1,12 +0,0 @@
1
- import hydra
2
- from omegaconf import DictConfig
3
- from train import train
4
-
5
-
6
- @hydra.main(version_base=None, config_path="conf", config_name="config")
7
- def main(cfg: DictConfig):
8
- train(cfg)
9
-
10
-
11
- if __name__ == '__main__':
12
- main()
 
 
 
 
 
 
 
 
 
 
 
 
 
notebooks/EDA.ipynb DELETED
The diff for this file is too large to render. See raw diff
 
optimizers.py DELETED
@@ -1,10 +0,0 @@
1
- import torch
2
-
3
- OPTIMIZERS_POOL = {
4
- 'sgd': torch.optim.SGD,
5
- }
6
-
7
- def get_optimizer(model_params, optimizer_config):
8
- name, params = list(optimizer_config.items())[0]
9
- optimizer = OPTIMIZERS_POOL[name](model_params, **params)
10
- return optimizer
 
 
 
 
 
 
 
 
 
 
 
testing/evaluation.py DELETED
@@ -1,62 +0,0 @@
1
- import argparse
2
- from tqdm import tqdm
3
-
4
- from utils import load_wav, collect_valentini_paths
5
- from metrics import Metrics
6
- from denoisers.SpectralGating import SpectralGating
7
-
8
-
9
- PARSERS = {
10
- 'valentini': collect_valentini_paths
11
- }
12
- MODELS = {
13
- 'baseline': SpectralGating
14
- }
15
-
16
-
17
-
18
- def evaluate_on_dataset(model_name, dataset_path, dataset_type):
19
- if model_name is not None:
20
- model = MODELS[model_name]()
21
- parser = PARSERS[dataset_type]
22
- clean_wavs, noisy_wavs = parser(dataset_path)
23
-
24
- metrics = Metrics()
25
- mean_scores = {'PESQ': 0, 'STOI': 0}
26
- for clean_path, noisy_path in tqdm(zip(clean_wavs, noisy_wavs), total=len(clean_wavs)):
27
- clean_wav = load_wav(clean_path)
28
- noisy_wav = load_wav(noisy_path)
29
-
30
- if model_name is None:
31
- scores = metrics.calculate(denoised=noisy_wav, clean=clean_wav)
32
- else:
33
- denoised_wav = model(noisy_wav)
34
- scores = metrics.calculate(denoised=denoised_wav, clean=clean_wav)
35
-
36
- mean_scores['PESQ'] += scores['PESQ']
37
- mean_scores['STOI'] += scores['STOI']
38
-
39
- mean_scores['PESQ'] = mean_scores['PESQ'].numpy() / len(clean_wavs)
40
- mean_scores['STOI'] = mean_scores['STOI'].numpy() / len(clean_wavs)
41
-
42
- return mean_scores
43
-
44
-
45
- if __name__ == '__main__':
46
- parser = argparse.ArgumentParser(prog='Program to evaluate denoising')
47
- parser.add_argument('--dataset_path', type=str,
48
- default='/media/public/dataset/denoising/DS_10283_2791/',
49
- help='Path to dataset folder')
50
- parser.add_argument('--dataset_type', type=str, required=True,
51
- choices=['valentini'])
52
- parser.add_argument('--model_name', type=str,
53
- choices=['baseline'])
54
-
55
-
56
- args = parser.parse_args()
57
-
58
- mean_scores = evaluate_on_dataset(model_name=args.model_name,
59
- dataset_path=args.dataset_path,
60
- dataset_type=args.dataset_type)
61
- print(f"Metrics on {args.dataset_type} dataset with "
62
- f"{args.model_name if args.model_name is not None else 'ideal denoising'} = {mean_scores}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
testing/metrics.py DELETED
@@ -1,36 +0,0 @@
1
- import pesq
2
- from torchmetrics.audio.pesq import PerceptualEvaluationSpeechQuality
3
- from torchmetrics.audio.stoi import ShortTimeObjectiveIntelligibility
4
- from torchaudio.transforms import Resample
5
- import torch
6
- import torchaudio
7
- from torchmetrics import SignalNoiseRatio
8
-
9
- class Metrics(torch.nn.Module):
10
- def __init__(self, source_rate, target_rate=16000, *args, **kwargs):
11
- super().__init__(*args, **kwargs)
12
- self.source_rate = source_rate
13
- self.target_rate = target_rate
14
- self.resampler = Resample(orig_freq=source_rate, new_freq=target_rate)
15
- self.nb_pesq = PerceptualEvaluationSpeechQuality(target_rate, 'wb')
16
- self.stoi = ShortTimeObjectiveIntelligibility(target_rate, False)
17
- self.snr = SignalNoiseRatio()
18
-
19
- def forward(self, denoised, clean):
20
- pesq_scores, stoi_scores = 0, 0
21
- for denoised_wav, clean_wav in zip(denoised, clean):
22
- if self.source_rate != self.target_rate:
23
- denoised_wav = self.resampler(denoised_wav)
24
- clean_wav = self.resampler(clean_wav)
25
- try:
26
- pesq_scores += self.nb_pesq(denoised_wav, clean_wav).item()
27
- stoi_scores += self.stoi(denoised_wav, clean_wav).item()
28
- except pesq.NoUtterancesError as e:
29
- print(e)
30
- except ValueError as e:
31
- print(e)
32
-
33
- return {'PESQ': pesq_scores,
34
- 'STOI': stoi_scores}
35
-
36
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
testing/test_inference.py DELETED
@@ -1,6 +0,0 @@
1
- def func(x):
2
- return x + 1
3
-
4
-
5
- def test_answer():
6
- assert func(3) == 4
 
 
 
 
 
 
 
train.py DELETED
@@ -1,123 +0,0 @@
1
- import os
2
- import torch
3
- from torch.utils.data import DataLoader
4
- import omegaconf
5
- from omegaconf import DictConfig
6
- import wandb
7
-
8
- from checkpoing_saver import CheckpointSaver
9
- from denoisers import get_model
10
- from optimizers import get_optimizer
11
- from losses import get_loss
12
- from datasets import get_datasets
13
- from testing.metrics import Metrics
14
- from datasets.minimal import Minimal
15
- from tqdm import tqdm
16
-
17
- def init_wandb(cfg):
18
- wandb.login(key=cfg['wandb']['api_key'], host=cfg['wandb']['host'])
19
- wandb.init(project=cfg['wandb']['project'],
20
- notes=cfg['wandb']['notes'],
21
- config=omegaconf.OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True),
22
- resume=cfg['wandb']['resume'],
23
- name=cfg['wandb']['run_name'])
24
- if wandb.run.resumed:
25
- api = wandb.Api()
26
- runs = api.runs(f"{cfg['wandb']['entity']}/{cfg['wandb']['project']}",
27
- order='train_pesq')
28
- run = [run for run in runs if run.name == cfg['wandb']['run_name'] and run.state != 'running'][0]
29
- artifacts = run.logged_artifacts()
30
- best_model = [artifact for artifact in artifacts if artifact.type == 'model'][0]
31
-
32
- best_model.download()
33
-
34
- def train(cfg: DictConfig):
35
- device = torch.device(f'cuda:{cfg.gpu}' if torch.cuda.is_available() else 'cpu')
36
- init_wandb(cfg)
37
- checkpoint_saver = CheckpointSaver(dirpath=cfg['training']['model_save_path'], run_name=wandb.run.name,
38
- decreasing=False)
39
- metrics = Metrics(source_rate=cfg['dataloader']['sample_rate']).to(device)
40
-
41
- model = get_model(cfg['model']).to(device)
42
- optimizer = get_optimizer(model.parameters(), cfg['optimizer'])
43
- loss_fn = get_loss(cfg['loss'], device)
44
- train_dataset, valid_dataset = get_datasets(cfg)
45
- minimal_dataset = Minimal(cfg)
46
-
47
- dataloaders = {
48
- 'train': DataLoader(train_dataset, batch_size=cfg['dataloader']['train_batch_size'], shuffle=True,
49
- num_workers=cfg['dataloader']['num_workers']),
50
- 'val': DataLoader(valid_dataset, batch_size=cfg['dataloader']['valid_batch_size'], shuffle=False,
51
- num_workers=cfg['dataloader']['num_workers']),
52
- 'minimal': DataLoader(minimal_dataset)
53
- }
54
-
55
- wandb.watch(model, log_freq=cfg['wandb']['log_interval'])
56
- epoch = 0
57
- while epoch < cfg['training']['num_epochs']:
58
- for phase in ['train', 'val']:
59
- if phase == 'train':
60
- model.train()
61
- else:
62
- model.eval()
63
-
64
- running_loss, running_pesq, running_stoi = 0.0, 0.0, 0.0
65
- loop = tqdm(dataloaders[phase])
66
- for i, (inputs, labels) in enumerate(loop):
67
- inputs = inputs.to(device)
68
- labels = labels.to(device)
69
-
70
- optimizer.zero_grad()
71
-
72
- with torch.set_grad_enabled(phase == 'train'):
73
- outputs = model(inputs)
74
- loss = loss_fn(outputs, labels)
75
-
76
- if phase == 'train':
77
- loss.backward()
78
- optimizer.step()
79
-
80
- running_metrics = metrics(denoised=outputs, clean=labels)
81
- running_loss += loss.item() * inputs.size(0)
82
- running_pesq += running_metrics['PESQ']
83
- running_stoi += running_metrics['STOI']
84
-
85
- loop.set_description(f"Epoch [{epoch}/{cfg['training']['num_epochs']}][{phase}]")
86
- loop.set_postfix(loss=running_loss / (i + 1) / inputs.size(0),
87
- pesq=running_pesq / (i + 1) / inputs.size(0),
88
- stoi=running_stoi / (i + 1) / inputs.size(0))
89
-
90
- if phase == 'train' and i % cfg['wandb']['log_interval'] == 0:
91
- wandb.log({"train_loss": running_loss / (i + 1) / inputs.size(0),
92
- "train_pesq": running_pesq / (i + 1) / inputs.size(0),
93
- "train_stoi": running_stoi / (i + 1) / inputs.size(0)})
94
-
95
- epoch_loss = running_loss / len(dataloaders[phase].dataset)
96
- eposh_pesq = running_pesq / len(dataloaders[phase].dataset)
97
- eposh_stoi = running_stoi / len(dataloaders[phase].dataset)
98
-
99
- wandb.log({f"{phase}_loss": epoch_loss,
100
- f"{phase}_pesq": eposh_pesq,
101
- f"{phase}_stoi": eposh_stoi})
102
-
103
- if phase == 'val':
104
- for i, (wav, rate) in enumerate(dataloaders['minimal']):
105
- if cfg['dataloader']['normalize']:
106
- std = torch.std(wav)
107
- wav = wav / std
108
- prediction = model(wav.to(device))
109
- prediction = prediction * std
110
- else:
111
- prediction = model(wav.to(device))
112
- wandb.log({
113
- f"{i}_example": wandb.Audio(
114
- prediction.detach().cpu().numpy()[0][0],
115
- sample_rate=rate)})
116
-
117
- checkpoint_saver(model, epoch, metric_val=eposh_pesq,
118
- optimizer=optimizer, loss=epoch_loss)
119
- epoch += 1
120
-
121
-
122
- if __name__ == "__main__":
123
- pass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
transforms.py DELETED
@@ -1,31 +0,0 @@
1
- import torch
2
-
3
- from torchaudio.transforms import Resample
4
- from torchvision.transforms import RandomCrop
5
-
6
-
7
-
8
- class Transform(torch.nn.Module):
9
- def __init__(
10
- self,
11
- input_sample_rate,
12
- sample_rate,
13
- max_seconds,
14
- normalize,
15
- *args,
16
- **kwargs
17
- ):
18
- super().__init__()
19
- self.input_sample_rate = input_sample_rate
20
- self.sample_rate = sample_rate
21
- self.resample = Resample(orig_freq=input_sample_rate, new_freq=sample_rate)
22
- self.random_crop = RandomCrop((1, int(max_seconds * sample_rate)), pad_if_needed=True)
23
- self.normalize = normalize
24
-
25
- def forward(self, waveform: torch.Tensor) -> torch.Tensor:
26
- if self.input_sample_rate != self.sample_rate:
27
- waveform = self.resample(waveform)
28
- if self.normalize:
29
- waveform = waveform / torch.std(waveform)
30
- cropped = self.random_crop(waveform)
31
- return cropped
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
upload_to_HF.py DELETED
@@ -1,56 +0,0 @@
1
- import os
2
- import wandb
3
- from huggingface_hub import HfApi
4
- from pathlib import Path
5
- import huggingface_hub
6
- import ssl
7
- import os
8
- os.environ['CURL_CA_BUNDLE'] = ''
9
-
10
- ssl._create_default_https_context = ssl._create_unverified_context
11
-
12
- class Uploader:
13
- def __init__(self, entity, project, run_name, repo_id, username):
14
- self.entity = entity
15
- self.project = project
16
- self.run_name = run_name
17
- self.hf_api = HfApi()
18
- self.wandb_api = wandb.Api()
19
- self.repo_id = repo_id
20
- self.username = username
21
- huggingface_hub.login(os.environ.get('HUGGINGFACE_TOKEN'))
22
-
23
- def get_model_from_wandb_run(self):
24
- runs = self.wandb_api.runs(f"{self.entity}/{self.project}",
25
- # order='+summary_metrics.train_pesq'
26
- )
27
- run = [run for run in runs if run.name == self.run_name][0]
28
- artifacts = run.logged_artifacts()
29
- best_model = [artifact for artifact in artifacts if artifact.type == 'model'][0]
30
- artifact_dir = best_model.download()
31
- model_path = list(Path(artifact_dir).glob("*.pt"))[0].absolute().as_posix()
32
- print(f"Model validation score = {best_model.metadata['Validation score']}")
33
- return model_path
34
-
35
- def upload_to_HF(self):
36
- model_path = self.get_model_from_wandb_run()
37
- self.hf_api.upload_file(
38
- path_or_fileobj=model_path,
39
- path_in_repo=Path(model_path).name,
40
- repo_id=f'{self.username}/{self.repo_id}',
41
- )
42
-
43
- def create_repo(self):
44
- self.hf_api.create_repo(repo_id=self.repo_id, exist_ok=True)
45
-
46
-
47
-
48
- if __name__ == '__main__':
49
- uploader = Uploader(entity='borisovmaksim',
50
- project='denoising',
51
- run_name='wav_normalization',
52
- repo_id='demucs',
53
- username='BorisovMaksim')
54
- uploader.create_repo()
55
- uploader.upload_to_HF()
56
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
utils.py CHANGED
@@ -1,7 +1,3 @@
1
- import torchaudio
2
- import torch
3
- import matplotlib.pyplot as plt
4
- from pathlib import Path
5
  from torch.nn.functional import pad
6
 
7
 
@@ -17,45 +13,3 @@ def pad_cut_batch_audio(wavs, new_shape):
17
  return wavs
18
 
19
 
20
- def collect_valentini_paths(dataset_path):
21
- clean_path = Path(dataset_path) / 'clean_testset_wav'
22
- noisy_path = Path(dataset_path) / 'noisy_testset_wav'
23
-
24
- clean_wavs = list(clean_path.glob("*"))
25
- noisy_wavs = list(noisy_path.glob("*"))
26
-
27
- return clean_wavs, noisy_wavs
28
-
29
-
30
-
31
-
32
- def plot_spectrogram(stft, title="Spectrogram", xlim=None):
33
- magnitude = stft.abs()
34
- spectrogram = 20 * torch.log10(magnitude + 1e-8).numpy()
35
- figure, axis = plt.subplots(1, 1)
36
- img = axis.imshow(spectrogram, cmap="viridis", vmin=-100, vmax=0, origin="lower", aspect="auto")
37
- figure.suptitle(title)
38
- plt.colorbar(img, ax=axis)
39
- plt.show()
40
-
41
-
42
- def plot_mask(mask, title="Mask", xlim=None):
43
- mask = mask.numpy()
44
- figure, axis = plt.subplots(1, 1)
45
- img = axis.imshow(mask, cmap="viridis", origin="lower", aspect="auto")
46
- figure.suptitle(title)
47
- plt.colorbar(img, ax=axis)
48
- plt.show()
49
-
50
-
51
-
52
-
53
- def generate_mixture(waveform_clean, waveform_noise, target_snr):
54
-
55
- power_clean_signal = waveform_clean.pow(2).mean()
56
- power_noise_signal = waveform_noise.pow(2).mean()
57
- current_snr = 10 * torch.log10(power_clean_signal / power_noise_signal)
58
-
59
- waveform_noise *= 10 ** (-(target_snr - current_snr) / 20)
60
- return waveform_clean + waveform_noise
61
-
 
 
 
 
 
1
  from torch.nn.functional import pad
2
 
3
 
 
13
  return wavs
14
 
15