import torch
class LayerNorm(torch.nn.LayerNorm):
    """Layer normalization module that can normalize an arbitrary dimension.

    Thin wrapper around :class:`torch.nn.LayerNorm` (fixed ``eps=1e-12``)
    that allows normalizing a dimension other than the last one by
    transposing it to the end, normalizing, and transposing back.

    :param int nout: output dim size (size of the normalized dimension)
    :param int dim: dimension to be normalized (default: -1, the last dim)
    """

    def __init__(self, nout, dim=-1):
        """Construct a LayerNorm object."""
        super().__init__(nout, eps=1e-12)
        # Axis to normalize; -1 falls through to the parent implementation.
        self.dim = dim

    def forward(self, x):
        """Apply layer normalization.

        :param torch.Tensor x: input tensor
        :return: layer normalized tensor
        :rtype: torch.Tensor
        """
        if self.dim == -1:
            return super().forward(x)
        # Move the target dim to the last position, normalize there, then
        # restore the original layout.  (Previously hardcoded dim 1, which
        # silently normalized the wrong axis for any other `dim`.)
        return super().forward(x.transpose(self.dim, -1)).transpose(self.dim, -1)