Source code for espnet.nets.chainer_backend.rnn.attentions

import chainer
import chainer.functions as F
import chainer.links as L

import numpy as np

from espnet.nets.chainer_backend.nets_utils import linear_tensor


# dot product based attention
class AttDot(chainer.Chain):
    """Compute attention based on dot product.

    Args:
        eprojs (int | None): Dimension of input vectors from encoder.
        dunits (int | None): Dimension of input vectors for decoder.
        att_dim (int): Dimension of input vectors for attention.

    """

    def __init__(self, eprojs, dunits, att_dim):
        super(AttDot, self).__init__()
        with self.init_scope():
            self.mlp_enc = L.Linear(eprojs, att_dim)
            self.mlp_dec = L.Linear(dunits, att_dim)

        self.dunits = dunits
        self.eprojs = eprojs
        self.att_dim = att_dim
        self.h_length = None
        self.enc_h = None
        self.pre_compute_enc_h = None

    def reset(self):
        """Reset states."""
        self.h_length = None
        self.enc_h = None
        self.pre_compute_enc_h = None

    def __call__(self, enc_hs, dec_z, att_prev, scaling=2.0):
        """Compute AttDot forward layer.

        Args:
            enc_hs (chainer.Variable | N-dimensional array): Input variable from encoder.
            dec_z (chainer.Variable | N-dimensional array): Input variable of decoder.
            att_prev: Not used; kept for compatibility with the other attention types.
            scaling (float): Scaling weight to make attention sharp.

        Returns:
            chainer.Variable: Weighted sum over frames.
            chainer.Variable: Attention weight.

        """
        batch = len(enc_hs)
        # pre-compute all h outside the decoder loop
        if self.pre_compute_enc_h is None:
            self.enc_h = F.pad_sequence(enc_hs)  # utt x frame x hdim
            self.h_length = self.enc_h.shape[1]
            # utt x frame x att_dim
            self.pre_compute_enc_h = F.tanh(
                linear_tensor(self.mlp_enc, self.enc_h))

        if dec_z is None:
            dec_z = chainer.Variable(self.xp.zeros(
                (batch, self.dunits), dtype=np.float32))
        else:
            dec_z = F.reshape(dec_z, (batch, self.dunits))

        # <phi (h_t), psi (s)> for all t
        u = F.broadcast_to(F.expand_dims(F.tanh(self.mlp_dec(dec_z)), 1),
                           self.pre_compute_enc_h.shape)
        e = F.sum(self.pre_compute_enc_h * u, axis=2)  # utt x frame

        # NOTE: applying a minus-large-number filter to zero out the probability of the
        # padded area simply degraded performance, so no padding mask is applied here.
        # Apply a scaling factor to sharpen the attention distribution.
        w = F.softmax(scaling * e)

        # weighted sum over frames
        # utt x hdim
        c = F.sum(self.enc_h
                  * F.broadcast_to(F.expand_dims(w, 2), self.enc_h.shape), axis=1)

        return c, w
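
A minimal usage sketch (not part of the original module; the dimensions and data are placeholders). The encoder states are a list of variable-length float32 arrays, and passing dec_z=None on the first decoder step uses a zero decoder state:

    import numpy as np
    from espnet.nets.chainer_backend.rnn.attentions import AttDot

    att = AttDot(eprojs=320, dunits=300, att_dim=320)
    enc_hs = [np.random.randn(t, 320).astype(np.float32) for t in (7, 5)]
    att.reset()
    c, w = att(enc_hs, None, None)  # c: (2, 320) context vectors, w: (2, 7) weights
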
# location based attention
class AttLoc(chainer.Chain):
    """Compute location-based attention.

    Args:
        eprojs (int | None): Dimension of input vectors from encoder.
        dunits (int | None): Dimension of input vectors for decoder.
        att_dim (int): Dimension of input vectors for attention.
        aconv_chans (int): Number of channels of output arrays from convolutional layer.
        aconv_filts (int): Size of filters of convolutional layer.

    """

    def __init__(self, eprojs, dunits, att_dim, aconv_chans, aconv_filts):
        super(AttLoc, self).__init__()
        with self.init_scope():
            self.mlp_enc = L.Linear(eprojs, att_dim)
            self.mlp_dec = L.Linear(dunits, att_dim, nobias=True)
            self.mlp_att = L.Linear(aconv_chans, att_dim, nobias=True)
            self.loc_conv = L.Convolution2D(
                1, aconv_chans, ksize=(1, 2 * aconv_filts + 1), pad=(0, aconv_filts))
            self.gvec = L.Linear(att_dim, 1)

        self.dunits = dunits
        self.eprojs = eprojs
        self.att_dim = att_dim
        self.h_length = None
        self.enc_h = None
        self.pre_compute_enc_h = None
        self.aconv_chans = aconv_chans

    def reset(self):
        """Reset states."""
        self.h_length = None
        self.enc_h = None
        self.pre_compute_enc_h = None

    def __call__(self, enc_hs, dec_z, att_prev, scaling=2.0):
        """Compute AttLoc forward layer.

        Args:
            enc_hs (chainer.Variable | N-dimensional array): Input variable from encoders.
            dec_z (chainer.Variable | N-dimensional array): Input variable of decoder.
            att_prev (chainer.Variable | None): Attention weight of the previous step
                (None to initialize it with a uniform distribution).
            scaling (float): Scaling weight to make attention sharp.

        Returns:
            chainer.Variable: Weighted sum over frames.
            chainer.Variable: Attention weight.

        """
        batch = len(enc_hs)
        # pre-compute all h outside the decoder loop
        if self.pre_compute_enc_h is None:
            self.enc_h = F.pad_sequence(enc_hs)  # utt x frame x hdim
            self.h_length = self.enc_h.shape[1]
            # utt x frame x att_dim
            self.pre_compute_enc_h = linear_tensor(self.mlp_enc, self.enc_h)

        if dec_z is None:
            dec_z = chainer.Variable(self.xp.zeros(
                (batch, self.dunits), dtype=np.float32))
        else:
            dec_z = F.reshape(dec_z, (batch, self.dunits))

        # initialize attention weight with uniform dist.
        if att_prev is None:
            att_prev = [self.xp.full(
                hh.shape[0], 1.0 / hh.shape[0], dtype=np.float32) for hh in enc_hs]
            att_prev = [chainer.Variable(att) for att in att_prev]
            att_prev = F.pad_sequence(att_prev)

        # TODO(watanabe) use <chainer variable>.reshape(), instead of F.reshape()
        # att_prev: utt x frame -> utt x 1 x 1 x frame -> utt x att_conv_chans x 1 x frame
        att_conv = self.loc_conv(
            F.reshape(att_prev, (batch, 1, 1, self.h_length)))
        # att_conv: utt x att_conv_chans x 1 x frame -> utt x frame x att_conv_chans
        att_conv = F.swapaxes(F.squeeze(att_conv, axis=2), 1, 2)
        # att_conv: utt x frame x att_conv_chans -> utt x frame x att_dim
        att_conv = linear_tensor(self.mlp_att, att_conv)

        # dec_z_tiled: utt x frame x att_dim
        dec_z_tiled = F.broadcast_to(
            F.expand_dims(self.mlp_dec(dec_z), 1), self.pre_compute_enc_h.shape)

        # dot with gvec
        # utt x frame x att_dim -> utt x frame
        # TODO(watanabe) use batch_matmul
        e = F.squeeze(linear_tensor(self.gvec, F.tanh(
            att_conv + self.pre_compute_enc_h + dec_z_tiled)), axis=2)

        # NOTE: applying a minus-large-number filter to zero out the probability of the
        # padded area simply degraded performance, so no padding mask is applied here.
        # Apply a scaling factor to sharpen the attention distribution.
        w = F.softmax(scaling * e)

        # weighted sum over frames
        # utt x hdim
        c = F.sum(self.enc_h
                  * F.broadcast_to(F.expand_dims(w, 2), self.enc_h.shape), axis=1)

        return c, w
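
A usage sketch (illustrative only; the layer sizes are assumptions, not values taken from this module). The attention weight of the previous step is fed back so that loc_conv can exploit its location, and passing None on the first step initializes it uniformly:

    import numpy as np
    from espnet.nets.chainer_backend.rnn.attentions import AttLoc

    att = AttLoc(eprojs=320, dunits=300, att_dim=320, aconv_chans=10, aconv_filts=100)
    enc_hs = [np.random.randn(t, 320).astype(np.float32) for t in (7, 5)]
    att.reset()
    w = None
    for _ in range(3):  # a few decoder steps with a dummy (zero) decoder state
        c, w = att(enc_hs, None, w)  # w of the previous step drives loc_conv
    # c: (2, 320) context vectors, w: (2, 7) attention weights
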
class NoAtt(chainer.Chain):
    """Compute non-attention layer.

    This layer is a dummy attention layer to be compatible with other
    attention-based models.

    """

    def __init__(self):
        super(NoAtt, self).__init__()
        self.h_length = None
        self.enc_h = None
        self.pre_compute_enc_h = None
        self.c = None

    def reset(self):
        """Reset states."""
        self.h_length = None
        self.enc_h = None
        self.pre_compute_enc_h = None
        self.c = None

    def __call__(self, enc_hs, dec_z, att_prev):
        """Compute NoAtt forward layer.

        Args:
            enc_hs (chainer.Variable | N-dimensional array): Input variable from encoders.
            dec_z: Dummy.
            att_prev (chainer.Variable | None): Attention weight.

        Returns:
            chainer.Variable: Sum over frames.
            chainer.Variable: Attention weight.

        """
        # pre-compute all h outside the decoder loop
        if self.pre_compute_enc_h is None:
            self.enc_h = F.pad_sequence(enc_hs)  # utt x frame x hdim
            self.h_length = self.enc_h.shape[1]

        # initialize attention weight with uniform dist.
        if att_prev is None:
            att_prev = [self.xp.full(
                hh.shape[0], 1.0 / hh.shape[0], dtype=np.float32) for hh in enc_hs]
            att_prev = [chainer.Variable(att) for att in att_prev]
            att_prev = F.pad_sequence(att_prev)
            self.c = F.sum(self.enc_h
                           * F.broadcast_to(F.expand_dims(att_prev, 2), self.enc_h.shape),
                           axis=1)

        return self.c, att_prev
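
A brief sketch (illustrative only): with uniform weights, the returned context is simply the average over the real (unpadded) frames of each utterance and does not change across decoder steps:

    import numpy as np
    from espnet.nets.chainer_backend.rnn.attentions import NoAtt

    att = NoAtt()
    enc_hs = [np.random.randn(t, 4).astype(np.float32) for t in (3, 2)]
    att.reset()
    c, w = att(enc_hs, None, None)  # c: (2, 4), w: (2, 3) uniform over real frames
    c2, _ = att(enc_hs, None, w)    # the same context is returned on later steps
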

def att_for(args):
    """Returns an attention layer given the program arguments.

    Args:
        args (Namespace): The arguments.

    Returns:
        chainer.Chain: The corresponding attention module.

    """
    if args.atype == 'dot':
        att = AttDot(args.eprojs, args.dunits, args.adim)
    elif args.atype == 'location':
        att = AttLoc(args.eprojs, args.dunits, args.adim,
                     args.aconv_chans, args.aconv_filts)
    elif args.atype == 'noatt':
        att = NoAtt()
    else:
        raise NotImplementedError(
            'chainer supports only noatt, dot, and location attention.')
    return att
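
A small usage sketch of att_for (the attribute values are placeholders): any object exposing the expected attributes works, for example an argparse.Namespace:

    from argparse import Namespace
    from espnet.nets.chainer_backend.rnn.attentions import att_for

    args = Namespace(atype='location', eprojs=320, dunits=300, adim=320,
                     aconv_chans=10, aconv_filts=100)
    att = att_for(args)  # returns an AttLoc instance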