import torch
from espnet2.enh.decoder.abs_decoder import AbsDecoder
[docs]class ConvDecoder(AbsDecoder):
"""Transposed Convolutional decoder for speech enhancement and separation"""
def __init__(
self,
channel: int,
kernel_size: int,
stride: int,
):
super().__init__()
self.convtrans1d = torch.nn.ConvTranspose1d(
channel, 1, kernel_size, bias=False, stride=stride
)
[docs] def forward(self, input: torch.Tensor, ilens: torch.Tensor):
"""Forward.
Args:
input (torch.Tensor): spectrum [Batch, T, F]
ilens (torch.Tensor): input lengths [Batch]
"""
input = input.transpose(1, 2)
batch_size = input.shape[0]
wav = self.convtrans1d(input, output_size=(batch_size, 1, ilens.max()))
wav = wav.squeeze(1)
return wav, ilens