File size: 2,732 Bytes
33e3a91
 
 
 
4f821f0
33e3a91
 
 
 
 
 
 
 
 
 
 
 
73e61ac
 
 
33e3a91
 
73e61ac
 
 
 
33e3a91
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
04d9b94
33e3a91
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96c45a5
 
 
 
 
73e61ac
 
96c45a5
04d9b94
33e3a91
 
28d63d4
33e3a91
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import json
from tqdm import tqdm
from copy import deepcopy

import soundfile as sf
import numpy as np
import gradio as gr
import torch

import random
# Seed the Python, PyTorch and NumPy random number generators with the same
# fixed value so that repeated runs draw the same random numbers.
random.seed(0)
torch.manual_seed(0)
np.random.seed(0)

from util import print_size, sampling
from network import CleanUNet
import torchaudio
import torchaudio.transforms as T

# Sample rate (Hz) that loaded audio is resampled to in load_simple and that
# the denoised output file is written with in denoise.
SAMPLE_RATE = 22050

def load_simple(filename):
    """
    Load an audio file and resample it to SAMPLE_RATE.

    Parameters:
    filename (str):     path of the audio file, handed to torchaudio.load

    Returns:
    the waveform returned by torchaudio.load, resampled from the file's own
    sample rate to SAMPLE_RATE
    """
    wav, sr = torchaudio.load(filename)
    # The resampler is created with the same dtype as the loaded waveform.
    resampler = T.Resample(sr, SAMPLE_RATE, dtype=wav.dtype)
    # Resample the waveform that was just loaded; `audio` is not a local here
    # and would resolve to the module-level Gradio input component instead.
    return resampler(wav)

CONFIG = "configs/DNS-large-full.json"
CHECKPOINT = "./exp/DNS-large-high/checkpoint/pretrained.pkl"

# Read the JSON config once at import time and expose each section as a
# module-level name; the functions below read these names directly.
with open(CONFIG) as f:
    data = f.read()
config = json.loads(data)

gen_config = config["gen_config"]
network_config = config["network_config"]      # keyword arguments for CleanUNet
train_config = config["train_config"]          # training settings, incl. "exp_path"
trainset_config = config["trainset_config"]    # training-set settings

def denoise(filename, ckpt_path = CHECKPOINT, out = "out.wav"):
    """
    Denoise an audio file with a CleanUNet checkpoint and save the result.

    Parameters:
    filename (str):     path of the noisy audio file to load
    ckpt_path (str):    checkpoint file to load; it must contain a
                        'model_state_dict' entry
    out (str):          path the denoised audio is written to

    Returns:
    out (str):          the path that was written

    Raises:
    ValueError:         if the loaded audio yields nothing to process
    """

    # report the experiment path from the training config
    exp_path = train_config["exp_path"]
    print('exp_path:', exp_path)

    # predefine model
    net = CleanUNet(**network_config)
    print_size(net)

    # load checkpoint
    # NOTE: torch.load unpickles the file, so ckpt_path must only ever point
    # at a trusted checkpoint.
    checkpoint = torch.load(ckpt_path, map_location='cpu')
    net.load_state_dict(checkpoint['model_state_dict'])
    net.eval()

    # inference
    noisy_audio = load_simple(filename)

    # Iterate over the first dimension of the loaded tensor (the channels, with
    # torchaudio.load's default layout) and keep every result, so that a
    # multi-channel input is not reduced to whichever channel ran last.
    denoised = []
    with torch.no_grad():
        for batch in tqdm(noisy_audio):
            generated_audio = sampling(net, batch).cpu()
            denoised.append(np.ravel(generated_audio.squeeze()))

    if not denoised:
        # Without this, the caller would receive a path that was never written
        # (or a stale file left over from an earlier call).
        raise ValueError("no audio channels found in {!r}".format(filename))

    # A single channel is written as a 1-D signal; several channels are stacked
    # as (frames, channels), the layout soundfile expects.
    if len(denoised) == 1:
        samples = denoised[0]
    else:
        samples = np.stack(denoised, axis=1)
    sf.write(out, samples, SAMPLE_RATE)

    return out

# Gradio front end: one audio file in, one denoised audio file out, both
# exchanged with denoise() as file paths. The interface is launched as soon as
# this module runs.
title = "Speech Denoising in the Waveform Domain with Self-Attention from Nvidia"

audio = gr.inputs.Audio(type='filepath', label="Audio to denoise")
inputs = [audio]
outputs = gr.outputs.Audio(type='filepath', label="Denoised audio")

gr.Interface(
    fn=denoise,
    inputs=inputs,
    outputs=outputs,
    title=title,
    enable_queue=True,
).launch()