Source code for espnet.nets.pytorch_backend.tacotron2.decoder

#!/usr/bin/env python3
# -*- coding: utf-8 -*-

# Copyright 2019 Nagoya University (Tomoki Hayashi)
#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)

"""Tacotron2 decoder related modules."""

import six

import torch
import torch.nn.functional as F

from espnet.nets.pytorch_backend.rnn.attentions import AttForwardTA


def decoder_init(m):
    """Initialize decoder parameters.

    Re-initializes the weight of ``m`` with Xavier-uniform values scaled by
    the ``tanh`` gain when ``m`` is a ``torch.nn.Conv1d``. Every other
    module type is left untouched, which makes the function suitable for
    ``torch.nn.Module.apply``.

    Args:
        m (torch.nn.Module): Module to (possibly) initialize in place.

    """
    if not isinstance(m, torch.nn.Conv1d):
        return
    tanh_gain = torch.nn.init.calculate_gain("tanh")
    torch.nn.init.xavier_uniform_(m.weight, tanh_gain)
class ZoneOutCell(torch.nn.Module):
    """ZoneOut Cell module.

    Wraps a recurrent cell and applies zoneout to its states, as described in
    `Zoneout: Regularizing RNNs by Randomly Preserving Hidden Activations`_.
    This code is modified from `eladhoffer/seq2seq.pytorch`_.

    Examples:
        >>> lstm = torch.nn.LSTMCell(16, 32)
        >>> lstm = ZoneOutCell(lstm, 0.5)

    .. _`Zoneout: Regularizing RNNs by Randomly Preserving Hidden Activations`:
        https://arxiv.org/abs/1606.01305

    .. _`eladhoffer/seq2seq.pytorch`:
        https://github.com/eladhoffer/seq2seq.pytorch

    """

    def __init__(self, cell, zoneout_rate=0.1):
        """Initialize zone out cell module.

        Args:
            cell (torch.nn.Module): Pytorch recurrent cell module
                e.g. `torch.nn.LSTMCell`. It must expose ``hidden_size``.
            zoneout_rate (float, optional): Probability of zoneout from 0.0 to 1.0.

        Raises:
            ValueError: If ``zoneout_rate`` is below 0.0 or above 1.0.

        """
        super(ZoneOutCell, self).__init__()
        self.cell = cell
        self.hidden_size = cell.hidden_size
        self.zoneout_rate = zoneout_rate
        if zoneout_rate < 0.0 or zoneout_rate > 1.0:
            raise ValueError("zoneout probability must be in the range from 0.0 to 1.0.")

    def forward(self, inputs, hidden):
        """Calculate forward propagation.

        Args:
            inputs (Tensor): Batch of input tensor (B, input_size).
            hidden (tuple):
                - Tensor: Batch of initial hidden states (B, hidden_size).
                - Tensor: Batch of initial cell states (B, hidden_size).

        Returns:
            tuple:
                - Tensor: Batch of next hidden states (B, hidden_size).
                - Tensor: Batch of next cell states (B, hidden_size).

        """
        candidate = self.cell(inputs, hidden)
        return self._zoneout(hidden, candidate, self.zoneout_rate)

    def _zoneout(self, h, next_h, prob):
        # Tuples of states are handled element-wise by recursing into them;
        # a scalar probability is shared by every element.
        if isinstance(h, tuple):
            probs = prob if isinstance(prob, tuple) else (prob,) * len(h)
            return tuple(
                self._zoneout(prev, next_h[i], probs[i]) for i, prev in enumerate(h)
            )

        if not self.training:
            # evaluation: blend previous and new state with fixed weights
            return prob * h + (1 - prob) * next_h

        # training: keep each element of the previous state with probability `prob`
        mask = h.new(*h.size()).bernoulli_(prob)
        return mask * h + (1 - mask) * next_h
