# encoding: utf-8
from argparse import Namespace
from distutils.util import strtobool
import logging
import math
import numpy as np
import six
import chainer
from chainer import reporter
import chainer.functions as F
from espnet.nets.asr_interface import ASRInterface
from espnet.nets.chainer_backend import ctc
from espnet.nets.chainer_backend.transformer.attention import MultiHeadAttention
from espnet.nets.chainer_backend.transformer.decoder import Decoder
from espnet.nets.chainer_backend.transformer.encoder import Encoder
from espnet.nets.chainer_backend.transformer.label_smoothing_loss import LabelSmoothingLoss
from espnet.nets.chainer_backend.transformer.plot import PlotAttentionReport
from espnet.nets.ctc_prefix_score import CTCPrefixScore
CTC_SCORING_RATIO = 1.5
MAX_DECODER_OUTPUT = 5
class E2E(ASRInterface, chainer.Chain):
"""E2E module.
Args:
idim (int): Dimension of inputs.
odim (int): Dimension of outputs.
args (Namespace): Training config.
flag_return (bool): If True, then return value of `forward()`
would be tuple of (loss, loss_ctc, loss_att, acc)
"""
@staticmethod
def add_arguments(parser):
"""Define arguments for the transformer model."""
group = parser.add_argument_group('transformer model setting')
group.add_argument("--transformer-init", type=str, default="pytorch",
help='how to initialize transformer parameters')
group.add_argument("--transformer-input-layer", type=str, default="conv2d",
choices=["conv2d", "linear", "embed"],
help='transformer input layer type')
group.add_argument('--transformer-attn-dropout-rate', default=None, type=float,
help='dropout in transformer attention. use --dropout-rate if None is set')
group.add_argument('--transformer-lr', default=10.0, type=float,
help='Initial value of learning rate')
group.add_argument('--transformer-warmup-steps', default=25000, type=int,
help='optimizer warmup steps')
group.add_argument('--transformer-length-normalized-loss', default=True, type=strtobool,
help='normalize loss by length')
group.add_argument('--dropout-rate', default=0.0, type=float,
help='Dropout rate for the encoder')
# Encoder
group.add_argument('--elayers', default=4, type=int,
help='Number of encoder layers (for shared recognition part in multi-speaker asr mode)')
group.add_argument('--eunits', '-u', default=300, type=int,
help='Number of encoder hidden units')
# Attention
group.add_argument('--adim', default=320, type=int,
help='Number of attention transformation dimensions')
group.add_argument('--aheads', default=4, type=int,
help='Number of heads for multi head attention')
# Decoder
group.add_argument('--dlayers', default=1, type=int,
help='Number of decoder layers')
group.add_argument('--dunits', default=320, type=int,
help='Number of decoder hidden units')
return parser
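# Example (illustrative; the entry-point script and the remaining required
# flags depend on the recipe being run): these options are consumed from the
# training command line, e.g.
#   asr_train.py ... --transformer-init pytorch --transformer-input-layer conv2d \
#       --adim 320 --aheads 4 --elayers 4 --eunits 300 --dlayers 1 --dunits 320 \
#       --transformer-lr 10.0 --transformer-warmup-steps 25000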
def __init__(self, idim, odim, args, ignore_id=-1, flag_return=True):
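"""Initialize the transformer model.

Args:
idim (int): Dimension of inputs.
odim (int): Dimension of outputs.
args (Namespace): Training config.
ignore_id (int): Id used to pad target sequences; ignored by the loss and accuracy.
flag_return (bool): If True, `forward()` returns (loss, loss_ctc, loss_att, acc).
"""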
chainer.Chain.__init__(self)
self.mtlalpha = args.mtlalpha
assert 0 <= self.mtlalpha <= 1, "mtlalpha must be in [0, 1]"
if args.transformer_attn_dropout_rate is None:
self.dropout = args.dropout_rate
else:
self.dropout = args.transformer_attn_dropout_rate
self.use_label_smoothing = False
self.char_list = args.char_list
self.space = args.sym_space
self.blank = args.sym_blank
self.scale_emb = args.adim ** 0.5
self.sos = odim - 1
self.eos = odim - 1
self.subsample = [0]
self.ignore_id = ignore_id
self.reset_parameters(args)
with self.init_scope():
self.encoder = Encoder(idim, args, initialW=self.initialW, initial_bias=self.initialB)
self.decoder = Decoder(odim, args, initialW=self.initialW, initial_bias=self.initialB)
self.criterion = LabelSmoothingLoss(args.lsm_weight, len(args.char_list),
args.transformer_length_normalized_loss)
if args.mtlalpha > 0.0:
if args.ctc_type == 'builtin':
logging.info("Using chainer CTC implementation")
self.ctc = ctc.CTC(odim, args.adim, args.dropout_rate)
elif args.ctc_type == 'warpctc':
logging.info("Using warpctc CTC implementation")
self.ctc = ctc.WarpCTC(odim, args.adim, args.dropout_rate)
else:
raise ValueError('ctc_type must be "builtin" or "warpctc": {}'
.format(args.ctc_type))
else:
self.ctc = None
self.dims = args.adim
self.odim = odim
self.flag_return = flag_return
if args.report_cer or args.report_wer:
from espnet.nets.e2e_asr_common import ErrorCalculator
self.error_calculator = ErrorCalculator(args.char_list,
args.sym_space, args.sym_blank,
args.report_cer, args.report_wer)
else:
self.error_calculator = None
if 'Namespace' in str(type(args)):
self.verbose = 0 if 'verbose' not in args else args.verbose
else:
self.verbose = 0 if args.verbose is None else args.verbose
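# Construction sketch (illustrative; the field values are assumptions, not
# recommended settings). These are the minimal `args` fields read by this
# class and by the Encoder/Decoder it builds:
#   >>> from argparse import Namespace
#   >>> args = Namespace(mtlalpha=0.3, adim=320, aheads=4, elayers=4, eunits=300,
#   ...                  dlayers=1, dunits=320, dropout_rate=0.1,
#   ...                  transformer_attn_dropout_rate=None,
#   ...                  transformer_init='pytorch', transformer_input_layer='conv2d',
#   ...                  transformer_length_normalized_loss=True, lsm_weight=0.1,
#   ...                  ctc_type='builtin',
#   ...                  char_list=['<blank>', 'a', 'b', '<eos>'],
#   ...                  sym_space='<space>', sym_blank='<blank>',
#   ...                  report_cer=False, report_wer=False, verbose=0)
#   >>> model = E2E(idim=83, odim=len(args.char_list), args=args)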
def reset_parameters(self, args):
"""Initialize the weights according to the given initialization type.

Args:
args (Namespace): Transformer config.
"""
type_init = args.transformer_init
if type_init == 'lecun_uniform':
logging.info('Using LeCunUniform as Parameter initializer')
self.initialW = chainer.initializers.LeCunUniform
elif type_init == 'lecun_normal':
logging.info('Using LeCunNormal as Parameter initializer')
self.initialW = chainer.initializers.LeCunNormal
elif type_init == 'gorot_uniform':
logging.info('Using GlorotUniform as Parameter initializer')
self.initialW = chainer.initializers.GlorotUniform
elif type_init == 'gorot_normal':
logging.info('Using GlorotNormal as Parameter initializer')
self.initialW = chainer.initializers.GlorotNormal
elif type_init == 'he_uniform':
logging.info('Using HeUniform as Parameter initializer')
self.initialW = chainer.initializers.HeUniform
elif type_init == 'he_normal':
logging.info('Using HeNormal as Parameter initializer')
self.initialW = chainer.initializers.HeNormal
elif type_init == 'pytorch':
logging.info('Using PyTorch initializer')
self.initialW = chainer.initializers.Uniform
else:
logging.info('Using Chainer default as Parameter initializer')
self.initialW = chainer.initializers.Uniform
self.initialB = chainer.initializers.Uniform
def make_attention_mask(self, source_block, target_block):
"""Make a boolean mask over the valid (id >= 0) source/target positions.

Returns:
ndarray: Mask of shape (batch, source_length, target_length).
"""
mask = (target_block[:, None, :] >= 0) * \
(source_block[:, :, None] >= 0)
return mask
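# Example (illustrative): positions whose id is negative (the padding
# convention used for targets here) are masked out:
#   >>> import numpy as np
#   >>> src = np.array([[1, 2, -1]])   # last position is padding
#   >>> tgt = np.array([[3, 4]])
#   >>> (tgt[:, None, :] >= 0) * (src[:, :, None] >= 0)
#   array([[[ True,  True],
#           [ True,  True],
#           [False, False]]])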
def make_history_mask(self, block):
"""Make a causal mask so that position i attends only to positions j <= i."""
batch, length = block.shape
arange = self.xp.arange(length)
history_mask = (arange[None, :] <= arange[:, None])[None, :]
history_mask = self.xp.broadcast_to(
history_mask, (batch, length, length))
return history_mask
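# Example (illustrative): for a length-3 block, each row i of the mask allows
# attention only to positions j <= i:
#   >>> import numpy as np
#   >>> arange = np.arange(3)
#   >>> arange[None, :] <= arange[:, None]
#   array([[ True, False, False],
#          [ True,  True, False],
#          [ True,  True,  True]])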
def forward(self, xs, ilens, ys_pad, calculate_attentions=False):
"""E2E forward propagation.

Args:
xs (ndarray): Batch of padded source feature sequences. (B, Tmax, idim)
ilens (ndarray): Batch of input sequence lengths. (B,)
ys_pad (List[ndarray]): Batch of character id sequences. [(L1,), (L2,), ...]
calculate_attentions (bool): If True, return the encoder output instead of the losses.

Returns:
float: Training loss.
float (optional): Training loss for ctc.
float (optional): Training loss for attention.
float (optional): Accuracy.
chainer.Variable (optional): Output of the encoder when calculate_attentions is True.
"""
xp = self.xp
with chainer.no_backprop_mode():
xs = xp.array(xs)
eos = np.array([self.eos], 'i')
sos = np.array([self.sos], 'i')
ys_out = [F.concat([y, eos], axis=0) for y in ys_pad]
ys = [F.concat([sos, y], axis=0) for y in ys_pad]
ys = F.pad_sequence(ys, padding=self.eos)
ys_out = F.pad_sequence(ys_out, padding=-1)
ys = xp.array(ys.data)
ys_out = chainer.Variable(xp.array(ys_out.data))
ys_pad_cpu = [y.astype(np.int32) for y in ys_pad]
logging.info(self.__class__.__name__ + ' input lengths: ' + str(ilens))
# Encode Sources
# xs: utt x frame x dim
logging.debug('Init size: ' + str(xs.shape))
logging.debug('Out size: ' + str(ys.shape))
# Dims along encoder and decoder: batchsize * length x dims
xs, x_mask, ilens = self.encoder(xs, ilens)
logging.info(self.__class__.__name__ + ' subsampled input lengths: ' + str(ilens))
logging.info(self.__class__.__name__ + ' output lengths: ' + str(xp.array([y.shape[0] for y in ys_out])))
xy_mask = self.make_attention_mask(ys, xp.array(x_mask))
yy_mask = self.make_attention_mask(ys, ys)
yy_mask *= self.make_history_mask(ys)
batch, length = ys.shape
ys = self.decoder(ys, yy_mask, xs, xy_mask)
if calculate_attentions:
return xs
# Compute Attention Loss
loss_att = self.criterion(ys, ys_out, batch, length)
acc = F.accuracy(ys, ys_out.reshape((-1)).data, ignore_label=self.ignore_id)
# Compute CTC Loss and CER CTC
cer_ctc = None
if self.ctc is None:
loss_ctc = None
else:
xs = xs.reshape(batch, -1, self.dims)
xs = [xs[i, :ilens[i], :] for i in range(len(ilens))]
loss_ctc = self.ctc(xs, ys_pad_cpu)
if self.error_calculator is not None:
with chainer.no_backprop_mode():
ys_hat = chainer.backends.cuda.to_cpu(self.ctc.argmax(xs).data)
cer_ctc = self.error_calculator(ys_hat, ys_pad_cpu, is_ctc=True)
# Compute cer/wer
with chainer.no_backprop_mode():
y_hats = ys.reshape((batch, length, -1))
y_hats = chainer.backends.cuda.to_cpu(F.argmax(y_hats, axis=2).data)
if chainer.config.train or self.error_calculator is None:
cer, wer = None, None
else:
cer, wer = self.error_calculator(y_hats, ys_pad_cpu)
# Print Output
if chainer.config.train and (self.verbose > 0 and self.char_list is not None):
for i, y_hat in enumerate(y_hats):
y_true = chainer.backends.cuda.to_cpu(ys_pad[i].data)
if i == MAX_DECODER_OUTPUT:
break
eos_true = np.where(y_true == -1)[0]
eos_true = eos_true[0] if len(eos_true) > 0 else len(y_true)
seq_hat = [self.char_list[int(idx)] for idx in y_hat[:eos_true]]
seq_true = [self.char_list[int(idx)] for idx in y_true[y_true != -1]]
seq_hat = "".join(seq_hat).replace(self.space, ' ')
seq_true = "".join(seq_true).replace(self.space, ' ')
logging.info("groundtruth[%d]: " % i + seq_true)
logging.info("prediction [%d]: " % i + seq_hat)
alpha = self.mtlalpha
if alpha == 0.0:
self.loss = loss_att
loss_att_data = loss_att.data
loss_ctc_data = None
elif alpha == 1.0:
self.loss = loss_ctc
loss_att_data = None
loss_ctc_data = loss_ctc.data
else:
self.loss = alpha * loss_ctc + (1 - alpha) * loss_att
loss_att_data = loss_att.data
loss_ctc_data = loss_ctc.data
loss_data = self.loss.data
if not math.isnan(loss_data):
reporter.report({'loss_ctc': loss_ctc_data}, self)
reporter.report({'loss_att': loss_att_data}, self)
reporter.report({'acc': acc}, self)
reporter.report({'cer_ctc': cer_ctc}, self)
reporter.report({'cer': cer}, self)
reporter.report({'wer': wer}, self)
logging.info('mtl loss:' + str(loss_data))
reporter.report({'loss': loss_data}, self)
else:
logging.warning('loss (=%f) is not a valid value (NaN); skipping report', loss_data)
if self.flag_return:
return self.loss, loss_ctc, loss_att, acc
else:
return self.loss
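# Worked example (illustrative numbers): with mtlalpha = 0.3, loss_ctc = 120.0
# and loss_att = 50.0, the reported multitask loss is
#   0.3 * 120.0 + (1 - 0.3) * 50.0 = 36.0 + 35.0 = 71.0
# With mtlalpha = 0.0 only the attention loss is trained; with 1.0, only CTC.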
def recognize(self, x_block, recog_args, char_list=None, rnnlm=None):
"""E2E beam search.

Args:
x_block (ndarray): Input acoustic feature. (T, D)
recog_args (Namespace): Argument namespace containing decoding options.
char_list (List[str]): List of characters.
rnnlm (chainer.Chain): Language model module defined at
`espnet.lm.chainer_backend.lm`.

Returns:
List: N-best decoding results.
"""
with chainer.no_backprop_mode(), chainer.using_config('train', False):
# 1. encoder
ilens = [x_block.shape[0]]
batch = len(ilens)
xs, _, _ = self.encoder(x_block[None, :, :], ilens)
# calculate log P(z_t|X) for CTC scores
if recog_args.ctc_weight > 0.0:
lpz = self.ctc.log_softmax(xs.reshape(batch, -1, self.dims)).data[0]
else:
lpz = None
# 2. decoder
if recog_args.lm_weight == 0.0:
rnnlm = None
y = self.recognize_beam(xs, lpz, recog_args, char_list, rnnlm)
return y
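# Usage sketch (illustrative; `model` is a constructed E2E instance and the
# option values are assumptions, not recommended settings):
#   >>> from argparse import Namespace
#   >>> import numpy as np
#   >>> recog_args = Namespace(beam_size=5, penalty=0.0, ctc_weight=0.3,
#   ...                        lm_weight=0.0, maxlenratio=0.0, minlenratio=0.0,
#   ...                        nbest=1)
#   >>> feats = np.random.randn(200, 83).astype(np.float32)   # (T, idim)
#   >>> nbest = model.recognize(feats, recog_args)
#   >>> nbest[0]['yseq']   # best token id sequence, beginning with <sos>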
def recognize_beam(self, h, lpz, recog_args, char_list=None, rnnlm=None):
"""Beam search implementation.

Args:
h (chainer.Variable): Encoder output. (T, adim)
lpz (ndarray): Log probabilities from CTC, or None when CTC scoring is disabled. (T, odim)
recog_args (Namespace): Argument namespace containing decoding options.
char_list (List[str]): List of characters.
rnnlm (chainer.Chain): Language model module.

Returns:
List[dict]: N-best hypotheses sorted by score.
"""
logging.info('input lengths: ' + str(h.shape[0]))
# initialization
xp = self.xp
h_mask = xp.ones((1, h.shape[0]))
batch = 1
# search parameters
beam = recog_args.beam_size
penalty = recog_args.penalty
ctc_weight = recog_args.ctc_weight
# prepare sos
y = self.sos
if recog_args.maxlenratio == 0:
maxlen = h.shape[0]
else:
maxlen = max(1, int(recog_args.maxlenratio * h.shape[0]))
minlen = int(recog_args.minlenratio * h.shape[0])
logging.info('max output length: ' + str(maxlen))
logging.info('min output length: ' + str(minlen))
# initialize hypothesis
if rnnlm:
hyp = {'score': 0.0, 'yseq': [y], 'rnnlm_prev': None}
else:
hyp = {'score': 0.0, 'yseq': [y]}
if lpz is not None:
ctc_prefix_score = CTCPrefixScore(lpz, 0, self.eos, self.xp)
hyp['ctc_state_prev'] = ctc_prefix_score.initial_state()
hyp['ctc_score_prev'] = 0.0
if ctc_weight != 1.0:
# pre-pruning based on attention scores; CTC_SCORING_RATIO is the
# module-level constant defined at the top of this file
ctc_beam = min(lpz.shape[-1], int(beam * CTC_SCORING_RATIO))
else:
ctc_beam = lpz.shape[-1]
hyps = [hyp]
ended_hyps = []
for i in six.moves.range(maxlen):
logging.debug('position ' + str(i))
hyps_best_kept = []
for hyp in hyps:
ys = F.expand_dims(xp.array(hyp['yseq']), axis=0).data
yy_mask = self.make_attention_mask(ys, ys)
yy_mask *= self.make_history_mask(ys)
xy_mask = self.make_attention_mask(ys, h_mask)
out = self.decoder(ys, yy_mask, h, xy_mask).reshape(batch, -1, self.odim)
# get nbest local scores and their ids
local_att_scores = F.log_softmax(out[:, -1], axis=-1).data
if rnnlm:
rnnlm_state, local_lm_scores = rnnlm.predict(hyp['rnnlm_prev'], hyp['yseq'][i])
local_scores = local_att_scores + recog_args.lm_weight * local_lm_scores
else:
local_scores = local_att_scores
if lpz is not None:
local_best_ids = xp.argsort(local_scores, axis=1)[0, ::-1][:ctc_beam]
ctc_scores, ctc_states = ctc_prefix_score(hyp['yseq'], local_best_ids, hyp['ctc_state_prev'])
local_scores = \
(1.0 - ctc_weight) * local_att_scores[:, local_best_ids] \
+ ctc_weight * (ctc_scores - hyp['ctc_score_prev'])
if rnnlm:
local_scores += recog_args.lm_weight * local_lm_scores[:, local_best_ids]
joint_best_ids = xp.argsort(local_scores, axis=1)[0, ::-1][:beam]
local_best_scores = local_scores[:, joint_best_ids]
local_best_ids = local_best_ids[joint_best_ids]
else:
local_best_ids = self.xp.argsort(local_scores, axis=1)[0, ::-1][:beam]
local_best_scores = local_scores[:, local_best_ids]
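# Joint scoring, restated (a summary of the block above): each candidate
# token c is scored as
#   (1 - ctc_weight) * log p_att(c)
#   + ctc_weight * (ctc_prefix_score(c) - prev_ctc_score)
#   + lm_weight * log p_lm(c)
# where the CTC term is the incremental prefix score of extending the
# current hypothesis with c, not an absolute sequence score.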
for j in six.moves.range(beam):
new_hyp = {}
new_hyp['score'] = hyp['score'] + float(local_best_scores[0, j])
new_hyp['yseq'] = [0] * (1 + len(hyp['yseq']))
new_hyp['yseq'][:len(hyp['yseq'])] = hyp['yseq']
new_hyp['yseq'][len(hyp['yseq'])] = int(local_best_ids[j])
if rnnlm:
new_hyp['rnnlm_prev'] = rnnlm_state
if lpz is not None:
new_hyp['ctc_state_prev'] = ctc_states[joint_best_ids[j]]
new_hyp['ctc_score_prev'] = ctc_scores[joint_best_ids[j]]
hyps_best_kept.append(new_hyp)
hyps_best_kept = sorted(
hyps_best_kept, key=lambda x: x['score'], reverse=True)[:beam]
# sort and get nbest
hyps = hyps_best_kept
logging.debug('number of pruned hypotheses: ' + str(len(hyps)))
if char_list is not None:
logging.debug(
'best hypo: ' + ''.join([
char_list[int(x)] for x in hyps[0]['yseq'][1:]]) + ' score: ' + str(hyps[0]['score']))
# add eos in the final loop so that at least some hypotheses can end
if i == maxlen - 1:
logging.info('adding <eos> at the last position in the loop')
for hyp in hyps:
hyp['yseq'].append(self.eos)
# move ended hypotheses to the final list and remove them from the current ones
# (as a result, the number of active hyps can drop below the beam size)
remained_hyps = []
for hyp in hyps:
if hyp['yseq'][-1] == self.eos:
# only store the sequence that has more than minlen outputs
# also add penalty
if len(hyp['yseq']) > minlen:
hyp['score'] += (i + 1) * penalty
if rnnlm: # Word LM needs to add final <eos> score
hyp['score'] += recog_args.lm_weight * rnnlm.final(
hyp['rnnlm_prev'])
ended_hyps.append(hyp)
else:
remained_hyps.append(hyp)
# end detection
from espnet.nets.e2e_asr_common import end_detect
if end_detect(ended_hyps, i) and recog_args.maxlenratio == 0.0:
logging.info('end detected at %d', i)
break
hyps = remained_hyps
if len(hyps) > 0:
logging.debug('remaining hypotheses: ' + str(len(hyps)))
else:
logging.info('no hypothesis. Finish decoding.')
break
if char_list is not None:
for hyp in hyps:
logging.debug(
'hypo: ' + ''.join([char_list[int(x)] for x in hyp['yseq'][1:]]))
logging.debug('number of ended hypotheses: ' + str(len(ended_hyps)))
nbest_hyps = sorted(
ended_hyps, key=lambda x: x['score'], reverse=True) # [:min(len(ended_hyps), recog_args.nbest)]
logging.debug(nbest_hyps)
# check the number of hypotheses
if len(nbest_hyps) == 0:
logging.warning('there is no N-best result; performing recognition again with a smaller minlenratio.')
# copy recog_args because the Namespace would otherwise be modified globally
recog_args = Namespace(**vars(recog_args))
recog_args.minlenratio = max(0.0, recog_args.minlenratio - 0.1)
return self.recognize_beam(h, lpz, recog_args, char_list, rnnlm)
logging.info('total log probability: ' + str(nbest_hyps[0]['score']))
logging.info('normalized log probability: ' + str(nbest_hyps[0]['score'] / len(nbest_hyps[0]['yseq'])))
# note: each returned yseq still begins with <sos>
return nbest_hyps
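# Shape of a returned hypothesis (illustrative):
#   {'score': -12.3, 'yseq': [<sos id>, id_1, ..., <eos id>]}
# plus 'ctc_state_prev'/'ctc_score_prev' when CTC scoring was used and
# 'rnnlm_prev' when an RNN language model was used.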
def calculate_all_attentions(self, xs, ilens, ys):
"""E2E attention weight calculation.

Args:
xs (List[ndarray]): List of padded input sequences. [(T1, idim), (T2, idim), ...]
ilens (ndarray): Batch of lengths of input sequences. (B,)
ys (List[ndarray]): List of character id sequences. [(L1,), (L2,), ...]

Returns:
dict: Attention weights of every MultiHeadAttention link, keyed by link name.
"""
with chainer.no_backprop_mode():
results = self(xs, ilens, ys, calculate_attentions=True) # NOQA
ret = dict()
for name, m in self.namedlinks():
if isinstance(m, MultiHeadAttention):
var = m.attn
var.to_cpu()
_name = name[1:].replace('/', '_')
ret[_name] = var.data
return ret
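# Example (illustrative): the returned dict maps flattened link names to
# attention weight arrays, e.g. a key resembling 'encoder_encoders_0_self_attn'.
# The exact names and array shapes depend on how Encoder and Decoder register
# their MultiHeadAttention sublinks.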
@property
def attention_plot_class(self):
"""Return the PlotAttentionReport class used for attention visualization."""
return PlotAttentionReport