|
|
|
import configargparse |
|
from distutils.util import strtobool |
|
import logging |
|
import os |
|
import random |
|
import sys |
|
|
|
import numpy as np |
|
|
|
from espnet.asr.pytorch_backend.asr import enhance |
|
|
|
|
|
|
|
def get_parser():
    """Build the argument parser for the speech-enhancement script.

    Options can be given on the command line or through up to three YAML
    config files (``--config``, ``--config2``, ``--config3``); per their help
    texts, each later config file overwrites settings from the earlier ones.

    Returns:
        configargparse.ArgumentParser: The fully configured parser.
    """
    parser = configargparse.ArgumentParser(
        description="Enhance noisy speech for speech recognition",
        config_file_parser_class=configargparse.YAMLConfigFileParser,
        formatter_class=configargparse.ArgumentDefaultsHelpFormatter,
    )

    # --- config files -------------------------------------------------------
    parser.add("--config", is_config_file=True, help="config file path")
    parser.add(
        "--config2",
        is_config_file=True,
        help="second config file path that overwrites the settings in `--config`.",
    )
    parser.add(
        "--config3",
        is_config_file=True,
        help="third config file path that overwrites the settings "
        "in `--config` and `--config2`.",
    )

    # --- general configuration ----------------------------------------------
    parser.add_argument("--ngpu", default=0, type=int, help="Number of GPUs")
    # NOTE: main() raises ValueError for any backend other than "pytorch", so
    # the "chainer" default only works if --backend pytorch is passed explicitly.
    parser.add_argument(
        "--backend",
        default="chainer",
        type=str,
        choices=["chainer", "pytorch"],
        help="Backend library",
    )
    parser.add_argument("--debugmode", default=1, type=int, help="Debugmode")
    parser.add_argument("--seed", default=1, type=int, help="Random seed")
    parser.add_argument("--verbose", "-V", default=1, type=int, help="Verbose option")
    parser.add_argument(
        "--batchsize",
        default=1,
        type=int,
        help="Batch size for beam search (0: means no batch processing)",
    )
    parser.add_argument(
        "--preprocess-conf",
        type=str,
        default=None,
        help="The configuration file for the pre-processing",
    )

    # --- task related -------------------------------------------------------
    parser.add_argument(
        "--recog-json", type=str, help="Filename of recognition data (json)"
    )

    # --- model (parameter) related ------------------------------------------
    parser.add_argument(
        "--model", type=str, required=True, help="Model file parameters to read"
    )
    parser.add_argument(
        "--model-conf", type=str, default=None, help="Model config file"
    )

    # --- outputs configuration ----------------------------------------------
    parser.add_argument(
        "--enh-wspecifier",
        type=str,
        default=None,
        # The trailing space is needed: the two literals are concatenated.
        help="Specify the output way for enhanced speech. "
        "e.g. ark,scp:outdir,wav.scp",
    )
    parser.add_argument(
        "--enh-filetype",
        type=str,
        default="sound",
        choices=["mat", "hdf5", "sound.hdf5", "sound"],
        help="Specify the file format for enhanced speech. "
        '"mat" is the matrix format in kaldi',
    )
    parser.add_argument("--fs", type=int, default=16000, help="The sample frequency")
    # strtobool maps strings such as "true"/"false" to 1/0.
    parser.add_argument(
        "--keep-length",
        type=strtobool,
        default=True,
        help="Adjust the output length to match " "with the input for enhanced speech",
    )
    parser.add_argument(
        "--image-dir", type=str, default=None, help="The directory saving the images."
    )
    parser.add_argument(
        "--num-images",
        type=int,
        default=20,
        help="The number of images files to be saved. "
        "If negative, all samples are to be saved.",
    )

    # --- inverse-STFT options -----------------------------------------------
    parser.add_argument(
        "--apply-istft",
        type=strtobool,
        default=True,
        help="Apply istft to the output from the network",
    )
    parser.add_argument(
        "--istft-win-length",
        type=int,
        default=512,
        help="The window length for istft. "
        "This option is ignored "
        "if stft is found in the preprocess-conf",
    )
    # type=int (was type=str): a command-line value must be parsed into the same
    # integer type as the default, consistent with --istft-win-length, rather
    # than being forwarded to the istft as a string.
    parser.add_argument(
        "--istft-n-shift",
        type=int,
        default=256,
        help="The window shift for istft. "
        "This option is ignored "
        "if stft is found in the preprocess-conf",
    )
    parser.add_argument(
        "--istft-window",
        type=str,
        default="hann",
        help="The window type for istft. "
        "This option is ignored "
        "if stft is found in the preprocess-conf",
    )
    return parser
|
|
|
|
|
def main(args):
    """Parse command-line arguments and run speech enhancement.

    Args:
        args (list of str): Command-line arguments, without the program name.

    Raises:
        ValueError: If ``--backend`` is anything other than "pytorch".

    Exits with status 1 (via ``sys.exit``) when ``--ngpu`` does not match the
    number of devices listed in CUDA_VISIBLE_DEVICES, or when ``--ngpu`` > 1.
    """
    parser = get_parser()
    args = parser.parse_args(args)

    # Logging: verbose=1 -> INFO, verbose=2 -> DEBUG, anything else -> WARNING.
    # (logging.WARNING is used instead of its deprecated alias logging.WARN.)
    level = {1: logging.INFO, 2: logging.DEBUG}.get(args.verbose, logging.WARNING)
    logging.basicConfig(
        level=level,
        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
    )
    if args.verbose not in (1, 2):
        logging.warning("Skip DEBUG/INFO messages")

    # Sanity-check the requested GPU count against CUDA_VISIBLE_DEVICES.
    if args.ngpu > 0:
        cvd = os.environ.get("CUDA_VISIBLE_DEVICES")
        if cvd is None:
            logging.warning("CUDA_VISIBLE_DEVICES is not set.")
        elif args.ngpu != len(cvd.split(",")):
            logging.error("#gpus is not matched with CUDA_VISIBLE_DEVICES.")
            sys.exit(1)

    # Only single-GPU (or CPU) decoding is supported.
    if args.ngpu > 1:
        logging.error("The program only supports ngpu=1.")
        sys.exit(1)

    # Display PYTHONPATH to help debug environment issues.
    logging.info("python path = %s", os.environ.get("PYTHONPATH", "(None)"))

    # Seed both Python's and NumPy's RNGs for reproducibility.
    random.seed(args.seed)
    np.random.seed(args.seed)
    logging.info("set random seed = %d", args.seed)

    # Dispatch to the backend implementation; only pytorch is implemented.
    logging.info("backend = %s", args.backend)
    if args.backend == "pytorch":
        enhance(args)
    else:
        raise ValueError("Only pytorch is supported.")
|
|
|
|
|
if __name__ == "__main__":
    # Hand every argument after the program name to the CLI entry point.
    cli_arguments = sys.argv[1:]
    main(cli_arguments)
|
|