#!/usr/bin/env python

import argparse
import logging
import math

import editdistance

import chainer
import numpy as np
import six
import torch

from chainer import reporter

from espnet.nets.asr_interface import ASRInterface
from espnet.nets.pytorch_backend.rnn.attentions import att_for
from espnet.nets.pytorch_backend.rnn.decoders_transducer import decoder_for
from espnet.nets.pytorch_backend.rnn.encoders import encoder_for

from espnet.nets.pytorch_backend.nets_utils import pad_list
from espnet.nets.pytorch_backend.nets_utils import to_device
from espnet.nets.pytorch_backend.nets_utils import to_torch_tensor

from espnet.utils.cli_utils import strtobool


class Reporter(chainer.Chain):
    """A chainer reporter wrapper."""

    def report(self, loss, cer, wer):
        reporter.report({'cer': cer}, self)
        reporter.report({'wer': wer}, self)
        logging.info('loss:' + str(loss))
        reporter.report({'loss': loss}, self)


class E2E(ASRInterface, torch.nn.Module):
    """E2E module for the RNN-Transducer model.

    Args:
        idim (int): dimension of inputs
        odim (int): dimension of outputs
        args (namespace): argument Namespace containing options

    """

    @staticmethod
    def add_arguments(parser):
        """Add arguments for the transducer model."""
        group = parser.add_argument_group("transducer model setting")
        # encoder
        group.add_argument('--etype', default='blstmp', type=str,
                           choices=['lstm', 'blstm', 'lstmp', 'blstmp', 'vgglstmp',
                                    'vggblstmp', 'vgglstm', 'vggblstm', 'gru', 'bgru',
                                    'grup', 'bgrup', 'vgggrup', 'vggbgrup', 'vgggru',
                                    'vggbgru'],
                           help='Type of encoder network architecture')
        group.add_argument('--elayers', default=4, type=int,
                           help='Number of encoder layers '
                                '(for shared recognition part in multi-speaker asr mode)')
        group.add_argument('--eunits', '-u', default=300, type=int,
                           help='Number of encoder hidden units')
        group.add_argument('--eprojs', default=320, type=int,
                           help='Number of encoder projection units')
        group.add_argument('--subsample', default="1", type=str,
                           help='Subsample input frames x_y_z means subsample every x frame '
                                'at 1st layer, every y frame at 2nd layer etc.')
        # attention
        group.add_argument('--atype', default='dot', type=str,
                           choices=['noatt', 'dot', 'add', 'location', 'coverage',
                                    'coverage_location', 'location2d', 'location_recurrent',
                                    'multi_head_dot', 'multi_head_add', 'multi_head_loc',
                                    'multi_head_multi_res_loc'],
                           help='Type of attention architecture')
        group.add_argument('--adim', default=320, type=int,
                           help='Number of attention transformation dimensions')
        group.add_argument('--awin', default=5, type=int,
                           help='Window size for location2d attention')
        group.add_argument('--aheads', default=4, type=int,
                           help='Number of heads for multi head attention')
        group.add_argument('--aconv-chans', default=-1, type=int,
                           help='Number of attention convolution channels '
                                '(negative value indicates no location-aware attention)')
        group.add_argument('--aconv-filts', default=100, type=int,
                           help='Number of attention convolution filters '
                                '(negative value indicates no location-aware attention)')
        group.add_argument('--dropout-rate', default=0.0, type=float,
                           help='Dropout rate for the encoder')
        # decoder
        group.add_argument('--dtype', default='lstm', type=str,
                           choices=['lstm', 'gru'],
                           help='Type of decoder network architecture')
        group.add_argument('--dlayers', default=1, type=int,
                           help='Number of decoder layers')
        group.add_argument('--dunits', default=320, type=int,
                           help='Number of decoder hidden units')
        group.add_argument('--dropout-rate-decoder', default=0.0, type=float,
                           help='Dropout rate for the decoder')
        # prediction
        group.add_argument('--dec-embed-dim', default=320, type=int,
                           help='Number of decoder embeddings dimensions')
        group.add_argument('--dropout-rate-embed-decoder', default=0.0, type=float,
                           help='Dropout rate for the decoder embeddings')
        # general
        group.add_argument('--rnnt_type', default='warp-transducer', type=str,
                           choices=['warp-transducer'],
                           help='Type of transducer implementation to calculate loss.')
        group.add_argument('--rnnt-mode', default='rnnt', type=str,
                           choices=['rnnt', 'rnnt-att'],
                           help='RNN-Transducer mode')
        group.add_argument('--joint-dim', default=320, type=int,
                           help='Number of dimensions in joint space')
        # decoding
        group.add_argument('--score-norm-transducer', type=strtobool, nargs='?',
                           default=True,
                           help='Normalize transducer scores by length')

        return parser
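
    # Hedged usage sketch (not part of the original source): `add_arguments`
    # registers the transducer options on an existing parser, so a training
    # script can do, e.g.:
    #
    # >>> parser = argparse.ArgumentParser()
    # >>> parser = E2E.add_arguments(parser)
    # >>> args = parser.parse_args(['--etype', 'blstmp', '--rnnt-mode', 'rnnt'])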

    def __init__(self, idim, odim, args):
        super(E2E, self).__init__()
        torch.nn.Module.__init__(self)
        self.rnnt_mode = args.rnnt_mode
        self.etype = args.etype
        self.verbose = args.verbose
        self.char_list = args.char_list
        self.outdir = args.outdir
        self.space = args.sym_space
        self.blank = args.sym_blank
        self.reporter = Reporter()

        # note that eos is the same as sos (equivalent ID)
        self.sos = odim - 1
        self.eos = odim - 1

        # subsample info
        # +1 means input (+1) and layers outputs (args.elayers)
        subsample = np.ones(args.elayers + 1, dtype=np.int64)
        if args.etype.endswith("p") and not args.etype.startswith("vgg"):
            ss = args.subsample.split("_")
            for j in range(min(args.elayers + 1, len(ss))):
                subsample[j] = int(ss[j])
        else:
            logging.warning(
                'Subsampling is not performed for vgg*. '
                'It is performed in max pooling layers at CNN.')
        logging.info('subsample: ' + ' '.join([str(x) for x in subsample]))
        self.subsample = subsample

        if args.use_frontend:
            # Relative importing because of using python3 syntax
            from espnet.nets.pytorch_backend.frontends.feature_transform \
                import feature_transform_for
            from espnet.nets.pytorch_backend.frontends.frontend \
                import frontend_for

            self.frontend = frontend_for(args, idim)
            self.feature_transform = feature_transform_for(args, (idim - 1) * 2)
            idim = args.n_mels
        else:
            self.frontend = None

        # encoder
        self.enc = encoder_for(args, idim, self.subsample)

        if args.rnnt_mode == 'rnnt-att':
            # attention
            self.att = att_for(args)
            # decoder
            self.dec = decoder_for(args, odim, self.att)
        else:
            # prediction
            self.dec = decoder_for(args, odim)

        # weight initialization
        self.init_like_chainer()

        # options for beam search
        if 'report_cer' in vars(args) and (args.report_cer or args.report_wer):
            recog_args = {'beam_size': args.beam_size, 'nbest': args.nbest,
                          'space': args.sym_space,
                          'score_norm_transducer': args.score_norm_transducer}

            self.recog_args = argparse.Namespace(**recog_args)
            self.report_cer = args.report_cer
            self.report_wer = args.report_wer
        else:
            self.report_cer = False
            self.report_wer = False

        self.logzero = -10000000000.0
        self.rnnlm = None
        self.loss = None
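
    # Illustration of the subsample bookkeeping above: with `--elayers 4` and
    # `--subsample "1_2_2_1_1"` on a projected, non-VGG encoder (e.g. blstmp),
    # `self.subsample` becomes [1, 2, 2, 1, 1], i.e. the raw input is kept
    # as-is and frames are halved after each of the first two encoder layers.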

    def init_like_chainer(self):
        """Initialize weight like chainer.

        chainer basically uses LeCun way: W ~ Normal(0, fan_in ** -0.5), b = 0
        pytorch basically uses W, b ~ Uniform(-fan_in ** -0.5, fan_in ** -0.5)

        however, there are two exceptions as far as I know.
        - EmbedID.W ~ Normal(0, 1)
        - LSTM.upward.b[forget_gate_range] = 1 (but not used in NStepLSTM)
        """
        def lecun_normal_init_parameters(module):
            for p in module.parameters():
                data = p.data
                if data.dim() == 1:
                    # bias
                    data.zero_()
                elif data.dim() == 2:
                    # linear weight
                    n = data.size(1)
                    stdv = 1. / math.sqrt(n)
                    data.normal_(0, stdv)
                elif data.dim() == 4:
                    # conv weight
                    n = data.size(1)
                    for k in data.size()[2:]:
                        n *= k
                    stdv = 1. / math.sqrt(n)
                    data.normal_(0, stdv)
                else:
                    raise NotImplementedError

        def set_forget_bias_to_one(bias):
            n = bias.size(0)
            start, end = n // 4, n // 2
            bias.data[start:end].fill_(1.)

        lecun_normal_init_parameters(self)

        if self.rnnt_mode == 'rnnt-att':
            # embed weight ~ Normal(0, 1)
            self.dec.embed.weight.data.normal_(0, 1)
            # forget-bias = 1.0
            # https://discuss.pytorch.org/t/set-forget-gate-bias-of-lstm/1745
            for l in range(len(self.dec.decoder)):
                set_forget_bias_to_one(self.dec.decoder[l].bias_ih)
        else:
            self.dec.embed.weight.data.normal_(0, 1)
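
    # Note on `set_forget_bias_to_one` above: PyTorch packs LSTM gate
    # parameters in the order (input, forget, cell, output), each block of
    # size `hidden_size`, so the slice `bias[n // 4:n // 2]` addresses exactly
    # the forget-gate portion of `bias_ih`.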

    def forward(self, xs_pad, ilens, ys_pad):
        """E2E forward.

        Args:
            xs_pad (torch.Tensor): batch of padded input sequences (B, Tmax, idim)
            ilens (torch.Tensor): batch of lengths of input sequences (B)
            ys_pad (torch.Tensor): batch of padded character id sequence tensor (B, Lmax)

        Returns:
            loss (torch.Tensor): transducer loss value

        """
        # 0. Frontend
        if self.frontend is not None:
            hs_pad, hlens, mask = self.frontend(to_torch_tensor(xs_pad), ilens)
            hs_pad, hlens = self.feature_transform(hs_pad, hlens)
        else:
            hs_pad, hlens = xs_pad, ilens

        # 1. encoder
        hs_pad, hlens, _ = self.enc(hs_pad, hlens)

        # 2. decoder
        loss = self.dec(hs_pad, hlens, ys_pad)

        # 3. compute cer/wer
        # note: not recommended outside debugging right now,
        # because the training time is hugely impacted.
        if self.training or not (self.report_cer or self.report_wer):
            cer, wer = 0.0, 0.0
        else:
            word_eds, word_ref_lens, char_eds, char_ref_lens = [], [], [], []

            batchsize = int(hs_pad.size(0))
            batch_nbest = []

            for b in six.moves.range(batchsize):
                nbest_hyps = self.dec.recognize_beam(hs_pad[b], self.recog_args)
                batch_nbest.append(nbest_hyps)

            y_hats = [nbest_hyp[0]['yseq'][1:] for nbest_hyp in batch_nbest]

            for i, y_hat in enumerate(y_hats):
                y_true = ys_pad[i]

                seq_hat = [self.char_list[int(idx)] for idx in y_hat]
                seq_true = [self.char_list[int(idx)] for idx in y_true if int(idx) != -1]
                seq_hat_text = "".join(seq_hat).replace(self.recog_args.space, ' ')
                seq_true_text = "".join(seq_true).replace(self.recog_args.space, ' ')

                hyp_words = seq_hat_text.split()
                ref_words = seq_true_text.split()
                word_eds.append(editdistance.eval(hyp_words, ref_words))
                word_ref_lens.append(len(ref_words))

                hyp_chars = seq_hat_text.replace(' ', '')
                ref_chars = seq_true_text.replace(' ', '')
                char_eds.append(editdistance.eval(hyp_chars, ref_chars))
                char_ref_lens.append(len(ref_chars))

            wer = 0.0 if not self.report_wer else float(sum(word_eds)) / sum(word_ref_lens)
            cer = 0.0 if not self.report_cer else float(sum(char_eds)) / sum(char_ref_lens)

        self.loss = loss
        loss_data = float(self.loss)
        if not math.isnan(loss_data):
            self.reporter.report(loss_data, cer, wer)
        else:
            logging.warning('loss (=%f) is not correct', loss_data)

        return self.loss
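
    # How the edit-distance accounting above behaves on a single hyp/ref pair
    # (doctest-style sketch; `editdistance.eval` returns the Levenshtein
    # distance between two sequences):
    #
    # >>> import editdistance
    # >>> editdistance.eval('kitten sat'.split(), 'sitting sat'.split())  # word level
    # 1
    # >>> editdistance.eval(list('kitten'), list('sitting'))  # char level
    # 3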

    def recognize(self, x, recog_args, char_list, rnnlm=None):
        """E2E recognize.

        Args:
            x (ndarray): input acoustic feature (T, D)
            recog_args (namespace): argument Namespace containing options
            char_list (list): list of characters
            rnnlm (torch.nn.Module): language model module

        Returns:
            y (list): n-best decoding results

        """
        prev = self.training
        self.eval()

        ilens = [x.shape[0]]

        # subsample frame
        x = x[::self.subsample[0], :]
        h = to_device(self, to_torch_tensor(x).float())

        # make a utt list (1) to use the same interface for encoder
        hs = h.contiguous().unsqueeze(0)

        # 0. Frontend
        if self.frontend is not None:
            enhanced, hlens, mask = self.frontend(hs, ilens)
            hs, hlens = self.feature_transform(enhanced, hlens)
        else:
            hs, hlens = hs, ilens

        # 1. Encoder
        h, _, _ = self.enc(hs, hlens)

        # 2. Decoder
        if recog_args.beam_size == 1:
            y = self.dec.recognize(h[0], recog_args)
        else:
            y = self.dec.recognize_beam(h[0], recog_args, rnnlm)

        if prev:
            self.train()

        return y
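
    # Hedged usage sketch: the fields below mirror the `recog_args` namespace
    # built in __init__; a full decoding configuration may carry more options.
    #
    # >>> recog_args = argparse.Namespace(beam_size=5, nbest=1, space='<space>',
    # ...                                 score_norm_transducer=True)
    # >>> nbest_hyps = model.recognize(x, recog_args, char_list)  # x: (T, D) ndarray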

    def enhance(self, xs):
        """Forward only the frontend stage.

        Args:
            xs (ndarray): input acoustic feature (T, C, F)

        Returns:
            enhanced (ndarray): enhanced feature
            mask (torch.Tensor): estimated mask
            ilens (torch.Tensor): batch of lengths of input sequences (B)

        """
        if self.frontend is None:
            raise RuntimeError('Frontend does not exist')

        prev = self.training
        self.eval()

        ilens = np.fromiter((xx.shape[0] for xx in xs), dtype=np.int64)

        # subsample frame
        xs = [xx[::self.subsample[0], :] for xx in xs]
        xs = [to_device(self, to_torch_tensor(xx).float()) for xx in xs]
        xs_pad = pad_list(xs, 0.0)

        enhanced, hlens, mask = self.frontend(xs_pad, ilens)

        if prev:
            self.train()

        return enhanced.cpu().numpy(), mask.cpu().numpy(), ilens
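
    # Usage note (descriptive addition, not in the original source): `enhance`
    # exercises only the frontend and returns numpy arrays, so the enhanced
    # features and the estimated mask can be inspected or saved without
    # running the encoder/decoder.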

    def calculate_all_attentions(self, xs_pad, ilens, ys_pad):
        """E2E attention calculation.

        Args:
            xs_pad (torch.Tensor): batch of padded input sequences (B, Tmax, idim)
            ilens (torch.Tensor): batch of lengths of input sequences (B)
            ys_pad (torch.Tensor): batch of padded character id sequence tensor (B, Lmax)

        Returns:
            att_ws (ndarray): attention weights with the following shape,
                1) multi-head case => attention weights (B, H, Lmax, Tmax),
                2) other case => attention weights (B, Lmax, Tmax).

        """
        if self.rnnt_mode == 'rnnt':
            return []

        with torch.no_grad():
            # 0. Frontend
            if self.frontend is not None:
                hs_pad, hlens, mask = self.frontend(to_torch_tensor(xs_pad), ilens)
                hs_pad, hlens = self.feature_transform(hs_pad, hlens)
            else:
                hs_pad, hlens = xs_pad, ilens

            # 1. Encoder
            hpad, hlens, _ = self.enc(hs_pad, hlens)

            # 2. Decoder
            att_ws = self.dec.calculate_all_attentions(hpad, hlens, ys_pad)

        return att_ws
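
    # Note on the early return above: in plain 'rnnt' mode the decoder is a
    # prediction network without an attention module (see __init__), so there
    # are no attention weights to visualize and an empty list is returned.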

    def subsample_frames(self, x):
        """Subsample input frames and return them with their length."""
        # subsample frame
        x = x[::self.subsample[0], :]
        ilen = [x.shape[0]]
        h = to_device(self, torch.from_numpy(
            np.array(x, dtype=np.float32)))
        h = h.contiguous()
        return h, ilen
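

# Minimal end-to-end sketch (hedged; not part of the original source). The
# `args` namespace must carry every option read in `E2E.__init__` (verbose,
# char_list, outdir, sym_space, sym_blank, use_frontend, ...) on top of the
# options registered by `E2E.add_arguments`; the `idim`/`odim` values below
# are illustrative only.
#
# >>> parser = argparse.ArgumentParser()
# >>> parser = E2E.add_arguments(parser)
# >>> args = parser.parse_args([])
# >>> # ... set the remaining attributes on `args`, then:
# >>> model = E2E(idim=83, odim=52, args=args)
# >>> loss = model(xs_pad, ilens, ys_pad)  # one training forward pass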