class Prenet(torch.nn.Module):
    """Prenet module for decoder of Spectrogram prediction network.

    This is a module of Prenet in the decoder of Spectrogram prediction network,
    which is described in `Natural TTS Synthesis by Conditioning WaveNet on Mel
    Spectrogram Predictions`_. The Prenet performs a nonlinear conversion of the
    inputs before they are fed to the auto-regressive lstm, which helps to learn
    diagonal attentions.

    Note:
        This module always applies dropout, even in evaluation mode
        (``F.dropout`` is called without passing the module's training flag).
        See the details in `Natural TTS Synthesis by Conditioning WaveNet on
        Mel Spectrogram Predictions`_.

    .. _`Natural TTS Synthesis by Conditioning WaveNet on Mel Spectrogram Predictions`:
       https://arxiv.org/abs/1712.05884

    """

    def __init__(self, idim, n_layers=2, n_units=256, dropout_rate=0.5):
        """Initialize prenet module.

        Args:
            idim (int): Dimension of the inputs.
            n_layers (int, optional): The number of prenet layers.
            n_units (int, optional): The number of prenet units
                (also the output dimension when ``n_layers`` > 0).
            dropout_rate (float, optional): Dropout rate applied after every layer.

        """
        super(Prenet, self).__init__()
        self.dropout_rate = dropout_rate
        self.prenet = torch.nn.ModuleList()
        in_features = idim
        for _ in range(n_layers):
            block = torch.nn.Sequential(
                torch.nn.Linear(in_features, n_units),
                torch.nn.ReLU(),
            )
            self.prenet.append(block)
            # every layer after the first consumes the previous layer's output
            in_features = n_units

    def forward(self, x):
        """Calculate forward propagation.

        Args:
            x (Tensor): Batch of input tensors (B, *, idim).

        Returns:
            Tensor: Batch of output tensors (B, *, n_units).
                With zero layers the input is returned unchanged.

        """
        for block in self.prenet:
            x = F.dropout(block(x), self.dropout_rate)
        return x
class Postnet(torch.nn.Module):
    """Postnet module for Spectrogram prediction network.

    This is a module of Postnet in Spectrogram prediction network, which is
    described in `Natural TTS Synthesis by Conditioning WaveNet on Mel
    Spectrogram Predictions`_. The Postnet refines the predicted Mel-filterbank
    of the decoder, which helps to compensate the detail structure of the
    spectrogram.

    .. _`Natural TTS Synthesis by Conditioning WaveNet on Mel Spectrogram Predictions`:
       https://arxiv.org/abs/1712.05884

    """

    def __init__(self, idim, odim, n_layers=5, n_chans=512, n_filts=5,
                 dropout_rate=0.5, use_batch_norm=True):
        """Initialize postnet module.

        Args:
            idim (int): Dimension of the inputs. Accepted for interface
                compatibility only; it is not used, the first convolution
                takes ``odim`` input channels.
            odim (int): Dimension of the outputs (and of the tensors fed to
                the first convolution).
            n_layers (int, optional): The number of layers.
            n_chans (int, optional): The number of filter channels.
            n_filts (int, optional): The number of filter size.
            dropout_rate (float, optional): Dropout rate.
            use_batch_norm (bool, optional): Whether to use batch normalization.

        """
        super(Postnet, self).__init__()
        padding = (n_filts - 1) // 2
        self.postnet = torch.nn.ModuleList()

        # hidden layers: conv -> [batch norm] -> tanh -> dropout.
        # They always emit n_chans channels; the projection back to odim
        # is done by the final layer below.
        for layer in range(n_layers - 1):
            ichans = odim if layer == 0 else n_chans
            modules = [torch.nn.Conv1d(ichans, n_chans, n_filts, stride=1,
                                       padding=padding, bias=False)]
            if use_batch_norm:
                modules.append(torch.nn.BatchNorm1d(n_chans))
            modules += [torch.nn.Tanh(), torch.nn.Dropout(dropout_rate)]
            self.postnet.append(torch.nn.Sequential(*modules))

        # final layer: conv -> [batch norm] -> dropout (no tanh).
        # With a single layer it reads the odim-channel input directly.
        ichans = n_chans if n_layers != 1 else odim
        modules = [torch.nn.Conv1d(ichans, odim, n_filts, stride=1,
                                   padding=padding, bias=False)]
        if use_batch_norm:
            modules.append(torch.nn.BatchNorm1d(odim))
        modules.append(torch.nn.Dropout(dropout_rate))
        self.postnet.append(torch.nn.Sequential(*modules))

    def forward(self, xs):
        """Calculate forward propagation.

        Args:
            xs (Tensor): Batch of the sequences of padded input tensors (B, odim, Tmax).

        Returns:
            Tensor: Batch of padded output tensor. (B, odim, Tmax).

        """
        for block in self.postnet:
            xs = block(xs)
        return xs
