# Source code for espnet.asr.pytorch_backend.asr

#!/usr/bin/env python

# Copyright 2017 Johns Hopkins University (Shinji Watanabe)
#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)


import copy
import json
import logging
import math
import os
import sys

from chainer.datasets import TransformDataset
from chainer import reporter as reporter_module
from chainer import training
from chainer.training import extensions
from chainer.training.updater import StandardUpdater
import numpy as np
from tensorboardX import SummaryWriter
import torch
from torch.nn.parallel import data_parallel

from espnet.asr.asr_utils import adadelta_eps_decay
from espnet.asr.asr_utils import add_results_to_json
from espnet.asr.asr_utils import CompareValueTrigger
from espnet.asr.asr_utils import get_model_conf
from espnet.asr.asr_utils import plot_spectrogram
from espnet.asr.asr_utils import restore_snapshot
from espnet.asr.asr_utils import snapshot_object
from espnet.asr.asr_utils import torch_load
from espnet.asr.asr_utils import torch_resume
from espnet.asr.asr_utils import torch_snapshot
from espnet.asr.pytorch_backend.asr_init import load_trained_model
from espnet.asr.pytorch_backend.asr_init import load_trained_modules
import espnet.lm.pytorch_backend.extlm as extlm_pytorch
from espnet.nets.asr_interface import ASRInterface
from espnet.nets.pytorch_backend.e2e_asr import pad_list
import espnet.nets.pytorch_backend.lm.default as lm_pytorch
from espnet.nets.pytorch_backend.streaming.segment import SegmentStreamingE2E
from espnet.nets.pytorch_backend.streaming.window import WindowStreamingE2E
from espnet.transform.spectrogram import IStft
from espnet.transform.transformation import Transformation
from espnet.utils.cli_writers import file_writer_helper
from espnet.utils.deterministic_utils import set_deterministic_pytorch
from espnet.utils.dynamic_import import dynamic_import
from espnet.utils.io_utils import LoadInputsAndTargets
from espnet.utils.training.batchfy import make_batchset
from espnet.utils.training.evaluator import BaseEvaluator
from espnet.utils.training.iterators import ShufflingEnabler
from espnet.utils.training.iterators import ToggleableShufflingMultiprocessIterator
from espnet.utils.training.iterators import ToggleableShufflingSerialIterator
from espnet.utils.training.tensorboard_logger import TensorboardLogger
from espnet.utils.training.train_utils import check_early_stop
from espnet.utils.training.train_utils import set_early_stop

import matplotlib
# Select the 'Agg' backend at import time; pyplot itself is only imported
# lazily inside enhance() when spectrogram images are written.
matplotlib.use('Agg')

# itertools names this function izip_longest on Python 2 and zip_longest on
# Python 3; alias it so the rest of the module can use a single name.
if sys.version_info[0] == 2:
    from itertools import izip_longest as zip_longest
else:
    from itertools import zip_longest as zip_longest

class CustomEvaluator(BaseEvaluator):
    """Custom Evaluator for Pytorch.

    Args:
        model (torch.nn.Module): The model to evaluate.
        iterator (chainer.dataset.Iterator): Iterator yielding the batches
            to evaluate; it is registered under the name ``'main'``.
        target (link | dict[str, link]): Link object or a dictionary of
            links to evaluate. If this is just a link object, the link is
            registered by the name ``'main'``.
        converter (espnet.asr.pytorch_backend.asr.CustomConverter): Converter
            function to build input arrays. It is called with each batch
            and ``device``.
        device (torch.device): The device used.
        ngpu (int): The number of GPUs. When ``None`` it is inferred from
            ``device``: 0 for a CPU device, 1 otherwise.

    """

    def __init__(self, model, iterator, target, converter, device, ngpu=None):
        super(CustomEvaluator, self).__init__(iterator, target)
        self.model = model
        self.converter = converter
        self.device = device
        if ngpu is None:
            # No explicit count given: derive it from the device type.
            ngpu = 0 if device.type == "cpu" else 1
        self.ngpu = ngpu

    # The core part of the update routine can be customized by overriding
    def evaluate(self):
        """Run the model over every batch and return the averaged observations.

        The model is switched to eval mode for the pass (with gradients
        disabled) and put back into train mode before returning.

        Returns:
            The result of ``DictSummary.compute_mean()`` over the values
            reported during the forward passes.

        """
        main_iter = self._iterators['main']

        if self.eval_hook:
            self.eval_hook(self)

        # Resettable iterators are rewound and reused; others are shallow-copied
        # so that the registered iterator is not consumed.
        if hasattr(main_iter, 'reset'):
            main_iter.reset()
            batches = main_iter
        else:
            batches = copy.copy(main_iter)

        summary = reporter_module.DictSummary()

        self.model.eval()
        with torch.no_grad():
            for batch in batches:
                observation = {}
                with reporter_module.report_scope(observation):
                    # The converter turns the loaded batch into model inputs
                    # located on self.device.
                    inputs = self.converter(batch, self.device)
                    if self.ngpu != 0:
                        # apex does not support torch.nn.DataParallel
                        data_parallel(self.model, inputs, range(self.ngpu))
                    else:
                        self.model(*inputs)
                summary.add(observation)
        self.model.train()

        return summary.compute_mean()
