#!/usr/bin/env python
# Copyright 2017 Johns Hopkins University (Shinji Watanabe)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
import copy
import json
import logging
import math
import os
import sys
from chainer.datasets import TransformDataset
from chainer import reporter as reporter_module
from chainer import training
from chainer.training import extensions
from chainer.training.updater import StandardUpdater
import numpy as np
from tensorboardX import SummaryWriter
import torch
from torch.nn.parallel import data_parallel
from espnet.asr.asr_utils import adadelta_eps_decay
from espnet.asr.asr_utils import add_results_to_json
from espnet.asr.asr_utils import CompareValueTrigger
from espnet.asr.asr_utils import get_model_conf
from espnet.asr.asr_utils import plot_spectrogram
from espnet.asr.asr_utils import restore_snapshot
from espnet.asr.asr_utils import snapshot_object
from espnet.asr.asr_utils import torch_load
from espnet.asr.asr_utils import torch_resume
from espnet.asr.asr_utils import torch_snapshot
from espnet.asr.pytorch_backend.asr_init import load_trained_model
from espnet.asr.pytorch_backend.asr_init import load_trained_modules
import espnet.lm.pytorch_backend.extlm as extlm_pytorch
from espnet.nets.asr_interface import ASRInterface
from espnet.nets.pytorch_backend.e2e_asr import pad_list
import espnet.nets.pytorch_backend.lm.default as lm_pytorch
from espnet.nets.pytorch_backend.streaming.segment import SegmentStreamingE2E
from espnet.nets.pytorch_backend.streaming.window import WindowStreamingE2E
from espnet.transform.spectrogram import IStft
from espnet.transform.transformation import Transformation
from espnet.utils.cli_writers import file_writer_helper
from espnet.utils.deterministic_utils import set_deterministic_pytorch
from espnet.utils.dynamic_import import dynamic_import
from espnet.utils.io_utils import LoadInputsAndTargets
from espnet.utils.training.batchfy import make_batchset
from espnet.utils.training.evaluator import BaseEvaluator
from espnet.utils.training.iterators import ShufflingEnabler
from espnet.utils.training.iterators import ToggleableShufflingMultiprocessIterator
from espnet.utils.training.iterators import ToggleableShufflingSerialIterator
from espnet.utils.training.tensorboard_logger import TensorboardLogger
from espnet.utils.training.train_utils import check_early_stop
from espnet.utils.training.train_utils import set_early_stop
import matplotlib
matplotlib.use('Agg')
if sys.version_info[0] == 2:
    from itertools import izip_longest as zip_longest
else:
    from itertools import zip_longest


class CustomEvaluator(BaseEvaluator):
    """Custom Evaluator for PyTorch.

Args:
model (torch.nn.Module): The model to evaluate.
        iterator (chainer.dataset.Iterator): The validation iterator.
        target (link | dict[str, link]): Link object or a dictionary of
            links to evaluate. If this is just a link object, the link is
            registered by the name ``'main'``.
converter (espnet.asr.pytorch_backend.asr.CustomConverter): Converter
function to build input arrays. Each batch extracted by the main
iterator and the ``device`` option are passed to this function.
:func:`chainer.dataset.concat_examples` is used by default.
device (torch.device): The device used.
ngpu (int): The number of GPUs.
"""
def __init__(self, model, iterator, target, converter, device, ngpu=None):
super(CustomEvaluator, self).__init__(iterator, target)
self.model = model
self.converter = converter
self.device = device
if ngpu is not None:
self.ngpu = ngpu
elif device.type == "cpu":
self.ngpu = 0
else:
self.ngpu = 1

    # The core part of the evaluation routine can be customized by overriding.
    def evaluate(self):
"""Main evaluate routine for CustomEvaluator."""
iterator = self._iterators['main']
if self.eval_hook:
self.eval_hook(self)
if hasattr(iterator, 'reset'):
iterator.reset()
it = iterator
else:
it = copy.copy(iterator)
summary = reporter_module.DictSummary()
self.model.eval()
with torch.no_grad():
for batch in it:
observation = {}
with reporter_module.report_scope(observation):
# read scp files
# x: original json with loaded features
# will be converted to chainer variable later
x = self.converter(batch, self.device)
if self.ngpu == 0:
self.model(*x)
else:
# apex does not support torch.nn.DataParallel
data_parallel(self.model, x, range(self.ngpu))
summary.add(observation)
self.model.train()
return summary.compute_mean()
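
# Hedged usage sketch (illustrative, not part of the original file): train()
# below registers this evaluator as a trainer extension, roughly:
#
#     trainer.extend(CustomEvaluator(model, valid_iter, model.reporter,
#                                    converter, device, args.ngpu))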


class CustomUpdater(StandardUpdater):
    """Custom Updater for PyTorch.

Args:
model (torch.nn.Module): The model to update.
grad_clip_threshold (float): The gradient clipping value to use.
train_iter (chainer.dataset.Iterator): The training iterator.
optimizer (torch.optim.optimizer): The training optimizer.
converter (espnet.asr.pytorch_backend.asr.CustomConverter): Converter
function to build input arrays. Each batch extracted by the main
iterator and the ``device`` option are passed to this function.
:func:`chainer.dataset.concat_examples` is used by default.
device (torch.device): The device to use.
ngpu (int): The number of gpus to use.
use_apex (bool): The flag to use Apex in backprop.
"""
def __init__(self, model, grad_clip_threshold, train_iter,
optimizer, converter, device, ngpu, grad_noise=False, accum_grad=1, use_apex=False):
super(CustomUpdater, self).__init__(train_iter, optimizer)
self.model = model
self.grad_clip_threshold = grad_clip_threshold
self.converter = converter
self.device = device
self.ngpu = ngpu
self.accum_grad = accum_grad
self.forward_count = 0
self.grad_noise = grad_noise
self.iteration = 0
self.use_apex = use_apex

    # The core part of the update routine can be customized by overriding.
    def update_core(self):
"""Main update routine of the CustomUpdater."""
# When we pass one iterator and optimizer to StandardUpdater.__init__,
# they are automatically named 'main'.
train_iter = self.get_iterator('main')
optimizer = self.get_optimizer('main')
        # Get the next batch (a list of json entries)
batch = train_iter.next()
self.iteration += 1
x = self.converter(batch, self.device)
# Compute the loss at this time step and accumulate it
if self.ngpu == 0:
loss = self.model(*x).mean() / self.accum_grad
else:
# apex does not support torch.nn.DataParallel
loss = data_parallel(self.model, x, range(self.ngpu)).mean() / self.accum_grad
if self.use_apex:
from apex import amp
            # NOTE: for compatibility with the noam optimizer
opt = optimizer.optimizer if hasattr(optimizer, "optimizer") else optimizer
with amp.scale_loss(loss, opt) as scaled_loss:
scaled_loss.backward()
else:
loss.backward()
# gradient noise injection
if self.grad_noise:
from espnet.asr.asr_utils import add_gradient_noise
add_gradient_noise(self.model, self.iteration, duration=100, eta=1.0, scale_factor=0.55)
loss.detach() # Truncate the graph
# update parameters
self.forward_count += 1
if self.forward_count != self.accum_grad:
return
self.forward_count = 0
# compute the gradient norm to check if it is normal or not
grad_norm = torch.nn.utils.clip_grad_norm_(
self.model.parameters(), self.grad_clip_threshold)
logging.info('grad norm={}'.format(grad_norm))
if math.isnan(grad_norm):
logging.warning('grad norm is nan. Do not update model.')
else:
optimizer.step()
optimizer.zero_grad()

    def update(self):
        self.update_core()
        # increment the iteration counter only once every accum_grad forward passes
        # Ref.: https://github.com/espnet/espnet/issues/777
if self.forward_count == 0:
self.iteration += 1
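
# Hedged note (illustrative, not part of the original file): with accum_grad > 1,
# optimizer.step() fires only every accum_grad calls to update(), so gradients
# from accum_grad minibatches are accumulated and the effective batch size is
# batch_size * accum_grad, e.g.
#
#     updater = CustomUpdater(model, 5.0, train_iter, optimizer, converter,
#                             device, ngpu=0, accum_grad=4)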


class CustomConverter(object):
    """Custom batch converter for PyTorch.

Args:
subsampling_factor (int): The subsampling factor.
dtype (torch.dtype): Data type to convert.
"""
def __init__(self, subsampling_factor=1, dtype=torch.float32):
self.subsampling_factor = subsampling_factor
self.ignore_id = -1
self.dtype = dtype

    def __call__(self, batch, device):
        """Transform a batch and send it to a device.

Args:
batch (list): The batch to transform.
device (torch.device): The device to send to.
Returns:
tuple(torch.Tensor, torch.Tensor, torch.Tensor)
"""
        # the minibatch should be wrapped in a one-element list
assert len(batch) == 1
xs, ys = batch[0]
# perform subsampling
if self.subsampling_factor > 1:
xs = [x[::self.subsampling_factor, :] for x in xs]
# get batch of lengths of input sequences
ilens = np.array([x.shape[0] for x in xs])
        # perform padding and convert to tensor
        # currently only real numbers are supported
if xs[0].dtype.kind == 'c':
xs_pad_real = pad_list(
[torch.from_numpy(x.real).float() for x in xs], 0).to(device, dtype=self.dtype)
xs_pad_imag = pad_list(
[torch.from_numpy(x.imag).float() for x in xs], 0).to(device, dtype=self.dtype)
# Note(kamo):
# {'real': ..., 'imag': ...} will be changed to ComplexTensor in E2E.
# Don't create ComplexTensor and give it E2E here
            # because torch.nn.DataParallel can't handle it.
xs_pad = {'real': xs_pad_real, 'imag': xs_pad_imag}
else:
xs_pad = pad_list([torch.from_numpy(x).float() for x in xs], 0).to(device, dtype=self.dtype)
ilens = torch.from_numpy(ilens).to(device)
# NOTE: this is for multi-task learning (e.g., speech translation)
ys_pad = pad_list([torch.from_numpy(np.array(y[0]) if isinstance(y, tuple) else y).long()
for y in ys], self.ignore_id).to(device)
return xs_pad, ilens, ys_pad
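
# Hedged usage sketch (illustrative): the converter expects the minibatch to be
# wrapped in a one-element list, as produced by the iterators built in train():
#
#     converter = CustomConverter(subsampling_factor=1, dtype=torch.float32)
#     xs = [np.random.randn(100, 83).astype(np.float32)]  # one utterance, 100 frames
#     ys = [np.array([3, 7, 7, 4])]                        # its token id sequence
#     xs_pad, ilens, ys_pad = converter([(xs, ys)], torch.device("cpu"))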


def train(args):
    """Train with the given args.

Args:
args (namespace): The program arguments.
"""
set_deterministic_pytorch(args)
# check cuda availability
if not torch.cuda.is_available():
logging.warning('cuda is not available')
# get input and output dimension info
with open(args.valid_json, 'rb') as f:
valid_json = json.load(f)['utts']
utts = list(valid_json.keys())
idim = int(valid_json[utts[0]]['input'][0]['shape'][-1])
odim = int(valid_json[utts[0]]['output'][0]['shape'][-1])
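    # Hedged sketch of the expected data.json layout (field values illustrative):
    #     {"utts": {"utt_id_1": {"input":  [{"shape": [T, idim], ...}],
    #                            "output": [{"shape": [L, odim], ...}]}, ...}}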
logging.info('#input dims : ' + str(idim))
logging.info('#output dims: ' + str(odim))
# specify attention, CTC, hybrid mode
if args.mtlalpha == 1.0:
mtl_mode = 'ctc'
logging.info('Pure CTC mode')
elif args.mtlalpha == 0.0:
mtl_mode = 'att'
logging.info('Pure attention mode')
else:
mtl_mode = 'mtl'
logging.info('Multitask learning mode')
if args.enc_init is not None or args.dec_init is not None:
model = load_trained_modules(idim, odim, args)
else:
model_class = dynamic_import(args.model_module)
model = model_class(idim, odim, args)
assert isinstance(model, ASRInterface)
subsampling_factor = model.subsample[0]
if args.rnnlm is not None:
rnnlm_args = get_model_conf(args.rnnlm, args.rnnlm_conf)
rnnlm = lm_pytorch.ClassifierWithState(
lm_pytorch.RNNLM(
len(args.char_list), rnnlm_args.layer, rnnlm_args.unit))
        torch_load(args.rnnlm, rnnlm)
model.rnnlm = rnnlm
# write model config
if not os.path.exists(args.outdir):
os.makedirs(args.outdir)
model_conf = args.outdir + '/model.json'
with open(model_conf, 'wb') as f:
logging.info('writing a model config file to ' + model_conf)
f.write(json.dumps((idim, odim, vars(args)),
indent=4, ensure_ascii=False, sort_keys=True).encode('utf_8'))
for key in sorted(vars(args).keys()):
logging.info('ARGS: ' + key + ': ' + str(vars(args)[key]))
reporter = model.reporter
# check the use of multi-gpu
if args.ngpu > 1:
if args.batch_size != 0:
logging.info('batch size is automatically increased (%d -> %d)' % (
args.batch_size, args.batch_size * args.ngpu))
args.batch_size *= args.ngpu
# set torch device
device = torch.device("cuda" if args.ngpu > 0 else "cpu")
if args.train_dtype in ("float16", "float32", "float64"):
dtype = getattr(torch, args.train_dtype)
else:
dtype = torch.float32
model = model.to(device=device, dtype=dtype)
# Setup an optimizer
if args.opt == 'adadelta':
optimizer = torch.optim.Adadelta(
model.parameters(), rho=0.95, eps=args.eps,
weight_decay=args.weight_decay)
elif args.opt == 'adam':
optimizer = torch.optim.Adam(model.parameters(),
weight_decay=args.weight_decay)
elif args.opt == 'noam':
from espnet.nets.pytorch_backend.transformer.optimizer import get_std_opt
optimizer = get_std_opt(model, args.adim, args.transformer_warmup_steps, args.transformer_lr)
else:
raise NotImplementedError("unknown optimizer: " + args.opt)
# setup apex.amp
if args.train_dtype in ("O0", "O1", "O2", "O3"):
try:
from apex import amp
except ImportError as e:
logging.error(f"You need to install apex for --train-dtype {args.train_dtype}. "
"See https://github.com/NVIDIA/apex#linux")
raise e
if args.opt == 'noam':
model, optimizer.optimizer = amp.initialize(model, optimizer.optimizer, opt_level=args.train_dtype)
else:
model, optimizer = amp.initialize(model, optimizer, opt_level=args.train_dtype)
use_apex = True
else:
use_apex = False
# FIXME: TOO DIRTY HACK
setattr(optimizer, "target", reporter)
setattr(optimizer, "serialize", lambda s: reporter.serialize(s))
# Setup a converter
converter = CustomConverter(subsampling_factor=subsampling_factor, dtype=dtype)
# read json data
with open(args.train_json, 'rb') as f:
train_json = json.load(f)['utts']
with open(args.valid_json, 'rb') as f:
valid_json = json.load(f)['utts']
use_sortagrad = args.sortagrad == -1 or args.sortagrad > 0
# make minibatch list (variable length)
train = make_batchset(train_json, args.batch_size,
args.maxlen_in, args.maxlen_out, args.minibatches,
min_batch_size=args.ngpu if args.ngpu > 1 else 1,
shortest_first=use_sortagrad,
count=args.batch_count,
batch_bins=args.batch_bins,
batch_frames_in=args.batch_frames_in,
batch_frames_out=args.batch_frames_out,
batch_frames_inout=args.batch_frames_inout)
valid = make_batchset(valid_json, args.batch_size,
args.maxlen_in, args.maxlen_out, args.minibatches,
min_batch_size=args.ngpu if args.ngpu > 1 else 1,
count=args.batch_count,
batch_bins=args.batch_bins,
batch_frames_in=args.batch_frames_in,
batch_frames_out=args.batch_frames_out,
batch_frames_inout=args.batch_frames_inout)
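    # Hedged note (illustrative): make_batchset returns a list of minibatches,
    # each a list of (utt_id, info_dict) pairs taken from the json above, e.g.
    #     train[0] -> [("utt_id_1", {...}), ("utt_id_2", {...}), ...]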
load_tr = LoadInputsAndTargets(
mode='asr', load_output=True, preprocess_conf=args.preprocess_conf,
preprocess_args={'train': True} # Switch the mode of preprocessing
)
load_cv = LoadInputsAndTargets(
mode='asr', load_output=True, preprocess_conf=args.preprocess_conf,
preprocess_args={'train': False} # Switch the mode of preprocessing
)
    # hack to make the batch size argument 1
    # the actual batch size is contained in each element of the list
if args.n_iter_processes > 0:
train_iter = ToggleableShufflingMultiprocessIterator(
TransformDataset(train, load_tr),
batch_size=1, n_processes=args.n_iter_processes, n_prefetch=8, maxtasksperchild=20,
shuffle=not use_sortagrad)
valid_iter = ToggleableShufflingMultiprocessIterator(
TransformDataset(valid, load_cv),
batch_size=1, repeat=False, shuffle=False,
n_processes=args.n_iter_processes, n_prefetch=8, maxtasksperchild=20)
else:
train_iter = ToggleableShufflingSerialIterator(
TransformDataset(train, load_tr),
batch_size=1, shuffle=not use_sortagrad)
valid_iter = ToggleableShufflingSerialIterator(
TransformDataset(valid, load_cv),
batch_size=1, repeat=False, shuffle=False)
# Set up a trainer
updater = CustomUpdater(
model, args.grad_clip, train_iter, optimizer,
converter, device, args.ngpu, args.grad_noise, args.accum_grad, use_apex=use_apex)
trainer = training.Trainer(
updater, (args.epochs, 'epoch'), out=args.outdir)
if use_sortagrad:
trainer.extend(ShufflingEnabler([train_iter]),
trigger=(args.sortagrad if args.sortagrad != -1 else args.epochs, 'epoch'))
# Resume from a snapshot
if args.resume:
logging.info('resumed from %s' % args.resume)
torch_resume(args.resume, trainer)
# Evaluate the model with the test dataset for each epoch
trainer.extend(CustomEvaluator(model, valid_iter, reporter, converter, device, args.ngpu))
# Save attention weight each epoch
if args.num_save_attention > 0 and args.mtlalpha != 1.0:
data = sorted(list(valid_json.items())[:args.num_save_attention],
key=lambda x: int(x[1]['input'][0]['shape'][1]), reverse=True)
if hasattr(model, "module"):
att_vis_fn = model.module.calculate_all_attentions
plot_class = model.module.attention_plot_class
else:
att_vis_fn = model.calculate_all_attentions
plot_class = model.attention_plot_class
att_reporter = plot_class(
att_vis_fn, data, args.outdir + "/att_ws",
converter=converter, transform=load_cv, device=device)
trainer.extend(att_reporter, trigger=(1, 'epoch'))
else:
att_reporter = None
# Make a plot for training and validation values
trainer.extend(extensions.PlotReport(['main/loss', 'validation/main/loss',
'main/loss_ctc', 'validation/main/loss_ctc',
'main/loss_att', 'validation/main/loss_att'],
'epoch', file_name='loss.png'))
trainer.extend(extensions.PlotReport(['main/acc', 'validation/main/acc'],
'epoch', file_name='acc.png'))
trainer.extend(extensions.PlotReport(['main/cer_ctc', 'validation/main/cer_ctc'],
'epoch', file_name='cer.png'))
# Save best models
trainer.extend(snapshot_object(model, 'model.loss.best'),
trigger=training.triggers.MinValueTrigger('validation/main/loss'))
if mtl_mode != 'ctc':
trainer.extend(snapshot_object(model, 'model.acc.best'),
trigger=training.triggers.MaxValueTrigger('validation/main/acc'))
# save snapshot which contains model and optimizer states
trainer.extend(torch_snapshot(), trigger=(1, 'epoch'))
# epsilon decay in the optimizer
if args.opt == 'adadelta':
if args.criterion == 'acc' and mtl_mode != 'ctc':
trainer.extend(restore_snapshot(model, args.outdir + '/model.acc.best', load_fn=torch_load),
trigger=CompareValueTrigger(
'validation/main/acc',
lambda best_value, current_value: best_value > current_value))
trainer.extend(adadelta_eps_decay(args.eps_decay),
trigger=CompareValueTrigger(
'validation/main/acc',
lambda best_value, current_value: best_value > current_value))
elif args.criterion == 'loss':
trainer.extend(restore_snapshot(model, args.outdir + '/model.loss.best', load_fn=torch_load),
trigger=CompareValueTrigger(
'validation/main/loss',
lambda best_value, current_value: best_value < current_value))
trainer.extend(adadelta_eps_decay(args.eps_decay),
trigger=CompareValueTrigger(
'validation/main/loss',
lambda best_value, current_value: best_value < current_value))
# Write a log of evaluation statistics for each epoch
trainer.extend(extensions.LogReport(trigger=(args.report_interval_iters, 'iteration')))
report_keys = ['epoch', 'iteration', 'main/loss', 'main/loss_ctc', 'main/loss_att',
'validation/main/loss', 'validation/main/loss_ctc', 'validation/main/loss_att',
'main/acc', 'validation/main/acc', 'main/cer_ctc', 'validation/main/cer_ctc',
'elapsed_time']
if args.opt == 'adadelta':
trainer.extend(extensions.observe_value(
'eps', lambda trainer: trainer.updater.get_optimizer('main').param_groups[0]["eps"]),
trigger=(args.report_interval_iters, 'iteration'))
report_keys.append('eps')
if args.report_cer:
report_keys.append('validation/main/cer')
if args.report_wer:
report_keys.append('validation/main/wer')
trainer.extend(extensions.PrintReport(
report_keys), trigger=(args.report_interval_iters, 'iteration'))
trainer.extend(extensions.ProgressBar(update_interval=args.report_interval_iters))
set_early_stop(trainer, args)
if args.tensorboard_dir is not None and args.tensorboard_dir != "":
trainer.extend(TensorboardLogger(SummaryWriter(args.tensorboard_dir), att_reporter),
trigger=(args.report_interval_iters, "iteration"))
# Run the training
trainer.run()
check_early_stop(trainer, args.epochs)


def recog(args):
    """Decode with the given args.

Args:
args (namespace): The program arguments.
"""
set_deterministic_pytorch(args)
model, train_args = load_trained_model(args.model)
assert isinstance(model, ASRInterface)
model.recog_args = args
# read rnnlm
if args.rnnlm:
rnnlm_args = get_model_conf(args.rnnlm, args.rnnlm_conf)
if getattr(rnnlm_args, "model_module", "default") != "default":
raise ValueError("use '--api v2' option to decode with non-default language model")
rnnlm = lm_pytorch.ClassifierWithState(
lm_pytorch.RNNLM(
len(train_args.char_list), rnnlm_args.layer, rnnlm_args.unit))
torch_load(args.rnnlm, rnnlm)
rnnlm.eval()
else:
rnnlm = None
if args.word_rnnlm:
rnnlm_args = get_model_conf(args.word_rnnlm, args.word_rnnlm_conf)
word_dict = rnnlm_args.char_list_dict
char_dict = {x: i for i, x in enumerate(train_args.char_list)}
word_rnnlm = lm_pytorch.ClassifierWithState(lm_pytorch.RNNLM(
len(word_dict), rnnlm_args.layer, rnnlm_args.unit))
torch_load(args.word_rnnlm, word_rnnlm)
word_rnnlm.eval()
if rnnlm is not None:
rnnlm = lm_pytorch.ClassifierWithState(
extlm_pytorch.MultiLevelLM(word_rnnlm.predictor,
rnnlm.predictor, word_dict, char_dict))
else:
rnnlm = lm_pytorch.ClassifierWithState(
extlm_pytorch.LookAheadWordLM(word_rnnlm.predictor,
word_dict, char_dict))
# gpu
if args.ngpu == 1:
gpu_id = list(range(args.ngpu))
logging.info('gpu id: ' + str(gpu_id))
model.cuda()
if rnnlm:
rnnlm.cuda()
# read json data
with open(args.recog_json, 'rb') as f:
js = json.load(f)['utts']
new_js = {}
load_inputs_and_targets = LoadInputsAndTargets(
mode='asr', load_output=False, sort_in_input_length=False,
preprocess_conf=train_args.preprocess_conf
if args.preprocess_conf is None else args.preprocess_conf,
preprocess_args={'train': False})
if args.batchsize == 0:
with torch.no_grad():
for idx, name in enumerate(js.keys(), 1):
logging.info('(%d/%d) decoding ' + name, idx, len(js.keys()))
batch = [(name, js[name])]
feat = load_inputs_and_targets(batch)[0][0]
if args.streaming_mode == 'window':
logging.info('Using streaming recognizer with window size %d frames', args.streaming_window)
se2e = WindowStreamingE2E(e2e=model, recog_args=args, rnnlm=rnnlm)
for i in range(0, feat.shape[0], args.streaming_window):
logging.info('Feeding frames %d - %d', i, i + args.streaming_window)
se2e.accept_input(feat[i:i + args.streaming_window])
logging.info('Running offline attention decoder')
se2e.decode_with_attention_offline()
logging.info('Offline attention decoder finished')
nbest_hyps = se2e.retrieve_recognition()
elif args.streaming_mode == 'segment':
logging.info('Using streaming recognizer with threshold value %d', args.streaming_min_blank_dur)
nbest_hyps = []
for n in range(args.nbest):
nbest_hyps.append({'yseq': [], 'score': 0.0})
se2e = SegmentStreamingE2E(e2e=model, recog_args=args, rnnlm=rnnlm)
r = np.prod(model.subsample)
for i in range(0, feat.shape[0], r):
hyps = se2e.accept_input(feat[i:i + r])
if hyps is not None:
text = ''.join([train_args.char_list[int(x)]
for x in hyps[0]['yseq'][1:-1] if int(x) != -1])
text = text.replace('\u2581', ' ').strip() # for SentencePiece
text = text.replace(model.space, ' ')
text = text.replace(model.blank, '')
logging.info(text)
for n in range(args.nbest):
nbest_hyps[n]['yseq'].extend(hyps[n]['yseq'])
nbest_hyps[n]['score'] += hyps[n]['score']
else:
nbest_hyps = model.recognize(feat, args, train_args.char_list, rnnlm)
new_js[name] = add_results_to_json(js[name], nbest_hyps, train_args.char_list)
else:
def grouper(n, iterable, fillvalue=None):
kargs = [iter(iterable)] * n
return zip_longest(*kargs, fillvalue=fillvalue)
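        # Hedged example: grouper pads the last group with fillvalue, e.g.
        #     list(grouper(2, ["a", "b", "c"])) -> [("a", "b"), ("c", None)]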
# sort data if batchsize > 1
keys = list(js.keys())
if args.batchsize > 1:
feat_lens = [js[key]['input'][0]['shape'][0] for key in keys]
sorted_index = sorted(range(len(feat_lens)), key=lambda i: -feat_lens[i])
keys = [keys[i] for i in sorted_index]
with torch.no_grad():
for names in grouper(args.batchsize, keys, None):
names = [name for name in names if name]
batch = [(name, js[name]) for name in names]
feats = load_inputs_and_targets(batch)[0]
nbest_hyps = model.recognize_batch(feats, args, train_args.char_list, rnnlm=rnnlm)
for i, nbest_hyp in enumerate(nbest_hyps):
name = names[i]
new_js[name] = add_results_to_json(js[name], nbest_hyp, train_args.char_list)
with open(args.result_label, 'wb') as f:
f.write(json.dumps({'utts': new_js}, indent=4, ensure_ascii=False, sort_keys=True).encode('utf_8'))


def enhance(args):
    """Dump enhanced speech and masks.

Args:
args (namespace): The program arguments.
"""
set_deterministic_pytorch(args)
# read training config
idim, odim, train_args = get_model_conf(args.model, args.model_conf)
# load trained model parameters
logging.info('reading model parameters from ' + args.model)
model_class = dynamic_import(train_args.model_module)
model = model_class(idim, odim, train_args)
assert isinstance(model, ASRInterface)
torch_load(args.model, model)
model.recog_args = args
# gpu
if args.ngpu == 1:
gpu_id = list(range(args.ngpu))
logging.info('gpu id: ' + str(gpu_id))
model.cuda()
# read json data
with open(args.recog_json, 'rb') as f:
js = json.load(f)['utts']
load_inputs_and_targets = LoadInputsAndTargets(
mode='asr', load_output=False, sort_in_input_length=False,
preprocess_conf=None # Apply pre_process in outer func
)
if args.batchsize == 0:
args.batchsize = 1
# Creates writers for outputs from the network
if args.enh_wspecifier is not None:
enh_writer = file_writer_helper(args.enh_wspecifier,
filetype=args.enh_filetype)
else:
enh_writer = None
# Creates a Transformation instance
preprocess_conf = (
train_args.preprocess_conf if args.preprocess_conf is None
else args.preprocess_conf)
if preprocess_conf is not None:
        logging.info('Use preprocessing: {}'.format(preprocess_conf))
transform = Transformation(preprocess_conf)
else:
transform = None
    # Creates an IStft instance
istft = None
    frame_shift = args.istft_n_shift  # Used for plotting the spectrogram
if args.apply_istft:
if preprocess_conf is not None:
            # Read the config file and find the stft setting
with open(preprocess_conf) as f:
# Json format: e.g.
# {"process": [{"type": "stft",
# "win_length": 400,
# "n_fft": 512, "n_shift": 160,
# "window": "han"},
# {"type": "foo", ...}, ...]}
conf = json.load(f)
assert 'process' in conf, conf
# Find stft setting
for p in conf['process']:
if p['type'] == 'stft':
istft = IStft(win_length=p['win_length'],
n_shift=p['n_shift'],
window=p.get('window', 'hann'))
logging.info('stft is found in {}. '
'Setting istft config from it\n{}'
.format(preprocess_conf, istft))
frame_shift = p['n_shift']
break
if istft is None:
# Set from command line arguments
istft = IStft(win_length=args.istft_win_length,
n_shift=args.istft_n_shift,
window=args.istft_window)
logging.info('Setting istft config from the command line args\n{}'
.format(istft))
# sort data
keys = list(js.keys())
feat_lens = [js[key]['input'][0]['shape'][0] for key in keys]
sorted_index = sorted(range(len(feat_lens)), key=lambda i: -feat_lens[i])
keys = [keys[i] for i in sorted_index]
def grouper(n, iterable, fillvalue=None):
kargs = [iter(iterable)] * n
return zip_longest(*kargs, fillvalue=fillvalue)
num_images = 0
if not os.path.exists(args.image_dir):
os.makedirs(args.image_dir)
for names in grouper(args.batchsize, keys, None):
batch = [(name, js[name]) for name in names]
# May be in time region: (Batch, [Time, Channel])
org_feats = load_inputs_and_targets(batch)[0]
if transform is not None:
            # May be in time-freq region: (Batch, [Time, Channel, Freq])
feats = transform(org_feats, train=False)
else:
feats = org_feats
with torch.no_grad():
enhanced, mask, ilens = model.enhance(feats)
for idx, name in enumerate(names):
            # Assuming mask, feats: [Batch, Time, Channel, Freq]
# enhanced : [Batch, Time, Freq]
enh = enhanced[idx][:ilens[idx]]
mas = mask[idx][:ilens[idx]]
feat = feats[idx]
# Plot spectrogram
if args.image_dir is not None and num_images < args.num_images:
import matplotlib.pyplot as plt
num_images += 1
ref_ch = 0
plt.figure(figsize=(20, 10))
plt.subplot(4, 1, 1)
plt.title('Mask [ref={}ch]'.format(ref_ch))
plot_spectrogram(plt, mas[:, ref_ch].T, fs=args.fs,
mode='linear', frame_shift=frame_shift,
bottom=False, labelbottom=False)
plt.subplot(4, 1, 2)
plt.title('Noisy speech [ref={}ch]'.format(ref_ch))
plot_spectrogram(plt, feat[:, ref_ch].T, fs=args.fs,
mode='db', frame_shift=frame_shift,
bottom=False, labelbottom=False)
plt.subplot(4, 1, 3)
plt.title('Masked speech [ref={}ch]'.format(ref_ch))
plot_spectrogram(
plt, (feat[:, ref_ch] * mas[:, ref_ch]).T,
frame_shift=frame_shift,
fs=args.fs, mode='db', bottom=False, labelbottom=False)
plt.subplot(4, 1, 4)
plt.title('Enhanced speech')
plot_spectrogram(plt, enh.T, fs=args.fs,
mode='db', frame_shift=frame_shift)
plt.savefig(os.path.join(args.image_dir, name + '.png'))
plt.clf()
# Write enhanced wave files
if enh_writer is not None:
                if istft is not None:
                    enh = istft(enh)
if args.keep_length:
if len(org_feats[idx]) < len(enh):
# Truncate the frames added by stft padding
enh = enh[:len(org_feats[idx])]
                    elif len(org_feats[idx]) > len(enh):
padwidth = [(0, (len(org_feats[idx]) - len(enh)))] \
+ [(0, 0)] * (enh.ndim - 1)
enh = np.pad(enh, padwidth, mode='constant')
if args.enh_filetype in ('sound', 'sound.hdf5'):
enh_writer[name] = (args.fs, enh)
else:
# Hint: To dump stft_signal, mask or etc,
# enh_filetype='hdf5' might be convenient.
enh_writer[name] = enh
if num_images >= args.num_images and enh_writer is None:
logging.info('Breaking the process.')
break