# Source code for espnet.nets.pytorch_backend.transformer.repeat

import torch


class MultiSequential(torch.nn.Sequential):
    """Multi-input multi-output torch.nn.Sequential.

    Unlike ``torch.nn.Sequential``, whose modules take a single input,
    every contained module is called with all positional arguments, and
    its return value is star-unpacked as the positional arguments of the
    next module. Intermediate modules are therefore expected to return a
    tuple (or other iterable) of their outputs. The return value of the
    last module is returned unchanged.
    """

    def forward(self, *args):
        """Run ``args`` through every contained module in order.

        :param args: positional inputs for the first module
        :return: the output of the last module, or the input ``args``
            tuple itself when the container holds no modules
        """
        for m in self:
            # Each module's output becomes the next module's *positional*
            # arguments, so a non-tuple return would be iterated/unpacked.
            args = m(*args)
        return args
def repeat(N, fn):
    """Repeat module N times.

    ``fn`` is called once per copy with no arguments. Each call should
    build a fresh module. If it returned the same instance every time,
    all N entries would be one shared object.

    :param int N: repeat time (number of copies; ``0`` or less yields an
        empty container)
    :param function fn: zero-argument function to generate a module
    :return: repeated modules
    :rtype: MultiSequential
    """
    return MultiSequential(*[fn() for _ in range(N)])