# Source code for espnet.nets.chainer_backend.e2e_asr

#!/usr/bin/env python

# Copyright 2017 Johns Hopkins University (Shinji Watanabe)
#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)


import logging
import math

import chainer
from chainer import reporter
import numpy as np

from espnet.nets.asr_interface import ASRInterface
from espnet.nets.chainer_backend.ctc import ctc_for
from espnet.nets.chainer_backend.rnn.attentions import att_for
from espnet.nets.chainer_backend.rnn.decoders import decoder_for
from espnet.nets.chainer_backend.rnn.encoders import encoder_for
from espnet.nets.e2e_asr_common import label_smoothing_dist
from espnet.nets.pytorch_backend.e2e_asr import E2E as E2E_pytorch

CTC_LOSS_THRESHOLD = 10000


class E2E(ASRInterface, chainer.Chain):
    """E2E module for chainer backend.

    Combines an encoder, a CTC branch and an attention decoder, and mixes
    the CTC and attention losses with the weight ``args.mtlalpha``.

    Args:
        idim (int): Dimension of the inputs.
        odim (int): Dimension of the outputs.
        args (parser.args): Training config.
        flag_return (bool): If True, forward() returns the CTC loss, the
            attention loss and the accuracy in addition to the training loss.

    """

    @staticmethod
    def add_arguments(parser):
        """Add model arguments to ``parser`` by delegating to the PyTorch E2E."""
        return E2E_pytorch.add_arguments(parser)

    def __init__(self, idim, odim, args, flag_return=True):
        chainer.Chain.__init__(self)
        self.mtlalpha = args.mtlalpha
        assert 0 <= self.mtlalpha <= 1, "mtlalpha must be [0,1]"
        self.etype = args.etype
        self.verbose = args.verbose
        self.char_list = args.char_list
        self.outdir = args.outdir

        # below means the last number becomes eos/sos ID
        # note that sos/eos IDs are identical
        self.sos = odim - 1
        self.eos = odim - 1

        # subsample info
        # +1 means input (+1) and layers outputs (args.elayers)
        # NOTE: the builtin ``int`` is used instead of ``np.int``; the latter
        # was only an alias of ``int`` and was removed in NumPy 1.24.
        subsample = np.ones(args.elayers + 1, dtype=int)
        if args.etype.endswith("p") and not args.etype.startswith("vgg"):
            # Only as many factors as there are slots are taken from the
            # "_"-separated option; remaining slots keep the factor 1.
            factors = args.subsample.split("_")
            for j, factor in enumerate(factors[:args.elayers + 1]):
                subsample[j] = int(factor)
        else:
            # This branch is also reached for encoder types that neither
            # start with "vgg" nor end with "p"; the message is kept as-is.
            logging.warning(
                'Subsampling is not performed for vgg*. It is performed in max pooling layers at CNN.')
        logging.info('subsample: %s', ' '.join(str(x) for x in subsample))
        self.subsample = subsample

        # label smoothing info
        if args.lsm_type:
            logging.info("Use label smoothing with %s", args.lsm_type)
            labeldist = label_smoothing_dist(odim, args.lsm_type, transcript=args.train_json)
        else:
            labeldist = None

        with self.init_scope():
            # encoder
            self.enc = encoder_for(args, idim, self.subsample)
            # ctc
            self.ctc = ctc_for(args, odim)
            # attention
            self.att = att_for(args)
            # decoder
            self.dec = decoder_for(args, odim, self.sos, self.eos, self.att, labeldist)

        self.acc = None
        self.loss = None
        self.flag_return = flag_return

    def forward(self, xs, ilens, ys):
        """E2E forward propagation.

        Args:
            xs: Batch of input feature sequences, passed to the encoder.
                (presumably a list of (T, idim) arrays, as in
                calculate_all_attentions -- confirm against the caller)
            ilens: Lengths of the input sequences, passed to the encoder. (B,)
            ys: Batch of target label sequences, passed to the CTC and
                decoder modules.

        Returns:
            Loss mixed from the attention and CTC losses according to
            ``mtlalpha``. If ``flag_return`` is True, a tuple
            ``(loss, loss_ctc, loss_att, acc)`` is returned instead, where
            ``loss_ctc`` is None when ``mtlalpha == 0`` and ``loss_att`` and
            ``acc`` are None when ``mtlalpha == 1``.

        """
        # 1. encoder
        hs, ilens = self.enc(xs, ilens)

        alpha = self.mtlalpha

        # 2. CTC loss (skipped in pure-attention mode)
        loss_ctc = None if alpha == 0 else self.ctc(hs, ys)

        # 3. attention loss (skipped in pure-CTC mode)
        if alpha == 1:
            loss_att, acc = None, None
        else:
            loss_att, acc = self.dec(hs, ys)
        self.acc = acc

        # 4. multi-task loss
        if alpha == 0:
            self.loss = loss_att
        elif alpha == 1:
            self.loss = loss_ctc
        else:
            self.loss = alpha * loss_ctc + (1 - alpha) * loss_att

        # Only report values when the loss is finite and below the sanity
        # threshold; otherwise just emit a warning.
        if self.loss.data < CTC_LOSS_THRESHOLD and not math.isnan(self.loss.data):
            reporter.report({'loss_ctc': loss_ctc}, self)
            reporter.report({'loss_att': loss_att}, self)
            reporter.report({'acc': acc}, self)

            logging.info('mtl loss:%s', self.loss.data)
            reporter.report({'loss': self.loss}, self)
        else:
            logging.warning('loss (=%f) is not correct', self.loss.data)

        if self.flag_return:
            return self.loss, loss_ctc, loss_att, acc
        return self.loss

    def recognize(self, x, recog_args, char_list, rnnlm=None):
        """E2E greedy/beam search.

        Args:
            x: Input features of a single utterance; indexed as a 2-D array
                (frames on the first axis).
            recog_args (parser.args): Arguments of config file.
            char_list (List[str]): List of characters.
            rnnlm (Module): RNNLM module defined at `espnet.lm.chainer_backend.lm`.

        Returns:
            List[Dict[str, Any]]: Result of recognition.

        """
        # subsample frame
        x = x[::self.subsample[0], :]
        ilen = self.xp.array(x.shape[0], dtype=np.int32)
        h = chainer.Variable(self.xp.array(x, dtype=np.float32))

        with chainer.no_backprop_mode(), chainer.using_config('train', False):
            # 1. encoder
            # make a utt list (1) to use the same interface for encoder
            h, _ = self.enc([h], [ilen])

            # calculate log P(z_t|X) for CTC scores
            if recog_args.ctc_weight > 0.0:
                lpz = self.ctc.log_softmax(h).data[0]
            else:
                lpz = None

            # 2. decoder
            # decode the first utterance
            y = self.dec.recognize_beam(h[0], lpz, recog_args, char_list, rnnlm)

        return y

    def calculate_all_attentions(self, xs, ilens, ys):
        """E2E attention calculation.

        Args:
            xs (List): List of padded input sequences. [(T1, idim), (T2, idim), ...]
            ilens (np.ndarray): Batch of lengths of input sequences. (B)
            ys (List): List of character id sequence tensor. [(L1), (L2), (L3), ...]

        Returns:
            float np.ndarray: Attention weights. (B, Lmax, Tmax)

        """
        hs, ilens = self.enc(xs, ilens)
        att_ws = self.dec.calculate_all_attentions(hs, ys)

        return att_ws