Source code for espnet.nets.chainer_backend.transformer.encoder_layer

# encoding: utf-8

import chainer

import chainer.functions as F

from espnet.nets.chainer_backend.transformer.attention import MultiHeadAttention
from espnet.nets.chainer_backend.transformer.layer_norm import LayerNorm
from espnet.nets.chainer_backend.transformer.positionwise_feed_forward import PositionwiseFeedForward


class EncoderLayer(chainer.Chain):
    """One encoder layer: self-attention followed by a feed-forward sub-layer.

    Each sub-layer is applied in the same pattern: the input is layer-normalized,
    passed through the sub-layer, run through dropout, and then added back onto
    the un-normalized input (residual connection).

    Args:
        n_units: Size handed to ``MultiHeadAttention``, ``PositionwiseFeedForward``
            and both ``LayerNorm`` links (presumably the model/feature
            dimension -- confirm against those classes).
        d_units: Forwarded as ``d_units`` to ``PositionwiseFeedForward``
            (presumably its hidden size; the meaning of the default ``0`` is
            decided there).
        h: Second positional argument of ``MultiHeadAttention`` (presumably the
            number of attention heads -- confirm).
        dropout: Dropout ratio. Forwarded to both sub-layers and also used for
            the dropout applied to each sub-layer's output in ``__call__``.
        initialW: Weight initializer forwarded to both sub-layers.
        initial_bias: Bias initializer forwarded to both sub-layers.

    """

    def __init__(self, n_units, d_units=0, h=8, dropout=0.1,
                 initialW=None, initial_bias=None):
        super(EncoderLayer, self).__init__()

        # Keyword arguments shared verbatim by the attention and feed-forward
        # sub-layers.
        sublayer_kwargs = dict(
            dropout=dropout,
            initialW=initialW,
            initial_bias=initial_bias,
        )

        # Child links are created inside init_scope so that Chainer registers
        # them (and their parameters) on this chain.
        with self.init_scope():
            self.self_attn = MultiHeadAttention(n_units, h, **sublayer_kwargs)
            self.feed_forward = PositionwiseFeedForward(
                n_units, d_units=d_units, **sublayer_kwargs)
            self.norm1 = LayerNorm(n_units)
            self.norm2 = LayerNorm(n_units)

        # Plain (non-link) attributes.
        self.dropout = dropout
        self.n_units = n_units

    def __call__(self, e, xx_mask, batch):
        """Run the layer on ``e`` and return the transformed result.

        Args:
            e: Input to the layer; also serves as the residual term.
                (Shape is not visible from this file -- confirm with callers.)
            xx_mask: Passed to the self-attention link as ``mask``.
            batch: Passed to the self-attention link as ``batch``.

        Returns:
            ``e`` after the self-attention and feed-forward residual updates.

        """
        # Sub-layer 1: norm -> self-attention -> dropout -> residual add.
        attn_out = self.self_attn(self.norm1(e), mask=xx_mask, batch=batch)
        e = e + F.dropout(attn_out, self.dropout)

        # Sub-layer 2: norm -> feed-forward -> dropout -> residual add.
        ff_out = self.feed_forward(self.norm2(e))
        return e + F.dropout(ff_out, self.dropout)