class CustomUpdater(StandardUpdater):
    """Custom Updater for Pytorch.

    Args:
        model (torch.nn.Module): The model to update.
        grad_clip_threshold (float): The gradient clipping value to use.
        train_iter (chainer.dataset.Iterator): The training iterator.
        optimizer (torch.optim.optimizer): The training optimizer.
        converter (espnet.asr.pytorch_backend.asr.CustomConverter): Converter
            function to build input arrays. It is called with each batch
            and ``device``.
        device (torch.device): The device to use.
        ngpu (int): The number of gpus to use.
        grad_noise (bool): If true, gradient noise is injected after each
            backward pass.
        accum_grad (int): Number of forward/backward passes whose gradients
            are accumulated before one optimizer step.
        use_apex (bool): The flag to use Apex in backprop.

    """

    def __init__(self, model, grad_clip_threshold, train_iter,
                 optimizer, converter, device, ngpu, grad_noise=False, accum_grad=1, use_apex=False):
        super(CustomUpdater, self).__init__(train_iter, optimizer)
        self.model = model
        self.grad_clip_threshold = grad_clip_threshold
        self.converter = converter
        self.device = device
        self.ngpu = ngpu
        self.accum_grad = accum_grad
        self.forward_count = 0
        self.grad_noise = grad_noise
        self.iteration = 0
        self.use_apex = use_apex

    # The core part of the update routine can be customized by overriding.
    def update_core(self):
        """Main update routine of the CustomUpdater.

        Runs one forward/backward pass on the next batch. The optimizer step
        (with gradient clipping) is only taken once ``accum_grad`` passes have
        been accumulated, and is skipped when the gradient norm is NaN.
        """
        # When we pass one iterator and optimizer to StandardUpdater.__init__,
        # they are automatically named 'main'.
        train_iter = self.get_iterator('main')
        optimizer = self.get_optimizer('main')

        # Get the next batch ( a list of json files)
        batch = train_iter.next()
        # NOTE: self.iteration is deliberately NOT incremented here; update()
        # advances it once per completed parameter update. Incrementing in
        # both places made iteration-based triggers fire too early.
        x = self.converter(batch, self.device)

        # Compute the loss at this time step and accumulate it
        if self.ngpu == 0:
            loss = self.model(*x).mean() / self.accum_grad
        else:
            # apex does not support torch.nn.DataParallel
            loss = data_parallel(self.model, x, range(self.ngpu)).mean() / self.accum_grad
        if self.use_apex:
            from apex import amp
            # NOTE: for a compatibility with noam optimizer
            opt = optimizer.optimizer if hasattr(optimizer, "optimizer") else optimizer
            with amp.scale_loss(loss, opt) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()
        # gradient noise injection
        if self.grad_noise:
            from espnet.asr.asr_utils import add_gradient_noise
            add_gradient_noise(self.model, self.iteration, duration=100, eta=1.0, scale_factor=0.55)

        # update parameters only after accum_grad passes have been accumulated
        self.forward_count += 1
        if self.forward_count != self.accum_grad:
            return
        self.forward_count = 0
        # compute the gradient norm to check if it is normal or not
        grad_norm = torch.nn.utils.clip_grad_norm_(
            self.model.parameters(), self.grad_clip_threshold)
        logging.info('grad norm={}'.format(grad_norm))
        if math.isnan(grad_norm):
            logging.warning('grad norm is nan. Do not update model.')
        else:
            optimizer.step()
        optimizer.zero_grad()

    def update(self):
        """Run ``update_core`` and count the iteration when an update cycle ended."""
        self.update_core()
        # #iterations with accum_grad > 1
        # Ref.: https://github.com/espnet/espnet/issues/777
        if self.forward_count == 0:
            self.iteration += 1
