# Source code for espnet.nets.pytorch_backend.transformer.layer_norm

import torch


class LayerNorm(torch.nn.LayerNorm):
    """Layer normalization module.

    Normalizes the input over the single dimension ``dim``, whose size must be
    ``nout``. When ``dim`` is not the last dimension, the tensor is transposed
    so that ``dim`` becomes last, normalized, and then transposed back.

    :param int nout: output dim size (size of the dimension being normalized)
    :param int dim: dimension to be normalized (default: -1, the last one)
    :param float eps: value added to the variance for numerical stability
        (default: 1e-12)
    """

    def __init__(self, nout, dim=-1, eps=1e-12):
        super().__init__(nout, eps=eps)
        self.dim = dim

    def forward(self, x):
        """Apply layer normalization.

        :param torch.Tensor x: input tensor; its size along ``self.dim`` must
            equal ``nout``
        :return: layer normalized tensor with the same shape as ``x``
        :rtype torch.Tensor
        """
        if self.dim == -1:
            return super().forward(x)
        # torch.nn.LayerNorm only normalizes trailing dimensions, so move the
        # configured dimension (not a hard-coded axis 1) to the end, normalize,
        # and move it back to its original position.
        return super().forward(x.transpose(self.dim, -1)).transpose(self.dim, -1)