Spaces:
Runtime error
Runtime error
Commit
·
c45e107
1
Parent(s):
ba6efe9
deleted files that are not needed for app
Browse files- .github/workflows/python-app.yml +0 -39
- app.py +30 -26
- checkpoing_saver.py +0 -53
- conf/config.yaml +0 -34
- conf/dataset/valentini.yaml +0 -4
- conf/loss/L1_Multi_STFT.yaml +0 -1
- conf/loss/mse.yaml +0 -1
- conf/model/demucs.yaml +0 -19
- conf/optimizer/adam.yaml +0 -0
- conf/optimizer/sgd.yaml +0 -3
- datasets/__init__.py +0 -17
- datasets/minimal.py +0 -23
- datasets/valentini.py +0 -44
- denoisers/SpectralGating.py +3 -6
- losses.py +0 -185
- main.py +0 -12
- notebooks/EDA.ipynb +0 -0
- optimizers.py +0 -10
- testing/evaluation.py +0 -62
- testing/metrics.py +0 -36
- testing/test_inference.py +0 -6
- train.py +0 -123
- transforms.py +0 -31
- upload_to_HF.py +0 -56
- utils.py +0 -46
.github/workflows/python-app.yml
DELETED
@@ -1,39 +0,0 @@
|
|
1 |
-
# This workflow will install Python dependencies, run tests and lint with a single version of Python
|
2 |
-
# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python
|
3 |
-
|
4 |
-
name: Python application
|
5 |
-
|
6 |
-
on:
|
7 |
-
push:
|
8 |
-
branches: [ "main" ]
|
9 |
-
pull_request:
|
10 |
-
branches: [ "main" ]
|
11 |
-
|
12 |
-
permissions:
|
13 |
-
contents: read
|
14 |
-
|
15 |
-
jobs:
|
16 |
-
build:
|
17 |
-
|
18 |
-
runs-on: ubuntu-latest
|
19 |
-
|
20 |
-
steps:
|
21 |
-
- uses: actions/checkout@v3
|
22 |
-
- name: Set up Python 3.8
|
23 |
-
uses: actions/setup-python@v3
|
24 |
-
with:
|
25 |
-
python-version: "3.8"
|
26 |
-
- name: Install dependencies
|
27 |
-
run: |
|
28 |
-
python -m pip install --upgrade pip
|
29 |
-
pip install flake8 pytest
|
30 |
-
if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
|
31 |
-
- name: Lint with flake8
|
32 |
-
run: |
|
33 |
-
# stop the build if there are Python syntax errors or undefined names
|
34 |
-
flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
|
35 |
-
# exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
|
36 |
-
flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
|
37 |
-
- name: Test with pytest
|
38 |
-
run: |
|
39 |
-
pytest
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
app.py
CHANGED
@@ -10,6 +10,20 @@ import torchaudio
|
|
10 |
import yaml
|
11 |
|
12 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
13 |
|
14 |
def run_app(model_filename, config_filename):
|
15 |
model_path = hf_hub_download(repo_id="BorisovMaksim/demucs", filename=model_filename)
|
@@ -20,37 +34,27 @@ def run_app(model_filename, config_filename):
|
|
20 |
checkpoint = torch.load(model_path, map_location=torch.device('cpu'))
|
21 |
model.load_state_dict(checkpoint['model_state_dict'])
|
22 |
|
23 |
-
|
24 |
-
|
25 |
-
src_path = Path("cache_wav/original/{}.wav".format(str(uuid.uuid4())))
|
26 |
-
tgt_path = Path("cache_wav/denoised/{}.wav".format(str(uuid.uuid4())))
|
27 |
-
src_path.parent.mkdir(exist_ok=True, parents=True)
|
28 |
-
tgt_path.parent.mkdir(exist_ok=True, parents=True)
|
29 |
-
(ffmpeg.input(audio)
|
30 |
-
.output(src_path.as_posix(), acodec='pcm_s16le', ac=1, ar=22050)
|
31 |
-
.run()
|
32 |
-
)
|
33 |
-
wav, rate = torchaudio.load(audio)
|
34 |
-
reduced_noise = model.predict(wav)
|
35 |
-
torchaudio.save(tgt_path, reduced_noise, rate)
|
36 |
-
return tgt_path
|
37 |
-
|
38 |
-
demo = gr.Interface(
|
39 |
-
fn=denoising_transform,
|
40 |
inputs=gr.Audio(label="Source Audio", source="microphone", type='filepath'),
|
41 |
-
outputs=gr.Audio(label="
|
42 |
-
|
43 |
-
["testing/wavs/p232_071.wav"],
|
44 |
-
["testing/wavs/p232_284.wav"],
|
45 |
-
],
|
46 |
-
title="Denoising"
|
47 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
48 |
|
49 |
|
50 |
-
demo.launch()
|
51 |
-
|
52 |
if __name__ == "__main__":
|
53 |
model_filename = "original_sr/Demucs_original_sr_epoch3.pt"
|
54 |
config_filename = "original_sr/config.yaml"
|
55 |
run_app(model_filename, config_filename)
|
56 |
-
|
|
|
10 |
import yaml
|
11 |
|
12 |
|
13 |
+
def denoising_transform(audio, model):
|
14 |
+
src_path = Path("cache_wav/original/{}.wav".format(str(uuid.uuid4())))
|
15 |
+
tgt_path = Path("cache_wav/denoised/{}.wav".format(str(uuid.uuid4())))
|
16 |
+
src_path.parent.mkdir(exist_ok=True, parents=True)
|
17 |
+
tgt_path.parent.mkdir(exist_ok=True, parents=True)
|
18 |
+
(ffmpeg.input(audio)
|
19 |
+
.output(src_path.as_posix(), acodec='pcm_s16le', ac=1, ar=22050)
|
20 |
+
.run()
|
21 |
+
)
|
22 |
+
wav, rate = torchaudio.load(audio)
|
23 |
+
reduced_noise = model.predict(wav)
|
24 |
+
torchaudio.save(tgt_path, reduced_noise, rate)
|
25 |
+
return tgt_path
|
26 |
+
|
27 |
|
28 |
def run_app(model_filename, config_filename):
|
29 |
model_path = hf_hub_download(repo_id="BorisovMaksim/demucs", filename=model_filename)
|
|
|
34 |
checkpoint = torch.load(model_path, map_location=torch.device('cpu'))
|
35 |
model.load_state_dict(checkpoint['model_state_dict'])
|
36 |
|
37 |
+
interface_demucs = gr.Interface(
|
38 |
+
fn=lambda x: denoising_transform(x, model),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
39 |
inputs=gr.Audio(label="Source Audio", source="microphone", type='filepath'),
|
40 |
+
outputs=gr.Audio(label="Demucs", type='filepath'),
|
41 |
+
allow_flagging='never'
|
|
|
|
|
|
|
|
|
42 |
)
|
43 |
+
interface_spectral_gating = gr.Interface(
|
44 |
+
fn=lambda x: denoising_transform(x, SpectralGating()),
|
45 |
+
inputs=gr.Audio(label="Source Audio", source="microphone", type='filepath'),
|
46 |
+
outputs=gr.Audio(label="Spectral Gating", type='filepath'),
|
47 |
+
allow_flagging='never'
|
48 |
+
)
|
49 |
+
gr.Parallel(interface_demucs, interface_spectral_gating,
|
50 |
+
title="Denoising",
|
51 |
+
examples=[
|
52 |
+
["testing/wavs/p232_071.wav"],
|
53 |
+
["testing/wavs/p232_284.wav"],
|
54 |
+
]).launch()
|
55 |
|
56 |
|
|
|
|
|
57 |
if __name__ == "__main__":
|
58 |
model_filename = "original_sr/Demucs_original_sr_epoch3.pt"
|
59 |
config_filename = "original_sr/config.yaml"
|
60 |
run_app(model_filename, config_filename)
|
|
checkpoing_saver.py
DELETED
@@ -1,53 +0,0 @@
|
|
1 |
-
import os
|
2 |
-
import numpy as np
|
3 |
-
import logging
|
4 |
-
import torch
|
5 |
-
import wandb
|
6 |
-
|
7 |
-
class CheckpointSaver:
|
8 |
-
def __init__(self, dirpath, run_name='', decreasing=True, top_n=5):
|
9 |
-
"""
|
10 |
-
dirpath: Directory path where to store all model weights
|
11 |
-
decreasing: If decreasing is `True`, then lower metric is better
|
12 |
-
top_n: Total number of models to track based on validation metric value
|
13 |
-
"""
|
14 |
-
if not os.path.exists(dirpath): os.makedirs(dirpath)
|
15 |
-
self.dirpath = dirpath
|
16 |
-
self.top_n = top_n
|
17 |
-
self.decreasing = decreasing
|
18 |
-
self.top_model_paths = []
|
19 |
-
self.best_metric_val = np.Inf if decreasing else -np.Inf
|
20 |
-
self.run_name = run_name
|
21 |
-
|
22 |
-
|
23 |
-
def __call__(self, model, epoch, metric_val, optimizer, loss):
|
24 |
-
model_path = os.path.join(self.dirpath, model.__class__.__name__ + f'_{self.run_name}_epoch{epoch}.pt')
|
25 |
-
save = metric_val < self.best_metric_val if self.decreasing else metric_val > self.best_metric_val
|
26 |
-
if save:
|
27 |
-
logging.info(
|
28 |
-
f"Current metric value better than {metric_val} better than best {self.best_metric_val}, saving model at {model_path}, & logging model weights to W&B.")
|
29 |
-
self.best_metric_val = metric_val
|
30 |
-
torch.save(
|
31 |
-
{ # Save our checkpoint loc
|
32 |
-
'epoch': epoch,
|
33 |
-
'model_state_dict': model.state_dict(),
|
34 |
-
'optimizer_state_dict': optimizer.state_dict(),
|
35 |
-
'loss': loss,
|
36 |
-
}, model_path)
|
37 |
-
self.log_artifact(f'model-ckpt-epoch-{epoch}.pt', model_path, metric_val)
|
38 |
-
self.top_model_paths.append({'path': model_path, 'score': metric_val})
|
39 |
-
self.top_model_paths = sorted(self.top_model_paths, key=lambda o: o['score'], reverse=not self.decreasing)
|
40 |
-
if len(self.top_model_paths) > self.top_n:
|
41 |
-
self.cleanup()
|
42 |
-
|
43 |
-
def log_artifact(self, filename, model_path, metric_val):
|
44 |
-
artifact = wandb.Artifact(filename, type='model', metadata={'Validation score': metric_val})
|
45 |
-
artifact.add_file(model_path)
|
46 |
-
wandb.run.log_artifact(artifact)
|
47 |
-
|
48 |
-
def cleanup(self):
|
49 |
-
to_remove = self.top_model_paths[self.top_n:]
|
50 |
-
logging.info(f"Removing extra models.. {to_remove}")
|
51 |
-
for o in to_remove:
|
52 |
-
os.remove(o['path'])
|
53 |
-
self.top_model_paths = self.top_model_paths[:self.top_n]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
conf/config.yaml
DELETED
@@ -1,34 +0,0 @@
|
|
1 |
-
defaults:
|
2 |
-
- dataset: valentini
|
3 |
-
- model: demucs
|
4 |
-
- loss: L1_Multi_STFT
|
5 |
-
- optimizer: sgd
|
6 |
-
|
7 |
-
training:
|
8 |
-
num_epochs: 10
|
9 |
-
model_save_path: /media/public/checkpoints
|
10 |
-
|
11 |
-
dataloader:
|
12 |
-
max_seconds: 2
|
13 |
-
sample_rate: 16000
|
14 |
-
train_batch_size: 12
|
15 |
-
valid_batch_size: 12
|
16 |
-
normalize: False
|
17 |
-
num_workers: 12
|
18 |
-
|
19 |
-
validation:
|
20 |
-
path: /media/public/datasets/denoising/DS_10283_2791/noisy_testset_wav
|
21 |
-
sample_rate: 48000
|
22 |
-
|
23 |
-
|
24 |
-
wandb:
|
25 |
-
entity: borisovmaksim
|
26 |
-
run_name: default
|
27 |
-
project: denoising
|
28 |
-
log_interval: 100
|
29 |
-
api_key: local-e23d01ece807cb31e69b2cf4137e4998e4b9856f
|
30 |
-
host: http://localhost:8080/
|
31 |
-
notes: "Experiment note"
|
32 |
-
resume: False
|
33 |
-
|
34 |
-
gpu: 1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
conf/dataset/valentini.yaml
DELETED
@@ -1,4 +0,0 @@
|
|
1 |
-
valentini:
|
2 |
-
dataset_path: /media/public/datasets/denoising/DS_10283_2791/
|
3 |
-
val_fraction: 0.2
|
4 |
-
sample_rate: 48000
|
|
|
|
|
|
|
|
|
|
conf/loss/L1_Multi_STFT.yaml
DELETED
@@ -1 +0,0 @@
|
|
1 |
-
name: L1_Multi_STFT
|
|
|
|
conf/loss/mse.yaml
DELETED
@@ -1 +0,0 @@
|
|
1 |
-
name: mse
|
|
|
|
conf/model/demucs.yaml
DELETED
@@ -1,19 +0,0 @@
|
|
1 |
-
demucs:
|
2 |
-
H: 64
|
3 |
-
L: 3
|
4 |
-
encoder:
|
5 |
-
conv1:
|
6 |
-
kernel_size: 8
|
7 |
-
stride: 2
|
8 |
-
conv2:
|
9 |
-
kernel_size: 1
|
10 |
-
stride: 1
|
11 |
-
|
12 |
-
decoder:
|
13 |
-
conv1:
|
14 |
-
kernel_size: 1
|
15 |
-
stride: 1
|
16 |
-
conv2:
|
17 |
-
kernel_size: 8
|
18 |
-
stride: 2
|
19 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
conf/optimizer/adam.yaml
DELETED
File without changes
|
conf/optimizer/sgd.yaml
DELETED
@@ -1,3 +0,0 @@
|
|
1 |
-
sgd:
|
2 |
-
lr: 0.001
|
3 |
-
momentum: 0.9
|
|
|
|
|
|
|
|
datasets/__init__.py
DELETED
@@ -1,17 +0,0 @@
|
|
1 |
-
from typing import Tuple
|
2 |
-
from torch.utils.data import Dataset
|
3 |
-
|
4 |
-
from datasets.valentini import Valentini
|
5 |
-
from transforms import Transform
|
6 |
-
|
7 |
-
DATASETS_POOL = {
|
8 |
-
'valentini': Valentini
|
9 |
-
}
|
10 |
-
|
11 |
-
|
12 |
-
def get_datasets(cfg) -> Tuple[Dataset, Dataset]:
|
13 |
-
name, dataset_params = list(cfg['dataset'].items())[0]
|
14 |
-
transform = Transform(input_sample_rate=dataset_params['sample_rate'], **cfg['dataloader'])
|
15 |
-
train_dataset = DATASETS_POOL[name](valid=False, transform=transform, **dataset_params)
|
16 |
-
valid_dataset = DATASETS_POOL[name](valid=True, transform=transform, **dataset_params)
|
17 |
-
return train_dataset, valid_dataset
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
datasets/minimal.py
DELETED
@@ -1,23 +0,0 @@
|
|
1 |
-
import torch
|
2 |
-
from torch.utils.data import Dataset
|
3 |
-
from pathlib import Path
|
4 |
-
import torchaudio
|
5 |
-
import numpy as np
|
6 |
-
from torchaudio.transforms import Resample
|
7 |
-
|
8 |
-
|
9 |
-
class Minimal(Dataset):
|
10 |
-
def __init__(self, cfg):
|
11 |
-
self.wavs = ['p232_284.wav', 'p232_071.wav', 'p257_171.wav']
|
12 |
-
self.dataset_path = cfg['validation']['path']
|
13 |
-
self.target_rate = cfg['dataloader']['sample_rate']
|
14 |
-
self.resampler = Resample(orig_freq=cfg['validation']['sample_rate'],
|
15 |
-
new_freq=cfg['dataloader']['sample_rate'])
|
16 |
-
|
17 |
-
def __len__(self):
|
18 |
-
return len(self.wavs)
|
19 |
-
|
20 |
-
def __getitem__(self, idx):
|
21 |
-
wav, rate = torchaudio.load(Path(self.dataset_path) / self.wavs[idx])
|
22 |
-
wav = self.resampler(wav)
|
23 |
-
return wav, self.target_rate
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
datasets/valentini.py
DELETED
@@ -1,44 +0,0 @@
|
|
1 |
-
import torch
|
2 |
-
from torch.utils.data import Dataset
|
3 |
-
from pathlib import Path
|
4 |
-
import torchaudio
|
5 |
-
|
6 |
-
MAX_RANDOM_SEED = 1000
|
7 |
-
|
8 |
-
class Valentini(Dataset):
|
9 |
-
def __init__(self, dataset_path, val_fraction, transform=None, valid=False, *args, **kwargs):
|
10 |
-
clean_path = Path(dataset_path) / 'clean_trainset_56spk_wav'
|
11 |
-
noisy_path = Path(dataset_path) / 'noisy_trainset_56spk_wav'
|
12 |
-
clean_wavs = list(clean_path.glob("*"))
|
13 |
-
noisy_wavs = list(noisy_path.glob("*"))
|
14 |
-
valid_threshold = int(len(clean_wavs) * val_fraction)
|
15 |
-
if valid:
|
16 |
-
self.clean_wavs = clean_wavs[:valid_threshold]
|
17 |
-
self.noisy_wavs = noisy_wavs[:valid_threshold]
|
18 |
-
else:
|
19 |
-
self.clean_wavs = clean_wavs[valid_threshold:]
|
20 |
-
self.noisy_wavs = noisy_wavs[valid_threshold:]
|
21 |
-
|
22 |
-
assert len(self.clean_wavs) == len(self.noisy_wavs)
|
23 |
-
|
24 |
-
self.transform = transform
|
25 |
-
self.valid = valid
|
26 |
-
|
27 |
-
def __len__(self):
|
28 |
-
return len(self.clean_wavs)
|
29 |
-
|
30 |
-
def __getitem__(self, idx):
|
31 |
-
noisy_wav, noisy_sr = torchaudio.load(self.noisy_wavs[idx])
|
32 |
-
clean_wav, clean_sr = torchaudio.load(self.clean_wavs[idx])
|
33 |
-
|
34 |
-
if self.transform:
|
35 |
-
random_seed = 0 if self.valid else torch.randint(MAX_RANDOM_SEED, (1,))[0]
|
36 |
-
torch.manual_seed(random_seed)
|
37 |
-
noisy_wav = self.transform(noisy_wav)
|
38 |
-
torch.manual_seed(random_seed)
|
39 |
-
clean_wav = self.transform(clean_wav)
|
40 |
-
return noisy_wav, clean_wav
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
denoisers/SpectralGating.py
CHANGED
@@ -4,7 +4,7 @@ import torchaudio
|
|
4 |
|
5 |
|
6 |
class SpectralGating(torch.nn.Module):
|
7 |
-
def __init__(self, rate=
|
8 |
super(SpectralGating, self).__init__()
|
9 |
self.rate = rate
|
10 |
|
@@ -12,11 +12,8 @@ class SpectralGating(torch.nn.Module):
|
|
12 |
reduced_noise = torch.Tensor(nr.reduce_noise(y=wav, sr=self.rate))
|
13 |
return reduced_noise
|
14 |
|
15 |
-
def predict(self,
|
16 |
-
|
17 |
-
reduced_noise = torch.Tensor(nr.reduce_noise(y=data, sr=rate))
|
18 |
-
torchaudio.save(out_path, reduced_noise, rate)
|
19 |
-
return out_path
|
20 |
|
21 |
|
22 |
|
|
|
4 |
|
5 |
|
6 |
class SpectralGating(torch.nn.Module):
|
7 |
+
def __init__(self, rate=48000):
|
8 |
super(SpectralGating, self).__init__()
|
9 |
self.rate = rate
|
10 |
|
|
|
12 |
reduced_noise = torch.Tensor(nr.reduce_noise(y=wav, sr=self.rate))
|
13 |
return reduced_noise
|
14 |
|
15 |
+
def predict(self, wav):
|
16 |
+
return self.forward(wav)
|
|
|
|
|
|
|
17 |
|
18 |
|
19 |
|
losses.py
DELETED
@@ -1,185 +0,0 @@
|
|
1 |
-
# -*- coding: utf-8 -*-
|
2 |
-
# Copyright (c) Facebook, Inc. and its affiliates.
|
3 |
-
# All rights reserved.
|
4 |
-
#
|
5 |
-
# This source code is licensed under the license found in the
|
6 |
-
# LICENSE file in the root directory of this source tree.
|
7 |
-
|
8 |
-
# Original copyright 2019 Tomoki Hayashi
|
9 |
-
# MIT License (https://opensource.org/licenses/MIT)
|
10 |
-
|
11 |
-
|
12 |
-
import torch
|
13 |
-
import torch.nn.functional as F
|
14 |
-
|
15 |
-
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
16 |
-
|
17 |
-
"""STFT-based Loss modules."""
|
18 |
-
|
19 |
-
|
20 |
-
def stft(x, fft_size, hop_size, win_length, window):
|
21 |
-
"""Perform STFT and convert to magnitude spectrogram.
|
22 |
-
Args:
|
23 |
-
x (Tensor): Input signal tensor (B, T).
|
24 |
-
fft_size (int): FFT size.
|
25 |
-
hop_size (int): Hop size.
|
26 |
-
win_length (int): Window length.
|
27 |
-
window (str): Window function type.
|
28 |
-
Returns:
|
29 |
-
Tensor: Magnitude spectrogram (B, #frames, fft_size // 2 + 1).
|
30 |
-
"""
|
31 |
-
x_stft = torch.stft(x[:, 0, :], fft_size, hop_size, win_length, window, return_complex=True)
|
32 |
-
x_stft = torch.view_as_real(x_stft)
|
33 |
-
real = x_stft[..., 0]
|
34 |
-
imag = x_stft[..., 1]
|
35 |
-
|
36 |
-
# NOTE(kan-bayashi): clamp is needed to avoid nan or inf
|
37 |
-
return torch.sqrt(torch.clamp(real ** 2 + imag ** 2, min=1e-7)).transpose(2, 1)
|
38 |
-
|
39 |
-
|
40 |
-
class SpectralConvergengeLoss(torch.nn.Module):
|
41 |
-
"""Spectral convergence loss module."""
|
42 |
-
|
43 |
-
def __init__(self):
|
44 |
-
"""Initilize spectral convergence loss module."""
|
45 |
-
super(SpectralConvergengeLoss, self).__init__()
|
46 |
-
|
47 |
-
def forward(self, x_mag, y_mag):
|
48 |
-
"""Calculate forward propagation.
|
49 |
-
Args:
|
50 |
-
x_mag (Tensor): Magnitude spectrogram of predicted signal (B, #frames, #freq_bins).
|
51 |
-
y_mag (Tensor): Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins).
|
52 |
-
Returns:
|
53 |
-
Tensor: Spectral convergence loss value.
|
54 |
-
"""
|
55 |
-
return torch.norm(y_mag - x_mag, p="fro") / torch.norm(y_mag, p="fro")
|
56 |
-
|
57 |
-
|
58 |
-
class LogSTFTMagnitudeLoss(torch.nn.Module):
|
59 |
-
"""Log STFT magnitude loss module."""
|
60 |
-
|
61 |
-
def __init__(self):
|
62 |
-
"""Initilize los STFT magnitude loss module."""
|
63 |
-
super(LogSTFTMagnitudeLoss, self).__init__()
|
64 |
-
|
65 |
-
def forward(self, x_mag, y_mag):
|
66 |
-
"""Calculate forward propagation.
|
67 |
-
Args:
|
68 |
-
x_mag (Tensor): Magnitude spectrogram of predicted signal (B, #frames, #freq_bins).
|
69 |
-
y_mag (Tensor): Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins).
|
70 |
-
Returns:
|
71 |
-
Tensor: Log STFT magnitude loss value.
|
72 |
-
"""
|
73 |
-
return F.l1_loss(torch.log(y_mag), torch.log(x_mag))
|
74 |
-
|
75 |
-
|
76 |
-
class STFTLoss(torch.nn.Module):
|
77 |
-
"""STFT loss module."""
|
78 |
-
|
79 |
-
def __init__(self, fft_size=1024, shift_size=120, win_length=600, window="hann_window"):
|
80 |
-
"""Initialize STFT loss module."""
|
81 |
-
super(STFTLoss, self).__init__()
|
82 |
-
self.fft_size = fft_size
|
83 |
-
self.shift_size = shift_size
|
84 |
-
self.win_length = win_length
|
85 |
-
self.register_buffer("window", getattr(torch, window)(win_length))
|
86 |
-
self.spectral_convergenge_loss = SpectralConvergengeLoss()
|
87 |
-
self.log_stft_magnitude_loss = LogSTFTMagnitudeLoss()
|
88 |
-
|
89 |
-
def forward(self, x, y):
|
90 |
-
"""Calculate forward propagation.
|
91 |
-
Args:
|
92 |
-
x (Tensor): Predicted signal (B, T).
|
93 |
-
y (Tensor): Groundtruth signal (B, T).
|
94 |
-
Returns:
|
95 |
-
Tensor: Spectral convergence loss value.
|
96 |
-
Tensor: Log STFT magnitude loss value.
|
97 |
-
"""
|
98 |
-
x_mag = stft(x, self.fft_size, self.shift_size, self.win_length, self.window)
|
99 |
-
y_mag = stft(y, self.fft_size, self.shift_size, self.win_length, self.window)
|
100 |
-
sc_loss = self.spectral_convergenge_loss(x_mag, y_mag)
|
101 |
-
mag_loss = self.log_stft_magnitude_loss(x_mag, y_mag)
|
102 |
-
|
103 |
-
return sc_loss, mag_loss
|
104 |
-
|
105 |
-
|
106 |
-
class MultiResolutionSTFTLoss(torch.nn.Module):
|
107 |
-
"""Multi resolution STFT loss module."""
|
108 |
-
|
109 |
-
def __init__(self,
|
110 |
-
fft_sizes=(1024, 2048, 512),
|
111 |
-
hop_sizes=(120, 240, 50),
|
112 |
-
win_lengths=(600, 1200, 240),
|
113 |
-
window="hann_window", factor_sc=0.1, factor_mag=0.1):
|
114 |
-
"""Initialize Multi resolution STFT loss module.
|
115 |
-
Args:
|
116 |
-
fft_sizes (list): List of FFT sizes.
|
117 |
-
hop_sizes (list): List of hop sizes.
|
118 |
-
win_lengths (list): List of window lengths.
|
119 |
-
window (str): Window function type.
|
120 |
-
factor (float): a balancing factor across different losses.
|
121 |
-
"""
|
122 |
-
super(MultiResolutionSTFTLoss, self).__init__()
|
123 |
-
assert len(fft_sizes) == len(hop_sizes) == len(win_lengths)
|
124 |
-
self.stft_losses = torch.nn.ModuleList()
|
125 |
-
for fs, ss, wl in zip(fft_sizes, hop_sizes, win_lengths):
|
126 |
-
self.stft_losses += [STFTLoss(fs, ss, wl, window)]
|
127 |
-
self.factor_sc = factor_sc
|
128 |
-
self.factor_mag = factor_mag
|
129 |
-
|
130 |
-
def forward(self, x, y):
|
131 |
-
"""Calculate forward propagation.
|
132 |
-
Args:
|
133 |
-
x (Tensor): Predicted signal (B, T).
|
134 |
-
y (Tensor): Groundtruth signal (B, T).
|
135 |
-
Returns:
|
136 |
-
Tensor: Multi resolution spectral convergence loss value.
|
137 |
-
Tensor: Multi resolution log STFT magnitude loss value.
|
138 |
-
"""
|
139 |
-
sc_loss = 0.0
|
140 |
-
mag_loss = 0.0
|
141 |
-
for f in self.stft_losses:
|
142 |
-
sc_l, mag_l = f(x, y)
|
143 |
-
sc_loss += sc_l
|
144 |
-
mag_loss += mag_l
|
145 |
-
sc_loss /= len(self.stft_losses)
|
146 |
-
mag_loss /= len(self.stft_losses)
|
147 |
-
|
148 |
-
return self.factor_sc*sc_loss, self.factor_mag*mag_loss
|
149 |
-
|
150 |
-
|
151 |
-
|
152 |
-
|
153 |
-
class L1_Multi_STFT(torch.nn.Module):
|
154 |
-
"""STFT loss module."""
|
155 |
-
|
156 |
-
def __init__(self):
|
157 |
-
"""Initialize STFT loss module."""
|
158 |
-
super(L1_Multi_STFT, self).__init__()
|
159 |
-
self.multi_STFT_loss = MultiResolutionSTFTLoss()
|
160 |
-
self.l1_loss = torch.nn.L1Loss()
|
161 |
-
|
162 |
-
def forward(self, x, y):
|
163 |
-
"""Calculate forward propagation.
|
164 |
-
Args:
|
165 |
-
x (Tensor): Predicted signal (B, T).
|
166 |
-
y (Tensor): Groundtruth signal (B, T).
|
167 |
-
Returns:
|
168 |
-
Tensor: Spectral convergence loss value.
|
169 |
-
Tensor: Log STFT magnitude loss value.
|
170 |
-
"""
|
171 |
-
sc_loss, mag_loss = self.multi_STFT_loss(x, y)
|
172 |
-
l1_loss = self.l1_loss(x, y)
|
173 |
-
return sc_loss + mag_loss + l1_loss
|
174 |
-
|
175 |
-
|
176 |
-
LOSSES = {
|
177 |
-
'mse': torch.nn.MSELoss(),
|
178 |
-
'L1': torch.nn.L1Loss(),
|
179 |
-
'Multi_STFT': MultiResolutionSTFTLoss(),
|
180 |
-
'L1_Multi_STFT': L1_Multi_STFT()
|
181 |
-
}
|
182 |
-
|
183 |
-
|
184 |
-
def get_loss(loss_config, device):
|
185 |
-
return LOSSES[loss_config['name']].to(device)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
main.py
DELETED
@@ -1,12 +0,0 @@
|
|
1 |
-
import hydra
|
2 |
-
from omegaconf import DictConfig
|
3 |
-
from train import train
|
4 |
-
|
5 |
-
|
6 |
-
@hydra.main(version_base=None, config_path="conf", config_name="config")
|
7 |
-
def main(cfg: DictConfig):
|
8 |
-
train(cfg)
|
9 |
-
|
10 |
-
|
11 |
-
if __name__ == '__main__':
|
12 |
-
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
notebooks/EDA.ipynb
DELETED
The diff for this file is too large to render.
See raw diff
|
|
optimizers.py
DELETED
@@ -1,10 +0,0 @@
|
|
1 |
-
import torch
|
2 |
-
|
3 |
-
OPTIMIZERS_POOL = {
|
4 |
-
'sgd': torch.optim.SGD,
|
5 |
-
}
|
6 |
-
|
7 |
-
def get_optimizer(model_params, optimizer_config):
|
8 |
-
name, params = list(optimizer_config.items())[0]
|
9 |
-
optimizer = OPTIMIZERS_POOL[name](model_params, **params)
|
10 |
-
return optimizer
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
testing/evaluation.py
DELETED
@@ -1,62 +0,0 @@
|
|
1 |
-
import argparse
|
2 |
-
from tqdm import tqdm
|
3 |
-
|
4 |
-
from utils import load_wav, collect_valentini_paths
|
5 |
-
from metrics import Metrics
|
6 |
-
from denoisers.SpectralGating import SpectralGating
|
7 |
-
|
8 |
-
|
9 |
-
PARSERS = {
|
10 |
-
'valentini': collect_valentini_paths
|
11 |
-
}
|
12 |
-
MODELS = {
|
13 |
-
'baseline': SpectralGating
|
14 |
-
}
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
def evaluate_on_dataset(model_name, dataset_path, dataset_type):
|
19 |
-
if model_name is not None:
|
20 |
-
model = MODELS[model_name]()
|
21 |
-
parser = PARSERS[dataset_type]
|
22 |
-
clean_wavs, noisy_wavs = parser(dataset_path)
|
23 |
-
|
24 |
-
metrics = Metrics()
|
25 |
-
mean_scores = {'PESQ': 0, 'STOI': 0}
|
26 |
-
for clean_path, noisy_path in tqdm(zip(clean_wavs, noisy_wavs), total=len(clean_wavs)):
|
27 |
-
clean_wav = load_wav(clean_path)
|
28 |
-
noisy_wav = load_wav(noisy_path)
|
29 |
-
|
30 |
-
if model_name is None:
|
31 |
-
scores = metrics.calculate(denoised=noisy_wav, clean=clean_wav)
|
32 |
-
else:
|
33 |
-
denoised_wav = model(noisy_wav)
|
34 |
-
scores = metrics.calculate(denoised=denoised_wav, clean=clean_wav)
|
35 |
-
|
36 |
-
mean_scores['PESQ'] += scores['PESQ']
|
37 |
-
mean_scores['STOI'] += scores['STOI']
|
38 |
-
|
39 |
-
mean_scores['PESQ'] = mean_scores['PESQ'].numpy() / len(clean_wavs)
|
40 |
-
mean_scores['STOI'] = mean_scores['STOI'].numpy() / len(clean_wavs)
|
41 |
-
|
42 |
-
return mean_scores
|
43 |
-
|
44 |
-
|
45 |
-
if __name__ == '__main__':
|
46 |
-
parser = argparse.ArgumentParser(prog='Program to evaluate denoising')
|
47 |
-
parser.add_argument('--dataset_path', type=str,
|
48 |
-
default='/media/public/dataset/denoising/DS_10283_2791/',
|
49 |
-
help='Path to dataset folder')
|
50 |
-
parser.add_argument('--dataset_type', type=str, required=True,
|
51 |
-
choices=['valentini'])
|
52 |
-
parser.add_argument('--model_name', type=str,
|
53 |
-
choices=['baseline'])
|
54 |
-
|
55 |
-
|
56 |
-
args = parser.parse_args()
|
57 |
-
|
58 |
-
mean_scores = evaluate_on_dataset(model_name=args.model_name,
|
59 |
-
dataset_path=args.dataset_path,
|
60 |
-
dataset_type=args.dataset_type)
|
61 |
-
print(f"Metrics on {args.dataset_type} dataset with "
|
62 |
-
f"{args.model_name if args.model_name is not None else 'ideal denoising'} = {mean_scores}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
testing/metrics.py
DELETED
@@ -1,36 +0,0 @@
|
|
1 |
-
import pesq
|
2 |
-
from torchmetrics.audio.pesq import PerceptualEvaluationSpeechQuality
|
3 |
-
from torchmetrics.audio.stoi import ShortTimeObjectiveIntelligibility
|
4 |
-
from torchaudio.transforms import Resample
|
5 |
-
import torch
|
6 |
-
import torchaudio
|
7 |
-
from torchmetrics import SignalNoiseRatio
|
8 |
-
|
9 |
-
class Metrics(torch.nn.Module):
|
10 |
-
def __init__(self, source_rate, target_rate=16000, *args, **kwargs):
|
11 |
-
super().__init__(*args, **kwargs)
|
12 |
-
self.source_rate = source_rate
|
13 |
-
self.target_rate = target_rate
|
14 |
-
self.resampler = Resample(orig_freq=source_rate, new_freq=target_rate)
|
15 |
-
self.nb_pesq = PerceptualEvaluationSpeechQuality(target_rate, 'wb')
|
16 |
-
self.stoi = ShortTimeObjectiveIntelligibility(target_rate, False)
|
17 |
-
self.snr = SignalNoiseRatio()
|
18 |
-
|
19 |
-
def forward(self, denoised, clean):
|
20 |
-
pesq_scores, stoi_scores = 0, 0
|
21 |
-
for denoised_wav, clean_wav in zip(denoised, clean):
|
22 |
-
if self.source_rate != self.target_rate:
|
23 |
-
denoised_wav = self.resampler(denoised_wav)
|
24 |
-
clean_wav = self.resampler(clean_wav)
|
25 |
-
try:
|
26 |
-
pesq_scores += self.nb_pesq(denoised_wav, clean_wav).item()
|
27 |
-
stoi_scores += self.stoi(denoised_wav, clean_wav).item()
|
28 |
-
except pesq.NoUtterancesError as e:
|
29 |
-
print(e)
|
30 |
-
except ValueError as e:
|
31 |
-
print(e)
|
32 |
-
|
33 |
-
return {'PESQ': pesq_scores,
|
34 |
-
'STOI': stoi_scores}
|
35 |
-
|
36 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
testing/test_inference.py
DELETED
@@ -1,6 +0,0 @@
|
|
1 |
-
def func(x):
|
2 |
-
return x + 1
|
3 |
-
|
4 |
-
|
5 |
-
def test_answer():
|
6 |
-
assert func(3) == 4
|
|
|
|
|
|
|
|
|
|
|
|
|
|
train.py
DELETED
@@ -1,123 +0,0 @@
|
|
1 |
-
import os
|
2 |
-
import torch
|
3 |
-
from torch.utils.data import DataLoader
|
4 |
-
import omegaconf
|
5 |
-
from omegaconf import DictConfig
|
6 |
-
import wandb
|
7 |
-
|
8 |
-
from checkpoing_saver import CheckpointSaver
|
9 |
-
from denoisers import get_model
|
10 |
-
from optimizers import get_optimizer
|
11 |
-
from losses import get_loss
|
12 |
-
from datasets import get_datasets
|
13 |
-
from testing.metrics import Metrics
|
14 |
-
from datasets.minimal import Minimal
|
15 |
-
from tqdm import tqdm
|
16 |
-
|
17 |
-
def init_wandb(cfg):
|
18 |
-
wandb.login(key=cfg['wandb']['api_key'], host=cfg['wandb']['host'])
|
19 |
-
wandb.init(project=cfg['wandb']['project'],
|
20 |
-
notes=cfg['wandb']['notes'],
|
21 |
-
config=omegaconf.OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True),
|
22 |
-
resume=cfg['wandb']['resume'],
|
23 |
-
name=cfg['wandb']['run_name'])
|
24 |
-
if wandb.run.resumed:
|
25 |
-
api = wandb.Api()
|
26 |
-
runs = api.runs(f"{cfg['wandb']['entity']}/{cfg['wandb']['project']}",
|
27 |
-
order='train_pesq')
|
28 |
-
run = [run for run in runs if run.name == cfg['wandb']['run_name'] and run.state != 'running'][0]
|
29 |
-
artifacts = run.logged_artifacts()
|
30 |
-
best_model = [artifact for artifact in artifacts if artifact.type == 'model'][0]
|
31 |
-
|
32 |
-
best_model.download()
|
33 |
-
|
34 |
-
def train(cfg: DictConfig):
|
35 |
-
device = torch.device(f'cuda:{cfg.gpu}' if torch.cuda.is_available() else 'cpu')
|
36 |
-
init_wandb(cfg)
|
37 |
-
checkpoint_saver = CheckpointSaver(dirpath=cfg['training']['model_save_path'], run_name=wandb.run.name,
|
38 |
-
decreasing=False)
|
39 |
-
metrics = Metrics(source_rate=cfg['dataloader']['sample_rate']).to(device)
|
40 |
-
|
41 |
-
model = get_model(cfg['model']).to(device)
|
42 |
-
optimizer = get_optimizer(model.parameters(), cfg['optimizer'])
|
43 |
-
loss_fn = get_loss(cfg['loss'], device)
|
44 |
-
train_dataset, valid_dataset = get_datasets(cfg)
|
45 |
-
minimal_dataset = Minimal(cfg)
|
46 |
-
|
47 |
-
dataloaders = {
|
48 |
-
'train': DataLoader(train_dataset, batch_size=cfg['dataloader']['train_batch_size'], shuffle=True,
|
49 |
-
num_workers=cfg['dataloader']['num_workers']),
|
50 |
-
'val': DataLoader(valid_dataset, batch_size=cfg['dataloader']['valid_batch_size'], shuffle=False,
|
51 |
-
num_workers=cfg['dataloader']['num_workers']),
|
52 |
-
'minimal': DataLoader(minimal_dataset)
|
53 |
-
}
|
54 |
-
|
55 |
-
wandb.watch(model, log_freq=cfg['wandb']['log_interval'])
|
56 |
-
epoch = 0
|
57 |
-
while epoch < cfg['training']['num_epochs']:
|
58 |
-
for phase in ['train', 'val']:
|
59 |
-
if phase == 'train':
|
60 |
-
model.train()
|
61 |
-
else:
|
62 |
-
model.eval()
|
63 |
-
|
64 |
-
running_loss, running_pesq, running_stoi = 0.0, 0.0, 0.0
|
65 |
-
loop = tqdm(dataloaders[phase])
|
66 |
-
for i, (inputs, labels) in enumerate(loop):
|
67 |
-
inputs = inputs.to(device)
|
68 |
-
labels = labels.to(device)
|
69 |
-
|
70 |
-
optimizer.zero_grad()
|
71 |
-
|
72 |
-
with torch.set_grad_enabled(phase == 'train'):
|
73 |
-
outputs = model(inputs)
|
74 |
-
loss = loss_fn(outputs, labels)
|
75 |
-
|
76 |
-
if phase == 'train':
|
77 |
-
loss.backward()
|
78 |
-
optimizer.step()
|
79 |
-
|
80 |
-
running_metrics = metrics(denoised=outputs, clean=labels)
|
81 |
-
running_loss += loss.item() * inputs.size(0)
|
82 |
-
running_pesq += running_metrics['PESQ']
|
83 |
-
running_stoi += running_metrics['STOI']
|
84 |
-
|
85 |
-
loop.set_description(f"Epoch [{epoch}/{cfg['training']['num_epochs']}][{phase}]")
|
86 |
-
loop.set_postfix(loss=running_loss / (i + 1) / inputs.size(0),
|
87 |
-
pesq=running_pesq / (i + 1) / inputs.size(0),
|
88 |
-
stoi=running_stoi / (i + 1) / inputs.size(0))
|
89 |
-
|
90 |
-
if phase == 'train' and i % cfg['wandb']['log_interval'] == 0:
|
91 |
-
wandb.log({"train_loss": running_loss / (i + 1) / inputs.size(0),
|
92 |
-
"train_pesq": running_pesq / (i + 1) / inputs.size(0),
|
93 |
-
"train_stoi": running_stoi / (i + 1) / inputs.size(0)})
|
94 |
-
|
95 |
-
epoch_loss = running_loss / len(dataloaders[phase].dataset)
|
96 |
-
eposh_pesq = running_pesq / len(dataloaders[phase].dataset)
|
97 |
-
eposh_stoi = running_stoi / len(dataloaders[phase].dataset)
|
98 |
-
|
99 |
-
wandb.log({f"{phase}_loss": epoch_loss,
|
100 |
-
f"{phase}_pesq": eposh_pesq,
|
101 |
-
f"{phase}_stoi": eposh_stoi})
|
102 |
-
|
103 |
-
if phase == 'val':
|
104 |
-
for i, (wav, rate) in enumerate(dataloaders['minimal']):
|
105 |
-
if cfg['dataloader']['normalize']:
|
106 |
-
std = torch.std(wav)
|
107 |
-
wav = wav / std
|
108 |
-
prediction = model(wav.to(device))
|
109 |
-
prediction = prediction * std
|
110 |
-
else:
|
111 |
-
prediction = model(wav.to(device))
|
112 |
-
wandb.log({
|
113 |
-
f"{i}_example": wandb.Audio(
|
114 |
-
prediction.detach().cpu().numpy()[0][0],
|
115 |
-
sample_rate=rate)})
|
116 |
-
|
117 |
-
checkpoint_saver(model, epoch, metric_val=eposh_pesq,
|
118 |
-
optimizer=optimizer, loss=epoch_loss)
|
119 |
-
epoch += 1
|
120 |
-
|
121 |
-
|
122 |
-
if __name__ == "__main__":
|
123 |
-
pass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
transforms.py
DELETED
@@ -1,31 +0,0 @@
|
|
1 |
-
import torch
|
2 |
-
|
3 |
-
from torchaudio.transforms import Resample
|
4 |
-
from torchvision.transforms import RandomCrop
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
class Transform(torch.nn.Module):
|
9 |
-
def __init__(
|
10 |
-
self,
|
11 |
-
input_sample_rate,
|
12 |
-
sample_rate,
|
13 |
-
max_seconds,
|
14 |
-
normalize,
|
15 |
-
*args,
|
16 |
-
**kwargs
|
17 |
-
):
|
18 |
-
super().__init__()
|
19 |
-
self.input_sample_rate = input_sample_rate
|
20 |
-
self.sample_rate = sample_rate
|
21 |
-
self.resample = Resample(orig_freq=input_sample_rate, new_freq=sample_rate)
|
22 |
-
self.random_crop = RandomCrop((1, int(max_seconds * sample_rate)), pad_if_needed=True)
|
23 |
-
self.normalize = normalize
|
24 |
-
|
25 |
-
def forward(self, waveform: torch.Tensor) -> torch.Tensor:
|
26 |
-
if self.input_sample_rate != self.sample_rate:
|
27 |
-
waveform = self.resample(waveform)
|
28 |
-
if self.normalize:
|
29 |
-
waveform = waveform / torch.std(waveform)
|
30 |
-
cropped = self.random_crop(waveform)
|
31 |
-
return cropped
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
upload_to_HF.py
DELETED
@@ -1,56 +0,0 @@
|
|
1 |
-
import os
|
2 |
-
import wandb
|
3 |
-
from huggingface_hub import HfApi
|
4 |
-
from pathlib import Path
|
5 |
-
import huggingface_hub
|
6 |
-
import ssl
|
7 |
-
import os
|
8 |
-
os.environ['CURL_CA_BUNDLE'] = ''
|
9 |
-
|
10 |
-
ssl._create_default_https_context = ssl._create_unverified_context
|
11 |
-
|
12 |
-
class Uploader:
|
13 |
-
def __init__(self, entity, project, run_name, repo_id, username):
|
14 |
-
self.entity = entity
|
15 |
-
self.project = project
|
16 |
-
self.run_name = run_name
|
17 |
-
self.hf_api = HfApi()
|
18 |
-
self.wandb_api = wandb.Api()
|
19 |
-
self.repo_id = repo_id
|
20 |
-
self.username = username
|
21 |
-
huggingface_hub.login(os.environ.get('HUGGINGFACE_TOKEN'))
|
22 |
-
|
23 |
-
def get_model_from_wandb_run(self):
|
24 |
-
runs = self.wandb_api.runs(f"{self.entity}/{self.project}",
|
25 |
-
# order='+summary_metrics.train_pesq'
|
26 |
-
)
|
27 |
-
run = [run for run in runs if run.name == self.run_name][0]
|
28 |
-
artifacts = run.logged_artifacts()
|
29 |
-
best_model = [artifact for artifact in artifacts if artifact.type == 'model'][0]
|
30 |
-
artifact_dir = best_model.download()
|
31 |
-
model_path = list(Path(artifact_dir).glob("*.pt"))[0].absolute().as_posix()
|
32 |
-
print(f"Model validation score = {best_model.metadata['Validation score']}")
|
33 |
-
return model_path
|
34 |
-
|
35 |
-
def upload_to_HF(self):
|
36 |
-
model_path = self.get_model_from_wandb_run()
|
37 |
-
self.hf_api.upload_file(
|
38 |
-
path_or_fileobj=model_path,
|
39 |
-
path_in_repo=Path(model_path).name,
|
40 |
-
repo_id=f'{self.username}/{self.repo_id}',
|
41 |
-
)
|
42 |
-
|
43 |
-
def create_repo(self):
|
44 |
-
self.hf_api.create_repo(repo_id=self.repo_id, exist_ok=True)
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
if __name__ == '__main__':
|
49 |
-
uploader = Uploader(entity='borisovmaksim',
|
50 |
-
project='denoising',
|
51 |
-
run_name='wav_normalization',
|
52 |
-
repo_id='demucs',
|
53 |
-
username='BorisovMaksim')
|
54 |
-
uploader.create_repo()
|
55 |
-
uploader.upload_to_HF()
|
56 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
utils.py
CHANGED
@@ -1,7 +1,3 @@
|
|
1 |
-
import torchaudio
|
2 |
-
import torch
|
3 |
-
import matplotlib.pyplot as plt
|
4 |
-
from pathlib import Path
|
5 |
from torch.nn.functional import pad
|
6 |
|
7 |
|
@@ -17,45 +13,3 @@ def pad_cut_batch_audio(wavs, new_shape):
|
|
17 |
return wavs
|
18 |
|
19 |
|
20 |
-
def collect_valentini_paths(dataset_path):
|
21 |
-
clean_path = Path(dataset_path) / 'clean_testset_wav'
|
22 |
-
noisy_path = Path(dataset_path) / 'noisy_testset_wav'
|
23 |
-
|
24 |
-
clean_wavs = list(clean_path.glob("*"))
|
25 |
-
noisy_wavs = list(noisy_path.glob("*"))
|
26 |
-
|
27 |
-
return clean_wavs, noisy_wavs
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
def plot_spectrogram(stft, title="Spectrogram", xlim=None):
|
33 |
-
magnitude = stft.abs()
|
34 |
-
spectrogram = 20 * torch.log10(magnitude + 1e-8).numpy()
|
35 |
-
figure, axis = plt.subplots(1, 1)
|
36 |
-
img = axis.imshow(spectrogram, cmap="viridis", vmin=-100, vmax=0, origin="lower", aspect="auto")
|
37 |
-
figure.suptitle(title)
|
38 |
-
plt.colorbar(img, ax=axis)
|
39 |
-
plt.show()
|
40 |
-
|
41 |
-
|
42 |
-
def plot_mask(mask, title="Mask", xlim=None):
|
43 |
-
mask = mask.numpy()
|
44 |
-
figure, axis = plt.subplots(1, 1)
|
45 |
-
img = axis.imshow(mask, cmap="viridis", origin="lower", aspect="auto")
|
46 |
-
figure.suptitle(title)
|
47 |
-
plt.colorbar(img, ax=axis)
|
48 |
-
plt.show()
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
def generate_mixture(waveform_clean, waveform_noise, target_snr):
|
54 |
-
|
55 |
-
power_clean_signal = waveform_clean.pow(2).mean()
|
56 |
-
power_noise_signal = waveform_noise.pow(2).mean()
|
57 |
-
current_snr = 10 * torch.log10(power_clean_signal / power_noise_signal)
|
58 |
-
|
59 |
-
waveform_noise *= 10 ** (-(target_snr - current_snr) / 20)
|
60 |
-
return waveform_clean + waveform_noise
|
61 |
-
|
|
|
|
|
|
|
|
|
|
|
1 |
from torch.nn.functional import pad
|
2 |
|
3 |
|
|
|
13 |
return wavs
|
14 |
|
15 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|