Source code for espnet.nets.pytorch_backend.e2e_asr_mix

#!/usr/bin/env python

# Copyright 2017 Johns Hopkins University (Shinji Watanabe)
#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)


from __future__ import division

import argparse
import logging
import math
import sys

import chainer
from chainer import reporter
import editdistance
import numpy as np
import six
import torch

from espnet.nets.asr_interface import ASRInterface
from espnet.nets.e2e_asr_common import get_vgg2l_odim
from espnet.nets.e2e_asr_common import label_smoothing_dist
from espnet.nets.pytorch_backend.ctc import ctc_for
from espnet.nets.pytorch_backend.nets_utils import make_pad_mask
from espnet.nets.pytorch_backend.nets_utils import pad_list
from espnet.nets.pytorch_backend.nets_utils import to_device
from espnet.nets.pytorch_backend.rnn.attentions import att_for
from espnet.nets.pytorch_backend.rnn.decoders import decoder_for
from espnet.nets.pytorch_backend.rnn.encoders import RNNP
from espnet.nets.pytorch_backend.rnn.encoders import VGG2L

CTC_LOSS_THRESHOLD = 10000


class Reporter(chainer.Chain):
    """A chainer reporter wrapper"""

    def report(self, loss_ctc, loss_att, acc, cer, wer, mtl_loss):
        reporter.report({'loss_ctc': loss_ctc}, self)
        reporter.report({'loss_att': loss_att}, self)
        reporter.report({'acc': acc}, self)
        reporter.report({'cer': cer}, self)
        reporter.report({'wer': wer}, self)
        logging.info('mtl loss:' + str(mtl_loss))
        reporter.report({'loss': mtl_loss}, self)


class PIT(object):
    """Permutation Invariant Training (PIT) module

    :parameter int num_spkrs: number of speakers for PIT process (2 or 3)
    """

    def __init__(self, num_spkrs):
        self.num_spkrs = num_spkrs
        if self.num_spkrs == 2:
            self.perm_choices = [[0, 1], [1, 0]]
        elif self.num_spkrs == 3:
            self.perm_choices = [[0, 1, 2], [0, 2, 1], [1, 2, 0], [1, 0, 2], [2, 0, 1], [2, 1, 0]]
        else:
            raise ValueError

    def min_pit_sample(self, loss):
        """PIT min_pit_sample

        :param 1-D torch.Tensor loss: list of losses for one sample,
            including [h1r1, h1r2, h2r1, h2r2] or
            [h1r1, h1r2, h1r3, h2r1, h2r2, h2r3, h3r1, h3r2, h3r3]
        :return min_loss
        :rtype torch.Tensor (1)
        :return permutation
        :rtype List: len=num_spkrs
        """
        if self.num_spkrs == 2:
            score_perms = torch.stack([loss[0] + loss[3],
                                       loss[1] + loss[2]]) / self.num_spkrs
        elif self.num_spkrs == 3:
            score_perms = torch.stack([loss[0] + loss[4] + loss[8],
                                       loss[0] + loss[5] + loss[7],
                                       loss[1] + loss[5] + loss[6],
                                       loss[1] + loss[3] + loss[8],
                                       loss[2] + loss[3] + loss[7],
                                       loss[2] + loss[4] + loss[6]]) / self.num_spkrs

        perm_loss, min_idx = torch.min(score_perms, 0)
        permutation = self.perm_choices[min_idx]

        return perm_loss, permutation

    def pit_process(self, losses):
        """PIT pit_process

        :param torch.Tensor losses: losses (B, 1|4|9)
        :return pit_loss
        :rtype torch.Tensor (1)
        :return permutation
        :rtype torch.LongTensor (B, 1|2|3)
        """
        bs = losses.size(0)
        ret = [self.min_pit_sample(losses[i]) for i in range(bs)]

        loss_perm = torch.stack([r[0] for r in ret], dim=0).to(losses.device)  # (B)
        permutation = torch.tensor([r[1] for r in ret]).long().to(losses.device)

        return torch.mean(loss_perm), permutation
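

# --- Illustrative sketch (not part of the original ESPnet source) ----------------
# A minimal demonstration of how PIT.pit_process selects the loss-minimising
# speaker permutation for each sample. The dummy loss values are assumptions;
# the column order follows the convention documented in min_pit_sample:
# [h1r1, h1r2, h2r1, h2r2].
def _example_pit_usage():
    pit = PIT(num_spkrs=2)
    # batch of 2 samples with pairwise losses, shape (B, num_spkrs ** 2)
    dummy_losses = torch.tensor([[1.0, 5.0, 6.0, 2.0],   # best permutation: [0, 1]
                                 [4.0, 1.5, 2.5, 3.0]])  # best permutation: [1, 0]
    loss, perm = pit.pit_process(dummy_losses)
    # loss -> tensor(1.7500): mean of the per-sample minimum permutation losses
    # perm -> tensor([[0, 1], [1, 0]])
    return loss, perm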


class E2E(ASRInterface, torch.nn.Module):
    """E2E module

    :param int idim: dimension of inputs
    :param int odim: dimension of outputs
    :param Namespace args: argument Namespace containing options
    """

    def __init__(self, idim, odim, args):
        torch.nn.Module.__init__(self)
        self.mtlalpha = args.mtlalpha
        assert 0.0 <= self.mtlalpha <= 1.0, "mtlalpha should be [0.0, 1.0]"
        self.etype = args.etype
        self.verbose = args.verbose
        self.char_list = args.char_list
        self.outdir = args.outdir
        self.reporter = Reporter()
        self.num_spkrs = args.num_spkrs
        self.spa = args.spa
        self.pit = PIT(self.num_spkrs)

        # below means the last number becomes eos/sos ID
        # note that sos/eos IDs are identical
        self.sos = odim - 1
        self.eos = odim - 1

        # subsample info
        # +1 means input (+1) and layers outputs (args.elayers_sd + args.elayers)
        subsample = np.ones(args.elayers_sd + args.elayers + 1, dtype=np.int)
        if args.etype.endswith("p") and not args.etype.startswith("vgg"):
            ss = args.subsample.split("_")
            for j in range(min(args.elayers_sd + args.elayers + 1, len(ss))):
                subsample[j] = int(ss[j])
        else:
            logging.warning(
                'Subsampling is not performed for vgg*. It is performed in max pooling layers at CNN.')
        logging.info('subsample: ' + ' '.join([str(x) for x in subsample]))
        self.subsample = subsample

        # label smoothing info
        if args.lsm_type:
            logging.info("Use label smoothing with " + args.lsm_type)
            labeldist = label_smoothing_dist(odim, args.lsm_type, transcript=args.train_json)
        else:
            labeldist = None

        # encoder
        self.enc = encoder_for(args, idim, self.subsample)
        # ctc
        self.ctc = ctc_for(args, odim, reduce=False)
        # attention
        num_att = self.num_spkrs if args.spa else 1
        self.att = att_for(args, num_att)
        # decoder
        self.dec = decoder_for(args, odim, self.sos, self.eos, self.att, labeldist)

        # weight initialization
        self.init_like_chainer()

        # options for beam search
        if 'report_cer' in vars(args) and (args.report_cer or args.report_wer):
            recog_args = {'beam_size': args.beam_size, 'penalty': args.penalty,
                          'ctc_weight': args.ctc_weight, 'maxlenratio': args.maxlenratio,
                          'minlenratio': args.minlenratio, 'lm_weight': args.lm_weight,
                          'rnnlm': args.rnnlm, 'nbest': args.nbest,
                          'space': args.sym_space, 'blank': args.sym_blank}

            self.recog_args = argparse.Namespace(**recog_args)
            self.report_cer = args.report_cer
            self.report_wer = args.report_wer
        else:
            self.report_cer = False
            self.report_wer = False
        self.rnnlm = None

        self.logzero = -10000000000.0
        self.loss = None
        self.acc = None

    def init_like_chainer(self):
        """Initialize weight like chainer

        chainer basically uses LeCun way: W ~ Normal(0, fan_in ** -0.5), b = 0
        pytorch basically uses W, b ~ Uniform(-fan_in**-0.5, fan_in**-0.5)

        however, there are two exceptions as far as I know.
        - EmbedID.W ~ Normal(0, 1)
        - LSTM.upward.b[forget_gate_range] = 1 (but not used in NStepLSTM)
        """
        def lecun_normal_init_parameters(module):
            for p in module.parameters():
                data = p.data
                if data.dim() == 1:
                    # bias
                    data.zero_()
                elif data.dim() == 2:
                    # linear weight
                    n = data.size(1)
                    stdv = 1. / math.sqrt(n)
                    data.normal_(0, stdv)
                elif data.dim() == 4:
                    # conv weight
                    n = data.size(1)
                    for k in data.size()[2:]:
                        n *= k
                    stdv = 1. / math.sqrt(n)
                    data.normal_(0, stdv)
                else:
                    raise NotImplementedError

        def set_forget_bias_to_one(bias):
            n = bias.size(0)
            start, end = n // 4, n // 2
            bias.data[start:end].fill_(1.)

        lecun_normal_init_parameters(self)
        # exceptions
        # embed weight ~ Normal(0, 1)
        self.dec.embed.weight.data.normal_(0, 1)
        # forget-bias = 1.0
        # https://discuss.pytorch.org/t/set-forget-gate-bias-of-lstm/1745
        for l in six.moves.range(len(self.dec.decoder)):
            set_forget_bias_to_one(self.dec.decoder[l].bias_ih)

    def forward(self, xs_pad, ilens, ys_pad_sd):
        """E2E forward

        :param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, idim)
        :param torch.Tensor ilens: batch of lengths of input sequences (B)
        :param torch.Tensor ys_pad_sd: batch of padded character id sequence tensor (B, num_spkrs, Lmax)
        :return: loss value
        :rtype: torch.Tensor
        """
        # 1. encoder
        hs_pad_sd, hlens = self.enc(xs_pad, ilens)

        # 2. CTC loss
        ys_pad_sd = ys_pad_sd.transpose(0, 1)  # (num_spkrs, B, Lmax)
        if self.mtlalpha == 0:
            loss_ctc, min_perm = None, None
        elif self.num_spkrs <= 3:
            loss_ctc_perm = torch.stack([self.ctc(hs_pad_sd[i // self.num_spkrs],
                                                  hlens,
                                                  ys_pad_sd[i % self.num_spkrs])
                                         for i in range(self.num_spkrs ** 2)], dim=1)  # (B, num_spkrs^2)
            loss_ctc, min_perm = self.pit.pit_process(loss_ctc_perm)
            logging.info('ctc loss:' + str(float(loss_ctc)))

        # 3. attention loss
        if self.mtlalpha == 1:
            loss_att = None
            acc = None
        else:
            for i in range(ys_pad_sd.size(1)):  # B
                ys_pad_sd[:, i] = ys_pad_sd[min_perm[i], i]
            rslt = [self.dec(hs_pad_sd[i], hlens, ys_pad_sd[i], strm_idx=i)
                    for i in range(self.num_spkrs)]
            loss_att = sum([r[0] for r in rslt]) / float(len(rslt))
            acc = sum([r[1] for r in rslt]) / float(len(rslt))
        self.acc = acc

        # 4. compute cer/wer
        if self.training or not (self.report_cer or self.report_wer):
            cer, wer = 0.0, 0.0
            # oracle_cer, oracle_wer = 0.0, 0.0
        else:
            if self.recog_args.ctc_weight > 0.0:
                lpz_sd = [self.ctc.log_softmax(hs_pad_sd[i]).data
                          for i in range(self.num_spkrs)]
            else:
                lpz_sd = None

            word_eds, char_eds, word_ref_lens, char_ref_lens = [], [], [], []
            nbest_hyps_sd = [self.dec.recognize_beam_batch(hs_pad_sd[i], torch.tensor(hlens), lpz_sd[i],
                                                           self.recog_args, self.char_list,
                                                           self.rnnlm, strm_idx=i)
                             for i in range(self.num_spkrs)]
            # remove <sos> and <eos>
            y_hats_sd = [[nbest_hyp[0]['yseq'][1:-1] for nbest_hyp in nbest_hyps_sd[i]]
                         for i in range(self.num_spkrs)]
            for i in range(len(y_hats_sd[0])):
                hyp_words = []
                hyp_chars = []
                ref_words = []
                ref_chars = []
                for ns in range(self.num_spkrs):
                    y_hat = y_hats_sd[ns][i]
                    y_true = ys_pad_sd[ns][i]

                    seq_hat = [self.char_list[int(idx)] for idx in y_hat if int(idx) != -1]
                    seq_true = [self.char_list[int(idx)] for idx in y_true if int(idx) != -1]
                    seq_hat_text = "".join(seq_hat).replace(self.recog_args.space, ' ')
                    seq_hat_text = seq_hat_text.replace(self.recog_args.blank, '')
                    seq_true_text = "".join(seq_true).replace(self.recog_args.space, ' ')

                    hyp_words.append(seq_hat_text.split())
                    ref_words.append(seq_true_text.split())
                    hyp_chars.append(seq_hat_text.replace(' ', ''))
                    ref_chars.append(seq_true_text.replace(' ', ''))

                tmp_word_ed = [editdistance.eval(hyp_words[ns // self.num_spkrs], ref_words[ns % self.num_spkrs])
                               for ns in range(self.num_spkrs ** 2)]  # h1r1, h1r2, h2r1, h2r2
                tmp_char_ed = [editdistance.eval(hyp_chars[ns // self.num_spkrs], ref_chars[ns % self.num_spkrs])
                               for ns in range(self.num_spkrs ** 2)]  # h1r1, h1r2, h2r1, h2r2

                word_eds.append(self.pit.min_pit_sample(torch.tensor(tmp_word_ed))[0])
                word_ref_lens.append(len(sum(ref_words, [])))
                char_eds.append(self.pit.min_pit_sample(torch.tensor(tmp_char_ed))[0])
                char_ref_lens.append(len(''.join(ref_chars)))

            wer = 0.0 if not self.report_wer else float(sum(word_eds)) / sum(word_ref_lens)
            cer = 0.0 if not self.report_cer else float(sum(char_eds)) / sum(char_ref_lens)

        alpha = self.mtlalpha
        if alpha == 0:
            self.loss = loss_att
            loss_att_data = float(loss_att)
            loss_ctc_data = None
        elif alpha == 1:
            self.loss = loss_ctc
            loss_att_data = None
            loss_ctc_data = float(loss_ctc)
        else:
            self.loss = alpha * loss_ctc + (1 - alpha) * loss_att
            loss_att_data = float(loss_att)
            loss_ctc_data = float(loss_ctc)

        loss_data = float(self.loss)
        if loss_data < CTC_LOSS_THRESHOLD and not math.isnan(loss_data):
            self.reporter.report(loss_ctc_data, loss_att_data, acc, cer, wer, loss_data)
        else:
            logging.warning('loss (=%f) is not correct', loss_data)
        return self.loss

    def recognize(self, x, recog_args, char_list, rnnlm=None):
        """E2E beam search

        :param ndarray x: input acoustic feature (T, D)
        :param Namespace recog_args: argument Namespace containing options
        :param list char_list: list of characters
        :param torch.nn.Module rnnlm: language model module
        :return: N-best decoding results
        :rtype: list
        """
        prev = self.training
        self.eval()
        # subsample frame
        x = x[::self.subsample[0], :]
        ilen = [x.shape[0]]
        h = to_device(self, torch.from_numpy(
            np.array(x, dtype=np.float32)))

        # 1. encoder
        # make a utt list (1) to use the same interface for encoder
        h = h.contiguous()
        h_sd, _ = self.enc(h.unsqueeze(0), ilen)

        # calculate log P(z_t|X) for CTC scores
        if recog_args.ctc_weight > 0.0:
            lpz_sd = [self.ctc.log_softmax(i)[0] for i in h_sd]
        else:
            lpz_sd = None

        # 2. decoder
        # decode the first utterance
        y = [self.dec.recognize_beam(h_sd[i][0], lpz_sd[i], recog_args, char_list, rnnlm, strm_idx=i)
             for i in range(self.num_spkrs)]

        if prev:
            self.train()
        return y

    def recognize_batch(self, xs, recog_args, char_list, rnnlm=None):
        """E2E beam search

        :param list xs: list of input acoustic feature arrays [(T_1, D), (T_2, D), ...]
        :param Namespace recog_args: argument Namespace containing options
        :param list char_list: list of characters
        :param torch.nn.Module rnnlm: language model module
        :return: N-best decoding results
        :rtype: list
        """
        prev = self.training
        self.eval()
        # subsample frame
        xs = [xx[::self.subsample[0], :] for xx in xs]
        ilens = np.fromiter((xx.shape[0] for xx in xs), dtype=np.int64)
        hs = [to_device(self, torch.from_numpy(np.array(xx, dtype=np.float32)))
              for xx in xs]

        # 1. encoder
        xpad = pad_list(hs, 0.0)
        hpad_sd, hlens = self.enc(xpad, ilens)

        # calculate log P(z_t|X) for CTC scores
        if recog_args.ctc_weight > 0.0:
            lpz_sd = [self.ctc.log_softmax(hpad_sd[i]) for i in range(self.num_spkrs)]
            normalize_score = False
        else:
            lpz_sd = None
            normalize_score = True

        # 2. decoder
        y = [self.dec.recognize_beam_batch(hpad_sd[i], hlens, lpz_sd[i], recog_args, char_list,
                                           rnnlm, normalize_score=normalize_score, strm_idx=i)
             for i in range(self.num_spkrs)]

        if prev:
            self.train()
        return y

    def calculate_all_attentions(self, xs_pad, ilens, ys_pad_sd):
        """E2E attention calculation

        :param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, idim)
        :param torch.Tensor ilens: batch of lengths of input sequences (B)
        :param torch.Tensor ys_pad_sd: batch of padded character id sequence tensor (B, num_spkrs, Lmax)
        :return: attention weights with the following shape,
            1) multi-head case => attention weights (B, H, Lmax, Tmax),
            2) other case => attention weights (B, Lmax, Tmax).
        :rtype: float ndarray
        """
        with torch.no_grad():
            # encoder
            hpad_sd, hlens = self.enc(xs_pad, ilens)

            # Permutation
            ys_pad_sd = ys_pad_sd.transpose(0, 1)  # (num_spkrs, B, Lmax)
            if self.num_spkrs <= 3:
                loss_ctc = torch.stack([self.ctc(hpad_sd[i // self.num_spkrs],
                                                 hlens,
                                                 ys_pad_sd[i % self.num_spkrs])
                                        for i in range(self.num_spkrs ** 2)], 1)  # (B, num_spkrs^2)
                loss_ctc, min_perm = self.pit.pit_process(loss_ctc)
            for i in range(ys_pad_sd.size(1)):  # B
                ys_pad_sd[:, i] = ys_pad_sd[min_perm[i], i]

            # decoder
            att_ws_sd = [self.dec.calculate_all_attentions(hpad_sd[i], hlens, ys_pad_sd[i], strm_idx=i)
                         for i in range(self.num_spkrs)]

        return att_ws_sd
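

# --- Illustrative sketch (not part of the original ESPnet source) ----------------
# E2E.forward stacks one CTC loss (and, for scoring, one edit distance) per
# hypothesis/reference pair before PIT picks the best permutation. The pairing
# order comes from hyp index = i // num_spkrs and ref index = i % num_spkrs,
# i.e. h1r1, h1r2, h2r1, h2r2 for two speakers. The helper below only returns
# that index pattern for an assumed num_spkrs value; it touches no model state.
def _example_pit_pair_order(num_spkrs=2):
    pairs = [(i // num_spkrs, i % num_spkrs) for i in range(num_spkrs ** 2)]
    # num_spkrs=2 -> [(0, 0), (0, 1), (1, 0), (1, 1)]
    return pairs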


class Encoder(torch.nn.Module):
    """Encoder module

    :param str etype: type of encoder network
    :param int idim: number of dimensions of encoder network
    :param int elayers_sd: number of layers of speaker differentiate part in encoder network
    :param int elayers_rec: number of layers of shared recognition part in encoder network
    :param int eunits: number of lstm units of encoder network
    :param int eprojs: number of projection units of encoder network
    :param np.ndarray subsample: list of subsampling numbers
    :param float dropout: dropout rate
    :param int in_channel: number of input channels
    :param int num_spkrs: number of speakers
    """

    def __init__(self, etype, idim, elayers_sd, elayers_rec, eunits, eprojs,
                 subsample, dropout, num_spkrs=2, in_channel=1):
        super(Encoder, self).__init__()
        typ = etype.lstrip("vgg").lstrip("b").rstrip("p")
        if typ != "lstm" and typ != "gru":
            logging.error("Error: need to specify an appropriate encoder architecture")
        if etype.startswith("vgg"):
            if etype[-1] == "p":
                self.enc_mix = torch.nn.ModuleList([VGG2L(in_channel)])
                self.enc_sd = torch.nn.ModuleList([torch.nn.ModuleList(
                    [RNNP(get_vgg2l_odim(idim, in_channel=in_channel), elayers_sd, eunits, eprojs,
                          subsample[:elayers_sd + 1], dropout, typ=typ)]) for i in range(num_spkrs)])
                self.enc_rec = torch.nn.ModuleList([RNNP(eprojs, elayers_rec, eunits, eprojs,
                                                         subsample[elayers_sd:], dropout, typ=typ)])
                logging.info('Use CNN-VGG + B' + typ.upper() + 'P for encoder')
        else:
            logging.error(
                "Error: need to specify an appropriate encoder architecture")
            sys.exit()

        self.num_spkrs = num_spkrs

    def forward(self, xs_pad, ilens):
        """Encoder forward

        :param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, D)
        :param torch.Tensor ilens: batch of lengths of input sequences (B)
        :return: list: batch of hidden state sequences [num_spkrs x (B, Tmax, eprojs)]
        :rtype: torch.Tensor
        """
        # mixture encoder
        for module in self.enc_mix:
            xs_pad, ilens, _ = module(xs_pad, ilens)

        # SD and Rec encoder
        xs_pad_sd = [xs_pad for i in range(self.num_spkrs)]
        ilens_sd = [ilens for i in range(self.num_spkrs)]
        for ns in range(self.num_spkrs):
            # Encoder_SD: speaker differentiate encoder
            for module in self.enc_sd[ns]:
                xs_pad_sd[ns], ilens_sd[ns], _ = module(xs_pad_sd[ns], ilens_sd[ns])
            # Encoder_Rec: recognition encoder
            for module in self.enc_rec:
                xs_pad_sd[ns], ilens_sd[ns], _ = module(xs_pad_sd[ns], ilens_sd[ns])

        # make mask to remove bias value in padded part
        mask = to_device(self, make_pad_mask(ilens_sd[0]).unsqueeze(-1))

        return [x.masked_fill(mask, 0.0) for x in xs_pad_sd], ilens_sd[0]
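

# --- Illustrative sketch (not part of the original ESPnet source) ----------------
# Encoder.forward zeroes padded frames with make_pad_mask before returning the
# per-speaker hidden states. The helper below shows that masking step in
# isolation, with assumed batch and frame sizes.
def _example_pad_masking():
    hs = torch.ones(2, 5, 3)                    # (B, Tmax, eprojs)
    mask = make_pad_mask([5, 3]).unsqueeze(-1)  # True over padded frames
    masked = hs.masked_fill(mask, 0.0)
    # masked[1, 3:] is all zeros: frames beyond length 3 of the second utterance
    return masked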


def encoder_for(args, idim, subsample):
    return Encoder(args.etype, idim, args.elayers_sd, args.elayers, args.eunits, args.eprojs,
                   subsample, args.dropout_rate, args.num_spkrs)
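

# --- Illustrative sketch (not part of the original ESPnet source) ----------------
# Constructing the multi-speaker encoder through the encoder_for factory above.
# All hyperparameter values here (feature dimension, layer counts, unit sizes)
# are assumptions chosen only for illustration.
def _example_encoder_for():
    example_args = argparse.Namespace(
        etype='vggblstmp', elayers_sd=2, elayers=2, eunits=320, eprojs=320,
        dropout_rate=0.0, num_spkrs=2)
    # input (+1) plus elayers_sd + elayers outputs -> 5 subsampling factors
    subsample = np.ones(example_args.elayers_sd + example_args.elayers + 1, dtype=int)
    return encoder_for(example_args, idim=83, subsample=subsample)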