"""(RNN-)Transducer decoder definition."""
from typing import Any
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union
import torch
from typeguard import check_argument_types
from espnet2.asr.decoder.abs_decoder import AbsDecoder
from espnet2.asr.transducer.beam_search_transducer import ExtendedHypothesis
from espnet2.asr.transducer.beam_search_transducer import Hypothesis
class TransducerDecoder(AbsDecoder):
    """(RNN-)Transducer decoder module.

    The decoder embeds label IDs and feeds them through a stack of
    single-layer LSTM or GRU modules, each followed by dropout.

    Args:
        vocab_size: Output dimension (number of embedding entries).
        rnn_type: Decoder layers type, either "lstm" or "gru".
        num_layers: Number of decoder layers.
        hidden_size: Number of decoder units per layer. Also used as the
            embedding dimension.
        dropout: Dropout rate for decoder layers.
        dropout_embed: Dropout rate for embedding layer.
        embed_pad: Embed/Blank symbol ID (used as the embedding padding index).

    Raises:
        ValueError: If ``rnn_type`` is neither "lstm" nor "gru".

    """

    def __init__(
        self,
        vocab_size: int,
        rnn_type: str = "lstm",
        num_layers: int = 1,
        hidden_size: int = 320,
        dropout: float = 0.0,
        dropout_embed: float = 0.0,
        embed_pad: int = 0,
    ):
        assert check_argument_types()

        if rnn_type not in {"lstm", "gru"}:
            raise ValueError(f"Not supported: rnn_type={rnn_type}")

        super().__init__()

        self.embed = torch.nn.Embedding(vocab_size, hidden_size, padding_idx=embed_pad)
        self.dropout_embed = torch.nn.Dropout(p=dropout_embed)

        # Each layer is its own single-layer RNN so that dropout can be
        # applied between layers and states can be handled per layer.
        dec_net = torch.nn.LSTM if rnn_type == "lstm" else torch.nn.GRU
        self.decoder = torch.nn.ModuleList(
            [
                dec_net(hidden_size, hidden_size, 1, batch_first=True)
                for _ in range(num_layers)
            ]
        )
        self.dropout_dec = torch.nn.ModuleList(
            [torch.nn.Dropout(p=dropout) for _ in range(num_layers)]
        )

        self.dlayers = num_layers
        self.dunits = hidden_size
        self.dtype = rnn_type
        self.odim = vocab_size

        self.ignore_id = -1
        self.blank_id = embed_pad

        # Device of the parameters at construction time. It is NOT updated
        # automatically when the module is moved; call set_device() for that.
        self.device = next(self.parameters()).device

    def set_device(self, device: torch.device):
        """Set the device on which new states and label tensors are created.

        Args:
            device: Device to use.

        """
        self.device = device

    def init_state(
        self, batch_size: int
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Initialize decoder states.

        Args:
            batch_size: Batch size.

        Returns:
            : Initial (zero) decoder hidden states. ((N, B, D_dec), (N, B, D_dec))
              The second element is None for GRU decoders.

        """
        h_n = torch.zeros(
            self.dlayers,
            batch_size,
            self.dunits,
            device=self.device,
        )

        if self.dtype == "lstm":
            c_n = torch.zeros(
                self.dlayers,
                batch_size,
                self.dunits,
                device=self.device,
            )

            return (h_n, c_n)

        return (h_n, None)

    def rnn_forward(
        self,
        sequence: torch.Tensor,
        state: Tuple[torch.Tensor, Optional[torch.Tensor]],
    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, Optional[torch.Tensor]]]:
        """Run the embedded label sequences through the stacked RNN layers.

        Args:
            sequence: RNN input sequences. (B, L, D_emb)
            state: Decoder hidden states. ((N, B, D_dec), (N, B, D_dec))

        Returns:
            sequence: RNN output sequences. (B, L, D_dec)
            (h_next, c_next): Decoder hidden states.
                ((N, B, D_dec), (N, B, D_dec)); c_next is None for GRU.

        """
        h_prev, c_prev = state
        # Fresh buffers that receive the per-layer output states.
        h_next, c_next = self.init_state(sequence.size(0))

        for layer in range(self.dlayers):
            if self.dtype == "lstm":
                sequence, (h_layer, c_layer) = self.decoder[layer](
                    sequence,
                    hx=(h_prev[layer : layer + 1], c_prev[layer : layer + 1]),
                )
                h_next[layer : layer + 1] = h_layer
                c_next[layer : layer + 1] = c_layer
            else:
                sequence, h_layer = self.decoder[layer](
                    sequence, hx=h_prev[layer : layer + 1]
                )
                h_next[layer : layer + 1] = h_layer

            sequence = self.dropout_dec[layer](sequence)

        return sequence, (h_next, c_next)

    def forward(self, labels: torch.Tensor) -> torch.Tensor:
        """Encode source label sequences.

        Args:
            labels: Label ID sequences. (B, L)

        Returns:
            dec_out: Decoder output sequences. (B, L, D_dec)

        """
        init_state = self.init_state(labels.size(0))
        dec_embed = self.dropout_embed(self.embed(labels))

        dec_out, _ = self.rnn_forward(dec_embed, init_state)

        return dec_out

    def score(
        self, hyp: Hypothesis, cache: Dict[str, Any]
    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, Optional[torch.Tensor]], torch.Tensor]:
        """One-step forward hypothesis.

        Args:
            hyp: Hypothesis.
            cache: Pairs of (dec_out, state) for each label sequence. (key)
                Updated in place when the hypothesis is not cached yet.

        Returns:
            dec_out: Decoder output sequence. (D_dec,)
            new_state: Decoder hidden states. ((N, 1, D_dec), (N, 1, D_dec))
            label: Label ID for LM. (1,)

        """
        label = torch.full((1, 1), hyp.yseq[-1], dtype=torch.long, device=self.device)

        # The whole label history identifies a decoder output in the cache.
        str_labels = "_".join(map(str, hyp.yseq))

        if str_labels in cache:
            dec_out, dec_state = cache[str_labels]
        else:
            dec_emb = self.embed(label)

            dec_out, dec_state = self.rnn_forward(dec_emb, hyp.dec_state)
            cache[str_labels] = (dec_out, dec_state)

        return dec_out[0][0], dec_state, label[0]

    def batch_score(
        self,
        hyps: Union[List[Hypothesis], List[ExtendedHypothesis]],
        dec_states: Tuple[torch.Tensor, Optional[torch.Tensor]],
        cache: Dict[str, Any],
        use_lm: bool,
    ) -> Tuple[
        torch.Tensor,
        Tuple[torch.Tensor, Optional[torch.Tensor]],
        Optional[torch.Tensor],
    ]:
        """One-step forward hypotheses.

        Args:
            hyps: Hypotheses.
            dec_states: Decoder hidden states. ((N, B, D_dec), (N, B, D_dec))
                Only forwarded to create_batch_states().
            cache: Pairs of (dec_out, dec_states) for each label sequences. (keys)
                Updated in place with the newly computed entries.
            use_lm: Whether to compute label ID sequences for LM.

        Returns:
            dec_out: Decoder output sequences. (B, D_dec)
            dec_states: Decoder hidden states. ((N, B, D_dec), (N, B, D_dec))
            lm_labels: Label ID sequences for LM. (B, 1), or None if
                ``use_lm`` is False.

        """
        final_batch = len(hyps)

        # Hypotheses that still need a decoder step: (cache key, label, state).
        process = []
        # Per-hypothesis (dec_out, state); None until resolved.
        done = [None] * final_batch

        for i, hyp in enumerate(hyps):
            str_labels = "_".join(map(str, hyp.yseq))

            if str_labels in cache:
                done[i] = cache[str_labels]
            else:
                process.append((str_labels, hyp.yseq[-1], hyp.dec_state))

        if process:
            # torch.tensor (not the legacy torch.LongTensor constructor) is
            # required here: the legacy constructor rejects non-CPU devices.
            labels = torch.tensor(
                [[p[1]] for p in process], dtype=torch.long, device=self.device
            )

            p_dec_states = self.create_batch_states(
                self.init_state(labels.size(0)), [p[2] for p in process]
            )

            dec_emb = self.embed(labels)
            dec_out, new_states = self.rnn_forward(dec_emb, p_dec_states)

        # Fill the gaps in `done` in order; `j` indexes the processed batch.
        j = 0
        for i in range(final_batch):
            if done[i] is None:
                state = self.select_state(new_states, j)

                done[i] = (dec_out[j], state)
                cache[process[j][0]] = (dec_out[j], state)

                j += 1

        dec_out = torch.cat([d[0] for d in done], dim=0)
        dec_states = self.create_batch_states(dec_states, [d[1] for d in done])

        if use_lm:
            lm_labels = torch.tensor(
                [h.yseq[-1] for h in hyps], dtype=torch.long, device=self.device
            ).view(final_batch, 1)

            return dec_out, dec_states, lm_labels

        return dec_out, dec_states, None

    def select_state(
        self, states: Tuple[torch.Tensor, Optional[torch.Tensor]], idx: int
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Get specified ID state from decoder hidden states.

        Args:
            states: Decoder hidden states. ((N, B, D_dec), (N, B, D_dec))
            idx: State ID to extract.

        Returns:
            : Decoder hidden state for given ID.
              ((N, 1, D_dec), (N, 1, D_dec)); second element is None for GRU.

        """
        return (
            states[0][:, idx : idx + 1, :],
            states[1][:, idx : idx + 1, :] if self.dtype == "lstm" else None,
        )

    def create_batch_states(
        self,
        states: Tuple[torch.Tensor, Optional[torch.Tensor]],
        new_states: List[Tuple[torch.Tensor, Optional[torch.Tensor]]],
        check_list: Optional[List] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Create decoder hidden states.

        The batch is built by concatenating ``new_states`` along the batch
        dimension (dim=1).

        Args:
            states: Decoder hidden states. ((N, B, D_dec), (N, B, D_dec))
                Unused by this implementation.
            new_states: Decoder hidden states. [B x ((N, 1, D_dec), (N, 1, D_dec))]
            check_list: Unused by this implementation.

        Returns:
            states: Decoder hidden states. ((N, B, D_dec), (N, B, D_dec))
                The second element is None for GRU.

        """
        return (
            torch.cat([s[0] for s in new_states], dim=1),
            torch.cat([s[1] for s in new_states], dim=1)
            if self.dtype == "lstm"
            else None,
        )