# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# author: adiyoss

import argparse
import json
import logging
import os
import sys

import torch
import torchaudio

from .audio import Audioset, find_audio_files
from . import distrib, pretrained
from .demucs import DemucsStreamer
from .utils import LogProgress

logger = logging.getLogger(__name__)


def add_flags(parser):
"""
Add the flags for the argument parser that are related to model loading and evaluation"
"""
pretrained.add_model_flags(parser)
parser.add_argument('--device', default="cpu")
parser.add_argument('--dry', type=float, default=0,
help='dry/wet knob coefficient. 0 is only input signal, 1 only denoised.')
parser.add_argument('--sample_rate', default=16_000, type=int, help='sample rate')
parser.add_argument('--num_workers', type=int, default=10)
parser.add_argument('--streaming', action="store_true",
help="true streaming evaluation for Demucs")
parser = argparse.ArgumentParser(
'denoiser.enhance',
description="Speech enhancement using Demucs - Generate enhanced files")
add_flags(parser)
parser.add_argument("--out_dir", type=str, default="enhanced",
help="directory putting enhanced wav files")
parser.add_argument("--batch_size", default=1, type=int, help="batch size")
parser.add_argument('-v', '--verbose', action='store_const', const=logging.DEBUG,
default=logging.INFO, help="more loggging")
group = parser.add_mutually_exclusive_group()
group.add_argument("--noisy_dir", type=str, default=None,
help="directory including noisy wav files")
group.add_argument("--noisy_json", type=str, default=None,
help="json file including noisy wav files")
def get_estimate(model, noisy, args):
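    """
    Run the model on a batch of noisy signals and return the enhanced estimate.

    In streaming mode the audio is fed chunk by chunk through DemucsStreamer
    (which receives `args.dry` directly); otherwise the full signal is passed
    through the model and mixed with the noisy input according to `args.dry`.
    """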
    torch.set_num_threads(1)
    if args.streaming:
        streamer = DemucsStreamer(model, dry=args.dry)
        with torch.no_grad():
            estimate = torch.cat([
                streamer.feed(noisy[0]),
                streamer.flush()], dim=1)[None]
    else:
        with torch.no_grad():
            estimate = model(noisy)
            estimate = (1 - args.dry) * estimate + args.dry * noisy
    return estimate


def save_wavs(estimates, noisy_sigs, filenames, out_dir, sr=16_000):
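    """Write each noisy/enhanced pair to `out_dir` as `<name>_noisy.wav` and `<name>_enhanced.wav`."""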
    # Write result
    for estimate, noisy, filename in zip(estimates, noisy_sigs, filenames):
        filename = os.path.join(out_dir, os.path.basename(filename).rsplit(".", 1)[0])
        write(noisy, filename + "_noisy.wav", sr=sr)
        write(estimate, filename + "_enhanced.wav", sr=sr)


def write(wav, filename, sr=16_000):
    # Peak-normalize only when needed to prevent clipping
    wav = wav / max(wav.abs().max().item(), 1)
    torchaudio.save(filename, wav.cpu(), sr)


def get_dataset(args):
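    """
    Build an Audioset of noisy files from either `--noisy_json` or `--noisy_dir`.

    Returns None (and logs a warning) when neither source is provided.
    """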
    if hasattr(args, 'dset'):
        paths = args.dset
    else:
        paths = args
    if paths.noisy_json:
        with open(paths.noisy_json) as f:
            files = json.load(f)
    elif paths.noisy_dir:
        files = find_audio_files(paths.noisy_dir)
    else:
        logger.warning(
            "Small sample set was not provided by either noisy_dir or noisy_json. "
            "Skipping enhancement.")
        return None
    return Audioset(files, with_path=True, sample_rate=args.sample_rate)


def enhance(args, model=None, local_out_dir=None):
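    """
    Enhance the noisy dataset described by `args` and write the results to disk.

    Loads the pretrained model unless one is passed in, iterates over the noisy
    files, and saves both the noisy and enhanced signals to the output directory.
    """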
    # Load model
    if not model:
        model = pretrained.get_model(args).to(args.device)
    model.eval()
    if local_out_dir:
        out_dir = local_out_dir
    else:
        out_dir = args.out_dir

    dset = get_dataset(args)
    if dset is None:
        return
    loader = distrib.loader(dset, batch_size=1)

    if distrib.rank == 0:
        os.makedirs(out_dir, exist_ok=True)
    distrib.barrier()

    with torch.no_grad():
        iterator = LogProgress(logger, loader, name="Generate enhanced files")
        for data in iterator:
            # Get batch data
            noisy_signals, filenames = data
            noisy_signals = noisy_signals.to(args.device)
            # Forward
            estimate = get_estimate(model, noisy_signals, args)
            save_wavs(estimate, noisy_signals, filenames, out_dir, sr=args.sample_rate)
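

# Example invocation (a sketch; `--dns64` is assumed to be one of the pretrained-model
# flags registered by pretrained.add_model_flags):
#   python -m denoiser.enhance --dns64 --noisy_dir <dir with noisy wavs> --out_dir enhanced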
if __name__ == "__main__":
args = parser.parse_args()
logging.basicConfig(stream=sys.stderr, level=args.verbose)
logger.debug(args)
enhance(args, local_out_dir=args.out_dir)