from distutils.version import LooseVersion
import logging
import random
import six
import numpy as np
import torch
import torch.nn.functional as F
from argparse import Namespace
from espnet.nets.ctc_prefix_score import CTCPrefixScore
from espnet.nets.ctc_prefix_score import CTCPrefixScoreTH
from espnet.nets.e2e_asr_common import end_detect
from espnet.nets.pytorch_backend.rnn.attentions import att_to_numpy
from espnet.nets.pytorch_backend.nets_utils import mask_by_length
from espnet.nets.pytorch_backend.nets_utils import pad_list
from espnet.nets.pytorch_backend.nets_utils import th_accuracy
from espnet.nets.pytorch_backend.nets_utils import to_device
from espnet.nets.scorer_interface import ScorerInterface
# Number of utterances whose ground-truth / predicted sequences are logged by
# Decoder.forward when verbose > 0 and a char_list is available.
MAX_DECODER_OUTPUT: int = 5
# Multiplier on the beam size that sets how many attention-ranked candidates are
# kept for CTC prefix scoring (pre-pruning); also forwarded to CTCPrefixScoreTH.
CTC_SCORING_RATIO: float = 1.5
class Decoder(torch.nn.Module, ScorerInterface):
    """Attention-based RNN decoder.

    :param int eprojs: number of encoder projection units
    :param int odim: dimension of outputs (vocabulary size)
    :param str dtype: recurrent cell type; "lstm" selects LSTMCell, anything else GRUCell
    :param int dlayers: number of decoder layers
    :param int dunits: number of decoder units
    :param int sos: start of sequence symbol id
    :param int eos: end of sequence symbol id
    :param att: indexable collection of attention modules (e.g. a torch.nn.ModuleList);
        the module used is selected by the stream index
    :param int verbose: verbose level
    :param list char_list: list of character strings
    :param ndarray labeldist: distribution of label smoothing
    :param float lsm_weight: label smoothing weight
    :param float sampling_probability: scheduled sampling probability
    :param float dropout: dropout rate
    :param bool context_residual: if True, use context vector for token generation
    :param bool replace_sos: use for multilingual (speech/text) translation
    """

    def __init__(self, eprojs, odim, dtype, dlayers, dunits, sos, eos, att, verbose=0,
                 char_list=None, labeldist=None, lsm_weight=0., sampling_probability=0.0,
                 dropout=0.0, context_residual=False, replace_sos=False):
        torch.nn.Module.__init__(self)
        self.dtype = dtype
        self.dunits = dunits
        self.dlayers = dlayers
        self.context_residual = context_residual
        self.embed = torch.nn.Embedding(odim, dunits)
        self.dropout_emb = torch.nn.Dropout(p=dropout)

        # NOTE: dropout is applied only for the vertical connections
        # see https://arxiv.org/pdf/1409.2329.pdf
        cell = torch.nn.LSTMCell if self.dtype == "lstm" else torch.nn.GRUCell
        self.decoder = torch.nn.ModuleList()
        self.dropout_dec = torch.nn.ModuleList()
        # the first layer consumes [token embedding; attention context]
        self.decoder += [cell(dunits + eprojs, dunits)]
        self.dropout_dec += [torch.nn.Dropout(p=dropout)]
        for _ in range(1, self.dlayers):
            self.decoder += [cell(dunits, dunits)]
            self.dropout_dec += [torch.nn.Dropout(p=dropout)]

        self.ignore_id = -1

        if context_residual:
            self.output = torch.nn.Linear(dunits + eprojs, odim)
        else:
            self.output = torch.nn.Linear(dunits, odim)

        self.loss = None
        self.att = att
        self.sos = sos
        self.eos = eos
        self.odim = odim
        self.verbose = verbose
        self.char_list = char_list
        # for label smoothing
        self.labeldist = labeldist
        self.vlabeldist = None  # lazily created device tensor of labeldist
        self.lsm_weight = lsm_weight
        self.sampling_probability = sampling_probability
        self.dropout = dropout

        # for multilingual translation
        self.replace_sos = replace_sos

        # score used to effectively forbid a candidate during beam search
        self.logzero = -10000000000.0

    def zero_state(self, hs_pad):
        """Return a zero state of shape (hs_pad.size(0), dunits) with hs_pad's dtype/device."""
        return hs_pad.new_zeros(hs_pad.size(0), self.dunits)

    def _init_states(self, ref):
        """Create zero per-layer states ``(z_list, c_list)`` with batch size ``ref.size(0)``."""
        # the first layer always exists (see __init__), even if dlayers < 1
        n_states = max(1, self.dlayers)
        z_list = [self.zero_state(ref) for _ in range(n_states)]
        c_list = [self.zero_state(ref) for _ in range(n_states)]
        return z_list, c_list

    def _pre_output(self, z_top, att_c):
        """Build the input of ``self.output`` from the top-layer state and the context."""
        z_top = self.dropout_dec[-1](z_top)
        if self.context_residual:
            return torch.cat((z_top, att_c), dim=-1)  # utt x (zdim + hdim)
        return z_top  # utt x zdim

    def _prepare_ys(self, ys_pad, tgt_lang_ids=None):
        """Build decoder input/output label sequences from padded targets.

        Input sequences are prefixed with <sos> (or the target-language id when
        ``replace_sos`` is set) and padded with eos; output sequences are suffixed
        with <eos> and padded with ``ignore_id``.

        :return: (ys_in, ys_out, ys_in_pad, ys_out_pad)
        """
        # TODO(kan-bayashi): need to make more smart way
        ys = [y[y != self.ignore_id] for y in ys_pad]  # parse padded ys
        eos = ys[0].new([self.eos])
        sos = ys[0].new([self.sos])
        if self.replace_sos:
            ys_in = [torch.cat([idx, y], dim=0) for idx, y in zip(tgt_lang_ids, ys)]
        else:
            ys_in = [torch.cat([sos, y], dim=0) for y in ys]
        ys_out = [torch.cat([y, eos], dim=0) for y in ys]
        # pys: utt x olen
        ys_in_pad = pad_list(ys_in, self.eos)
        ys_out_pad = pad_list(ys_out, self.ignore_id)
        return ys_in, ys_out, ys_in_pad, ys_out_pad

    def rnn_forward(self, ey, z_list, c_list, z_prev, c_prev):
        """Advance the stacked recurrent cells by one step.

        ``z_list`` / ``c_list`` are overwritten element-wise and returned; ``c_list`` is
        only touched for LSTM cells. ``z_prev`` / ``c_prev`` may be the very same lists
        as ``z_list`` / ``c_list``: every layer reads its previous state before its
        slot is overwritten.

        :param torch.Tensor ey: first-layer input, [embedding; context]
        :return: (z_list, c_list)
        """
        if self.dtype == "lstm":
            z_list[0], c_list[0] = self.decoder[0](ey, (z_prev[0], c_prev[0]))
            for layer in range(1, self.dlayers):
                z_list[layer], c_list[layer] = self.decoder[layer](
                    self.dropout_dec[layer - 1](z_list[layer - 1]), (z_prev[layer], c_prev[layer]))
        else:
            z_list[0] = self.decoder[0](ey, z_prev[0])
            for layer in range(1, self.dlayers):
                z_list[layer] = self.decoder[layer](
                    self.dropout_dec[layer - 1](z_list[layer - 1]), z_prev[layer])
        return z_list, c_list

    def forward(self, hs_pad, hlens, ys_pad, strm_idx=0, tgt_lang_ids=None):
        """Decoder forward

        :param torch.Tensor hs_pad: batch of padded hidden state sequences (B, Tmax, D)
        :param torch.Tensor hlens: batch of lengths of hidden state sequences (B)
        :param torch.Tensor ys_pad: batch of padded character id sequence tensor (B, Lmax),
            padded with ``ignore_id`` (-1)
        :param int strm_idx: stream index indicates the index of decoding stream.
        :param torch.Tensor tgt_lang_ids: batch of target language id tensor (B, 1)
        :return: attention loss value, accuracy and perplexity
        :rtype: tuple(torch.Tensor, float, float)
        """
        # attention index for the attention module
        # in SPA (speaker parallel attention), att_idx is used to select attention module. In other cases, it is 0.
        att_idx = min(strm_idx, len(self.att) - 1)

        # hlens should be list of integer
        hlens = list(map(int, hlens))

        self.loss = None
        ys_in, ys_out, ys_in_pad, ys_out_pad = self._prepare_ys(ys_pad, tgt_lang_ids)

        # get dim, length info
        batch = ys_out_pad.size(0)
        olength = ys_out_pad.size(1)
        logging.info(self.__class__.__name__ + ' input lengths: ' + str(hlens))
        logging.info(self.__class__.__name__ + ' output lengths: ' + str([y.size(0) for y in ys_out]))

        # initialization
        z_list, c_list = self._init_states(hs_pad)
        att_w = None
        z_all = []
        self.att[att_idx].reset()  # reset pre-computation of h

        # pre-computation of embedding
        eys = self.dropout_emb(self.embed(ys_in_pad))  # utt x olen x zdim

        # loop for an output sequence
        for i in range(olength):
            att_c, att_w = self.att[att_idx](hs_pad, hlens, self.dropout_dec[0](z_list[0]), att_w)
            if i > 0 and random.random() < self.sampling_probability:
                # scheduled sampling: feed back the model's own previous argmax prediction
                logging.info(' scheduled sampling ')
                prev_ids = self.output(z_all[-1]).detach().argmax(dim=1)
                z_out = self.dropout_emb(self.embed(prev_ids))
                ey = torch.cat((z_out, att_c), dim=1)  # utt x (zdim + hdim)
            else:
                ey = torch.cat((eys[:, i, :], att_c), dim=1)  # utt x (zdim + hdim)
            z_list, c_list = self.rnn_forward(ey, z_list, c_list, z_list, c_list)
            z_all.append(self._pre_output(z_list[-1], att_c))

        z_all = torch.stack(z_all, dim=1).view(batch * olength, -1)
        # compute loss
        y_all = self.output(z_all)
        if LooseVersion(torch.__version__) < LooseVersion('1.0'):
            reduction_str = 'elementwise_mean'
        else:
            reduction_str = 'mean'
        self.loss = F.cross_entropy(y_all, ys_out_pad.view(-1),
                                    ignore_index=self.ignore_id,
                                    reduction=reduction_str)
        # scale the per-token mean by the average number of target tokens
        # (len(ys_in) - 1 drops the leading <sos>/language-id symbol)
        self.loss *= (np.mean([len(x) for x in ys_in]) - 1)
        acc = th_accuracy(y_all, ys_out_pad, ignore_label=self.ignore_id)
        logging.info('att loss:' + ''.join(str(self.loss.item()).split('\n')))

        # compute perplexity
        ppl = np.exp(self.loss.item() * np.mean([len(x) for x in ys_in]) / np.sum([len(x) for x in ys_in]))

        # show predicted character sequence for debug
        if self.verbose > 0 and self.char_list is not None:
            ys_hat = y_all.view(batch, olength, -1)
            ys_true = ys_out_pad
            for i, (y_hat, y_true) in enumerate(zip(ys_hat.detach().cpu().numpy(),
                                                    ys_true.detach().cpu().numpy())):
                if i == MAX_DECODER_OUTPUT:
                    break
                idx_hat = np.argmax(y_hat[y_true != self.ignore_id], axis=1)
                idx_true = y_true[y_true != self.ignore_id]
                seq_hat = "".join(self.char_list[int(idx)] for idx in idx_hat)
                seq_true = "".join(self.char_list[int(idx)] for idx in idx_true)
                logging.info("groundtruth[%d]: " % i + seq_true)
                logging.info("prediction [%d]: " % i + seq_hat)

        if self.labeldist is not None:
            if self.vlabeldist is None:
                self.vlabeldist = to_device(self, torch.from_numpy(self.labeldist))
            # NOTE(review): y_all also contains padded positions, which are included here
            loss_reg = - torch.sum((F.log_softmax(y_all, dim=1) * self.vlabeldist).view(-1), dim=0) / len(ys_in)
            self.loss = (1. - self.lsm_weight) * self.loss + self.lsm_weight * loss_reg

        return self.loss, acc, ppl

    def recognize_beam(self, h, lpz, recog_args, char_list, rnnlm=None, strm_idx=0):
        """beam search implementation

        :param torch.Tensor h: encoder hidden state (T, eprojs)
        :param torch.Tensor lpz: ctc log softmax output (T, odim), or None to disable CTC scoring
        :param Namespace recog_args: argument Namespace containing options
        :param char_list: list of character strings
        :param torch.nn.Module rnnlm: language module
        :param int strm_idx: stream index for speaker parallel attention in multi-speaker case
        :return: N-best decoding results; each 'yseq' still starts with the <sos>
            (or target-language) symbol
        :rtype: list of dicts
        """
        logging.info('input lengths: ' + str(h.size(0)))
        att_idx = min(strm_idx, len(self.att) - 1)
        h_batch = h.unsqueeze(0)  # single-utterance batch

        # initialization
        z_list, c_list = self._init_states(h_batch)
        a = None
        self.att[att_idx].reset()  # reset pre-computation of h

        # search parms
        beam = recog_args.beam_size
        penalty = recog_args.penalty
        ctc_weight = recog_args.ctc_weight

        # preprate sos
        if self.replace_sos and recog_args.tgt_lang:
            y = char_list.index(recog_args.tgt_lang)
        else:
            y = self.sos
        logging.info('<sos> index: ' + str(y))
        logging.info('<sos> mark: ' + char_list[y])
        vy = h.new_zeros(1).long()

        if recog_args.maxlenratio == 0:
            maxlen = h.shape[0]
        else:
            # maxlen >= 1
            maxlen = max(1, int(recog_args.maxlenratio * h.size(0)))
        minlen = int(recog_args.minlenratio * h.size(0))
        logging.info('max output length: ' + str(maxlen))
        logging.info('min output length: ' + str(minlen))

        # initialize hypothesis
        if rnnlm:
            hyp = {'score': 0.0, 'yseq': [y], 'c_prev': c_list,
                   'z_prev': z_list, 'a_prev': a, 'rnnlm_prev': None}
        else:
            hyp = {'score': 0.0, 'yseq': [y], 'c_prev': c_list, 'z_prev': z_list, 'a_prev': a}
        if lpz is not None:
            ctc_prefix_score = CTCPrefixScore(lpz.detach().numpy(), 0, self.eos, np)
            hyp['ctc_state_prev'] = ctc_prefix_score.initial_state()
            hyp['ctc_score_prev'] = 0.0
            if ctc_weight != 1.0:
                # pre-pruning based on attention scores
                ctc_beam = min(lpz.shape[-1], int(beam * CTC_SCORING_RATIO))
            else:
                ctc_beam = lpz.shape[-1]
        hyps = [hyp]
        ended_hyps = []

        for i in range(maxlen):
            logging.debug('position ' + str(i))

            hyps_best_kept = []
            for hyp in hyps:
                vy[0] = hyp['yseq'][i]
                ey = self.dropout_emb(self.embed(vy))  # utt list (1) x zdim
                att_c, att_w = self.att[att_idx](h_batch, [h.size(0)],
                                                 self.dropout_dec[0](hyp['z_prev'][0]), hyp['a_prev'])
                ey = torch.cat((ey, att_c), dim=1)  # utt(1) x (zdim + hdim)
                z_list, c_list = self.rnn_forward(ey, z_list, c_list, hyp['z_prev'], hyp['c_prev'])

                # get nbest local scores and their ids
                logits = self.output(self._pre_output(z_list[-1], att_c))
                local_att_scores = F.log_softmax(logits, dim=1)
                if rnnlm:
                    rnnlm_state, local_lm_scores = rnnlm.predict(hyp['rnnlm_prev'], vy)
                    local_scores = local_att_scores + recog_args.lm_weight * local_lm_scores
                else:
                    local_scores = local_att_scores

                if lpz is not None:
                    local_best_scores, local_best_ids = torch.topk(
                        local_att_scores, ctc_beam, dim=1)
                    ctc_scores, ctc_states = ctc_prefix_score(
                        hyp['yseq'], local_best_ids[0], hyp['ctc_state_prev'])
                    local_scores = \
                        (1.0 - ctc_weight) * local_att_scores[:, local_best_ids[0]] \
                        + ctc_weight * torch.from_numpy(ctc_scores - hyp['ctc_score_prev'])
                    if rnnlm:
                        local_scores += recog_args.lm_weight * local_lm_scores[:, local_best_ids[0]]
                    local_best_scores, joint_best_ids = torch.topk(local_scores, beam, dim=1)
                    local_best_ids = local_best_ids[:, joint_best_ids[0]]
                else:
                    local_best_scores, local_best_ids = torch.topk(local_scores, beam, dim=1)

                for j in range(beam):
                    new_hyp = {}
                    # [:] is needed! z_list / c_list are overwritten in place for the next hypothesis
                    new_hyp['z_prev'] = z_list[:]
                    new_hyp['c_prev'] = c_list[:]
                    new_hyp['a_prev'] = att_w[:]
                    new_hyp['score'] = hyp['score'] + local_best_scores[0, j]
                    new_hyp['yseq'] = hyp['yseq'] + [int(local_best_ids[0, j])]
                    if rnnlm:
                        new_hyp['rnnlm_prev'] = rnnlm_state
                    if lpz is not None:
                        new_hyp['ctc_state_prev'] = ctc_states[joint_best_ids[0, j]]
                        new_hyp['ctc_score_prev'] = ctc_scores[joint_best_ids[0, j]]
                    # will be (2 x beam) hyps at most
                    hyps_best_kept.append(new_hyp)

                hyps_best_kept = sorted(
                    hyps_best_kept, key=lambda x: x['score'], reverse=True)[:beam]

            # sort and get nbest
            hyps = hyps_best_kept
            logging.debug('number of pruned hypotheses: ' + str(len(hyps)))
            logging.debug(
                'best hypo: ' + ''.join([char_list[int(x)] for x in hyps[0]['yseq'][1:]]))

            # add eos in the final loop to avoid that there are no ended hyps
            if i == maxlen - 1:
                logging.info('adding <eos> in the last position in the loop')
                for hyp in hyps:
                    hyp['yseq'].append(self.eos)

            # add ended hypotheses to a final list, and removed them from current hypotheses
            # (this will be a problem, number of hyps < beam)
            remained_hyps = []
            for hyp in hyps:
                if hyp['yseq'][-1] == self.eos:
                    # only store the sequence that has more than minlen outputs
                    # also add penalty
                    if len(hyp['yseq']) > minlen:
                        hyp['score'] += (i + 1) * penalty
                        if rnnlm:  # Word LM needs to add final <eos> score
                            hyp['score'] += recog_args.lm_weight * rnnlm.final(
                                hyp['rnnlm_prev'])
                        ended_hyps.append(hyp)
                else:
                    remained_hyps.append(hyp)

            # end detection
            if end_detect(ended_hyps, i) and recog_args.maxlenratio == 0.0:
                logging.info('end detected at %d', i)
                break

            hyps = remained_hyps
            if len(hyps) > 0:
                logging.debug('remaining hypotheses: ' + str(len(hyps)))
            else:
                logging.info('no hypothesis. Finish decoding.')
                break

            for hyp in hyps:
                logging.debug(
                    'hypo: ' + ''.join([char_list[int(x)] for x in hyp['yseq'][1:]]))

            logging.debug('number of ended hypotheses: ' + str(len(ended_hyps)))

        nbest_hyps = sorted(
            ended_hyps, key=lambda x: x['score'], reverse=True)[:min(len(ended_hyps), recog_args.nbest)]

        # check number of hypotheses
        if len(nbest_hyps) == 0:
            logging.warning('there is no N-best results, perform recognition again with smaller minlenratio.')
            # should copy because Namespace will be overwritten globally
            recog_args = Namespace(**vars(recog_args))
            recog_args.minlenratio = max(0.0, recog_args.minlenratio - 0.1)
            # keep the same attention stream on retry
            return self.recognize_beam(h, lpz, recog_args, char_list, rnnlm, strm_idx)

        logging.info('total log probability: ' + str(nbest_hyps[0]['score']))
        logging.info('normalized log probability: ' + str(nbest_hyps[0]['score'] / len(nbest_hyps[0]['yseq'])))

        return nbest_hyps

    def recognize_beam_batch(self, h, hlens, lpz, recog_args, char_list, rnnlm=None,
                             normalize_score=True, strm_idx=0, tgt_lang_ids=None):
        """Batch beam search over a padded batch of encoder outputs.

        :param torch.Tensor h: padded encoder hidden states (B, Tmax, D)
        :param torch.Tensor hlens: lengths of the encoder hidden states (B)
        :param torch.Tensor lpz: CTC log softmax output handed to CTCPrefixScoreTH, or None
            to disable CTC scoring (presumably (B, Tmax, odim) -- TODO confirm)
        :param Namespace recog_args: argument Namespace containing options
        :param char_list: list of character strings
        :param torch.nn.Module rnnlm: language module
        :param bool normalize_score: divide each final score by its sequence length
        :param int strm_idx: stream index for speaker parallel attention in multi-speaker case
        :param tgt_lang_ids: per-utterance target language ids (used during training eval)
        :return: for every utterance, an N-best list of dicts with 'yseq' and 'score'
            (real hypotheses also carry 'vscore')
        :rtype: list of list of dicts
        """
        logging.info('input lengths: ' + str(h.size(1)))
        att_idx = min(strm_idx, len(self.att) - 1)
        h = mask_by_length(h, hlens, 0.0)

        # search params
        batch = len(hlens)
        beam = recog_args.beam_size
        penalty = recog_args.penalty
        ctc_weight = recog_args.ctc_weight
        att_weight = 1.0 - ctc_weight
        ctc_margin = recog_args.ctc_window_margin

        n_bb = batch * beam
        # offset of each utterance's first beam in the flattened (batch * beam) axis
        pad_b = to_device(self, torch.arange(batch) * beam).view(-1, 1)

        max_hlen = int(max(hlens))
        if recog_args.maxlenratio == 0:
            maxlen = max_hlen
        else:
            maxlen = max(1, int(recog_args.maxlenratio * max_hlen))
        minlen = int(recog_args.minlenratio * max_hlen)
        logging.info('max output length: ' + str(maxlen))
        logging.info('min output length: ' + str(minlen))

        # initialization
        c_prev = [to_device(self, torch.zeros(n_bb, self.dunits)) for _ in range(self.dlayers)]
        z_prev = [to_device(self, torch.zeros(n_bb, self.dunits)) for _ in range(self.dlayers)]
        c_list = [to_device(self, torch.zeros(n_bb, self.dunits)) for _ in range(self.dlayers)]
        z_list = [to_device(self, torch.zeros(n_bb, self.dunits)) for _ in range(self.dlayers)]
        vscores = to_device(self, torch.zeros(batch, beam))

        a_prev = None
        rnnlm_state = None
        ctc_scorer = None
        ctc_state = None

        self.att[att_idx].reset()  # reset pre-computation of h

        if self.replace_sos and recog_args.tgt_lang:
            logging.info('<sos> index: ' + str(char_list.index(recog_args.tgt_lang)))
            logging.info('<sos> mark: ' + recog_args.tgt_lang)
            yseq = [[char_list.index(recog_args.tgt_lang)] for _ in range(n_bb)]
        elif tgt_lang_ids is not None:
            # NOTE: used for evaluation during training
            yseq = [[tgt_lang_ids[b // recog_args.beam_size]] for b in range(n_bb)]
        else:
            logging.info('<sos> index: ' + str(self.sos))
            logging.info('<sos> mark: ' + char_list[self.sos])
            yseq = [[self.sos] for _ in range(n_bb)]

        stop_search = [False for _ in range(batch)]
        ended_hyps = [[] for _ in range(batch)]

        # replicate every utterance `beam` times along the batch axis
        exp_hlens = hlens.repeat(beam).view(beam, batch).transpose(0, 1).contiguous()
        exp_hlens = exp_hlens.view(-1).tolist()
        exp_h = h.unsqueeze(1).repeat(1, beam, 1, 1).contiguous()
        exp_h = exp_h.view(n_bb, h.size(1), h.size(2))

        if lpz is not None:
            scoring_ratio = CTC_SCORING_RATIO if att_weight > 0.0 and not lpz.is_cuda else 0
            ctc_scorer = CTCPrefixScoreTH(lpz, hlens, 0, self.eos, beam,
                                          scoring_ratio, margin=ctc_margin)

        for i in range(maxlen):
            logging.debug('position ' + str(i))

            vy = to_device(self, torch.LongTensor(self._get_last_yseq(yseq)))
            ey = self.dropout_emb(self.embed(vy))
            att_c, att_w = self.att[att_idx](exp_h, exp_hlens, self.dropout_dec[0](z_prev[0]), a_prev)
            ey = torch.cat((ey, att_c), dim=1)

            # attention decoder
            z_list, c_list = self.rnn_forward(ey, z_list, c_list, z_prev, c_prev)
            logits = self.output(self._pre_output(z_list[-1], att_c))
            local_scores = att_weight * F.log_softmax(logits, dim=1)

            # rnnlm
            if rnnlm:
                rnnlm_state, local_lm_scores = rnnlm.buff_predict(rnnlm_state, vy, n_bb)
                local_scores = local_scores + recog_args.lm_weight * local_lm_scores

            # ctc
            if ctc_scorer:
                att_w_ = att_w if isinstance(att_w, torch.Tensor) else att_w[0]
                ctc_state, local_ctc_scores = ctc_scorer(yseq, ctc_state, local_scores, att_w_)
                local_scores = local_scores + ctc_weight * local_ctc_scores

            local_scores = local_scores.view(batch, beam, self.odim)
            if i == 0:
                # all beams of an utterance start from the same prefix; keep only the
                # first so the top-k does not pick duplicated hypotheses
                local_scores[:, 1:, :] = self.logzero

            # accumulate scores
            eos_vscores = local_scores[:, :, self.eos] + vscores
            vscores = vscores.view(batch, beam, 1).repeat(1, 1, self.odim)
            vscores[:, :, self.eos] = self.logzero
            vscores = (vscores + local_scores).view(batch, -1)

            # global pruning
            accum_best_scores, accum_best_ids = torch.topk(vscores, beam, 1)
            accum_odim_ids = torch.fmod(accum_best_ids, self.odim).view(-1).data.cpu().tolist()
            # NOTE: floor division is required; torch.div on integer tensors performs true
            # division in recent PyTorch and would produce float (unusable) list indices
            accum_padded_beam_ids = (accum_best_ids // self.odim + pad_b).view(-1).data.cpu().tolist()

            # _index_select_list copies the inner lists, so y_prev is left untouched
            y_prev = yseq[:]
            yseq = self._index_select_list(yseq, accum_padded_beam_ids)
            yseq = self._append_ids(yseq, accum_odim_ids)
            vscores = accum_best_scores
            vidx = to_device(self, torch.LongTensor(accum_padded_beam_ids))

            if isinstance(att_w, torch.Tensor):
                a_prev = torch.index_select(att_w.view(n_bb, *att_w.shape[1:]), 0, vidx)
            elif isinstance(att_w, list):
                # handle the case of multi-head attention
                a_prev = [torch.index_select(att_w_one.view(n_bb, -1), 0, vidx) for att_w_one in att_w]
            else:
                # handle the case of location_recurrent when return is a tuple
                a_prev_ = torch.index_select(att_w[0].view(n_bb, -1), 0, vidx)
                h_prev_ = torch.index_select(att_w[1][0].view(n_bb, -1), 0, vidx)
                c_prev_ = torch.index_select(att_w[1][1].view(n_bb, -1), 0, vidx)
                a_prev = (a_prev_, (h_prev_, c_prev_))
            z_prev = [torch.index_select(z_list[li].view(n_bb, -1), 0, vidx) for li in range(self.dlayers)]
            c_prev = [torch.index_select(c_list[li].view(n_bb, -1), 0, vidx) for li in range(self.dlayers)]

            if rnnlm:
                rnnlm_state = self._index_select_lm_state(rnnlm_state, 0, vidx)
            if ctc_scorer:
                ctc_state = ctc_scorer.index_select_state(ctc_state, accum_best_ids)

            # pick ended hyps
            if i > minlen:
                penalty_i = (i + 1) * penalty
                thr = accum_best_scores[:, -1]
                for samp_i in range(batch):
                    if stop_search[samp_i]:
                        continue
                    for beam_j in range(beam):
                        if eos_vscores[samp_i, beam_j] > thr[samp_i]:
                            # flat index of this (utterance, beam) before pruning
                            yk = y_prev[samp_i * beam + beam_j][:]
                            yk.append(self.eos)
                            if len(yk) < hlens[samp_i]:
                                _vscore = eos_vscores[samp_i][beam_j] + penalty_i
                                _score = _vscore.data.cpu().numpy()
                                ended_hyps[samp_i].append({'yseq': yk, 'vscore': _vscore, 'score': _score})

            # end detection
            stop_search = [stop_search[samp_i] or end_detect(ended_hyps[samp_i], i)
                           for samp_i in range(batch)]
            if all(stop_search):
                break

        torch.cuda.empty_cache()

        # utterances without any ended hypothesis get their own -inf placeholder
        for samp_i in range(batch):
            if len(ended_hyps[samp_i]) == 0:
                ended_hyps[samp_i] = [{'yseq': [self.sos, self.eos], 'score': np.array([-float('inf')])}]

        if normalize_score:
            for samp_i in range(batch):
                for x in ended_hyps[samp_i]:
                    x['score'] /= len(x['yseq'])

        nbest_hyps = [sorted(ended_hyps[samp_i], key=lambda x: x['score'],
                             reverse=True)[:min(len(ended_hyps[samp_i]), recog_args.nbest)]
                      for samp_i in range(batch)]

        return nbest_hyps

    def calculate_all_attentions(self, hs_pad, hlen, ys_pad, strm_idx=0, tgt_lang_ids=None):
        """Calculate all of attentions

        :param torch.Tensor hs_pad: batch of padded hidden state sequences (B, Tmax, D)
        :param torch.Tensor hlen: batch of lengths of hidden state sequences (B)
        :param torch.Tensor ys_pad: batch of padded character id sequence tensor (B, Lmax)
        :param int strm_idx: stream index for parallel speaker attention in multi-speaker case
        :param torch.Tensor tgt_lang_ids: batch of target language id tensor (B, 1)
        :return: attention weights with the following shape,
            1) multi-head case => attention weights (B, H, Lmax, Tmax),
            2) other case => attention weights (B, Lmax, Tmax).
        :rtype: float ndarray
        """
        att_idx = min(strm_idx, len(self.att) - 1)

        # hlen should be list of integer
        hlen = list(map(int, hlen))

        self.loss = None
        _, _, ys_in_pad, ys_out_pad = self._prepare_ys(ys_pad, tgt_lang_ids)

        # get length info
        olength = ys_out_pad.size(1)

        # initialization
        z_list, c_list = self._init_states(hs_pad)
        att_w = None
        att_ws = []
        self.att[att_idx].reset()  # reset pre-computation of h

        # pre-computation of embedding
        eys = self.dropout_emb(self.embed(ys_in_pad))  # utt x olen x zdim

        # loop for an output sequence (teacher forcing only)
        for i in range(olength):
            att_c, att_w = self.att[att_idx](hs_pad, hlen, self.dropout_dec[0](z_list[0]), att_w)
            ey = torch.cat((eys[:, i, :], att_c), dim=1)  # utt x (zdim + hdim)
            z_list, c_list = self.rnn_forward(ey, z_list, c_list, z_list, c_list)
            att_ws.append(att_w)

        # convert to numpy array with the shape (B, Lmax, Tmax)
        att_ws = att_to_numpy(att_ws, self.att[att_idx])
        return att_ws

    @staticmethod
    def _get_last_yseq(exp_yseq):
        """Return the last token of every hypothesis."""
        return [y_seq[-1] for y_seq in exp_yseq]

    @staticmethod
    def _append_ids(yseq, ids):
        """Append ``ids[i]`` to ``yseq[i]`` (or one scalar id to all) in place and return ``yseq``."""
        if isinstance(ids, list):
            for i, j in enumerate(ids):
                yseq[i].append(j)
        else:
            for seq in yseq:
                seq.append(ids)
        return yseq

    @staticmethod
    def _index_select_list(yseq, lst):
        """Gather hypotheses by index, shallow-copying each selected sequence."""
        return [yseq[idx][:] for idx in lst]

    @staticmethod
    def _index_select_lm_state(rnnlm_state, dim, vidx):
        """Reorder a batched LM state along the hypothesis axis.

        Dict states are gathered tensor-wise with ``torch.index_select``; list states
        are gathered element-wise (each element shallow-copied).

        :raises TypeError: if the state is neither a dict nor a list
        """
        if isinstance(rnnlm_state, dict):
            return {k: [torch.index_select(vi, dim, vidx) for vi in v]
                    for k, v in rnnlm_state.items()}
        if isinstance(rnnlm_state, list):
            return [rnnlm_state[int(i)][:] for i in vidx]
        raise TypeError('unsupported rnnlm state type: ' + type(rnnlm_state).__name__)

    # scorer interface methods
    def init_state(self, x):
        """Create the initial decoding state for one utterance ``x`` (ScorerInterface)."""
        z_list, c_list = self._init_states(x.unsqueeze(0))
        # TODO(karita): support strm_index for `asr_mix`
        strm_index = 0
        att_idx = min(strm_index, len(self.att) - 1)
        self.att[att_idx].reset()  # reset pre-computation of h
        return dict(c_prev=c_list[:], z_prev=z_list[:], a_prev=None, workspace=(att_idx, z_list, c_list))

    def score(self, yseq, state, x):
        """Score the next token given prefix ``yseq`` (ScorerInterface).

        :return: (log-probabilities over the vocabulary, new state)
        """
        att_idx, z_list, c_list = state["workspace"]
        vy = yseq[-1].unsqueeze(0)
        ey = self.dropout_emb(self.embed(vy))  # utt list (1) x zdim
        att_c, att_w = self.att[att_idx](
            x.unsqueeze(0), [x.size(0)],
            self.dropout_dec[0](state['z_prev'][0]), state['a_prev'])
        ey = torch.cat((ey, att_c), dim=1)  # utt(1) x (zdim + hdim)
        z_list, c_list = self.rnn_forward(ey, z_list, c_list, state['z_prev'], state['c_prev'])
        logits = self.output(self._pre_output(z_list[-1], att_c))
        logp = F.log_softmax(logits, dim=1).squeeze(0)
        return logp, dict(c_prev=c_list[:], z_prev=z_list[:], a_prev=att_w, workspace=(att_idx, z_list, c_list))
def decoder_for(args, odim, sos, eos, att, labeldist):
    """Instantiate a Decoder from a parsed argument namespace.

    ``context_residual`` and ``replace_sos`` are optional on ``args`` and fall back
    to False when missing, to stay compatible with older configurations.
    """
    return Decoder(
        eprojs=args.eprojs,
        odim=odim,
        dtype=args.dtype,
        dlayers=args.dlayers,
        dunits=args.dunits,
        sos=sos,
        eos=eos,
        att=att,
        verbose=args.verbose,
        char_list=args.char_list,
        labeldist=labeldist,
        lsm_weight=args.lsm_weight,
        sampling_probability=args.sampling_probability,
        dropout=args.dropout_rate_decoder,
        context_residual=getattr(args, "context_residual", False),
        replace_sos=getattr(args, "replace_sos", False),
    )