#!/usr/bin/env python3
import configargparse
from distutils.util import strtobool
import logging
import os
import random
import sys
import numpy as np
from espnet.asr.pytorch_backend.asr import enhance
# NOTE: you need this func to generate our sphinx doc
def get_parser():
    """Build the command-line parser for the speech enhancement script.

    Options can be given on the command line or through up to three YAML
    config files (``--config``, ``--config2``, ``--config3``).

    Returns:
        configargparse.ArgumentParser: the configured parser.
    """
    parser = configargparse.ArgumentParser(
        description='Enhance noisy speech for speech recognition',
        config_file_parser_class=configargparse.YAMLConfigFileParser,
        formatter_class=configargparse.ArgumentDefaultsHelpFormatter)

    # general configuration
    parser.add('--config', is_config_file=True, help='config file path')
    parser.add('--config2', is_config_file=True,
               help='second config file path that overwrites the settings in `--config`.')
    parser.add('--config3', is_config_file=True,
               help='third config file path that overwrites the settings in `--config` and `--config2`.')

    parser.add_argument('--ngpu', default=0, type=int,
                        help='Number of GPUs')
    parser.add_argument('--backend', default='chainer', type=str,
                        choices=['chainer', 'pytorch'],
                        help='Backend library')
    parser.add_argument('--debugmode', default=1, type=int,
                        help='Debugmode')
    parser.add_argument('--seed', default=1, type=int,
                        help='Random seed')
    parser.add_argument('--verbose', '-V', default=1, type=int,
                        help='Verbose option')
    parser.add_argument('--batchsize', default=1, type=int,
                        help='Batch size for beam search (0: means no batch processing)')
    parser.add_argument('--preprocess-conf', type=str, default=None,
                        help='The configuration file for the pre-processing')

    # task related
    parser.add_argument('--recog-json', type=str,
                        help='Filename of recognition data (json)')

    # model (parameter) related
    parser.add_argument('--model', type=str, required=True,
                        help='Model file parameters to read')
    parser.add_argument('--model-conf', type=str, default=None,
                        help='Model config file')

    # Outputs configuration
    # The two help literals are concatenated, so the first one needs a
    # trailing space to keep the rendered sentence readable.
    parser.add_argument('--enh-wspecifier', type=str, default=None,
                        help='Specify the output way for enhanced speech. '
                             'e.g. ark,scp:outdir,wav.scp')
    parser.add_argument('--enh-filetype', type=str, default='sound',
                        choices=['mat', 'hdf5', 'sound.hdf5', 'sound'],
                        help='Specify the file format for enhanced speech. '
                             '"mat" is the matrix format in kaldi')
    parser.add_argument('--fs', type=int, default=16000,
                        help='The sample frequency')
    parser.add_argument('--keep-length', type=strtobool, default=True,
                        help='Adjust the output length to match '
                             'with the input for enhanced speech')
    parser.add_argument('--image-dir', type=str, default=None,
                        help='The directory saving the images.')
    parser.add_argument('--num-images', type=int, default=20,
                        help='The number of images files to be saved. '
                             'If negative, all samples are to be saved.')

    # IStft
    parser.add_argument('--apply-istft', type=strtobool, default=True,
                        help='Apply istft to the output from the network')
    parser.add_argument('--istft-win-length', type=int, default=512,
                        help='The window length for istft. '
                             'This option is ignored '
                             'if stft is found in the preprocess-conf')
    # The shift is a sample count: parse it as int so a value given on the
    # command line has the same type as the integer default.
    parser.add_argument('--istft-n-shift', type=int, default=256,
                        help='The number of shift points for istft. '
                             'This option is ignored '
                             'if stft is found in the preprocess-conf')
    parser.add_argument('--istft-window', type=str, default='hann',
                        help='The window type for istft. '
                             'This option is ignored '
                             'if stft is found in the preprocess-conf')
    return parser
def main(args):
    """Parse the command line and run speech enhancement.

    Configures logging from ``--verbose``, validates the GPU settings against
    ``CUDA_VISIBLE_DEVICES``, seeds the ``random`` and NumPy generators, and
    hands the parsed options to ``enhance``.

    Args:
        args: command-line argument strings, without the program name.

    Raises:
        ValueError: if the selected backend is not ``pytorch``.
        SystemExit: if the GPU settings are inconsistent or more than one
            GPU is requested.
    """
    args = get_parser().parse_args(args)

    # logging setup: 1 -> INFO, 2 -> DEBUG, anything else -> WARN
    log_format = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
    level_by_verbosity = {1: logging.INFO, 2: logging.DEBUG}
    chosen_level = level_by_verbosity.get(args.verbose)
    if chosen_level is not None:
        logging.basicConfig(level=chosen_level, format=log_format)
    else:
        logging.basicConfig(level=logging.WARN, format=log_format)
        logging.warning("Skip DEBUG/INFO messages")

    # the requested GPU count must agree with CUDA_VISIBLE_DEVICES when set
    if args.ngpu > 0:
        visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES")
        if visible_devices is None:
            logging.warning("CUDA_VISIBLE_DEVICES is not set.")
        elif len(visible_devices.split(",")) != args.ngpu:
            logging.error("#gpus is not matched with CUDA_VISIBLE_DEVICES.")
            sys.exit(1)

        # TODO(kamo): support of multiple GPUs
        if args.ngpu > 1:
            logging.error("The program only supports ngpu=1.")
            sys.exit(1)

    # display PYTHONPATH
    python_path = os.environ.get('PYTHONPATH', '(None)')
    logging.info('python path = ' + python_path)

    # seed both generators with the same value
    seed = args.seed
    random.seed(seed)
    np.random.seed(seed)
    logging.info('set random seed = %d' % seed)

    # only the pytorch backend can run enhancement
    logging.info('backend = ' + args.backend)
    if args.backend != "pytorch":
        raise ValueError("Only pytorch is supported.")
    enhance(args)
if __name__ == '__main__':
    # Drop the program name; main() expects only the option strings.
    cli_args = sys.argv[1:]
    main(cli_args)