class CustomConverter(object):
    """Custom batch converter for Pytorch.

    Args:
        subsampling_factor (int): The subsampling factor.
        dtype (torch.dtype): Data type to convert.

    """

    def __init__(self, subsampling_factor=1, dtype=torch.float32):
        self.subsampling_factor = subsampling_factor
        self.ignore_id = -1
        self.dtype = dtype

    def __call__(self, batch, device):
        """Transforms a batch and send it to a device.

        Args:
            batch (list): The batch to transform; a one-element list holding
                an ``(xs, ys)`` pair.
            device (torch.device): The device to send to.

        Returns:
            tuple: ``(xs_pad, ilens, ys_pad)``. ``xs_pad`` is a padded tensor,
            or a ``{'real': ..., 'imag': ...}`` dict of padded tensors when
            the inputs are complex-valued. ``ilens`` holds the input lengths
            and ``ys_pad`` the targets padded with ``ignore_id``.

        """
        # batch should be located in list
        assert len(batch) == 1
        xs, ys = batch[0]

        # perform subsampling
        step = self.subsampling_factor
        if step > 1:
            xs = [x[::step, :] for x in xs]

        # get batch of lengths of input sequences
        ilens = np.array([x.shape[0] for x in xs])

        def _pad_to_device(arrays):
            # Zero-pad a list of numpy arrays into one tensor on the device.
            tensors = [torch.from_numpy(a).float() for a in arrays]
            return pad_list(tensors, 0).to(device, dtype=self.dtype)

        # perform padding and convert to tensor
        # currently only support real number
        if xs[0].dtype.kind == 'c':
            # Note(kamo):
            # {'real': ..., 'imag': ...} will be changed to ComplexTensor in E2E.
            # Don't create ComplexTensor and give it E2E here
            # because torch.nn.DataParellel can't handle it.
            xs_pad = {'real': _pad_to_device([x.real for x in xs]),
                      'imag': _pad_to_device([x.imag for x in xs])}
        else:
            xs_pad = _pad_to_device(xs)

        ilens = torch.from_numpy(ilens).to(device)

        targets = []
        for y in ys:
            # NOTE: this is for multi-task learning (e.g., speech translation)
            if isinstance(y, tuple):
                y = np.array(y[0])
            targets.append(torch.from_numpy(y).long())
        ys_pad = pad_list(targets, self.ignore_id).to(device)

        return xs_pad, ilens, ys_pad
