Source code for espnet.nets.chainer_backend.transformer.positionwise_feed_forward

# encoding: utf-8

import chainer

import chainer.functions as F
import chainer.links as L

import numpy as np


class PositionwiseFeedForward(chainer.Chain):
    """Position-wise feed-forward sublayer.

    Computes ``w_2(dropout(relu(w_1(e))))``:
    1. a linear projection from ``n_units`` to an inner width;
    2. ReLU;
    3. dropout;
    4. a linear projection back to ``n_units``.

    Args:
        n_units (int): Input and output feature size.
        d_units (int): Inner (hidden) size. When ``d_units <= 0`` the
            inner size defaults to ``4 * n_units``.
        dropout (float): Dropout ratio applied after the activation.
        initialW: Optional callable. It accepts a ``scale`` keyword and
            returns a weight initializer for ``L.Linear``. Each layer calls
            it with ``scale = 1 / sqrt(in_size)`` of that layer. If
            ``None``, ``L.Linear``'s own default weight initializer is used.
        initial_bias: Same as ``initialW``, but for the biases. If
            ``None``, ``L.Linear``'s own default bias initializer is used.
    """

    def __init__(self, n_units, d_units=0, dropout=0.1,
                 initialW=None, initial_bias=None):
        super(PositionwiseFeedForward, self).__init__()
        n_inner_units = d_units if d_units > 0 else n_units * 4
        with self.init_scope():
            # Initializer scale is 1/sqrt(fan_in) of each projection.
            stdv = 1. / np.sqrt(n_units)
            self.w_1 = L.Linear(
                n_units, n_inner_units,
                initialW=self._make_init(initialW, stdv),
                initial_bias=self._make_init(initial_bias, stdv))
            stdv = 1. / np.sqrt(n_inner_units)
            self.w_2 = L.Linear(
                n_inner_units, n_units,
                initialW=self._make_init(initialW, stdv),
                initial_bias=self._make_init(initial_bias, stdv))
            self.act = F.relu
            self.dropout = dropout

    @staticmethod
    def _make_init(factory, scale):
        """Return ``factory(scale=scale)``, or ``None`` if ``factory`` is None.

        Passing ``None`` through lets ``L.Linear`` use its default
        initializer. Without this, the constructor's ``None`` defaults
        crash on ``None(scale=...)``.
        """
        return None if factory is None else factory(scale=scale)

    def __call__(self, e):
        """Apply the feed-forward transform to ``e``.

        Args:
            e: Input features whose last dimension is ``n_units``. This
                presumably expects a 2-D ``(N, n_units)`` array, given how
                ``L.Linear`` treats its input. TODO confirm against callers.

        Returns:
            Output of ``w_2``, with feature size ``n_units``.
        """
        e = F.dropout(self.act(self.w_1(e)), self.dropout)
        return self.w_2(e)