# Source code for espnet.nets.pytorch_backend.transducer.custom_decoder

"""Custom decoder definition for Transducer model."""

from typing import Any
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union

import torch

from espnet.nets.pytorch_backend.transducer.blocks import build_blocks
from espnet.nets.pytorch_backend.transducer.utils import check_batch_states
from espnet.nets.pytorch_backend.transducer.utils import check_state
from espnet.nets.pytorch_backend.transducer.utils import pad_sequence
from espnet.nets.pytorch_backend.transformer.layer_norm import LayerNorm
from espnet.nets.pytorch_backend.transformer.mask import subsequent_mask
from espnet.nets.transducer_decoder_interface import ExtendedHypothesis
from espnet.nets.transducer_decoder_interface import Hypothesis
from espnet.nets.transducer_decoder_interface import TransducerDecoderInterface

class CustomDecoder(TransducerDecoderInterface, torch.nn.Module):
    """Custom decoder module for Transducer model.

    Args:
        odim: Output dimension (also used as the input-layer vocabulary size
            passed to ``build_blocks``).
        dec_arch: Decoder block architecture (type and parameters).
        input_layer: Input layer type.
        repeat_block: Number of times dec_arch is repeated.
        joint_activation_type: Type of activation for joint network.
            Not used by this class itself.
        positional_encoding_type: Positional encoding type.
        positionwise_layer_type: Positionwise layer type.
        positionwise_activation_type: Positionwise activation type.
        input_layer_dropout_rate: Dropout rate for input layer.
        blank_id: Blank symbol ID (also used as the padding index).

    """

    def __init__(
        self,
        odim: int,
        dec_arch: List,
        input_layer: str = "embed",
        repeat_block: int = 0,
        joint_activation_type: str = "tanh",
        positional_encoding_type: str = "abs_pos",
        positionwise_layer_type: str = "linear",
        positionwise_activation_type: str = "relu",
        input_layer_dropout_rate: float = 0.0,
        blank_id: int = 0,
    ):
        """Construct a CustomDecoder object."""
        torch.nn.Module.__init__(self)

        self.embed, self.decoders, ddim, _ = build_blocks(
            "decoder",
            odim,
            input_layer,
            dec_arch,
            repeat_block=repeat_block,
            positional_encoding_type=positional_encoding_type,
            positionwise_layer_type=positionwise_layer_type,
            positionwise_activation_type=positionwise_activation_type,
            input_layer_dropout_rate=input_layer_dropout_rate,
            padding_idx=blank_id,
        )

        self.after_norm = LayerNorm(ddim)

        self.dlayers = len(self.decoders)
        self.dunits = ddim
        self.odim = odim

        self.blank_id = blank_id

        # Default device so that score()/batch_score() do not fail with an
        # AttributeError when set_device() has not been called yet.
        self.device = torch.device("cpu")

    def set_device(self, device: torch.device):
        """Set the device on which label tensors are created during search.

        Args:
            device: Device to use.

        """
        self.device = device

    def init_state(
        self,
        batch_size: Optional[int] = None,
    ) -> List[Optional[torch.Tensor]]:
        """Initialize decoder states.

        Args:
            batch_size: Batch size. Ignored; kept for interface compatibility.

        Returns:
            state: Initial decoder hidden states, one ``None`` per layer.
                [N x None]

        """
        state = [None] * self.dlayers

        return state

    def forward(
        self, dec_input: torch.Tensor, dec_mask: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Encode label ID sequences.

        Args:
            dec_input: Label ID sequences. (B, U)
            dec_mask: Label mask sequences. (B, U)

        Return:
            dec_output: Decoder output sequences. (B, U, D_dec)
            dec_output_mask: Mask of decoder output sequences. (B, U)

        """
        dec_input = self.embed(dec_input)

        dec_output, dec_mask = self.decoders(dec_input, dec_mask)
        dec_output = self.after_norm(dec_output)

        return dec_output, dec_mask

    def score(
        self, hyp: Hypothesis, cache: Dict[str, Any]
    ) -> Tuple[torch.Tensor, List[Optional[torch.Tensor]], torch.Tensor]:
        """One-step forward hypothesis.

        Args:
            hyp: Hypothesis.
            cache: Pairs of (dec_out, dec_state) keyed by the label sequence
                joined with "_". Updated in place on a cache miss.

        Returns:
            dec_out: Decoder output for the last label, batch dim removed.
                (D_dec,)
            dec_state: Decoder hidden states (per-layer outputs).
                [N x (1, U, D_dec)]
            lm_label: Label ID for LM. (1,)

        """
        labels = torch.tensor([hyp.yseq], device=self.device)
        lm_label = labels[:, -1]

        str_labels = "_".join(map(str, hyp.yseq))

        if str_labels in cache:
            dec_out, dec_state = cache[str_labels]
        else:
            # NOTE(review): the mask uses subsequent_mask's default device;
            # confirm it matches self.device when decoding on GPU.
            dec_out_mask = subsequent_mask(len(hyp.yseq)).unsqueeze_(0)

            # Previous state covers all labels except the newest one.
            new_state = check_state(hyp.dec_state, (labels.size(1) - 1), self.blank_id)

            dec_out = self.embed(labels)

            dec_state = []
            for s, decoder in zip(new_state, self.decoders):
                dec_out, dec_out_mask = decoder(dec_out, dec_out_mask, cache=s)
                # Each layer's full output is kept as that layer's cache.
                dec_state.append(dec_out)

            dec_out = self.after_norm(dec_out[:, -1])

            cache[str_labels] = (dec_out, dec_state)

        return dec_out[0], dec_state, lm_label

    def batch_score(
        self,
        hyps: Union[List[Hypothesis], List[ExtendedHypothesis]],
        dec_states: List[Optional[torch.Tensor]],
        cache: Dict[str, Any],
        use_lm: bool,
    ) -> Tuple[torch.Tensor, List[Optional[torch.Tensor]], Optional[torch.Tensor]]:
        """One-step forward hypotheses.

        Args:
            hyps: Hypotheses.
            dec_states: Decoder hidden states. [N x (B, U, D_dec)]
                This list is modified in place and also returned.
            cache: Pairs of (dec_out, dec_state) keyed by the label sequence
                joined with "_". Updated in place for every cache miss.
            use_lm: Whether to compute label ID sequences for LM.

        Returns:
            dec_out: Decoder output sequences. (B, D_dec)
            dec_states: Decoder hidden states. [N x (B, U, D_dec)]
            lm_labels: Label ID sequences for LM. (B,), or None if not use_lm.

        """
        final_batch = len(hyps)

        process = []
        done = [None] * final_batch

        for i, hyp in enumerate(hyps):
            str_labels = "_".join(map(str, hyp.yseq))

            if str_labels in cache:
                done[i] = cache[str_labels]
            else:
                process.append((str_labels, hyp.yseq, hyp.dec_state))

        if process:
            labels = pad_sequence([p[1] for p in process], self.blank_id)
            # torch.LongTensor(data, device=...) is a legacy constructor that
            # only accepts CPU devices; build on CPU, then move to self.device.
            labels = torch.LongTensor(labels).to(self.device)

            p_dec_states = self.create_batch_states(
                self.init_state(),
                [p[2] for p in process],
                labels,
            )

            dec_out = self.embed(labels)

            # NOTE(review): mask is created on subsequent_mask's default
            # device; confirm it matches self.device when decoding on GPU.
            dec_out_mask = (
                subsequent_mask(labels.size(-1))
                .unsqueeze_(0)
                .expand(len(process), -1, -1)
            )

            new_states = []
            for s, decoder in zip(p_dec_states, self.decoders):
                dec_out, dec_out_mask = decoder(dec_out, dec_out_mask, cache=s)
                new_states.append(dec_out)

            dec_out = self.after_norm(dec_out[:, -1])

            # `process` was filled in ascending hypothesis order, so the j-th
            # pending slot corresponds to process[j].
            pending = [i for i, d in enumerate(done) if d is None]
            for j, i in enumerate(pending):
                state = self.select_state(new_states, j)

                done[i] = (dec_out[j], state)
                cache[process[j][0]] = (dec_out[j], state)

        dec_out = torch.stack([d[0] for d in done])
        # The leading 0 makes each check sequence one longer than yseq, so
        # create_batch_states' "max length - 1" equals the max len(yseq).
        dec_states = self.create_batch_states(
            dec_states, [d[1] for d in done], [[0] + h.yseq for h in hyps]
        )

        if use_lm:
            lm_labels = torch.LongTensor([hyp.yseq[-1] for hyp in hyps]).to(
                self.device
            )

            return dec_out, dec_states, lm_labels

        return dec_out, dec_states, None

    def select_state(
        self, states: List[Optional[torch.Tensor]], idx: int
    ) -> List[Optional[torch.Tensor]]:
        """Get specified ID state from decoder hidden states.

        Args:
            states: Decoder hidden states. [N x (B, U, D_dec)]
            idx: State ID to extract.

        Returns:
            state_idx: Decoder hidden state for given ID. Integer indexing
                drops the batch dimension. [N x (U, D_dec)]
                If the states are uninitialized (first entry is None),
                ``states`` is returned unchanged.

        """
        if states[0] is None:
            return states

        state_idx = [states[layer][idx] for layer in range(self.dlayers)]

        return state_idx

    def create_batch_states(
        self,
        states: List[Optional[torch.Tensor]],
        new_states: List[Optional[torch.Tensor]],
        check_list: List[List[int]],
    ) -> List[Optional[torch.Tensor]]:
        """Create decoder hidden states sequences.

        Args:
            states: Decoder hidden states. [N x (B, U, D_dec)]
                Overwritten layer by layer in place and returned.
            new_states: Per-hypothesis decoder hidden states.
                [B x [N x (1, U, D_dec)]]
            check_list: Label ID sequences; only their lengths are used
                (max length - 1 is the target state length).

        Returns:
            states: New decoder hidden states. [N x (B, U, D_dec)]
                ``states`` is returned untouched if the first hypothesis's
                first-layer state is None.

        """
        if new_states[0][0] is None:
            return states

        max_len = max(len(elem) for elem in check_list) - 1

        for layer in range(self.dlayers):
            states[layer] = check_batch_states(
                [s[layer] for s in new_states], max_len, self.blank_id
            )

        return states