DeepLearning101's picture
Upload 17 files
109bb65
raw
history blame contribute delete
No virus
4.36 kB
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# author: adiyoss
import argparse
from concurrent.futures import ProcessPoolExecutor
import json
import logging
import sys
from pesq import pesq
from pystoi import stoi
import torch
from .data import NoisyCleanSet
from .enhance import add_flags, get_estimate
from . import distrib, pretrained
from .utils import bold, LogProgress
logger = logging.getLogger(__name__)

# Command-line interface of the evaluation script. `add_flags` (imported from
# `.enhance`) registers its own options on the parser first; the options
# specific to evaluation are added afterwards.
# NOTE(review): `evaluate` also reads `args.device`, `args.num_workers` and
# `args.sample_rate`, which are not declared here -- presumably they come from
# `add_flags`; confirm against `.enhance`.
parser = argparse.ArgumentParser(
    'denoiser.evaluate',
    description='Speech enhancement using Demucs - Evaluate model performance')
add_flags(parser)
parser.add_argument('--data_dir', help='directory including noisy.json and clean.json files')
parser.add_argument('--matching', default="sort", help='set this to dns for the dns dataset.')
# `--no_pesq` stores False into `args.pesq`, which is True by default.
parser.add_argument('--no_pesq', action="store_false", dest="pesq", default=True,
                    help="Don't compute PESQ.")
# `args.verbose` holds a logging level (INFO by default, DEBUG with `-v`).
parser.add_argument('-v', '--verbose', action='store_const', const=logging.DEBUG,
                    default=logging.INFO, help="More logging")
def evaluate(args, model=None, data_loader=None):
    """Evaluate a model on noisy/clean pairs and return averaged PESQ and STOI.

    Args:
        args: options namespace. This function reads `device`, `num_workers`,
            and, when it has to build the data loader itself, `data_dir`,
            `matching` and `sample_rate`. `args` is also forwarded to
            `pretrained.get_model`, `get_estimate` and the metric helpers.
        model: model to evaluate. When None, the model returned by
            `pretrained.get_model(args)` is used, moved to `args.device`.
        data_loader: iterable yielding `(noisy, clean)` tensor pairs. When
            None, a loader over `NoisyCleanSet(args.data_dir, ...)` is built
            with a batch size of 1.

    Returns:
        The `(pesq, stoi)` pair produced by `distrib.average`.

    Raises:
        ZeroDivisionError: if the data loader yields no samples.
    """
    total_pesq = 0
    total_stoi = 0
    total_cnt = 0
    updates = 5

    # Load model.
    # Compare against None explicitly: a module that defines `__len__` (for
    # instance an empty container) is falsy and would otherwise be silently
    # replaced by the pretrained model.
    if model is None:
        model = pretrained.get_model(args).to(args.device)
    model.eval()

    # Load data
    if data_loader is None:
        dataset = NoisyCleanSet(args.data_dir, matching=args.matching,
                                sample_rate=args.sample_rate)
        data_loader = distrib.loader(dataset, batch_size=1, num_workers=2)

    pendings = []
    with ProcessPoolExecutor(args.num_workers) as pool:
        with torch.no_grad():
            iterator = LogProgress(logger, data_loader, name="Eval estimates")
            for data in iterator:
                # Get batch data
                noisy, clean = [x.to(args.device) for x in data]
                if args.device == 'cpu':
                    # On CPU, both the model forward pass and the metrics run
                    # inside the worker processes.
                    pendings.append(
                        pool.submit(_estimate_and_run_metrics, clean, model, noisy, args))
                else:
                    # Otherwise the estimate is computed in this process and
                    # only the metrics are sent to the workers, on CPU tensors.
                    estimate = get_estimate(model, noisy, args)
                    estimate = estimate.cpu()
                    clean = clean.cpu()
                    pendings.append(
                        pool.submit(_run_metrics, clean, estimate, args))
                total_cnt += clean.shape[0]

        # Collect the per-batch metric sums while the pool is still open.
        for pending in LogProgress(logger, pendings, updates, name="Eval metrics"):
            pesq_i, stoi_i = pending.result()
            total_pesq += pesq_i
            total_stoi += stoi_i

    # Per-sample means, then combined through `distrib.average` with the
    # sample count as second argument.
    metrics = [total_pesq, total_stoi]
    avg_pesq, avg_stoi = distrib.average([m / total_cnt for m in metrics], total_cnt)
    logger.info(bold(f'Test set performance:PESQ={avg_pesq}, STOI={avg_stoi}.'))
    return avg_pesq, avg_stoi
def _estimate_and_run_metrics(clean, model, noisy, args):
    """Run `get_estimate` on `noisy` and score the result against `clean`.

    Returns whatever `_run_metrics` returns for the computed estimate.
    """
    return _run_metrics(clean, get_estimate(model, noisy, args), args)
def _run_metrics(clean, estimate, args):
    """Compute the metric pair for one batch.

    Both inputs are converted with `.numpy()` and index 0 is taken along
    their second axis before scoring.

    Returns:
        `(pesq_i, stoi_i)`; `pesq_i` is 0 when `args.pesq` is falsy.
    """
    enhanced = estimate.numpy()[:, 0]
    reference = clean.numpy()[:, 0]
    # PESQ is optional (see `args.pesq`); STOI is always computed.
    pesq_score = get_pesq(reference, enhanced, sr=args.sample_rate) if args.pesq else 0
    stoi_score = get_stoi(reference, enhanced, sr=args.sample_rate)
    return pesq_score, stoi_score
def get_pesq(ref_sig, out_sig, sr):
    """Sum the PESQ score over every signal of a batch.

    Args:
        ref_sig: numpy.ndarray, [B, T], reference signals.
        out_sig: numpy.ndarray, [B, T], signals to score.
        sr: sample rate, passed as first argument to `pesq`.
    Returns:
        The sum (not the mean) of the per-signal PESQ values, computed with
        the 'wb' mode; 0 for an empty batch.
    """
    per_signal = (
        pesq(sr, reference, out_sig[idx], 'wb')
        for idx, reference in enumerate(ref_sig)
    )
    return sum(per_signal)
def get_stoi(ref_sig, out_sig, sr):
    """Sum the STOI score over every signal of a batch.

    Args:
        ref_sig: numpy.ndarray, [B, T], reference signals.
        out_sig: numpy.ndarray, [B, T], signals to score.
        sr: sample rate, passed as third argument to `stoi`.
    Returns:
        The sum (not the mean) of the per-signal STOI values, computed with
        `extended=False`; 0 for an empty batch.
    """
    per_signal = (
        stoi(reference, out_sig[idx], sr, extended=False)
        for idx, reference in enumerate(ref_sig)
    )
    return sum(per_signal)
def main():
    """Command-line entry point: evaluate and print the metrics as JSON.

    Logging goes to stderr at the level stored in `args.verbose`; the result
    is written to stdout as a single JSON object followed by a newline.
    """
    args = parser.parse_args()
    logging.basicConfig(stream=sys.stderr, level=args.verbose)
    logger.debug(args)

    pesq_score, stoi_score = evaluate(args)
    results = {'pesq': pesq_score, 'stoi': stoi_score}

    out = sys.stdout
    json.dump(results, out)
    out.write('\n')
# Run the evaluation only when this module is executed as a script.
if __name__ == "__main__":
    main()