Source code for espnet2.asr.transducer.utils

"""Utility functions for Transducer models."""

import torch

from espnet.nets.pytorch_backend.nets_utils import pad_list


def get_transducer_task_io(
    labels: torch.Tensor,
    encoder_out_lens: torch.Tensor,
    ignore_id: int = -1,
    blank_id: int = 0,
):
    """Get Transducer loss I/O.

    Args:
        labels: Label ID sequences. (B, L)
        encoder_out_lens: Encoder output lengths. (B,)
            A tensor with more than one dimension is instead treated as a mask
            whose first dimension is the batch: the length of each sequence is
            the number of non-zero entries in its row.
        ignore_id: Padding symbol ID in `labels`.
        blank_id: Blank symbol ID.

    Returns:
        decoder_in: Decoder inputs, i.e. each label sequence prefixed with
            `blank_id` and padded with `blank_id`. (B, U + 1)
        target: Target label ID sequences padded with `blank_id`, int32. (B, U)
        t_len: Time lengths, int32. (B,)
        u_len: Label lengths (count of non-`ignore_id` labels), int32. (B,)

    """
    device = labels.device

    # Remove padding from every sequence; rows may end up with different lengths.
    labels_unpad = [y[y != ignore_id] for y in labels]
    blank = labels[0].new([blank_id])

    # Decoder input = blank symbol followed by the reference labels.
    decoder_in = pad_list(
        [torch.cat([blank, label], dim=0) for label in labels_unpad], blank_id
    ).to(device)

    # Targets are padded with blank_id (not ignore_id) and cast to int32,
    # matching the dtype of t_len and u_len below.
    target = pad_list(labels_unpad, blank_id).type(torch.int32).to(device)

    if encoder_out_lens.dim() > 1:
        # Mask input: count non-zero entries over all non-batch dimensions.
        t_len = (encoder_out_lens != 0).flatten(start_dim=1).sum(dim=1)
    else:
        t_len = encoder_out_lens
    # A single vectorized cast replaces a per-element int() (one .item() call,
    # i.e. one host sync on GPU, per batch entry) and the legacy CPU-only
    # torch.IntTensor constructor. Float lengths are truncated like int().
    t_len = t_len.to(device=device, dtype=torch.int32)

    # Tensor sizes are host-side metadata, so this list needs no device sync;
    # build the result directly on the target device.
    u_len = torch.tensor(
        [y.size(0) for y in labels_unpad], dtype=torch.int32, device=device
    )

    return decoder_in, target, t_len, u_len