class Decoder(torch.nn.Module):
    """Decoder module of Spectrogram prediction network.

    This is a module of decoder of Spectrogram prediction network in Tacotron2,
    which is described in `Natural TTS Synthesis by Conditioning WaveNet on Mel
    Spectrogram Predictions`_. The decoder generates the sequence of features
    from the sequence of the hidden states.

    .. _`Natural TTS Synthesis by Conditioning WaveNet on Mel Spectrogram Predictions`:
       https://arxiv.org/abs/1712.05884

    """

    def __init__(self, idim, odim, att,
                 dlayers=2,
                 dunits=1024,
                 prenet_layers=2,
                 prenet_units=256,
                 postnet_layers=5,
                 postnet_chans=512,
                 postnet_filts=5,
                 output_activation_fn=None,
                 cumulate_att_w=True,
                 use_batch_norm=True,
                 use_concate=True,
                 dropout_rate=0.5,
                 zoneout_rate=0.1,
                 reduction_factor=1):
        """Initialize Tacotron2 decoder module.

        Args:
            idim (int): Dimension of the inputs.
            odim (int): Dimension of the outputs.
            att (torch.nn.Module): Instance of attention class.
            dlayers (int, optional): The number of decoder lstm layers.
            dunits (int, optional): The number of decoder lstm units.
            prenet_layers (int, optional): The number of prenet layers.
            prenet_units (int, optional): The number of prenet units.
            postnet_layers (int, optional): The number of postnet layers.
            postnet_filts (int, optional): The number of postnet filter size.
            postnet_chans (int, optional): The number of postnet filter channels.
            output_activation_fn (torch.nn.Module, optional): Activation function for outputs.
            cumulate_att_w (bool, optional): Whether to cumulate previous attention weight.
            use_batch_norm (bool, optional): Whether to use batch normalization.
            use_concate (bool, optional): Whether to concatenate encoder embedding
                with decoder lstm outputs.
            dropout_rate (float, optional): Dropout rate.
            zoneout_rate (float, optional): Zoneout rate.
            reduction_factor (int, optional): Reduction factor.

        """
        super(Decoder, self).__init__()

        # store the hyperparameters
        self.idim = idim
        self.odim = odim
        self.att = att
        self.output_activation_fn = output_activation_fn
        self.cumulate_att_w = cumulate_att_w
        self.use_concate = use_concate
        self.reduction_factor = reduction_factor

        # AttForwardTA additionally receives the previous output frame
        self.use_att_extra_inputs = isinstance(self.att, AttForwardTA)

        # define lstm network
        # (without a prenet the previous output frame is fed to the lstm directly)
        prenet_units = prenet_units if prenet_layers != 0 else odim
        self.lstm = torch.nn.ModuleList()
        for layer in range(dlayers):
            iunits = idim + prenet_units if layer == 0 else dunits
            lstm = torch.nn.LSTMCell(iunits, dunits)
            if zoneout_rate > 0.0:
                lstm = ZoneOutCell(lstm, zoneout_rate)
            self.lstm += [lstm]

        # define prenet
        if prenet_layers > 0:
            self.prenet = Prenet(
                idim=odim,
                n_layers=prenet_layers,
                n_units=prenet_units,
                dropout_rate=dropout_rate)
        else:
            self.prenet = None

        # define postnet
        if postnet_layers > 0:
            self.postnet = Postnet(
                idim=idim,
                odim=odim,
                n_layers=postnet_layers,
                n_chans=postnet_chans,
                n_filts=postnet_filts,
                use_batch_norm=use_batch_norm,
                dropout_rate=dropout_rate)
        else:
            self.postnet = None

        # define projection layers
        iunits = idim + dunits if use_concate else dunits
        self.feat_out = torch.nn.Linear(iunits, odim * reduction_factor, bias=False)
        self.prob_out = torch.nn.Linear(iunits, reduction_factor)

        # initialize
        self.apply(decoder_init)

    def _zero_state(self, hs):
        init_hs = hs.new_zeros(hs.size(0), self.lstm[0].hidden_size)
        return init_hs

    def _init_states(self, hs):
        # Build zero cell/hidden states for every lstm layer and the all-zero
        # "previous output" frame that starts the decoding loop.
        c_list = [self._zero_state(hs) for _ in self.lstm]
        z_list = [self._zero_state(hs) for _ in self.lstm]
        prev_out = hs.new_zeros(hs.size(0), self.odim)
        return c_list, z_list, prev_out

    def _step(self, hs, hlens, z_list, c_list, prev_att_w, prev_out):
        # Run one decoder step: attention -> prenet -> stacked lstm cells.
        # ``z_list`` and ``c_list`` are updated in place; the attention context
        # and the attention weights of this step are returned.
        if self.use_att_extra_inputs:
            att_c, att_w = self.att(hs, hlens, z_list[0], prev_att_w, prev_out)
        else:
            att_c, att_w = self.att(hs, hlens, z_list[0], prev_att_w)
        prenet_out = self.prenet(prev_out) if self.prenet is not None else prev_out
        xs = torch.cat([att_c, prenet_out], dim=1)
        z_list[0], c_list[0] = self.lstm[0](xs, (z_list[0], c_list[0]))
        for layer in range(1, len(self.lstm)):
            z_list[layer], c_list[layer] = self.lstm[layer](
                z_list[layer - 1], (z_list[layer], c_list[layer]))
        return att_c, att_w

    def _project_input(self, z_list, att_c):
        # Input of the output projections: top lstm state, optionally
        # concatenated with the attention context.
        if self.use_concate:
            return torch.cat([z_list[-1], att_c], dim=1)
        return z_list[-1]

    def _next_att_w(self, prev_att_w, att_w):
        # Attention weights handed to the next step (cumulative if enabled).
        if self.cumulate_att_w and prev_att_w is not None:
            # out-of-place on purpose: after the first step prev_att_w aliases
            # a tensor that was already collected (original note: error when use +=)
            return prev_att_w + att_w
        return att_w

    def forward(self, hs, hlens, ys):
        """Calculate forward propagation.

        Args:
            hs (Tensor): Batch of the sequences of padded hidden states (B, Tmax, idim).
            hlens (LongTensor): Batch of lengths of each input batch (B,).
            ys (Tensor): Batch of the sequences of padded target features (B, Lmax, odim).

        Returns:
            Tensor: Batch of output tensors after postnet (B, Lmax, odim).
            Tensor: Batch of output tensors before postnet (B, Lmax, odim).
            Tensor: Batch of logits of stop prediction (B, Lmax).
            Tensor: Batch of attention weights (B, Lmax, Tmax).

        Note:
            This computation is performed in teacher-forcing manner.

        """
        # thin out frames (B, Lmax, odim) ->  (B, Lmax/r, odim)
        if self.reduction_factor > 1:
            ys = ys[:, self.reduction_factor - 1::self.reduction_factor]

        # length list should be list of int
        hlens = list(map(int, hlens))

        # initialize hidden states of decoder
        c_list, z_list, prev_out = self._init_states(hs)

        # initialize attention
        prev_att_w = None
        self.att.reset()

        # loop for an output sequence
        outs, logits, att_ws = [], [], []
        for y in ys.transpose(0, 1):
            att_c, att_w = self._step(hs, hlens, z_list, c_list, prev_att_w, prev_out)
            zcs = self._project_input(z_list, att_c)
            outs.append(self.feat_out(zcs).view(hs.size(0), self.odim, -1))
            logits.append(self.prob_out(zcs))
            att_ws.append(att_w)
            prev_out = y  # teacher forcing
            prev_att_w = self._next_att_w(prev_att_w, att_w)

        logits = torch.cat(logits, dim=1)  # (B, Lmax)
        # each step contributes (B, odim, r), so this is already (B, odim, Lmax)
        before_outs = torch.cat(outs, dim=2)  # (B, odim, Lmax)
        att_ws = torch.stack(att_ws, dim=1)  # (B, Lmax, Tmax)

        if self.postnet is not None:
            after_outs = before_outs + self.postnet(before_outs)  # (B, odim, Lmax)
        else:
            after_outs = before_outs
        before_outs = before_outs.transpose(2, 1)  # (B, Lmax, odim)
        after_outs = after_outs.transpose(2, 1)  # (B, Lmax, odim)

        # apply activation function for scaling
        if self.output_activation_fn is not None:
            before_outs = self.output_activation_fn(before_outs)
            after_outs = self.output_activation_fn(after_outs)

        return after_outs, before_outs, logits, att_ws

    def inference(self, h, threshold=0.5, minlenratio=0.0, maxlenratio=10.0):
        """Generate the sequence of features given the sequences of characters.

        Args:
            h (Tensor): Input sequence of encoder hidden states (T, C).
            threshold (float, optional): Threshold to stop generation.
            minlenratio (float, optional): Minimum length ratio.
                If set to 1.0 and the length of input is 10,
                the minimum length of outputs will be 10 * 1 = 10.
            maxlenratio (float, optional): Maximum length ratio.
                If set to 10 and the length of input is 10,
                the maximum length of outputs will be 10 * 10 = 100.

        Returns:
            Tensor: Output sequence of features (L, odim).
            Tensor: Output sequence of stop probabilities (L,).
            Tensor: Attention weights (L, T).

        Note:
            This computation is performed in auto-regressive manner.

        """
        # setup
        assert len(h.size()) == 2
        hs = h.unsqueeze(0)
        ilens = [h.size(0)]
        maxlen = int(h.size(0) * maxlenratio)
        minlen = int(h.size(0) * minlenratio)

        # initialize hidden states of decoder
        c_list, z_list, prev_out = self._init_states(hs)

        # initialize attention
        prev_att_w = None
        self.att.reset()

        # loop for an output sequence
        idx = 0
        outs, att_ws, probs = [], [], []
        while True:
            # updated index (every step emits `reduction_factor` frames)
            idx += self.reduction_factor

            # decoder calculation
            att_c, att_w = self._step(hs, ilens, z_list, c_list, prev_att_w, prev_out)
            att_ws.append(att_w)
            zcs = self._project_input(z_list, att_c)
            outs.append(self.feat_out(zcs).view(1, self.odim, -1))  # [(1, odim, r), ...]
            probs.append(torch.sigmoid(self.prob_out(zcs))[0])  # [(r), ...]

            # the last generated frame becomes the next decoder input
            prev_out = outs[-1][:, :, -1]  # (1, odim)
            if self.output_activation_fn is not None:
                prev_out = self.output_activation_fn(prev_out)  # (1, odim)
            prev_att_w = self._next_att_w(prev_att_w, att_w)

            # check whether to finish generation:
            # a stop probability reached the threshold or the maximum length
            # is hit, and the minimum length has been produced
            finished = bool((probs[-1] >= threshold).any()) or idx >= maxlen
            if finished and idx >= minlen:
                break

        outs = torch.cat(outs, dim=2)  # (1, odim, L)
        if self.postnet is not None:
            outs = outs + self.postnet(outs)  # (1, odim, L)
        outs = outs.transpose(2, 1).squeeze(0)  # (L, odim)
        probs = torch.cat(probs, dim=0)
        att_ws = torch.cat(att_ws, dim=0)

        if self.output_activation_fn is not None:
            outs = self.output_activation_fn(outs)

        return outs, probs, att_ws

    def calculate_all_attentions(self, hs, hlens, ys):
        """Calculate all of the attention weights.

        Args:
            hs (Tensor): Batch of the sequences of padded hidden states (B, Tmax, idim).
            hlens (LongTensor): Batch of lengths of each input batch (B,).
            ys (Tensor): Batch of the sequences of padded target features (B, Lmax, odim).

        Returns:
            Tensor: Batch of attention weights (B, Lmax, Tmax).

        Note:
            This computation is performed in teacher-forcing manner.

        """
        # thin out frames (B, Lmax, odim) ->  (B, Lmax/r, odim)
        if self.reduction_factor > 1:
            ys = ys[:, self.reduction_factor - 1::self.reduction_factor]

        # length list should be list of int
        hlens = list(map(int, hlens))

        # initialize hidden states of decoder
        c_list, z_list, prev_out = self._init_states(hs)

        # initialize attention
        prev_att_w = None
        self.att.reset()

        # loop for an output sequence
        att_ws = []
        for y in ys.transpose(0, 1):
            _, att_w = self._step(hs, hlens, z_list, c_list, prev_att_w, prev_out)
            att_ws.append(att_w)
            prev_out = y  # teacher forcing
            prev_att_w = self._next_att_w(prev_att_w, att_w)

        att_ws = torch.stack(att_ws, dim=1)  # (B, Lmax, Tmax)

        return att_ws