Spaces:
Runtime error
Runtime error
Commit
·
bd0a813
1
Parent(s):
3f204d4
fixes
Browse files- .gitignore +1 -0
- EDA.ipynb +0 -0
- README.md +4 -4
- app.py +9 -7
- datasets.py +39 -0
- denoisers/SpectralGating.py +1 -1
- denoisers/__pycache__/SpectralGating.cpython-38.pyc +0 -0
- denoisers/demucs.py +33 -24
- evaluation.py +2 -2
- metrics.py +5 -4
- train.py +121 -0
.gitignore
CHANGED
@@ -2,3 +2,4 @@
|
|
2 |
.ipynb_checkpoints/**
|
3 |
nohup.out
|
4 |
__pycache__/**
|
|
|
|
2 |
.ipynb_checkpoints/**
|
3 |
nohup.out
|
4 |
__pycache__/**
|
5 |
+
cache_wav/
|
EDA.ipynb
CHANGED
The diff for this file is too large to render.
See raw diff
|
|
README.md
CHANGED
@@ -2,9 +2,9 @@
|
|
2 |
|
3 |
|
4 |
# Testing
|
5 |
-
| |
|
6 |
-
|
7 |
-
| ideal denoising |
|
8 |
-
| baseline |
|
9 |
|
10 |
|
|
|
2 |
|
3 |
|
4 |
# Testing
|
5 |
+
| | valentini_PESQ | valentini_STOI |
|
6 |
+
|:---------------:|:--------------:|:--------------:|
|
7 |
+
| ideal denoising | 1.9709 | 0.9211 |
|
8 |
+
| baseline | 1.7433 | 0.8844 |
|
9 |
|
10 |
|
app.py
CHANGED
@@ -9,35 +9,37 @@ import logging
|
|
9 |
import gradio as gr
|
10 |
|
11 |
|
|
|
12 |
from denoisers.SpectralGating import SpectralGating
|
13 |
|
14 |
model = SpectralGating()
|
15 |
|
16 |
|
17 |
def denoising_transform(audio):
|
18 |
-
src_path = "cache_wav/
|
19 |
-
tgt_path = "cache_wav/
|
20 |
-
# os.rename(audio.name, src_path)
|
21 |
(ffmpeg.input(audio)
|
22 |
.output(src_path, acodec='pcm_s16le', ac=1, ar=22050)
|
23 |
.run()
|
24 |
)
|
25 |
-
|
26 |
-
model.predict(src_path, tgt_path)
|
27 |
return tgt_path
|
|
|
|
|
28 |
|
29 |
|
30 |
inputs = gr.inputs.Audio(label="Source Audio", source="microphone", type='filepath')
|
31 |
outputs = gr.outputs.Audio(label="Target Audio", type='filepath')
|
32 |
|
33 |
-
title = "
|
34 |
#"""
|
35 |
gr.Interface(
|
36 |
denoising_transform, inputs, outputs, title=title,
|
37 |
-
allow_flagging='never'
|
38 |
).launch(
|
39 |
server_name='localhost',
|
40 |
server_port=7871,
|
|
|
41 |
#ssl_keyfile='example.key',
|
42 |
#ssl_certfile="example.crt",
|
43 |
)
|
|
|
9 |
import gradio as gr
|
10 |
|
11 |
|
12 |
+
|
13 |
from denoisers.SpectralGating import SpectralGating
|
14 |
|
15 |
model = SpectralGating()
|
16 |
|
17 |
|
18 |
def denoising_transform(audio):
|
19 |
+
src_path = "cache_wav/original/{}.wav".format(str(uuid.uuid4()))
|
20 |
+
tgt_path = "cache_wav/denoised/{}.wav".format(str(uuid.uuid4()))
|
|
|
21 |
(ffmpeg.input(audio)
|
22 |
.output(src_path, acodec='pcm_s16le', ac=1, ar=22050)
|
23 |
.run()
|
24 |
)
|
25 |
+
model.predict(audio, tgt_path)
|
|
|
26 |
return tgt_path
|
27 |
+
# model.predict(src_path, tgt_path)
|
28 |
+
# return tgt_path
|
29 |
|
30 |
|
31 |
inputs = gr.inputs.Audio(label="Source Audio", source="microphone", type='filepath')
|
32 |
outputs = gr.outputs.Audio(label="Target Audio", type='filepath')
|
33 |
|
34 |
+
title = "Denoising"
|
35 |
#"""
|
36 |
gr.Interface(
|
37 |
denoising_transform, inputs, outputs, title=title,
|
38 |
+
allow_flagging='never'
|
39 |
).launch(
|
40 |
server_name='localhost',
|
41 |
server_port=7871,
|
42 |
+
share=True
|
43 |
#ssl_keyfile='example.key',
|
44 |
#ssl_certfile="example.crt",
|
45 |
)
|
datasets.py
ADDED
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch.utils.data import Dataset
|
3 |
+
from pathlib import Path
|
4 |
+
from utils import load_wav
|
5 |
+
|
6 |
+
|
7 |
+
class Valentini(Dataset):
|
8 |
+
def __init__(self, dataset_path='/media/public/datasets/denoising/DS_10283_2791/', transform=None,
|
9 |
+
valid=False):
|
10 |
+
clean_path = Path(dataset_path) / 'clean_trainset_56spk_wav'
|
11 |
+
noisy_path = Path(dataset_path) / 'noisy_trainset_56spk_wav'
|
12 |
+
clean_wavs = list(clean_path.glob("*"))
|
13 |
+
noisy_wavs = list(noisy_path.glob("*"))
|
14 |
+
valid_threshold = int(len(clean_wavs) * 0.2)
|
15 |
+
if valid:
|
16 |
+
self.clean_wavs = clean_wavs[:valid_threshold]
|
17 |
+
self.noisy_wavs = noisy_wavs[:valid_threshold]
|
18 |
+
else:
|
19 |
+
self.clean_wavs = clean_wavs[valid_threshold:]
|
20 |
+
self.noisy_wavs = noisy_wavs[valid_threshold:]
|
21 |
+
|
22 |
+
assert len(self.clean_wavs) == len(self.noisy_wavs)
|
23 |
+
|
24 |
+
self.transform = transform
|
25 |
+
|
26 |
+
def __len__(self):
|
27 |
+
return len(self.clean_wavs)
|
28 |
+
|
29 |
+
def __getitem__(self, idx):
|
30 |
+
noisy_wav = load_wav(self.noisy_wavs[idx])
|
31 |
+
clean_wav = load_wav(self.clean_wavs[idx])
|
32 |
+
|
33 |
+
if self.transform:
|
34 |
+
random_seed = torch.randint(100, (1,))[0]
|
35 |
+
torch.manual_seed(random_seed)
|
36 |
+
noisy_wav = self.transform(noisy_wav)
|
37 |
+
torch.manual_seed(random_seed)
|
38 |
+
clean_wav = self.transform(clean_wav)
|
39 |
+
return noisy_wav, clean_wav
|
denoisers/SpectralGating.py
CHANGED
@@ -16,7 +16,7 @@ class SpectralGating(torch.nn.Module):
|
|
16 |
data, rate = torchaudio.load(wav_path)
|
17 |
reduced_noise = torch.Tensor(nr.reduce_noise(y=data, sr=rate))
|
18 |
torchaudio.save(out_path, reduced_noise, rate)
|
19 |
-
return
|
20 |
|
21 |
|
22 |
|
|
|
16 |
data, rate = torchaudio.load(wav_path)
|
17 |
reduced_noise = torch.Tensor(nr.reduce_noise(y=data, sr=rate))
|
18 |
torchaudio.save(out_path, reduced_noise, rate)
|
19 |
+
return out_path
|
20 |
|
21 |
|
22 |
|
denoisers/__pycache__/SpectralGating.cpython-38.pyc
ADDED
Binary file (1.08 kB). View file
|
|
denoisers/demucs.py
CHANGED
@@ -1,36 +1,34 @@
|
|
1 |
import torch
|
2 |
-
|
3 |
|
4 |
class Encoder(torch.nn.Module):
|
5 |
-
def __init__(self, in_channels, out_channels
|
6 |
-
kernel_size_1=8, stride_1=4,
|
7 |
-
kernel_size_2=1, stride_2=1):
|
8 |
super(Encoder, self).__init__()
|
9 |
|
10 |
self.conv1 = torch.nn.Conv1d(in_channels=in_channels, out_channels=out_channels,
|
11 |
-
kernel_size=
|
12 |
self.relu1 = torch.nn.ReLU()
|
13 |
self.conv2 = torch.nn.Conv1d(in_channels=out_channels, out_channels=2 * out_channels,
|
14 |
-
kernel_size=
|
15 |
-
self.glu = torch.nn.GLU()
|
16 |
|
17 |
def forward(self, x):
|
18 |
x = self.relu1(self.conv1(x))
|
|
|
|
|
19 |
x = self.glu(self.conv2(x))
|
20 |
return x
|
21 |
|
22 |
|
23 |
class Decoder(torch.nn.Module):
|
24 |
-
def __init__(self, in_channels, out_channels
|
25 |
-
kernel_size_1=3, stride_1=1,
|
26 |
-
kernel_size_2=8, stride_2=4):
|
27 |
super(Decoder, self).__init__()
|
28 |
|
29 |
self.conv1 = torch.nn.Conv1d(in_channels=in_channels, out_channels=2 * in_channels,
|
30 |
-
kernel_size=
|
31 |
-
self.glu = torch.nn.GLU()
|
32 |
self.conv2 = torch.nn.ConvTranspose1d(in_channels=in_channels, out_channels=out_channels,
|
33 |
-
kernel_size=
|
34 |
self.relu = torch.nn.ReLU()
|
35 |
|
36 |
def forward(self, x):
|
@@ -40,28 +38,39 @@ class Decoder(torch.nn.Module):
|
|
40 |
|
41 |
|
42 |
class Demucs(torch.nn.Module):
|
43 |
-
def __init__(self):
|
44 |
super(Demucs, self).__init__()
|
45 |
|
46 |
-
self.encoder1 = Encoder(in_channels=1, out_channels=
|
47 |
-
self.encoder2 = Encoder(in_channels=
|
48 |
-
self.encoder3 = Encoder(in_channels=
|
49 |
|
50 |
-
self.lstm = torch.nn.LSTM(
|
|
|
|
|
51 |
|
52 |
-
self.decoder1 = Decoder(in_channels=
|
53 |
-
self.decoder2 = Decoder(in_channels=
|
54 |
-
self.decoder3 = Decoder(in_channels=
|
55 |
|
56 |
def forward(self, x):
|
57 |
out1 = self.encoder1(x)
|
58 |
out2 = self.encoder2(out1)
|
59 |
out3 = self.encoder3(out2)
|
60 |
|
61 |
-
x = self.lstm(out3)
|
62 |
-
|
63 |
x = self.decoder1(x + out3)
|
|
|
64 |
x = self.decoder2(x + out2)
|
65 |
-
x =
|
|
|
|
|
|
|
|
|
|
|
66 |
|
|
|
67 |
return x
|
|
|
|
|
|
1 |
import torch
|
2 |
+
from torch.nn.functional import pad
|
3 |
|
4 |
class Encoder(torch.nn.Module):
|
5 |
+
def __init__(self, in_channels, out_channels):
|
|
|
|
|
6 |
super(Encoder, self).__init__()
|
7 |
|
8 |
self.conv1 = torch.nn.Conv1d(in_channels=in_channels, out_channels=out_channels,
|
9 |
+
kernel_size=8, stride=2)
|
10 |
self.relu1 = torch.nn.ReLU()
|
11 |
self.conv2 = torch.nn.Conv1d(in_channels=out_channels, out_channels=2 * out_channels,
|
12 |
+
kernel_size=1, stride=1)
|
13 |
+
self.glu = torch.nn.GLU(dim=-2)
|
14 |
|
15 |
def forward(self, x):
|
16 |
x = self.relu1(self.conv1(x))
|
17 |
+
if x.shape[-1] % 2 == 1:
|
18 |
+
x = pad(x, (0, 1))
|
19 |
x = self.glu(self.conv2(x))
|
20 |
return x
|
21 |
|
22 |
|
23 |
class Decoder(torch.nn.Module):
|
24 |
+
def __init__(self, in_channels, out_channels):
|
|
|
|
|
25 |
super(Decoder, self).__init__()
|
26 |
|
27 |
self.conv1 = torch.nn.Conv1d(in_channels=in_channels, out_channels=2 * in_channels,
|
28 |
+
kernel_size=1, stride=1)
|
29 |
+
self.glu = torch.nn.GLU(dim=-2)
|
30 |
self.conv2 = torch.nn.ConvTranspose1d(in_channels=in_channels, out_channels=out_channels,
|
31 |
+
kernel_size=8, stride=2)
|
32 |
self.relu = torch.nn.ReLU()
|
33 |
|
34 |
def forward(self, x):
|
|
|
38 |
|
39 |
|
40 |
class Demucs(torch.nn.Module):
|
41 |
+
def __init__(self, H):
|
42 |
super(Demucs, self).__init__()
|
43 |
|
44 |
+
self.encoder1 = Encoder(in_channels=1, out_channels=H)
|
45 |
+
self.encoder2 = Encoder(in_channels=H, out_channels=2*H)
|
46 |
+
self.encoder3 = Encoder(in_channels=2*H, out_channels=4*H)
|
47 |
|
48 |
+
self.lstm = torch.nn.LSTM(
|
49 |
+
input_size=4*H,
|
50 |
+
hidden_size=4*H, num_layers=2, batch_first=True)
|
51 |
|
52 |
+
self.decoder1 = Decoder(in_channels=4*H, out_channels=2*H)
|
53 |
+
self.decoder2 = Decoder(in_channels=2*H, out_channels=H)
|
54 |
+
self.decoder3 = Decoder(in_channels=H, out_channels=1)
|
55 |
|
56 |
def forward(self, x):
|
57 |
out1 = self.encoder1(x)
|
58 |
out2 = self.encoder2(out1)
|
59 |
out3 = self.encoder3(out2)
|
60 |
|
61 |
+
x, _ = self.lstm(out3.permute(0, 2, 1))
|
62 |
+
x = x.permute(0, 2, 1)
|
63 |
x = self.decoder1(x + out3)
|
64 |
+
x = x[:, :, :out2.shape[-1]]
|
65 |
x = self.decoder2(x + out2)
|
66 |
+
x = x[:, :, :-1]
|
67 |
+
out1 = out1[:, :, :-1]
|
68 |
+
if x.shape[-1] > out1.shape[-1]:
|
69 |
+
x = x[:, :, :out1.shape[-1]]
|
70 |
+
elif x.shape[-1] < out1.shape[-1]:
|
71 |
+
out1 = out1[:, :, :x.shape[-1]]
|
72 |
|
73 |
+
x = self.decoder3(x + out1)
|
74 |
return x
|
75 |
+
|
76 |
+
|
evaluation.py
CHANGED
@@ -28,10 +28,10 @@ def evaluate_on_dataset(model_name, dataset_path, dataset_type):
|
|
28 |
noisy_wav = load_wav(noisy_path)
|
29 |
|
30 |
if model_name is None:
|
31 |
-
scores = metrics.calculate(noisy_wav, clean_wav)
|
32 |
else:
|
33 |
denoised_wav = model(noisy_wav)
|
34 |
-
scores = metrics.calculate(
|
35 |
|
36 |
mean_scores['PESQ'] += scores['PESQ']
|
37 |
mean_scores['STOI'] += scores['STOI']
|
|
|
28 |
noisy_wav = load_wav(noisy_path)
|
29 |
|
30 |
if model_name is None:
|
31 |
+
scores = metrics.calculate(denoised=noisy_wav, clean=clean_wav)
|
32 |
else:
|
33 |
denoised_wav = model(noisy_wav)
|
34 |
+
scores = metrics.calculate(denoised=denoised_wav, clean=clean_wav)
|
35 |
|
36 |
mean_scores['PESQ'] += scores['PESQ']
|
37 |
mean_scores['STOI'] += scores['STOI']
|
metrics.py
CHANGED
@@ -2,16 +2,17 @@ from torchmetrics.audio.pesq import PerceptualEvaluationSpeechQuality
|
|
2 |
from torchmetrics.audio.stoi import ShortTimeObjectiveIntelligibility
|
3 |
import torch
|
4 |
import torchaudio
|
5 |
-
|
6 |
|
7 |
|
8 |
class Metrics:
|
9 |
def __init__(self, rate=16000):
|
10 |
self.nb_pesq = PerceptualEvaluationSpeechQuality(rate, 'wb')
|
11 |
self.stoi = ShortTimeObjectiveIntelligibility(rate, False)
|
|
|
12 |
|
13 |
-
def calculate(self,
|
14 |
-
return {'PESQ': self.nb_pesq(
|
15 |
-
'STOI': self.stoi(
|
16 |
|
17 |
|
|
|
2 |
from torchmetrics.audio.stoi import ShortTimeObjectiveIntelligibility
|
3 |
import torch
|
4 |
import torchaudio
|
5 |
+
from torchmetrics import SignalNoiseRatio
|
6 |
|
7 |
|
8 |
class Metrics:
|
9 |
def __init__(self, rate=16000):
|
10 |
self.nb_pesq = PerceptualEvaluationSpeechQuality(rate, 'wb')
|
11 |
self.stoi = ShortTimeObjectiveIntelligibility(rate, False)
|
12 |
+
self.snr = SignalNoiseRatio()
|
13 |
|
14 |
+
def calculate(self, denoised, clean):
|
15 |
+
return {'PESQ': self.nb_pesq(denoised, clean),
|
16 |
+
'STOI': self.stoi(denoised, clean)}
|
17 |
|
18 |
|
train.py
ADDED
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from torch.utils.tensorboard import SummaryWriter
|
3 |
+
import torch
|
4 |
+
from torch.nn import Sequential
|
5 |
+
from torch.utils.data import DataLoader
|
6 |
+
from datasets import Valentini
|
7 |
+
from datetime import datetime
|
8 |
+
from torchvision.transforms import RandomCrop
|
9 |
+
from utils import load_wav
|
10 |
+
from denoisers.demucs import Demucs
|
11 |
+
from pathlib import Path
|
12 |
+
|
13 |
+
os.environ['CUDA_VISIBLE_DEVICES'] = "1"
|
14 |
+
|
15 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
16 |
+
|
17 |
+
model = Demucs(H=64).to(device)
|
18 |
+
|
19 |
+
DATASET_PATH = Path('/media/public/datasets/denoising/DS_10283_2791/')
|
20 |
+
VALID_WAVS = {'hard': 'p257_171.wav',
|
21 |
+
'medium': 'p232_071.wav',
|
22 |
+
'easy': 'p232_284.wav'}
|
23 |
+
MAX_SECONDS = 3.2
|
24 |
+
SAMPLE_RATE = 16000
|
25 |
+
|
26 |
+
transform = Sequential(RandomCrop((1, int(MAX_SECONDS * SAMPLE_RATE)), pad_if_needed=True))
|
27 |
+
|
28 |
+
training_loader = DataLoader(Valentini(valid=False, transform=transform), batch_size=12, shuffle=True)
|
29 |
+
validation_loader = DataLoader(Valentini(valid=True, transform=transform), batch_size=12, shuffle=True)
|
30 |
+
|
31 |
+
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
|
32 |
+
loss_fn = torch.nn.MSELoss()
|
33 |
+
|
34 |
+
|
35 |
+
def train_one_epoch(epoch_index, tb_writer):
|
36 |
+
running_loss = 0.
|
37 |
+
last_loss = 0.
|
38 |
+
|
39 |
+
for i, data in enumerate(training_loader):
|
40 |
+
inputs, labels = data
|
41 |
+
inputs, labels = inputs.to(device), labels.to(device)
|
42 |
+
|
43 |
+
optimizer.zero_grad()
|
44 |
+
|
45 |
+
outputs = model(inputs)
|
46 |
+
|
47 |
+
loss = loss_fn(outputs, labels)
|
48 |
+
loss.backward()
|
49 |
+
|
50 |
+
optimizer.step()
|
51 |
+
|
52 |
+
running_loss += loss.item()
|
53 |
+
if i % 1000 == 999:
|
54 |
+
last_loss = running_loss / 100 # loss per batch
|
55 |
+
print(' batch {} loss: {}'.format(i + 1, last_loss))
|
56 |
+
tb_x = epoch_index * len(training_loader) + i + 1
|
57 |
+
tb_writer.add_scalar('Loss/train', last_loss, tb_x)
|
58 |
+
running_loss = 0.
|
59 |
+
|
60 |
+
|
61 |
+
return last_loss
|
62 |
+
|
63 |
+
|
64 |
+
def train():
|
65 |
+
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
|
66 |
+
writer = SummaryWriter('runs/denoising_trainer_{}'.format(timestamp))
|
67 |
+
epoch_number = 0
|
68 |
+
|
69 |
+
EPOCHS = 5
|
70 |
+
|
71 |
+
best_vloss = 1_000_000.
|
72 |
+
|
73 |
+
for tag, wav_path in VALID_WAVS.items():
|
74 |
+
wav = load_wav(DATASET_PATH / 'noisy_testset_wav' / wav_path)
|
75 |
+
writer.add_audio(tag=tag, snd_tensor=wav, sample_rate=SAMPLE_RATE)
|
76 |
+
writer.flush()
|
77 |
+
|
78 |
+
for epoch in range(EPOCHS):
|
79 |
+
print('EPOCH {}:'.format(epoch_number + 1))
|
80 |
+
|
81 |
+
# Make sure gradient tracking is on, and do a pass over the data
|
82 |
+
model.train(True)
|
83 |
+
avg_loss = train_one_epoch(epoch_number, writer)
|
84 |
+
|
85 |
+
# We don't need gradients on to do reporting
|
86 |
+
model.train(False)
|
87 |
+
|
88 |
+
running_vloss = 0.0
|
89 |
+
with torch.no_grad():
|
90 |
+
for i, vdata in enumerate(validation_loader):
|
91 |
+
vinputs, vlabels = vdata
|
92 |
+
vinputs, vlabels = vinputs.to(device), vlabels.to(device)
|
93 |
+
voutputs = model(vinputs)
|
94 |
+
vloss = loss_fn(voutputs, vlabels)
|
95 |
+
running_vloss += vloss
|
96 |
+
|
97 |
+
avg_vloss = running_vloss / (i + 1)
|
98 |
+
print('LOSS train {} valid {}'.format(avg_loss, avg_vloss))
|
99 |
+
|
100 |
+
writer.add_scalars('Training vs. Validation Loss',
|
101 |
+
{'Training': avg_loss, 'Validation': avg_vloss},
|
102 |
+
epoch_number + 1)
|
103 |
+
for tag, wav_path in VALID_WAVS.items():
|
104 |
+
wav = load_wav(DATASET_PATH / 'noisy_testset_wav' / wav_path)
|
105 |
+
wav = torch.reshape(wav, (1, 1, -1)).to(device)
|
106 |
+
prediction = model(wav)
|
107 |
+
writer.add_audio(tag=f"Model predicted {tag} on epoch {epoch}",
|
108 |
+
snd_tensor=prediction,
|
109 |
+
sample_rate=SAMPLE_RATE)
|
110 |
+
writer.flush()
|
111 |
+
|
112 |
+
if avg_vloss < best_vloss:
|
113 |
+
best_vloss = avg_vloss
|
114 |
+
model_path = 'checkpoints/model_{}_{}'.format(timestamp, epoch_number)
|
115 |
+
torch.save(model.state_dict(), model_path)
|
116 |
+
|
117 |
+
epoch_number += 1
|
118 |
+
|
119 |
+
|
120 |
+
if __name__ == '__main__':
|
121 |
+
train()
|