# encoding: utf-8
import chainer
import chainer.functions as F
from espnet.nets.chainer_backend.transformer.attention import MultiHeadAttention
from espnet.nets.chainer_backend.transformer.layer_norm import LayerNorm
from espnet.nets.chainer_backend.transformer.positionwise_feed_forward import PositionwiseFeedForward
class DecoderLayer(chainer.Chain):
    """Single transformer decoder layer.

    The layer chains three sub-layers: self-attention, source attention and
    a position-wise feed-forward network. Each sub-layer is applied to a
    layer-normalized copy of its input; the result goes through dropout and
    is added back onto the un-normalized input (residual connection).

    Args:
        n_units: Size handed to every ``LayerNorm``, both
            ``MultiHeadAttention`` links and ``PositionwiseFeedForward``.
        d_units: Forwarded as ``d_units`` to ``PositionwiseFeedForward``.
            NOTE(review): the meaning of the default ``0`` is decided by
            that class, which is not visible here -- confirm.
        h: Second positional argument of ``MultiHeadAttention``
            (presumably the number of attention heads -- confirm).
        dropout: Ratio passed to the attention / feed-forward links and to
            the ``F.dropout`` calls applied to each sub-layer output.
        initialW: Weight initializer forwarded to the sub-links.
        initial_bias: Bias initializer forwarded to the sub-links.

    """

    def __init__(self, n_units, d_units=0, h=8, dropout=0.1,
                 initialW=None, initial_bias=None):
        """Build the attention, feed-forward and normalization sub-links."""
        super(DecoderLayer, self).__init__()
        # Child links must be created inside init_scope() so that Chainer
        # registers them (and their parameters) on this chain.
        with self.init_scope():
            self.self_attn = MultiHeadAttention(n_units, h, dropout=dropout,
                                                initialW=initialW,
                                                initial_bias=initial_bias)
            self.src_attn = MultiHeadAttention(n_units, h, dropout=dropout,
                                               initialW=initialW,
                                               initial_bias=initial_bias)
            self.feed_forward = PositionwiseFeedForward(n_units, d_units=d_units,
                                                        dropout=dropout,
                                                        initialW=initialW,
                                                        initial_bias=initial_bias)
            # One independent normalization per sub-layer.
            self.norm1 = LayerNorm(n_units)
            self.norm2 = LayerNorm(n_units)
            self.norm3 = LayerNorm(n_units)
        # Plain float, not a link: kept outside init_scope().
        self.dropout = dropout

    def forward(self, e, s, xy_mask, yy_mask, batch):
        """Run the decoder layer on ``e``.

        Args:
            e: Decoder-side input; it is normalized, transformed and
                accumulated through the three sub-layers.
            s: Passed as ``s_var`` to the source attention (presumably the
                encoder output -- confirm against the caller).
            xy_mask: Mask handed to the source attention.
            yy_mask: Mask handed to the self-attention.
            batch: Forwarded as ``batch`` to both attention links.

        Returns:
            ``e`` after the three residual sub-layer updates.

        """
        # Sub-layer 1: self-attention over the decoder input.
        n_e = self.norm1(e)
        n_e = self.self_attn(n_e, mask=yy_mask, batch=batch)
        e = e + F.dropout(n_e, self.dropout)

        # Sub-layer 2: attention from the decoder state onto ``s``.
        n_e = self.norm2(e)
        n_e = self.src_attn(n_e, s_var=s, mask=xy_mask, batch=batch)
        e = e + F.dropout(n_e, self.dropout)

        # Sub-layer 3: position-wise feed-forward network.
        n_e = self.norm3(e)
        n_e = self.feed_forward(n_e)
        e = e + F.dropout(n_e, self.dropout)
        return e