import torch
from espnet.nets.pytorch_backend.transformer.attention import MultiHeadedAttention
from espnet.nets.pytorch_backend.transformer.decoder_layer import DecoderLayer
from espnet.nets.pytorch_backend.transformer.embedding import PositionalEncoding
from espnet.nets.pytorch_backend.transformer.layer_norm import LayerNorm
from espnet.nets.pytorch_backend.transformer.mask import subsequent_mask
from espnet.nets.pytorch_backend.transformer.positionwise_feed_forward import PositionwiseFeedForward
from espnet.nets.pytorch_backend.transformer.repeat import repeat
from espnet.nets.scorer_interface import ScorerInterface
[docs]class Decoder(ScorerInterface, torch.nn.Module):
"""Transfomer decoder module
:param int odim: output dim
:param int attention_dim: dimention of attention
:param int attention_heads: the number of heads of multi head attention
:param int linear_units: the number of units of position-wise feed forward
:param int num_blocks: the number of decoder blocks
:param float dropout_rate: dropout rate
:param float attention_dropout_rate: dropout rate for attention
:param str or torch.nn.Module input_layer: input layer type
:param bool use_output_layer: whether to use output layer
:param class pos_enc_class: PositionalEncoding or ScaledPositionalEncoding
:param bool normalize_before: whether to use layer_norm before the first block
:param bool concat_after: whether to concat attention layer's input and output
if True, additional linear will be applied. i.e. x -> x + linear(concat(x, att(x)))
if False, no additional linear will be applied. i.e. x -> x + att(x)
"""
def __init__(self, odim,
attention_dim=256,
attention_heads=4,
linear_units=2048,
num_blocks=6,
dropout_rate=0.1,
positional_dropout_rate=0.1,
self_attention_dropout_rate=0.0,
src_attention_dropout_rate=0.0,
input_layer="embed",
use_output_layer=True,
pos_enc_class=PositionalEncoding,
normalize_before=True,
concat_after=False):
torch.nn.Module.__init__(self)
if input_layer == "embed":
self.embed = torch.nn.Sequential(
torch.nn.Embedding(odim, attention_dim),
pos_enc_class(attention_dim, positional_dropout_rate)
)
elif input_layer == "linear":
self.embed = torch.nn.Sequential(
torch.nn.Linear(odim, attention_dim),
torch.nn.LayerNorm(attention_dim),
torch.nn.Dropout(dropout_rate),
torch.nn.ReLU(),
pos_enc_class(attention_dim, positional_dropout_rate)
)
elif isinstance(input_layer, torch.nn.Module):
self.embed = torch.nn.Sequential(
input_layer,
pos_enc_class(attention_dim, positional_dropout_rate)
)
else:
raise NotImplementedError("only `embed` or torch.nn.Module is supported.")
self.normalize_before = normalize_before
self.decoders = repeat(
num_blocks,
lambda: DecoderLayer(
attention_dim,
MultiHeadedAttention(attention_heads, attention_dim, self_attention_dropout_rate),
MultiHeadedAttention(attention_heads, attention_dim, src_attention_dropout_rate),
PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
dropout_rate,
normalize_before,
concat_after
)
)
if self.normalize_before:
self.after_norm = LayerNorm(attention_dim)
if use_output_layer:
self.output_layer = torch.nn.Linear(attention_dim, odim)
else:
self.output_layer = None
[docs] def forward(self, tgt, tgt_mask, memory, memory_mask):
"""forward decoder
:param torch.Tensor tgt: input token ids, int64 (batch, maxlen_out) if input_layer == "embed"
input tensor (batch, maxlen_out, #mels) in the other cases
:param torch.Tensor tgt_mask: input token mask, uint8 (batch, maxlen_out)
:param torch.Tensor memory: encoded memory, float32 (batch, maxlen_in, feat)
:param torch.Tensor memory_mask: encoded memory mask, uint8 (batch, maxlen_in)
:return x: decoded token score before softmax (batch, maxlen_out, token) if use_output_layer is True,
final block outputs (batch, maxlen_out, attention_dim) in the other cases
:rtype: torch.Tensor
:return tgt_mask: score mask before softmax (batch, maxlen_out)
:rtype: torch.Tensor
"""
x = self.embed(tgt)
x, tgt_mask, memory, memory_mask = self.decoders(x, tgt_mask, memory, memory_mask)
if self.normalize_before:
x = self.after_norm(x)
if self.output_layer is not None:
x = self.output_layer(x)
return x, tgt_mask
[docs] def recognize(self, tgt, tgt_mask, memory):
"""recognize one step
:param torch.Tensor tgt: input token ids, int64 (batch, maxlen_out)
:param torch.Tensor tgt_mask: input token mask, uint8 (batch, maxlen_out)
:param torch.Tensor memory: encoded memory, float32 (batch, maxlen_in, feat)
:return x: decoded token score before softmax (batch, maxlen_out, token)
:rtype: torch.Tensor
"""
x = self.embed(tgt)
x, tgt_mask, memory, memory_mask = self.decoders(x, tgt_mask, memory, None)
if self.normalize_before:
x_ = self.after_norm(x[:, -1])
else:
x_ = x[:, -1]
if self.output_layer is not None:
return torch.log_softmax(self.output_layer(x_), dim=-1)
else:
return x_
# beam search API (see ScorerInterface)
[docs] def score(self, ys, state, x):
# TODO(karita) cache previous attentions in state
ys_mask = subsequent_mask(len(ys), device=x.device).unsqueeze(0)
logp = self.recognize(ys.unsqueeze(0), ys_mask, x.unsqueeze(0))
return logp.squeeze(0), None