import logging
import random
import six
import chainer
import chainer.functions as F
import chainer.links as L
import numpy as np
import espnet.nets.chainer_backend.deterministic_embed_id as DL
from argparse import Namespace
from espnet.nets.ctc_prefix_score import CTCPrefixScore
from espnet.nets.e2e_asr_common import end_detect
# Multiplier applied to the beam size in Decoder.recognize_beam: when CTC scoring is
# used with ctc_weight != 1.0, only the top int(beam * CTC_SCORING_RATIO) candidates
# (capped at lpz.shape[-1]) are re-scored with the CTC prefix score.
CTC_SCORING_RATIO = 1.5
# Maximum number of utterances per batch whose ground-truth / predicted sequences are
# logged by Decoder.__call__ when verbose > 0 and a char_list is available.
MAX_DECODER_OUTPUT = 5
class Decoder(chainer.Chain):
    """Attention-based RNN decoder layer.

    Args:
        eprojs (int): Dimension of input variables from encoder.
        odim (int): The output dimension (number of output label IDs).
        dtype (str): Decoder RNN type: ``"lstm"`` builds LSTM layers, any other
            value builds GRU layers.
        dlayers (int): Number of layers for decoder.
        dunits (int): Number of units of the label embedding and of each decoder RNN layer.
        sos (int): Label ID indicating the start of sequences.
        eos (int): Label ID indicating the end of sequences (also used to pad decoder inputs).
        att (Module): Attention module defined at `espnet.nets.chainer_backend.attentions`.
        verbose (int): Verbosity level. When > 0 (and ``char_list`` is given), sample
            predictions are logged during training.
        char_list (List[str]): List of all characters.
        labeldist (numpy.array): Label distribution over the ``odim`` outputs used for
            label smoothing, or None to disable label smoothing.
        lsm_weight (float): Interpolation weight of the label-smoothing term in the training loss.
        sampling_probability (float): Probability of feeding back the model's own previous
            prediction instead of the ground truth (scheduled sampling).

    """

    def __init__(self, eprojs, odim, dtype, dlayers, dunits, sos, eos, att, verbose=0,
                 char_list=None, labeldist=None, lsm_weight=0., sampling_probability=0.0):
        super(Decoder, self).__init__()
        with self.init_scope():
            self.embed = DL.EmbedID(odim, dunits)
            # the first layer consumes [label embedding; attention context]
            self.rnn0 = L.StatelessLSTM(dunits + eprojs, dunits) if dtype == "lstm" \
                else L.StatelessGRU(dunits + eprojs, dunits)
            # deeper layers consume the hidden state of the layer below
            for l in six.moves.range(1, dlayers):
                setattr(self, 'rnn%d' % l,
                        L.StatelessLSTM(dunits, dunits) if dtype == "lstm" else L.StatelessGRU(dunits, dunits))
            self.output = L.Linear(dunits, odim)
        self.dtype = dtype
        self.loss = None
        self.att = att
        self.dlayers = dlayers
        self.dunits = dunits
        self.sos = sos
        self.eos = eos
        self.verbose = verbose
        self.char_list = char_list
        # for label smoothing
        self.labeldist = labeldist
        self.vlabeldist = None  # lazily created chainer.Variable copy of labeldist
        self.lsm_weight = lsm_weight
        self.sampling_probability = sampling_probability

    def rnn_forward(self, ey, z_list, c_list, z_prev, c_prev):
        """Advance the stacked decoder RNN by one time step.

        Args:
            ey (chainer.Variable): Input of the first layer
                (concatenated label embedding and attention context).
            z_list (list): Receives the new hidden state of each layer.
            c_list (list): Receives the new cell state of each layer (LSTM only).
            z_prev (list): Previous hidden state of each layer; ``None`` means "not initialized".
            c_prev (list): Previous cell state of each layer (LSTM only).

        Returns:
            tuple: ``(z_list, c_list)``, the same lists updated in place.

        Note:
            For GRU, ``None`` entries of ``z_prev`` are replaced in place by zero states.
            Callers may pass the same list as ``z_list`` and ``z_prev``; each layer reads its
            own previous state before overwriting it, so this aliasing is safe.
        """
        if self.dtype == "lstm":
            c_list[0], z_list[0] = self.rnn0(c_prev[0], z_prev[0], ey)
            for l in six.moves.range(1, self.dlayers):
                c_list[l], z_list[l] = self['rnn%d' % l](c_prev[l], z_prev[l], z_list[l - 1])
        else:
            if z_prev[0] is None:
                xp = self.xp
                with chainer.backends.cuda.get_device_from_id(self._device_id):
                    z_prev[0] = chainer.Variable(
                        xp.zeros((ey.shape[0], self.dunits), dtype=ey.dtype))
            z_list[0] = self.rnn0(z_prev[0], ey)
            for l in six.moves.range(1, self.dlayers):
                if z_prev[l] is None:
                    xp = self.xp
                    with chainer.backends.cuda.get_device_from_id(self._device_id):
                        z_prev[l] = chainer.Variable(
                            xp.zeros((z_list[l - 1].shape[0], self.dunits), dtype=z_list[l - 1].dtype))
                z_list[l] = self['rnn%d' % l](z_prev[l], z_list[l - 1])
        return z_list, c_list

    def __call__(self, hs, ys):
        """Core function of Decoder layer.

        Runs teacher-forced decoding (with optional scheduled sampling) over the
        padded target sequences and computes the training loss.

        Args:
            hs (list of chainer.Variable | N-dimension array): Input variable from encoder.
            ys (list of chainer.Variable | N-dimension array): Input variable of decoder.

        Returns:
            chainer.Variable: A variable holding a scalar array of the training loss.
            chainer.Variable: A variable holding a scalar array of the accuracy.

        """
        self.loss = None
        # prepare input and output word sequences with sos/eos IDs
        eos = self.xp.array([self.eos], 'i')
        sos = self.xp.array([self.sos], 'i')
        ys_in = [F.concat([sos, y], axis=0) for y in ys]
        ys_out = [F.concat([y, eos], axis=0) for y in ys]

        # decoder inputs are padded with <eos>; targets are padded with -1, which
        # F.accuracy (ignore_label=-1) and F.softmax_cross_entropy (default ignore_label=-1) skip
        # pys: utt x olen
        pad_ys_in = F.pad_sequence(ys_in, padding=self.eos)
        pad_ys_out = F.pad_sequence(ys_out, padding=-1)

        # get dim, length info
        batch = pad_ys_out.shape[0]
        olength = pad_ys_out.shape[1]
        logging.info(self.__class__.__name__ + ' input lengths: ' + str(self.xp.array([h.shape[0] for h in hs])))
        logging.info(self.__class__.__name__ + ' output lengths: ' + str(self.xp.array([y.shape[0] for y in ys_out])))

        # initialization
        c_list = [None]  # list of cell state of each layer
        z_list = [None]  # list of hidden state of each layer
        for _ in six.moves.range(1, self.dlayers):
            c_list.append(None)
            z_list.append(None)
        att_w = None
        z_all = []
        self.att.reset()  # reset pre-computation of h

        # pre-computation of embedding
        eys = self.embed(pad_ys_in)  # utt x olen x zdim
        eys = F.separate(eys, axis=1)

        # loop for an output sequence
        for i in six.moves.range(olength):
            att_c, att_w = self.att(hs, z_list[0], att_w)
            if i > 0 and random.random() < self.sampling_probability:
                # scheduled sampling: feed back the greedy prediction of the previous step
                logging.info(' scheduled sampling ')
                z_out = self.output(z_all[-1])
                z_out = F.argmax(F.log_softmax(z_out), axis=1)
                z_out = self.embed(z_out)
                ey = F.hstack((z_out, att_c))  # utt x (zdim + hdim)
            else:
                ey = F.hstack((eys[i], att_c))  # utt x (zdim + hdim)
            z_list, c_list = self.rnn_forward(ey, z_list, c_list, z_list, c_list)
            z_all.append(z_list[-1])

        z_all = F.reshape(F.stack(z_all, axis=1),
                          (batch * olength, self.dunits))
        # compute loss
        y_all = self.output(z_all)
        self.loss = F.softmax_cross_entropy(y_all, F.flatten(pad_ys_out))
        # rescale the token-averaged loss by the mean label length:
        # len(ys_in) - 1 == len(y), i.e. the added <sos>/<eos> symbol is not counted
        self.loss *= (np.mean([len(x) for x in ys_in]) - 1)
        acc = F.accuracy(y_all, F.flatten(pad_ys_out), ignore_label=-1)
        logging.info('att loss:' + str(self.loss.data))

        # show predicted character sequence for debug
        if self.verbose > 0 and self.char_list is not None:
            y_hat = F.reshape(y_all, (batch, olength, -1))
            y_true = pad_ys_out
            for i, (y_hat_, y_true_) in enumerate(zip(y_hat.data, y_true.data)):
                if i == MAX_DECODER_OUTPUT:
                    break
                # drop padded (-1) positions before decoding to characters
                idx_hat = self.xp.argmax(y_hat_[y_true_ != -1], axis=1)
                idx_true = y_true_[y_true_ != -1]
                seq_hat = [self.char_list[int(idx)] for idx in idx_hat]
                seq_true = [self.char_list[int(idx)] for idx in idx_true]
                seq_hat = "".join(seq_hat).replace('<space>', ' ')
                seq_true = "".join(seq_true).replace('<space>', ' ')
                logging.info("groundtruth[%d]: " % i + seq_true)
                logging.info("prediction [%d]: " % i + seq_hat)

        if self.labeldist is not None:
            if self.vlabeldist is None:
                self.vlabeldist = chainer.Variable(self.xp.asarray(self.labeldist))
            # label smoothing: cross entropy against the prior label distribution,
            # interpolated with the regular loss by lsm_weight
            loss_reg = - F.sum(F.scale(F.log_softmax(y_all), self.vlabeldist, axis=1)) / len(ys_in)
            self.loss = (1. - self.lsm_weight) * self.loss + self.lsm_weight * loss_reg

        return self.loss, acc

    def recognize_beam(self, h, lpz, recog_args, char_list, rnnlm=None):
        """Beam search implementation.

        Args:
            h (chainer.Variable): One of the output from the encoder.
            lpz (chainer.Variable | None): Per-frame label scores handed to ``CTCPrefixScore``
                (presumably CTC log-probabilities), or None to disable CTC scoring.
            recog_args (Namespace): Decoding options. Reads ``beam_size``, ``penalty``,
                ``ctc_weight``, ``maxlenratio``, ``minlenratio``, ``lm_weight`` and ``nbest``.
            char_list (List[str]): List of all characters (used for debug logging).
            rnnlm (Module): RNNLM module. Defined at `espnet.lm.chainer_backend.lm`

        Returns:
            List[Dict[str,Any]]: At most ``recog_args.nbest`` ended hypotheses sorted by
            descending ``'score'``. ``'yseq'`` is a list of 1-element label arrays that
            starts with <sos> and ends with <eos>. If no hypothesis ends, the search is
            retried on a copy of ``recog_args`` with ``minlenratio`` lowered by 0.1.

        """
        logging.info('input lengths: ' + str(h.shape[0]))
        # initialization
        c_list = [None]  # list of cell state of each layer
        z_list = [None]  # list of hidden state of each layer
        for _ in six.moves.range(1, self.dlayers):
            c_list.append(None)
            z_list.append(None)
        a = None
        self.att.reset()  # reset pre-computation of h

        # search parms
        beam = recog_args.beam_size
        penalty = recog_args.penalty
        ctc_weight = recog_args.ctc_weight

        # preprate sos
        y = self.xp.full(1, self.sos, 'i')
        if recog_args.maxlenratio == 0:
            maxlen = h.shape[0]
        else:
            # maxlen >= 1
            maxlen = max(1, int(recog_args.maxlenratio * h.shape[0]))
        minlen = int(recog_args.minlenratio * h.shape[0])
        logging.info('max output length: ' + str(maxlen))
        logging.info('min output length: ' + str(minlen))

        # initialize hypothesis
        if rnnlm:
            hyp = {'score': 0.0, 'yseq': [y], 'c_prev': c_list, 'z_prev': z_list, 'a_prev': a, 'rnnlm_prev': None}
        else:
            hyp = {'score': 0.0, 'yseq': [y], 'c_prev': c_list, 'z_prev': z_list, 'a_prev': a}
        if lpz is not None:
            ctc_prefix_score = CTCPrefixScore(lpz, 0, self.eos, self.xp)
            hyp['ctc_state_prev'] = ctc_prefix_score.initial_state()
            hyp['ctc_score_prev'] = 0.0
            if ctc_weight != 1.0:
                # pre-pruning based on attention scores
                ctc_beam = min(lpz.shape[-1], int(beam * CTC_SCORING_RATIO))
            else:
                ctc_beam = lpz.shape[-1]
        hyps = [hyp]
        ended_hyps = []

        for i in six.moves.range(maxlen):
            logging.debug('position ' + str(i))

            hyps_best_kept = []
            for hyp in hyps:
                ey = self.embed(hyp['yseq'][i])  # utt list (1) x zdim
                att_c, att_w = self.att([h], hyp['z_prev'][0], hyp['a_prev'])
                ey = F.hstack((ey, att_c))  # utt(1) x (zdim + hdim)
                z_list, c_list = self.rnn_forward(ey, z_list, c_list, hyp['z_prev'], hyp['c_prev'])

                # get nbest local scores and their ids
                local_att_scores = F.log_softmax(self.output(z_list[-1])).data
                if rnnlm:
                    rnnlm_state, local_lm_scores = rnnlm.predict(hyp['rnnlm_prev'], hyp['yseq'][i])
                    local_scores = local_att_scores + recog_args.lm_weight * local_lm_scores
                else:
                    local_scores = local_att_scores

                if lpz is not None:
                    # keep the ctc_beam best candidates, then rescore them jointly with CTC
                    local_best_ids = self.xp.argsort(local_scores, axis=1)[0, ::-1][:ctc_beam]
                    ctc_scores, ctc_states = ctc_prefix_score(hyp['yseq'], local_best_ids, hyp['ctc_state_prev'])
                    local_scores = \
                        (1.0 - ctc_weight) * local_att_scores[:, local_best_ids] \
                        + ctc_weight * (ctc_scores - hyp['ctc_score_prev'])
                    if rnnlm:
                        local_scores += recog_args.lm_weight * local_lm_scores[:, local_best_ids]
                    joint_best_ids = self.xp.argsort(local_scores, axis=1)[0, ::-1][:beam]
                    local_best_scores = local_scores[:, joint_best_ids]
                    local_best_ids = local_best_ids[joint_best_ids]
                else:
                    local_best_ids = self.xp.argsort(local_scores, axis=1)[0, ::-1][:beam]
                    local_best_scores = local_scores[:, local_best_ids]

                # expand only the candidates that survived pruning: there can be fewer
                # than `beam` of them when the vocabulary (or ctc_beam) is smaller
                for j in six.moves.range(len(local_best_ids)):
                    new_hyp = {}
                    # do not copy {z,c}_list directly
                    new_hyp['z_prev'] = z_list[:]
                    new_hyp['c_prev'] = c_list[:]
                    new_hyp['a_prev'] = att_w
                    new_hyp['score'] = hyp['score'] + local_best_scores[0, j]
                    new_hyp['yseq'] = hyp['yseq'] + [self.xp.full(1, local_best_ids[j], 'i')]
                    if rnnlm:
                        new_hyp['rnnlm_prev'] = rnnlm_state
                    if lpz is not None:
                        new_hyp['ctc_state_prev'] = ctc_states[joint_best_ids[j]]
                        new_hyp['ctc_score_prev'] = ctc_scores[joint_best_ids[j]]
                    # will be (2 x beam) hyps at most
                    hyps_best_kept.append(new_hyp)

                hyps_best_kept = sorted(
                    hyps_best_kept, key=lambda x: x['score'], reverse=True)[:beam]

            # sort and get nbest
            hyps = hyps_best_kept
            logging.debug('number of pruned hypotheses: ' + str(len(hyps)))
            logging.debug('best hypo: ' + ''.join([char_list[int(x)]
                                                  for x in hyps[0]['yseq'][1:]]).replace('<space>', ' '))

            # add eos in the final loop to avoid that there are no ended hyps
            if i == maxlen - 1:
                logging.info('adding <eos> in the last position in the loop')
                for hyp in hyps:
                    hyp['yseq'].append(self.xp.full(1, self.eos, 'i'))

            # add ended hypotheses to a final list, and removed them from current hypotheses
            # (this will be a problem, number of hyps < beam)
            remained_hyps = []
            for hyp in hyps:
                if hyp['yseq'][-1] == self.eos:
                    # only store the sequence that has more than minlen outputs
                    # also add penalty
                    if len(hyp['yseq']) > minlen:
                        hyp['score'] += (i + 1) * penalty
                        if rnnlm:  # Word LM needs to add final <eos> score
                            hyp['score'] += recog_args.lm_weight * rnnlm.final(
                                hyp['rnnlm_prev'])
                        ended_hyps.append(hyp)
                else:
                    remained_hyps.append(hyp)

            # end detection
            if end_detect(ended_hyps, i) and recog_args.maxlenratio == 0.0:
                logging.info('end detected at %d', i)
                break

            hyps = remained_hyps
            if len(hyps) > 0:
                logging.debug('remaining hypotheses: ' + str(len(hyps)))
            else:
                logging.info('no hypothesis. Finish decoding.')
                break

            for hyp in hyps:
                logging.debug('hypo: ' + ''.join([char_list[int(x)]
                                                  for x in hyp['yseq'][1:]]).replace('<space>', ' '))

            logging.debug('number of ended hypotheses: ' + str(len(ended_hyps)))

        nbest_hyps = sorted(
            ended_hyps, key=lambda x: x['score'], reverse=True)[:min(len(ended_hyps), recog_args.nbest)]

        # check number of hypotheses
        if len(nbest_hyps) == 0:
            logging.warning('there is no N-best results, perform recognition again with smaller minlenratio.')
            # should copy because Namespace will be overwritten globally
            recog_args = Namespace(**vars(recog_args))
            recog_args.minlenratio = max(0.0, recog_args.minlenratio - 0.1)
            return self.recognize_beam(h, lpz, recog_args, char_list, rnnlm)

        logging.info('total log probability: ' + str(nbest_hyps[0]['score']))
        logging.info('normalized log probability: ' + str(nbest_hyps[0]['score'] / len(nbest_hyps[0]['yseq'])))

        return nbest_hyps

    def calculate_all_attentions(self, hs, ys):
        """Calculate all of attentions.

        Runs teacher-forced decoding and collects the attention weights of every step.

        Args:
            hs (list of chainer.Variable | N-dimensional array): Input variable from encoder.
            ys (list of chainer.Variable | N-dimensional array): Input variable of decoder.

        Returns:
            numpy.ndarray: Attention weights of all decoding steps stacked along axis 1,
            copied to the CPU.

        """
        # prepare input word sequences with the sos ID
        sos = self.xp.array([self.sos], 'i')
        ys_in = [F.concat([sos, y], axis=0) for y in ys]

        # padding for ys_in with eos
        # pys: utt x olen
        pad_ys_in = F.pad_sequence(ys_in, padding=self.eos)

        # get length info: each input has len(y) + 1 entries, the same count as the
        # <eos>-appended targets, so the padded input width is the number of output steps
        olength = pad_ys_in.shape[1]

        # initialization
        c_list = [None]  # list of cell state of each layer
        z_list = [None]  # list of hidden state of each layer
        for _ in six.moves.range(1, self.dlayers):
            c_list.append(None)
            z_list.append(None)
        att_w = None
        att_ws = []
        self.att.reset()  # reset pre-computation of h

        # pre-computation of embedding
        eys = self.embed(pad_ys_in)  # utt x olen x zdim
        eys = F.separate(eys, axis=1)

        # loop for an output sequence
        for i in six.moves.range(olength):
            att_c, att_w = self.att(hs, z_list[0], att_w)
            ey = F.hstack((eys[i], att_c))  # utt x (zdim + hdim)
            z_list, c_list = self.rnn_forward(ey, z_list, c_list, z_list, c_list)
            att_ws.append(att_w)  # for debugging

        att_ws = F.stack(att_ws, axis=1)
        att_ws.to_cpu()

        return att_ws.data
def decoder_for(args, odim, sos, eos, att, labeldist):
    """Return the decoding layer corresponding to the args.

    Args:
        args (Namespace): The program arguments. Reads ``eprojs``, ``dtype``, ``dlayers``,
            ``dunits``, ``verbose``, ``char_list``, ``lsm_weight`` and ``sampling_probability``.
        odim (int): The output dimension.
        sos (int): Label ID indicating the start of sequences.
        eos (int): Label ID indicating the end of sequences.
        att (Module): Attention module defined at `espnet.nets.chainer_backend.attentions`.
        labeldist (numpy.array): Label distribution used for label smoothing, or None.

    Returns:
        chainer.Chain: The decoder module.

    """
    return Decoder(args.eprojs, odim, args.dtype, args.dlayers, args.dunits, sos, eos, att, args.verbose,
                   args.char_list, labeldist,
                   args.lsm_weight, args.sampling_probability)