"""Transducer and transducer with attention implementation for training and decoding."""
import six
import torch
import torch.nn.functional as F
from espnet.nets.pytorch_backend.rnn.attentions import att_to_numpy
from espnet.nets.pytorch_backend.nets_utils import pad_list
from espnet.nets.pytorch_backend.nets_utils import to_device
class DecoderRNNT(torch.nn.Module):
"""RNN-T Decoder module.
Args:
eprojs (int): # encoder projection units
odim (int): dimension of outputs
dtype (str): gru or lstm
dlayers (int): # prediction layers
dunits (int): # prediction units
blank (int): blank symbol id
            embed_dim (int): dimension of embeddings
joint_dim (int): dimension of joint space
dropout (float): dropout rate
dropout_embed (float): embedding dropout rate
rnnt_type (str): type of rnn-t implementation
"""
def __init__(self, eprojs, odim, dtype, dlayers, dunits, blank,
embed_dim, joint_dim, dropout=0.0, dropout_embed=0.0,
rnnt_type='warp-transducer'):
"""Transducer initializer."""
super(DecoderRNNT, self).__init__()
self.embed = torch.nn.Embedding(odim, embed_dim, padding_idx=blank)
self.dropout_embed = torch.nn.Dropout(p=dropout_embed)
if dtype == "lstm":
dec_net = torch.nn.LSTM
else:
dec_net = torch.nn.GRU
self.decoder = torch.nn.ModuleList([dec_net(embed_dim, dunits, 1,
bias=True, batch_first=True,
bidirectional=False)])
self.dropout_dec = torch.nn.ModuleList([torch.nn.Dropout(p=dropout)])
for _ in six.moves.range(1, dlayers):
self.decoder += [dec_net(dunits, dunits, 1,
bias=True, batch_first=True,
bidirectional=False)]
self.dropout_dec += [torch.nn.Dropout(p=dropout)]
if rnnt_type == 'warp-transducer':
from warprnnt_pytorch import RNNTLoss
self.rnnt_loss = RNNTLoss(blank=blank)
else:
raise NotImplementedError
self.lin_enc = torch.nn.Linear(eprojs, joint_dim)
self.lin_dec = torch.nn.Linear(dunits, joint_dim, bias=False)
self.lin_out = torch.nn.Linear(joint_dim, odim)
self.dlayers = dlayers
self.dunits = dunits
self.dtype = dtype
self.embed_dim = embed_dim
self.joint_dim = joint_dim
self.odim = odim
self.rnnt_type = rnnt_type
self.ignore_id = -1
self.blank = blank
    def zero_state(self, ey):
"""Initialize decoder states.
Args:
ey (torch.Tensor): batch of input features (B, Lmax, Emb_dim)
        Returns:
            (tuple): tuple of (z_list, c_list), each a list of L zero-initialized states (1, B, Hdec)
"""
z_list = [ey.new_zeros(1, ey.size(0), self.dunits)]
c_list = [ey.new_zeros(1, ey.size(0), self.dunits)]
for _ in six.moves.range(1, self.dlayers):
z_list.append(ey.new_zeros(1, ey.size(0), self.dunits))
c_list.append(ey.new_zeros(1, ey.size(0), self.dunits))
return (z_list, c_list)
    def rnn_forward(self, ey, dstate):
"""RNN forward.
Args:
ey (torch.Tensor): batch of input features (B, Lmax, Emb_dim)
            dstate (tuple): tuple of (z_list, c_list), lists of L input hidden and cell states (1, B, Hdec)
        Returns:
            output (torch.Tensor): batch of output features (B, Lmax, Hdec)
            dstate (tuple): tuple of (z_list, c_list), lists of L output hidden and cell states (1, B, Hdec)
"""
if dstate is None:
z_prev, c_prev = self.zero_state(ey)
else:
z_prev, c_prev = dstate
z_list, c_list = self.zero_state(ey)
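        # z_list/c_list serve only as pre-sized containers for the per-layer
        # output states; the recurrent state actually fed to each layer comes
        # from z_prev/c_prev above.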
if self.dtype == "lstm":
y, (z_list[0], c_list[0]) = self.decoder[0](ey, (z_prev[0], c_prev[0]))
for l in six.moves.range(1, self.dlayers):
y, (z_list[l], c_list[l]) = self.decoder[l](y, (z_prev[l], c_prev[l]))
else:
y, z_list[0] = self.decoder[0](ey, z_prev[0])
for l in six.moves.range(1, self.dlayers):
y, z_list[l] = self.decoder[l](y, z_prev[l])
return y, (z_list, c_list)
    def joint(self, h_enc, h_dec):
"""Joint computation of z.
Args:
h_enc (torch.Tensor): batch of expanded hidden state (B, T, 1, Henc)
h_dec (torch.Tensor): batch of expanded hidden state (B, 1, U, Hdec)
Returns:
z (torch.Tensor): output (B, T, U, odim)
"""
z = torch.tanh(self.lin_enc(h_enc) + self.lin_dec(h_dec))
z = self.lin_out(z)
return z
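    # Shape sketch for joint() broadcasting (illustrative values, not from
    # the original source): with B=1, T=100 encoder frames and U+1=21
    # prediction steps,
    #   self.lin_enc(h_enc) -> (1, 100, 1, joint_dim)
    #   self.lin_dec(h_dec) -> (1, 1, 21, joint_dim)
    # and the sum broadcasts to (1, 100, 21, joint_dim): every encoder frame
    # t is paired with every prediction step u before lin_out projects to
    # (1, 100, 21, odim), the lattice consumed by the RNN-T loss.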
    def forward(self, hs_pad, hlens, ys_pad):
"""Forward function for transducer.
Args:
hs_pad (torch.Tensor): batch of padded hidden state sequences (B, Tmax, D)
hlens (torch.Tensor): batch of lengths of hidden state sequences (B)
ys_pad (torch.Tensor): batch of padded character id sequence tensor (B, Lmax)
Returns:
            loss (torch.Tensor): rnnt loss value
"""
ys = [y[y != self.ignore_id] for y in ys_pad]
hlens = list(map(int, hlens))
blank = ys[0].new([self.blank])
ys_in = [torch.cat([blank, y], dim=0) for y in ys]
ys_in_pad = pad_list(ys_in, self.blank)
eys = self.dropout_embed(self.embed(ys_in_pad))
h_dec, _ = self.rnn_forward(eys, None)
h_enc = hs_pad.unsqueeze(2)
h_dec = h_dec.unsqueeze(1)
z = self.joint(h_enc, h_dec)
y = pad_list(ys, self.blank).type(torch.int32)
z_len = to_device(self, torch.IntTensor(hlens))
y_len = to_device(self, torch.IntTensor([_y.size(0) for _y in ys]))
loss = to_device(self, self.rnnt_loss(z, y, z_len, y_len))
return loss
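    # Sketch of the tensors handed to warp-transducer's RNNTLoss above,
    # assuming B=2, Tmax=50, Lmax=10, odim=30 (illustrative values):
    #   z:     (2, 50, 11, 30) float32 -- unnormalized joint outputs
    #   y:     (2, 10)         int32   -- padded labels, blank excluded
    #   z_len: (2,)            int32   -- encoder lengths
    #   y_len: (2,)            int32   -- label lengths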
    def recognize(self, h, recog_args):
"""Greedy search implementation.
Args:
h (torch.Tensor): encoder hidden state sequences (Tmax, Henc)
recog_args (Namespace): argument Namespace containing options
Returns:
hyp (list of dicts): 1-best decoding results
"""
hyp = {'score': 0.0, 'yseq': [self.blank]}
ey = torch.zeros((1, 1, self.embed_dim))
y, (z_list, c_list) = self.rnn_forward(ey, None)
for hi in h:
ytu = F.log_softmax(self.joint(hi, y[0][0]), dim=0)
logp, pred = torch.max(ytu, dim=0)
if pred != self.blank:
hyp['yseq'].append(int(pred))
hyp['score'] += float(logp)
eys = torch.full((1, 1), hyp['yseq'][-1], dtype=torch.long)
ey = self.dropout_embed(self.embed(eys))
y, (z_list, c_list) = self.rnn_forward(ey, (z_list, c_list))
return [hyp]
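    # Usage sketch (hypothetical names; recog_args is accepted for interface
    # compatibility but unused here):
    #   h = enc_out[0]                        # (Tmax, eprojs), one utterance
    #   nbest = dec.recognize(h, recog_args)  # [{'score': ..., 'yseq': [...]}]
    # Note the loop above emits at most one symbol per encoder frame, a
    # common simplification of full RNN-T greedy decoding.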
    def recognize_beam(self, h, recog_args, rnnlm=None):
"""Beam search implementation.
Args:
h (torch.Tensor): encoder hidden state sequences (Tmax, Henc)
recog_args (Namespace): argument Namespace containing options
            rnnlm (torch.nn.Module): language model module
Returns:
nbest_hyps (list of dicts): n-best decoding results
"""
beam = recog_args.beam_size
k_range = min(beam, self.odim)
nbest = recog_args.nbest
normscore = recog_args.score_norm_transducer
ey = torch.zeros((1, 1, self.embed_dim))
y, dstate = self.rnn_forward(ey, None)
if rnnlm:
kept_hyps = [{'score': 0.0, 'yseq': [self.blank], 'dstate': dstate, 'lm_state': None}]
else:
kept_hyps = [{'score': 0.0, 'yseq': [self.blank], 'dstate': dstate}]
for i, hi in enumerate(h):
hyps = kept_hyps
kept_hyps = []
while True:
new_hyp = max(hyps, key=lambda x: x['score'])
hyps.remove(new_hyp)
vy = to_device(self, torch.full((1, 1), new_hyp['yseq'][-1], dtype=torch.long))
ey = self.dropout_embed(self.embed(vy))
y, dstate = self.rnn_forward(ey, new_hyp['dstate'])
ytu = F.log_softmax(self.joint(hi, y[0][0]), dim=0)
if rnnlm:
rnnlm_state, rnnlm_scores = rnnlm.predict(new_hyp['lm_state'], vy[0])
for k in six.moves.range(self.odim):
beam_hyp = {'score': new_hyp['score'] + float(ytu[k]),
'yseq': new_hyp['yseq'][:],
'dstate': new_hyp['dstate']}
if rnnlm:
beam_hyp['lm_state'] = new_hyp['lm_state']
if k == self.blank:
kept_hyps.append(beam_hyp)
else:
beam_hyp['dstate'] = dstate
beam_hyp['yseq'].append(int(k))
if rnnlm:
beam_hyp['lm_state'] = rnnlm_state
                            beam_hyp['score'] += recog_args.lm_weight * rnnlm_scores[0][k]
hyps.append(beam_hyp)
if len(kept_hyps) >= k_range:
break
if normscore:
nbest_hyps = sorted(
kept_hyps, key=lambda x: x['score'] / len(x['yseq']), reverse=True)[:nbest]
else:
nbest_hyps = sorted(
kept_hyps, key=lambda x: x['score'], reverse=True)[:nbest]
return nbest_hyps
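    # recog_args for recognize_beam() is expected to provide beam_size, nbest,
    # score_norm_transducer and, when an RNNLM is passed, lm_weight. A sketch
    # with hypothetical values:
    #   from argparse import Namespace
    #   recog_args = Namespace(beam_size=5, nbest=1,
    #                          score_norm_transducer=True, lm_weight=0.1)
    #   nbest_hyps = dec.recognize_beam(h, recog_args)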
class DecoderRNNTAtt(torch.nn.Module):
"""RNNT-Att Decoder module.
Args:
eprojs (int): # encoder projection units
odim (int): dimension of outputs
dtype (str): gru or lstm
dlayers (int): # decoder layers
dunits (int): # decoder units
blank (int): blank symbol id
att (torch.nn.Module): attention module
embed_dim (int): dimension of embeddings
joint_dim (int): dimension of joint space
dropout (float): dropout rate
dropout_embed (float): embedding dropout rate
rnnt_type (str): type of rnnt implementation
"""
def __init__(self, eprojs, odim, dtype, dlayers, dunits, blank, att,
embed_dim, joint_dim, dropout=0.0, dropout_embed=0.0,
rnnt_type='warp-transducer'):
"""Transducer with attention initializer."""
super(DecoderRNNTAtt, self).__init__()
self.embed = torch.nn.Embedding(odim, embed_dim, padding_idx=blank)
self.dropout_emb = torch.nn.Dropout(p=dropout_embed)
if dtype == "lstm":
dec_net = torch.nn.LSTMCell
else:
dec_net = torch.nn.GRUCell
self.decoder = torch.nn.ModuleList([dec_net((embed_dim + eprojs), dunits)])
self.dropout_dec = torch.nn.ModuleList([torch.nn.Dropout(p=dropout)])
for _ in six.moves.range(1, dlayers):
self.decoder += [dec_net(dunits, dunits)]
self.dropout_dec += [torch.nn.Dropout(p=dropout)]
if rnnt_type == 'warp-transducer':
from warprnnt_pytorch import RNNTLoss
self.rnnt_loss = RNNTLoss(blank=blank)
else:
raise NotImplementedError
self.lin_enc = torch.nn.Linear(eprojs, joint_dim)
self.lin_dec = torch.nn.Linear(dunits, joint_dim, bias=False)
self.lin_out = torch.nn.Linear(joint_dim, odim)
self.att = att
self.dtype = dtype
self.dlayers = dlayers
self.dunits = dunits
self.embed_dim = embed_dim
self.joint_dim = joint_dim
self.odim = odim
self.rnnt_type = rnnt_type
self.ignore_id = -1
self.blank = blank
    def zero_state(self, ey):
"""Initialize decoder states.
Args:
ey (torch.Tensor): batch of input features (B, (Emb_dim + Eprojs))
        Returns:
z_list : list of L zero-init hidden state (B, Hdec)
c_list : list of L zero-init cell state (B, Hdec)
"""
z_list = [ey.new_zeros(ey.size(0), self.dunits)]
c_list = [ey.new_zeros(ey.size(0), self.dunits)]
for _ in six.moves.range(1, self.dlayers):
z_list.append(ey.new_zeros(ey.size(0), self.dunits))
c_list.append(ey.new_zeros(ey.size(0), self.dunits))
return z_list, c_list
    def rnn_forward(self, ey, dstate):
"""RNN forward.
Args:
ey (torch.Tensor): batch of input features (B, (Emb_dim + Eprojs))
            dstate (tuple): tuple of (z_list, c_list), lists of L input hidden and cell states (B, Hdec)
        Returns:
            y (torch.Tensor): decoder output for one step (B, Hdec)
            (tuple): tuple of (z_list, c_list), lists of L output hidden and cell states (B, Hdec)
"""
if dstate is None:
z_prev, c_prev = self.zero_state(ey)
else:
z_prev, c_prev = dstate
z_list, c_list = self.zero_state(ey)
if self.dtype == "lstm":
z_list[0], c_list[0] = self.decoder[0](ey, (z_prev[0], c_prev[0]))
for l in six.moves.range(1, self.dlayers):
z_list[l], c_list[l] = self.decoder[l](
self.dropout_dec[l - 1](z_list[l - 1]), (z_prev[l], c_prev[l]))
else:
z_list[0] = self.decoder[0](ey, z_prev[0])
for l in six.moves.range(1, self.dlayers):
z_list[l] = self.decoder[l](self.dropout_dec[l - 1](z_list[l - 1]), z_prev[l])
y = self.dropout_dec[-1](z_list[-1])
return y, (z_list, c_list)
    def joint(self, h_enc, h_dec):
"""Joint computation of z.
Args:
h_enc (torch.Tensor): batch of expanded hidden state (B, T, 1, Henc)
h_dec (torch.Tensor): batch of expanded hidden state (B, 1, U, Hdec)
Returns:
z (torch.Tensor): output (B, T, U, odim)
"""
z = torch.tanh(self.lin_enc(h_enc) + self.lin_dec(h_dec))
z = self.lin_out(z)
return z
    def forward(self, hs_pad, hlens, ys_pad):
"""Forward function for transducer with attention.
Args:
hs_pad (torch.Tensor): batch of padded hidden state sequences (B, Tmax, D)
hlens (torch.Tensor): batch of lengths of hidden state sequences (B)
ys_pad (torch.Tensor): batch of padded character id sequence tensor (B, Lmax)
Returns:
loss (torch.Tensor): rnnt-att loss value
"""
ys = [y[y != self.ignore_id] for y in ys_pad]
hlens = list(map(int, hlens))
blank = ys[0].new([self.blank])
ys_in = [torch.cat([blank, y], dim=0) for y in ys]
ys_in_pad = pad_list(ys_in, self.blank)
olength = ys_in_pad.size(1)
att_w = None
self.att[0].reset()
z_list, c_list = self.zero_state(hs_pad)
eys = self.dropout_emb(self.embed(ys_in_pad))
z_all = []
for i in six.moves.range(olength):
att_c, att_w = self.att[0](hs_pad, hlens, self.dropout_dec[0](z_list[0]), att_w)
ey = torch.cat((eys[:, i, :], att_c), dim=1)
y, (z_list, c_list) = self.rnn_forward(ey, (z_list, c_list))
z_all.append(y)
h_dec = torch.stack(z_all, dim=1)
h_enc = hs_pad.unsqueeze(2)
h_dec = h_dec.unsqueeze(1)
z = self.joint(h_enc, h_dec)
y = pad_list(ys, self.blank).type(torch.int32)
z_len = to_device(self, torch.IntTensor(hlens))
y_len = to_device(self, torch.IntTensor([_y.size(0) for _y in ys]))
loss = to_device(self, self.rnnt_loss(z, y, z_len, y_len))
return loss
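    # Unlike DecoderRNNT.forward, which runs the prediction network over the
    # whole label sequence with one batched LSTM/GRU call, the loop above
    # advances one label step at a time with LSTMCell/GRUCell so that the
    # attention context att_c, computed from the previous decoder state, can
    # be concatenated to each label embedding.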
    def recognize(self, h, recog_args):
"""Greedy search implementation.
Args:
h (torch.Tensor): encoder hidden state sequences (Tmax, Henc)
recog_args (Namespace): argument Namespace containing options
Returns:
hyp (list of dicts): 1-best decoding results
"""
self.att[0].reset()
z_list, c_list = self.zero_state(h.unsqueeze(0))
eys = torch.zeros((1, self.embed_dim))
att_c, att_w = self.att[0](h.unsqueeze(0), [h.size(0)],
self.dropout_dec[0](z_list[0]), None)
ey = torch.cat((eys, att_c), dim=1)
hyp = {'score': 0.0, 'yseq': [self.blank]}
y, (z_list, c_list) = self.rnn_forward(ey, (z_list, c_list))
for hi in h:
ytu = F.log_softmax(self.joint(hi, y[0]), dim=0)
logp, pred = torch.max(ytu, dim=0)
if pred != self.blank:
hyp['yseq'].append(int(pred))
hyp['score'] += float(logp)
eys = torch.full((1, 1), hyp['yseq'][-1], dtype=torch.long)
ey = self.dropout_emb(self.embed(eys))
att_c, att_w = self.att[0](h.unsqueeze(0), [h.size(0)],
self.dropout_dec[0](z_list[0]),
att_w)
ey = torch.cat((ey[0], att_c), dim=1)
y, (z_list, c_list) = self.rnn_forward(ey, (z_list, c_list))
return [hyp]
    def recognize_beam(self, h, recog_args, rnnlm=None):
"""Beam search recognition.
Args:
h (torch.Tensor): encoder hidden state sequences (Tmax, Henc)
recog_args (Namespace): argument Namespace containing options
            rnnlm (torch.nn.Module): language model module
        Returns:
nbest_hyps (list of dicts): n-best decoding results
"""
beam = recog_args.beam_size
k_range = min(beam, self.odim)
nbest = recog_args.nbest
normscore = recog_args.score_norm_transducer
self.att[0].reset()
z_list, c_list = self.zero_state(h.unsqueeze(0))
eys = torch.zeros((1, self.embed_dim))
att_c, att_w = self.att[0](h.unsqueeze(0), [h.size(0)],
self.dropout_dec[0](z_list[0]), None)
ey = torch.cat((eys, att_c), dim=1)
_, (z_list, c_list) = self.rnn_forward(ey, None)
if rnnlm:
kept_hyps = [{'score': 0.0, 'yseq': [self.blank], 'z_prev': z_list,
'c_prev': c_list, 'a_prev': None, 'lm_state': None}]
else:
kept_hyps = [{'score': 0.0, 'yseq': [self.blank], 'z_prev': z_list,
'c_prev': c_list, 'a_prev': None}]
for i, hi in enumerate(h):
hyps = kept_hyps
kept_hyps = []
while True:
new_hyp = max(hyps, key=lambda x: x['score'])
hyps.remove(new_hyp)
vy = to_device(self, torch.full((1, 1), new_hyp['yseq'][-1], dtype=torch.long))
ey = self.dropout_emb(self.embed(vy))
att_c, att_w = self.att[0](h.unsqueeze(0), [h.size(0)],
self.dropout_dec[0](new_hyp['z_prev'][0]),
new_hyp['a_prev'])
ey = torch.cat((ey[0], att_c), dim=1)
y, (z_list, c_list) = self.rnn_forward(ey, (new_hyp['z_prev'], new_hyp['c_prev']))
ytu = F.log_softmax(self.joint(hi, y[0]), dim=0)
if rnnlm:
rnnlm_state, rnnlm_scores = rnnlm.predict(new_hyp['lm_state'], vy[0])
for k in six.moves.range(self.odim):
beam_hyp = {'score': new_hyp['score'] + float(ytu[k]),
'yseq': new_hyp['yseq'][:],
'z_prev': new_hyp['z_prev'],
'c_prev': new_hyp['c_prev'],
'a_prev': new_hyp['a_prev']}
if rnnlm:
beam_hyp['lm_state'] = new_hyp['lm_state']
if k == self.blank:
kept_hyps.append(beam_hyp)
else:
beam_hyp['z_prev'] = z_list[:]
beam_hyp['c_prev'] = c_list[:]
beam_hyp['a_prev'] = att_w[:]
beam_hyp['yseq'].append(int(k))
if rnnlm:
beam_hyp['lm_state'] = rnnlm_state
                            beam_hyp['score'] += recog_args.lm_weight * rnnlm_scores[0][k]
hyps.append(beam_hyp)
if len(kept_hyps) >= k_range:
break
if normscore:
nbest_hyps = sorted(
kept_hyps, key=lambda x: x['score'] / len(x['yseq']), reverse=True)[:nbest]
else:
nbest_hyps = sorted(
kept_hyps, key=lambda x: x['score'], reverse=True)[:nbest]
return nbest_hyps
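    # This search mirrors DecoderRNNT.recognize_beam; the difference is that
    # each hypothesis additionally carries its attention weights ('a_prev')
    # and per-layer hidden/cell states ('z_prev', 'c_prev') so attention can
    # be resumed when the hypothesis is expanded.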
    def calculate_all_attentions(self, hs_pad, hlens, ys_pad):
"""Calculate all of attentions.
Args:
hs_pad (torch.Tensor): batch of padded hidden state sequences (B, Tmax, D)
hlens (torch.Tensor): batch of lengths of hidden state sequences (B)
ys_pad (torch.Tensor): batch of padded character id sequence tensor (B, Lmax)
Returns:
att_ws (ndarray): attention weights with the following shape,
1) multi-head case => attention weights (B, H, Lmax, Tmax),
2) other case => attention weights (B, Lmax, Tmax).
"""
ys = [y[y != self.ignore_id] for y in ys_pad]
hlens = list(map(int, hlens))
blank = ys[0].new([self.blank])
ys_in = [torch.cat([blank, y], dim=0) for y in ys]
ys_in_pad = pad_list(ys_in, self.blank)
olength = ys_in_pad.size(1)
att_w = None
att_ws = []
self.att[0].reset()
eys = self.dropout_emb(self.embed(ys_in_pad))
z_list, c_list = self.zero_state(eys)
for i in six.moves.range(olength):
att_c, att_w = self.att[0](hs_pad, hlens, self.dropout_dec[0](z_list[0]), att_w)
ey = torch.cat((eys[:, i, :], att_c), dim=1)
_, (z_list, c_list) = self.rnn_forward(ey, (z_list, c_list))
att_ws.append(att_w)
att_ws = att_to_numpy(att_ws, self.att[0])
return att_ws
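    # att_ws collects one attention weight tensor per output step; att_to_numpy
    # flattens the list into a single ndarray whose layout depends on the
    # attention type (multi-head vs. single-head), as documented above.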
def decoder_for(args, odim, att=None, blank=0):
"""Transducer mode selector."""
if args.rnnt_mode == 'rnnt':
return DecoderRNNT(args.eprojs, odim, args.dtype, args.dlayers, args.dunits,
blank, args.dec_embed_dim, args.joint_dim,
args.dropout_rate_decoder, args.dropout_rate_embed_decoder,
args.rnnt_type)
elif args.rnnt_mode == 'rnnt-att':
return DecoderRNNTAtt(args.eprojs, odim, args.dtype, args.dlayers, args.dunits,
blank, att, args.dec_embed_dim, args.joint_dim,
args.dropout_rate_decoder, args.dropout_rate_embed_decoder,
args.rnnt_type)
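

if __name__ == '__main__':
    # Smoke-test sketch, not part of the original module: build a decoder via
    # decoder_for() and run one forward pass. All hyperparameters below are
    # illustrative assumptions; warprnnt_pytorch must be installed.
    from argparse import Namespace

    args = Namespace(rnnt_mode='rnnt', eprojs=320, dtype='lstm', dlayers=2,
                     dunits=320, dec_embed_dim=256, joint_dim=256,
                     dropout_rate_decoder=0.1, dropout_rate_embed_decoder=0.1,
                     rnnt_type='warp-transducer')
    dec = decoder_for(args, odim=52)

    hs_pad = torch.randn(2, 50, 320)        # fake encoder output (B, Tmax, eprojs)
    hlens = torch.tensor([50, 45])          # encoder lengths
    ys_pad = torch.randint(1, 52, (2, 10))  # fake label ids (0 is blank)

    print(float(dec(hs_pad, hlens, ys_pad)))  # scalar RNN-T loss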