def train(args):
    """Train with the given args.

    Builds the model, optimizer, data iterators and a chainer ``Trainer``
    with its reporting/snapshot extensions, then runs the training loop.

    Args:
        args (namespace): The program arguments.

    Raises:
        NotImplementedError: If ``args.opt`` is not a supported optimizer.
        ImportError: If an apex opt level is requested but apex is missing.

    """
    set_deterministic_pytorch(args)

    # check cuda availability
    if not torch.cuda.is_available():
        logging.warning('cuda is not available')

    # get input and output dimension info
    with open(args.valid_json, 'rb') as f:
        valid_json = json.load(f)['utts']
    utts = list(valid_json.keys())
    idim = int(valid_json[utts[0]]['input'][0]['shape'][-1])
    odim = int(valid_json[utts[0]]['output'][0]['shape'][-1])
    logging.info('#input dims : ' + str(idim))
    logging.info('#output dims: ' + str(odim))

    # specify attention, CTC, hybrid mode
    if args.mtlalpha == 1.0:
        mtl_mode = 'ctc'
        logging.info('Pure CTC mode')
    elif args.mtlalpha == 0.0:
        mtl_mode = 'att'
        logging.info('Pure attention mode')
    else:
        mtl_mode = 'mtl'
        logging.info('Multitask learning mode')

    if args.enc_init is not None or args.dec_init is not None:
        model = load_trained_modules(idim, odim, args)
    else:
        model_class = dynamic_import(args.model_module)
        model = model_class(idim, odim, args)
    assert isinstance(model, ASRInterface)

    subsampling_factor = model.subsample[0]

    if args.rnnlm is not None:
        rnnlm_args = get_model_conf(args.rnnlm, args.rnnlm_conf)
        rnnlm = lm_pytorch.ClassifierWithState(
            lm_pytorch.RNNLM(
                len(args.char_list), rnnlm_args.layer, rnnlm_args.unit))
        # Use the project loader (as recog() does). torch.load(path, rnnlm)
        # would treat the module as ``map_location`` and never restore the
        # saved parameters into it.
        torch_load(args.rnnlm, rnnlm)
        model.rnnlm = rnnlm

    # write model config
    if not os.path.exists(args.outdir):
        os.makedirs(args.outdir)
    model_conf = args.outdir + '/model.json'
    with open(model_conf, 'wb') as f:
        logging.info('writing a model config file to ' + model_conf)
        f.write(json.dumps((idim, odim, vars(args)),
                           indent=4, ensure_ascii=False, sort_keys=True).encode('utf_8'))
    for key in sorted(vars(args).keys()):
        logging.info('ARGS: ' + key + ': ' + str(vars(args)[key]))

    reporter = model.reporter

    # check the use of multi-gpu
    if args.ngpu > 1:
        if args.batch_size != 0:
            logging.info('batch size is automatically increased (%d -> %d)' % (
                args.batch_size, args.batch_size * args.ngpu))
            args.batch_size *= args.ngpu

    # set torch device
    device = torch.device("cuda" if args.ngpu > 0 else "cpu")
    if args.train_dtype in ("float16", "float32", "float64"):
        dtype = getattr(torch, args.train_dtype)
    else:
        # apex opt levels ("O0".."O3") and anything else start from float32
        dtype = torch.float32
    model = model.to(device=device, dtype=dtype)

    # Setup an optimizer
    if args.opt == 'adadelta':
        optimizer = torch.optim.Adadelta(
            model.parameters(), rho=0.95, eps=args.eps,
            weight_decay=args.weight_decay)
    elif args.opt == 'adam':
        optimizer = torch.optim.Adam(model.parameters(),
                                     weight_decay=args.weight_decay)
    elif args.opt == 'noam':
        from espnet.nets.pytorch_backend.transformer.optimizer import get_std_opt
        optimizer = get_std_opt(model, args.adim, args.transformer_warmup_steps, args.transformer_lr)
    else:
        raise NotImplementedError("unknown optimizer: " + args.opt)

    # setup apex.amp
    if args.train_dtype in ("O0", "O1", "O2", "O3"):
        try:
            from apex import amp
        except ImportError as e:
            logging.error(f"You need to install apex for --train-dtype {args.train_dtype}. "
                          "See https://github.com/NVIDIA/apex#linux")
            raise e
        if args.opt == 'noam':
            # the noam wrapper keeps the real torch optimizer in ``.optimizer``
            model, optimizer.optimizer = amp.initialize(model, optimizer.optimizer, opt_level=args.train_dtype)
        else:
            model, optimizer = amp.initialize(model, optimizer, opt_level=args.train_dtype)
        use_apex = True
    else:
        use_apex = False

    # FIXME: TOO DIRTY HACK
    setattr(optimizer, "target", reporter)
    setattr(optimizer, "serialize", lambda s: reporter.serialize(s))

    # Setup a converter
    converter = CustomConverter(subsampling_factor=subsampling_factor, dtype=dtype)

    # read json data
    # (valid_json was already loaded above to get the dimensions; reuse it)
    with open(args.train_json, 'rb') as f:
        train_json = json.load(f)['utts']

    use_sortagrad = args.sortagrad == -1 or args.sortagrad > 0
    # make minibatch list (variable length)
    train = make_batchset(train_json, args.batch_size,
                          args.maxlen_in, args.maxlen_out, args.minibatches,
                          min_batch_size=args.ngpu if args.ngpu > 1 else 1,
                          shortest_first=use_sortagrad,
                          count=args.batch_count,
                          batch_bins=args.batch_bins,
                          batch_frames_in=args.batch_frames_in,
                          batch_frames_out=args.batch_frames_out,
                          batch_frames_inout=args.batch_frames_inout)
    valid = make_batchset(valid_json, args.batch_size,
                          args.maxlen_in, args.maxlen_out, args.minibatches,
                          min_batch_size=args.ngpu if args.ngpu > 1 else 1,
                          count=args.batch_count,
                          batch_bins=args.batch_bins,
                          batch_frames_in=args.batch_frames_in,
                          batch_frames_out=args.batch_frames_out,
                          batch_frames_inout=args.batch_frames_inout)

    load_tr = LoadInputsAndTargets(
        mode='asr', load_output=True, preprocess_conf=args.preprocess_conf,
        preprocess_args={'train': True}  # Switch the mode of preprocessing
    )
    load_cv = LoadInputsAndTargets(
        mode='asr', load_output=True, preprocess_conf=args.preprocess_conf,
        preprocess_args={'train': False}  # Switch the mode of preprocessing
    )
    # hack to make batchsize argument as 1
    # actual bathsize is included in a list
    if args.n_iter_processes > 0:
        train_iter = ToggleableShufflingMultiprocessIterator(
            TransformDataset(train, load_tr),
            batch_size=1, n_processes=args.n_iter_processes, n_prefetch=8, maxtasksperchild=20,
            shuffle=not use_sortagrad)
        valid_iter = ToggleableShufflingMultiprocessIterator(
            TransformDataset(valid, load_cv),
            batch_size=1, repeat=False, shuffle=False,
            n_processes=args.n_iter_processes, n_prefetch=8, maxtasksperchild=20)
    else:
        train_iter = ToggleableShufflingSerialIterator(
            TransformDataset(train, load_tr),
            batch_size=1, shuffle=not use_sortagrad)
        valid_iter = ToggleableShufflingSerialIterator(
            TransformDataset(valid, load_cv),
            batch_size=1, repeat=False, shuffle=False)

    # Set up a trainer
    updater = CustomUpdater(
        model, args.grad_clip, train_iter, optimizer,
        converter, device, args.ngpu, args.grad_noise, args.accum_grad, use_apex=use_apex)
    trainer = training.Trainer(
        updater, (args.epochs, 'epoch'), out=args.outdir)

    if use_sortagrad:
        trainer.extend(ShufflingEnabler([train_iter]),
                       trigger=(args.sortagrad if args.sortagrad != -1 else args.epochs, 'epoch'))

    # Resume from a snapshot
    if args.resume:
        logging.info('resumed from %s' % args.resume)
        torch_resume(args.resume, trainer)

    # Evaluate the model with the test dataset for each epoch
    trainer.extend(CustomEvaluator(model, valid_iter, reporter, converter, device, args.ngpu))

    # Save attention weight each epoch
    if args.num_save_attention > 0 and args.mtlalpha != 1.0:
        data = sorted(list(valid_json.items())[:args.num_save_attention],
                      key=lambda x: int(x[1]['input'][0]['shape'][1]), reverse=True)
        if hasattr(model, "module"):
            att_vis_fn = model.module.calculate_all_attentions
            plot_class = model.module.attention_plot_class
        else:
            att_vis_fn = model.calculate_all_attentions
            plot_class = model.attention_plot_class
        att_reporter = plot_class(
            att_vis_fn, data, args.outdir + "/att_ws",
            converter=converter, transform=load_cv, device=device)
        trainer.extend(att_reporter, trigger=(1, 'epoch'))
    else:
        att_reporter = None

    # Make a plot for training and validation values
    trainer.extend(extensions.PlotReport(['main/loss', 'validation/main/loss',
                                          'main/loss_ctc', 'validation/main/loss_ctc',
                                          'main/loss_att', 'validation/main/loss_att'],
                                         'epoch', file_name='loss.png'))
    trainer.extend(extensions.PlotReport(['main/acc', 'validation/main/acc'],
                                         'epoch', file_name='acc.png'))
    trainer.extend(extensions.PlotReport(['main/cer_ctc', 'validation/main/cer_ctc'],
                                         'epoch', file_name='cer.png'))

    # Save best models
    trainer.extend(snapshot_object(model, 'model.loss.best'),
                   trigger=training.triggers.MinValueTrigger('validation/main/loss'))
    if mtl_mode != 'ctc':
        trainer.extend(snapshot_object(model, 'model.acc.best'),
                       trigger=training.triggers.MaxValueTrigger('validation/main/acc'))

    # save snapshot which contains model and optimizer states
    trainer.extend(torch_snapshot(), trigger=(1, 'epoch'))

    # epsilon decay in the optimizer
    if args.opt == 'adadelta':
        if args.criterion == 'acc' and mtl_mode != 'ctc':
            trainer.extend(restore_snapshot(model, args.outdir + '/model.acc.best', load_fn=torch_load),
                           trigger=CompareValueTrigger(
                               'validation/main/acc',
                               lambda best_value, current_value: best_value > current_value))
            trainer.extend(adadelta_eps_decay(args.eps_decay),
                           trigger=CompareValueTrigger(
                               'validation/main/acc',
                               lambda best_value, current_value: best_value > current_value))
        elif args.criterion == 'loss':
            trainer.extend(restore_snapshot(model, args.outdir + '/model.loss.best', load_fn=torch_load),
                           trigger=CompareValueTrigger(
                               'validation/main/loss',
                               lambda best_value, current_value: best_value < current_value))
            trainer.extend(adadelta_eps_decay(args.eps_decay),
                           trigger=CompareValueTrigger(
                               'validation/main/loss',
                               lambda best_value, current_value: best_value < current_value))

    # Write a log of evaluation statistics for each epoch
    trainer.extend(extensions.LogReport(trigger=(args.report_interval_iters, 'iteration')))
    report_keys = ['epoch', 'iteration', 'main/loss', 'main/loss_ctc', 'main/loss_att',
                   'validation/main/loss', 'validation/main/loss_ctc', 'validation/main/loss_att',
                   'main/acc', 'validation/main/acc', 'main/cer_ctc', 'validation/main/cer_ctc',
                   'elapsed_time']
    if args.opt == 'adadelta':
        trainer.extend(extensions.observe_value(
            'eps', lambda trainer: trainer.updater.get_optimizer('main').param_groups[0]["eps"]),
            trigger=(args.report_interval_iters, 'iteration'))
        report_keys.append('eps')
    if args.report_cer:
        report_keys.append('validation/main/cer')
    if args.report_wer:
        report_keys.append('validation/main/wer')
    trainer.extend(extensions.PrintReport(
        report_keys), trigger=(args.report_interval_iters, 'iteration'))

    trainer.extend(extensions.ProgressBar(update_interval=args.report_interval_iters))
    set_early_stop(trainer, args)

    if args.tensorboard_dir is not None and args.tensorboard_dir != "":
        trainer.extend(TensorboardLogger(SummaryWriter(args.tensorboard_dir), att_reporter),
                       trigger=(args.report_interval_iters, "iteration"))
    # Run the training
    trainer.run()
    check_early_stop(trainer, args.epochs)
