Spaces:
Runtime error
Runtime error
import os | |
import json | |
from tqdm import tqdm | |
from copy import deepcopy | |
import numpy as np | |
import gradio as gr | |
import torch | |
import random | |
# Make runs reproducible: seed Python's, PyTorch's and NumPy's global RNGs with 0.
for _seed_fn in (random.seed, torch.manual_seed, np.random.seed):
    _seed_fn(0)
del _seed_fn
from scipy.io.wavfile import write as wavwrite | |
from util import print_size, sampling | |
from network import CleanUNet | |
import torchaudio | |
def load_simple(filename):
    """Print *filename*, load it with torchaudio and return only the waveform.

    The sample rate returned by ``torchaudio.load`` is discarded.
    """
    print(filename)
    waveform = torchaudio.load(filename)[0]
    return waveform
CONFIG = "configs/DNS-large-full.json"
CHECKPOINT = "./exp/DNS-large-high/checkpoint/pretrained.pkl"

# Parse the JSON config once at import time and expose its sections as
# module-level globals. (`global` statements at module scope are no-ops,
# so they are not needed here.)
with open(CONFIG, encoding="utf-8") as f:
    config = json.load(f)
gen_config = config["gen_config"]
network_config = config["network_config"]  # keyword arguments for CleanUNet(...)
train_config = config["train_config"]  # training configuration
trainset_config = config["trainset_config"]  # training-set configuration
# Models already built, keyed by checkpoint path, so each request does not
# rebuild the network and re-read the checkpoint from disk.
_NETWORK_CACHE = {}


def _load_network(ckpt_path):
    """Return a CleanUNet in eval mode with weights loaded from *ckpt_path* (cached)."""
    net = _NETWORK_CACHE.get(ckpt_path)
    if net is None:
        net = CleanUNet(**network_config)
        print_size(net)
        checkpoint = torch.load(ckpt_path, map_location="cpu")
        net.load_state_dict(checkpoint["model_state_dict"])
        net.eval()
        _NETWORK_CACHE[ckpt_path] = net
    return net


def denoise(filename, ckpt_path=CHECKPOINT, chunk_size=1000000):
    """Denoise an audio file with a pretrained CleanUNet and save the result as WAV.

    Parameters:
    filename (str):   path of the noisy input audio file
    ckpt_path (str):  checkpoint file whose 'model_state_dict' entry holds the weights
    chunk_size (int): maximum number of samples per channel fed to the network in
                      one forward pass; longer inputs are split into roughly equal
                      chunks along time that are denoised independently

    Returns:
    str: path of the written file, ``filename + "_denoised.wav"``

    Raises:
    ValueError: if *chunk_size* is not positive or the input has no samples
    """
    if chunk_size < 1:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")

    net = _load_network(ckpt_path)

    # Load directly (not via load_simple) because the sample rate is needed to
    # write the result at the right speed. torchaudio returns (channels, frames).
    print(filename)
    noisy_audio, sample_rate = torchaudio.load(filename)
    num_channels, length = noisy_audio.shape
    if length == 0:
        raise ValueError(f"{filename} contains no audio samples")

    # length // chunk_size + 1 chunks => every chunk has at most chunk_size frames.
    chunks = torch.chunk(noisy_audio, length // chunk_size + 1, dim=1)
    denoised = []
    with torch.no_grad():
        for chunk in tqdm(chunks):
            out = sampling(net, chunk)
            # Keep one row per input channel so multi-channel files are joined
            # along time rather than stacked along channels.
            # NOTE(review): assumes sampling() yields one output signal per input row.
            denoised.append(out.cpu().numpy().reshape(num_channels, -1))
    all_audio = np.concatenate(denoised, axis=1)

    # scipy.io.wavfile.write expects (frames,) for mono, (frames, channels) otherwise.
    if num_channels == 1:
        all_audio = all_audio[0]
    else:
        all_audio = np.ascontiguousarray(all_audio.T)

    new_file_name = filename + "_denoised.wav"
    print("saved to:", new_file_name)
    # Write at the input's sample rate (previously hard-coded 32000 Hz, which
    # changed playback speed for any input not recorded at 32 kHz).
    # NOTE(review): input is not resampled to the model's training rate.
    wavwrite(new_file_name, sample_rate, all_audio)
    return new_file_name
# Gradio UI: take an audio file path in, return the denoised file path.
# Uses the top-level gr.Audio component and Interface.queue(); the
# gr.inputs / gr.outputs namespaces and Interface(enable_queue=...) were
# deprecated in Gradio 3.x and are not available in Gradio 4.
audio = gr.Audio(label="Audio to denoise", type="filepath")
inputs = [audio]
outputs = gr.Audio(label="Denoised audio", type="filepath")
title = "Speech Denoising in the Waveform Domain with Self-Attention from Nvidia"
# queue() replaces enable_queue=True: requests go through Gradio's request queue.
gr.Interface(denoise, inputs, outputs, title=title).queue().launch()