#!/usr/bin/env python
# encoding: utf-8
# Copyright 2019 Kyoto University (Hirofumi Inaguma)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
from __future__ import division
import logging
import math
import os
import chainer
import numpy as np
import six
import torch
from chainer import reporter
from espnet.nets.e2e_asr_common import label_smoothing_dist
from espnet.nets.mt_interface import MTInterface
from espnet.nets.pytorch_backend.nets_utils import pad_list
from espnet.nets.pytorch_backend.nets_utils import to_device
from espnet.nets.pytorch_backend.rnn.attentions import att_for
from espnet.nets.pytorch_backend.rnn.decoders import decoder_for
from espnet.nets.pytorch_backend.rnn.encoders import encoder_for
class Reporter(chainer.Chain):
"""A chainer reporter wrapper"""
    def report(self, loss, acc, ppl):
reporter.report({'loss': loss}, self)
reporter.report({'acc': acc}, self)
reporter.report({'ppl': ppl}, self)
class E2E(MTInterface, torch.nn.Module):
"""E2E module
    :param int idim: dimension of inputs (source vocabulary size)
    :param int odim: dimension of outputs (target vocabulary size)
:param Namespace args: argument Namespace containing options
"""
def __init__(self, idim, odim, args):
super(E2E, self).__init__()
torch.nn.Module.__init__(self)
self.etype = args.etype
self.verbose = args.verbose
self.char_list = args.char_list
self.outdir = args.outdir
self.reporter = Reporter()
# below means the last number becomes eos/sos ID
# note that sos/eos IDs are identical
self.sos = odim - 1
self.eos = odim - 1
self.pad = odim
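        # NOTE: self.pad is used as the padding id when batching source id sequences
        # in translate_batch (see pad_list there)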
# subsample info
        # +1 means input (+1) and layer outputs (args.elayers)
        subsample = np.ones(args.elayers + 1, dtype=np.int64)
logging.warning('Subsampling is not performed for machine translation.')
logging.info('subsample: ' + ' '.join([str(x) for x in subsample]))
self.subsample = subsample
# label smoothing info
if args.lsm_type and os.path.isfile(args.train_json):
logging.info("Use label smoothing with " + args.lsm_type)
labeldist = label_smoothing_dist(odim, args.lsm_type, transcript=args.train_json)
else:
labeldist = None
# multilingual related
self.replace_sos = args.replace_sos
# encoder
self.embed_src = torch.nn.Embedding(idim + 1, args.eunits, padding_idx=idim)
# NOTE: +1 means the padding index
self.dropout_emb_src = torch.nn.Dropout(p=args.dropout_rate)
self.enc = encoder_for(args, args.eunits, self.subsample)
# attention
self.att = att_for(args)
# decoder
self.dec = decoder_for(args, odim, self.sos, self.eos, self.att, labeldist)
# weight initialization
self.init_like_chainer()
self.rnnlm = None
self.logzero = -10000000000.0
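        # NOTE: logzero stands in for log(0) (i.e. -inf) when masking scores during decoding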
self.loss = None
self.acc = None
    def init_like_chainer(self):
"""Initialize weight like chainer
chainer basically uses LeCun way: W ~ Normal(0, fan_in ** -0.5), b = 0
pytorch basically uses W, b ~ Uniform(-fan_in**-0.5, fan_in**-0.5)
however, there are two exceptions as far as I know.
- EmbedID.W ~ Normal(0, 1)
- LSTM.upward.b[forget_gate_range] = 1 (but not used in NStepLSTM)
"""
def lecun_normal_init_parameters(module):
for p in module.parameters():
data = p.data
if data.dim() == 1:
# bias
data.zero_()
elif data.dim() == 2:
# linear weight
n = data.size(1)
stdv = 1. / math.sqrt(n)
data.normal_(0, stdv)
elif data.dim() in (3, 4):
# conv weight
n = data.size(1)
for k in data.size()[2:]:
n *= k
stdv = 1. / math.sqrt(n)
data.normal_(0, stdv)
else:
raise NotImplementedError
def set_forget_bias_to_one(bias):
n = bias.size(0)
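            # NOTE: pytorch lays out LSTM biases as four gate blocks of size n // 4
            # in the order [input, forget, cell, output], so the forget gate is the
            # second quarter of the vector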
start, end = n // 4, n // 2
bias.data[start:end].fill_(1.)
lecun_normal_init_parameters(self)
# exceptions
# embed weight ~ Normal(0, 1)
self.dec.embed.weight.data.normal_(0, 1)
# forget-bias = 1.0
# https://discuss.pytorch.org/t/set-forget-gate-bias-of-lstm/1745
for l in six.moves.range(len(self.dec.decoder)):
set_forget_bias_to_one(self.dec.decoder[l].bias_ih)
    def forward(self, xs_pad, ilens, ys_pad):
"""E2E forward
        :param torch.Tensor xs_pad: batch of padded source id sequences (B, Tmax)
        :param torch.Tensor ilens: batch of lengths of source sequences (B)
        :param torch.Tensor ys_pad: batch of padded target id sequences (B, Lmax)
:return: loss value
:rtype: torch.Tensor
"""
# 1. Encoder
xs_pad, ys_pad, tgt_lang_ids = self.target_lang_biasing_train(xs_pad, ilens, ys_pad)
hs_pad, hlens, _ = self.enc(self.dropout_emb_src(self.embed_src(xs_pad)), ilens)
        # 2. attention loss
loss, acc, ppl = self.dec(hs_pad, hlens, ys_pad, tgt_lang_ids=tgt_lang_ids)
self.acc = acc
self.ppl = ppl
self.loss = loss
loss_data = float(self.loss)
if not math.isnan(loss_data):
self.reporter.report(loss_data, acc, ppl)
else:
logging.warning('loss (=%f) is not correct', loss_data)
return self.loss
    def target_lang_biasing_train(self, xs_pad, ilens, ys_pad):
"""Replace <sos> with target language IDs for multilingual MT during training.
        :param torch.Tensor xs_pad: batch of padded source id sequences (B, Tmax)
        :param torch.Tensor ilens: batch of lengths of source sequences (B)
        :param torch.Tensor ys_pad: batch of padded target id sequences (B, Lmax)
:return: source text without language IDs
:rtype: torch.Tensor
:return: target text without language IDs
:rtype: torch.Tensor
:return: target language IDs
:rtype: torch.Tensor (B, 1)
"""
tgt_lang_ids = None
if self.replace_sos:
            # remove the language ID prepended at the beginning; it is kept separately and passed to the decoder
tgt_lang_ids = ys_pad[:, 0].unsqueeze(1)
xs_pad = xs_pad[:, 1:]
ys_pad = ys_pad[:, 1:]
ilens -= 1
return xs_pad, ys_pad, tgt_lang_ids
    def translate(self, x, trans_args, char_list, rnnlm=None):
"""E2E beam search
        :param list x: input source token id sequence (wrapped in a length-1 list)
:param Namespace trans_args: argument Namespace containing options
:param list char_list: list of characters
:param torch.nn.Module rnnlm: language model module
:return: N-best decoding results
:rtype: list
"""
prev = self.training
self.eval()
# 1. encoder
# make a utt list (1) to use the same interface for encoder
if self.replace_sos:
ilen = [len(x[0][1:])]
h = to_device(self, torch.from_numpy(np.fromiter(map(int, x[0][1:]), dtype=np.int64)))
else:
ilen = [len(x[0])]
h = to_device(self, torch.from_numpy(np.fromiter(map(int, x[0]), dtype=np.int64)))
hs, _, _ = self.enc(self.dropout_emb_src(self.embed_src(h.unsqueeze(0))), ilen)
# 2. decoder
# decode the first utterance
y = self.dec.recognize_beam(hs[0], None, trans_args, char_list, rnnlm)
if prev:
self.train()
return y
    def translate_batch(self, xs, trans_args, char_list, rnnlm=None):
"""E2E beam search
        :param list xs: list of input source token id sequences [(T_1,), (T_2,), ...]
:param Namespace trans_args: argument Namespace containing options
:param list char_list: list of characters
:param torch.nn.Module rnnlm: language model module
:return: N-best decoding results
:rtype: list
"""
prev = self.training
self.eval()
# 1. Encoder
if self.replace_sos:
ilens = np.fromiter((len(xx[1:]) for xx in xs), dtype=np.int64)
hs = [to_device(self, torch.from_numpy(xx[1:])) for xx in xs]
else:
ilens = np.fromiter((len(xx) for xx in xs), dtype=np.int64)
hs = [to_device(self, torch.from_numpy(xx)) for xx in xs]
xpad = pad_list(hs, self.pad)
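        # pad variable-length source id sequences into a single (B, Tmax) batch tensor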
hs_pad, hlens, _ = self.enc(self.dropout_emb_src(self.embed_src(xpad)), ilens)
# 2. Decoder
hlens = torch.tensor(list(map(int, hlens))) # make sure hlens is tensor
y = self.dec.recognize_beam_batch(hs_pad, hlens, None, trans_args, char_list, rnnlm)
if prev:
self.train()
return y
    def calculate_all_attentions(self, xs_pad, ilens, ys_pad):
"""E2E attention calculation
        :param torch.Tensor xs_pad: batch of padded source id sequences (B, Tmax)
        :param torch.Tensor ilens: batch of lengths of source sequences (B)
        :param torch.Tensor ys_pad: batch of padded target id sequences (B, Lmax)
:return: attention weights with the following shape,
1) multi-head case => attention weights (B, H, Lmax, Tmax),
2) other case => attention weights (B, Lmax, Tmax).
:rtype: float ndarray
"""
with torch.no_grad():
# 1. Encoder
xs_pad, ys_pad, tgt_lang_ids = self.target_lang_biasing_train(xs_pad, ilens, ys_pad)
hpad, hlens, _ = self.enc(self.dropout_emb_src(self.embed_src(xs_pad)), ilens)
# 2. Decoder
att_ws = self.dec.calculate_all_attentions(hpad, hlens, ys_pad, tgt_lang_ids=tgt_lang_ids)
return att_ws