def recog(args):
    """Decode with the given args.

    Loads the trained model (and optional character/word language models),
    decodes every utterance listed in ``args.recog_json`` and writes the
    n-best results as JSON to ``args.result_label``.

    Args:
        args (namespace): The program arguments.

    Raises:
        ValueError: If the given RNNLM was trained with a non-default
            ``model_module``.

    """
    set_deterministic_pytorch(args)
    model, train_args = load_trained_model(args.model)
    assert isinstance(model, ASRInterface)
    model.recog_args = args

    # read rnnlm
    if args.rnnlm:
        rnnlm_args = get_model_conf(args.rnnlm, args.rnnlm_conf)
        if getattr(rnnlm_args, "model_module", "default") != "default":
            raise ValueError("use '--api v2' option to decode with non-default language model")
        rnnlm = lm_pytorch.ClassifierWithState(
            lm_pytorch.RNNLM(
                len(train_args.char_list), rnnlm_args.layer, rnnlm_args.unit))
        torch_load(args.rnnlm, rnnlm)
        rnnlm.eval()
    else:
        rnnlm = None

    if args.word_rnnlm:
        rnnlm_args = get_model_conf(args.word_rnnlm, args.word_rnnlm_conf)
        word_dict = rnnlm_args.char_list_dict
        char_dict = {x: i for i, x in enumerate(train_args.char_list)}
        word_rnnlm = lm_pytorch.ClassifierWithState(lm_pytorch.RNNLM(
            len(word_dict), rnnlm_args.layer, rnnlm_args.unit))
        torch_load(args.word_rnnlm, word_rnnlm)
        word_rnnlm.eval()

        # Combine with the character LM when both are given; otherwise use
        # the word LM alone through the look-ahead wrapper.
        if rnnlm is not None:
            rnnlm = lm_pytorch.ClassifierWithState(
                extlm_pytorch.MultiLevelLM(word_rnnlm.predictor,
                                           rnnlm.predictor, word_dict, char_dict))
        else:
            rnnlm = lm_pytorch.ClassifierWithState(
                extlm_pytorch.LookAheadWordLM(word_rnnlm.predictor,
                                              word_dict, char_dict))

    # gpu
    if args.ngpu == 1:
        gpu_id = list(range(args.ngpu))
        logging.info('gpu id: ' + str(gpu_id))
        model.cuda()
        if rnnlm:
            rnnlm.cuda()

    # read json data
    with open(args.recog_json, 'rb') as f:
        js = json.load(f)['utts']
    new_js = {}

    load_inputs_and_targets = LoadInputsAndTargets(
        mode='asr', load_output=False, sort_in_input_length=False,
        preprocess_conf=train_args.preprocess_conf
        if args.preprocess_conf is None else args.preprocess_conf,
        preprocess_args={'train': False})

    if args.batchsize == 0:
        with torch.no_grad():
            for idx, name in enumerate(js.keys(), 1):
                # Pass the name as a logging argument: concatenating it into
                # the format string breaks on names that contain '%'.
                logging.info('(%d/%d) decoding %s', idx, len(js), name)
                batch = [(name, js[name])]
                feat = load_inputs_and_targets(batch)[0][0]
                if args.streaming_mode == 'window':
                    logging.info('Using streaming recognizer with window size %d frames', args.streaming_window)
                    se2e = WindowStreamingE2E(e2e=model, recog_args=args, rnnlm=rnnlm)
                    for i in range(0, feat.shape[0], args.streaming_window):
                        logging.info('Feeding frames %d - %d', i, i + args.streaming_window)
                        se2e.accept_input(feat[i:i + args.streaming_window])
                    logging.info('Running offline attention decoder')
                    se2e.decode_with_attention_offline()
                    logging.info('Offline attention decoder finished')
                    nbest_hyps = se2e.retrieve_recognition()
                elif args.streaming_mode == 'segment':
                    logging.info('Using streaming recognizer with threshold value %d', args.streaming_min_blank_dur)
                    nbest_hyps = []
                    for n in range(args.nbest):
                        nbest_hyps.append({'yseq': [], 'score': 0.0})
                    se2e = SegmentStreamingE2E(e2e=model, recog_args=args, rnnlm=rnnlm)
                    # feed the input in chunks of the total subsampling factor
                    r = np.prod(model.subsample)
                    for i in range(0, feat.shape[0], r):
                        hyps = se2e.accept_input(feat[i:i + r])
                        if hyps is not None:
                            text = ''.join([train_args.char_list[int(x)]
                                            for x in hyps[0]['yseq'][1:-1] if int(x) != -1])
                            text = text.replace('\u2581', ' ').strip()  # for SentencePiece
                            text = text.replace(model.space, ' ')
                            text = text.replace(model.blank, '')
                            logging.info(text)
                            # append this segment's hypotheses to the running n-best
                            for n in range(args.nbest):
                                nbest_hyps[n]['yseq'].extend(hyps[n]['yseq'])
                                nbest_hyps[n]['score'] += hyps[n]['score']
                else:
                    nbest_hyps = model.recognize(feat, args, train_args.char_list, rnnlm)
                new_js[name] = add_results_to_json(js[name], nbest_hyps, train_args.char_list)

    else:
        def grouper(n, iterable, fillvalue=None):
            # Split iterable into n-sized tuples, padding the last with fillvalue.
            kargs = [iter(iterable)] * n
            return zip_longest(*kargs, fillvalue=fillvalue)

        # sort data if batchsize > 1
        keys = list(js.keys())
        if args.batchsize > 1:
            feat_lens = [js[key]['input'][0]['shape'][0] for key in keys]
            sorted_index = sorted(range(len(feat_lens)), key=lambda i: -feat_lens[i])
            keys = [keys[i] for i in sorted_index]

        with torch.no_grad():
            for names in grouper(args.batchsize, keys, None):
                # drop the fill values padding the last group
                names = [name for name in names if name]
                batch = [(name, js[name]) for name in names]
                feats = load_inputs_and_targets(batch)[0]
                nbest_hyps = model.recognize_batch(feats, args, train_args.char_list, rnnlm=rnnlm)

                for i, nbest_hyp in enumerate(nbest_hyps):
                    name = names[i]
                    new_js[name] = add_results_to_json(js[name], nbest_hyp, train_args.char_list)

    with open(args.result_label, 'wb') as f:
        f.write(json.dumps({'utts': new_js}, indent=4, ensure_ascii=False, sort_keys=True).encode('utf_8'))
def enhance(args):
    """Dumping enhanced speech and mask.

    Runs the model's ``enhance`` method over the utterances in
    ``args.recog_json``. Depending on the arguments it saves spectrogram
    images to ``args.image_dir`` and/or writes the enhanced signals through
    the writer given by ``args.enh_wspecifier``.

    Args:
        args (namespace): The program arguments.

    """
    set_deterministic_pytorch(args)
    # read training config
    idim, odim, train_args = get_model_conf(args.model, args.model_conf)

    # load trained model parameters
    logging.info('reading model parameters from ' + args.model)
    model_class = dynamic_import(train_args.model_module)
    model = model_class(idim, odim, train_args)
    assert isinstance(model, ASRInterface)
    torch_load(args.model, model)
    model.recog_args = args

    # gpu
    if args.ngpu == 1:
        gpu_id = list(range(args.ngpu))
        logging.info('gpu id: ' + str(gpu_id))
        model.cuda()

    # read json data
    with open(args.recog_json, 'rb') as f:
        js = json.load(f)['utts']

    load_inputs_and_targets = LoadInputsAndTargets(
        mode='asr', load_output=False, sort_in_input_length=False,
        preprocess_conf=None  # Apply pre_process in outer func
    )
    if args.batchsize == 0:
        args.batchsize = 1

    # Creates writers for outputs from the network
    if args.enh_wspecifier is not None:
        enh_writer = file_writer_helper(args.enh_wspecifier,
                                        filetype=args.enh_filetype)
    else:
        enh_writer = None

    # Creates a Transformation instance
    preprocess_conf = (
        train_args.preprocess_conf if args.preprocess_conf is None
        else args.preprocess_conf)
    if preprocess_conf is not None:
        logging.info('Use preprocessing: {}'.format(preprocess_conf))
        transform = Transformation(preprocess_conf)
    else:
        transform = None

    # Creates a IStft instance
    istft = None
    frame_shift = args.istft_n_shift  # Used for plot the spectrogram
    if args.apply_istft:
        if preprocess_conf is not None:
            # Read the conffile and find stft setting
            with open(preprocess_conf) as f:
                # Json format: e.g.
                #    {"process": [{"type": "stft",
                #                  "win_length": 400,
                #                  "n_fft": 512, "n_shift": 160,
                #                  "window": "han"},
                #                 {"type": "foo", ...}, ...]}
                conf = json.load(f)
                assert 'process' in conf, conf
                # Find stft setting
                for p in conf['process']:
                    if p['type'] == 'stft':
                        istft = IStft(win_length=p['win_length'],
                                      n_shift=p['n_shift'],
                                      window=p.get('window', 'hann'))
                        logging.info('stft is found in {}. '
                                     'Setting istft config from it\n{}'
                                     .format(preprocess_conf, istft))
                        frame_shift = p['n_shift']
                        break
        if istft is None:
            # Set from command line arguments
            istft = IStft(win_length=args.istft_win_length,
                          n_shift=args.istft_n_shift,
                          window=args.istft_window)
            logging.info('Setting istft config from the command line args\n{}'
                         .format(istft))

    # sort data
    keys = list(js.keys())
    feat_lens = [js[key]['input'][0]['shape'][0] for key in keys]
    sorted_index = sorted(range(len(feat_lens)), key=lambda i: -feat_lens[i])
    keys = [keys[i] for i in sorted_index]

    def grouper(n, iterable, fillvalue=None):
        # Split iterable into n-sized tuples, padding the last with fillvalue.
        kargs = [iter(iterable)] * n
        return zip_longest(*kargs, fillvalue=fillvalue)

    num_images = 0
    # image_dir may be None (no images requested); only create it when given
    if args.image_dir is not None and not os.path.exists(args.image_dir):
        os.makedirs(args.image_dir)

    finished = False
    for names in grouper(args.batchsize, keys, None):
        # The last group is padded with None fill values; drop them so they
        # are not looked up as utterance ids.
        names = [name for name in names if name is not None]
        batch = [(name, js[name]) for name in names]

        # May be in time region: (Batch, [Time, Channel])
        org_feats = load_inputs_and_targets(batch)[0]
        if transform is not None:
            # May be in time-freq region: : (Batch, [Time, Channel, Freq])
            feats = transform(org_feats, train=False)
        else:
            feats = org_feats

        with torch.no_grad():
            enhanced, mask, ilens = model.enhance(feats)

        for idx, name in enumerate(names):
            # Assuming mask, feats : [Batch, Time, Channel. Freq]
            #          enhanced    : [Batch, Time, Freq]
            enh = enhanced[idx][:ilens[idx]]
            mas = mask[idx][:ilens[idx]]
            feat = feats[idx]

            # Plot spectrogram
            if args.image_dir is not None and num_images < args.num_images:
                import matplotlib.pyplot as plt
                num_images += 1
                ref_ch = 0

                plt.figure(figsize=(20, 10))
                plt.subplot(4, 1, 1)
                plt.title('Mask [ref={}ch]'.format(ref_ch))
                plot_spectrogram(plt, mas[:, ref_ch].T, fs=args.fs,
                                 mode='linear', frame_shift=frame_shift,
                                 bottom=False, labelbottom=False)

                plt.subplot(4, 1, 2)
                plt.title('Noisy speech [ref={}ch]'.format(ref_ch))
                plot_spectrogram(plt, feat[:, ref_ch].T, fs=args.fs,
                                 mode='db', frame_shift=frame_shift,
                                 bottom=False, labelbottom=False)

                plt.subplot(4, 1, 3)
                plt.title('Masked speech [ref={}ch]'.format(ref_ch))
                plot_spectrogram(
                    plt,
                    (feat[:, ref_ch] * mas[:, ref_ch]).T,
                    frame_shift=frame_shift,
                    fs=args.fs, mode='db', bottom=False, labelbottom=False)

                plt.subplot(4, 1, 4)
                plt.title('Enhanced speech')
                plot_spectrogram(plt, enh.T, fs=args.fs,
                                 mode='db', frame_shift=frame_shift)

                plt.savefig(os.path.join(args.image_dir, name + '.png'))
                plt.clf()

            # Write enhanced wave files
            if enh_writer is not None:
                if istft is not None:
                    enh = istft(enh)

                if args.keep_length:
                    # Compare against this utterance's original length
                    # (org_feats[idx]), not the number of items in the batch.
                    org_len = len(org_feats[idx])
                    if org_len < len(enh):
                        # Truncate the frames added by stft padding
                        enh = enh[:org_len]
                    elif org_len > len(enh):
                        padwidth = [(0, (org_len - len(enh)))] \
                            + [(0, 0)] * (enh.ndim - 1)
                        enh = np.pad(enh, padwidth, mode='constant')

                if args.enh_filetype in ('sound', 'sound.hdf5'):
                    enh_writer[name] = (args.fs, enh)
                else:
                    # Hint: To dump stft_signal, mask or etc,
                    # enh_filetype='hdf5' might be convenient.
                    enh_writer[name] = enh

            if num_images >= args.num_images and enh_writer is None:
                # Nothing is left to output: all requested images are saved
                # and no writer is configured.
                logging.info('Breaking the process.')
                finished = True
                break

        if finished:
            # Also stop iterating over the remaining batches.
            break