#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright 2019 Nagoya University (Tomoki Hayashi)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""Tacotron2 decoder related modules."""
import six
import torch
import torch.nn.functional as F
from espnet.nets.pytorch_backend.rnn.attentions import AttForwardTA
def decoder_init(m):
    """Initialize decoder parameters.

    Intended to be used with ``torch.nn.Module.apply`` so that it is called
    once for every submodule. Only ``torch.nn.Conv1d`` modules are touched;
    every other module type keeps its existing initialization.

    Args:
        m (torch.nn.Module): Module to be (possibly) re-initialized in place.

    """
    if isinstance(m, torch.nn.Conv1d):
        # Xavier-uniform scaled by the gain recommended for tanh activations.
        torch.nn.init.xavier_uniform_(m.weight, torch.nn.init.calculate_gain('tanh'))
class ZoneOutCell(torch.nn.Module):
    """ZoneOut Cell module.

    This is a module of zoneout described in `Zoneout: Regularizing RNNs by Randomly Preserving Hidden Activations`_.
    This code is modified from `eladhoffer/seq2seq.pytorch`_.

    In training mode each element of the new state is replaced by the
    corresponding element of the previous state with probability
    ``zoneout_rate``. In evaluation mode the two states are blended
    deterministically with weights ``zoneout_rate`` and ``1 - zoneout_rate``.

    Examples:
        >>> lstm = torch.nn.LSTMCell(16, 32)
        >>> lstm = ZoneOutCell(lstm, 0.5)

    .. _`Zoneout: Regularizing RNNs by Randomly Preserving Hidden Activations`:
        https://arxiv.org/abs/1606.01305

    .. _`eladhoffer/seq2seq.pytorch`:
        https://github.com/eladhoffer/seq2seq.pytorch

    """

    def __init__(self, cell, zoneout_rate=0.1):
        """Initialize zone out cell module.

        Args:
            cell (torch.nn.Module): Pytorch recurrent cell module e.g. `torch.nn.LSTMCell`.
                It must expose a ``hidden_size`` attribute.
            zoneout_rate (float, optional): Probability of zoneout from 0.0 to 1.0.

        Raises:
            ValueError: If ``zoneout_rate`` is outside the range [0.0, 1.0].

        """
        super(ZoneOutCell, self).__init__()
        self.cell = cell
        self.hidden_size = cell.hidden_size
        self.zoneout_rate = zoneout_rate
        if zoneout_rate > 1.0 or zoneout_rate < 0.0:
            raise ValueError("zoneout probability must be in the range from 0.0 to 1.0.")

    def forward(self, inputs, hidden):
        """Calculate forward propagation.

        Args:
            inputs (Tensor): Batch of input tensor (B, input_size).
            hidden (tuple):
                - Tensor: Batch of initial hidden states (B, hidden_size).
                - Tensor: Batch of initial cell states (B, hidden_size).

        Returns:
            tuple:
                - Tensor: Batch of next hidden states (B, hidden_size).
                - Tensor: Batch of next cell states (B, hidden_size).

        """
        next_hidden = self.cell(inputs, hidden)
        return self._zoneout(hidden, next_hidden, self.zoneout_rate)

    def _zoneout(self, h, next_h, prob):
        """Mix previous state ``h`` and new state ``next_h`` with zoneout probability ``prob``."""
        # Tuples of states (e.g. LSTM (h, c)) are handled element by element;
        # a scalar probability is broadcast to every element of the tuple.
        if isinstance(h, tuple):
            num_h = len(h)
            if not isinstance(prob, tuple):
                prob = (prob,) * num_h
            return tuple(self._zoneout(h[i], next_h[i], prob[i]) for i in range(num_h))

        if self.training:
            # mask == 1 keeps the previous value, mask == 0 takes the new one
            mask = h.new(*h.size()).bernoulli_(prob)
            return mask * h + (1 - mask) * next_h

        # evaluation: blend with the expected value of the mask
        return prob * h + (1 - prob) * next_h
class Prenet(torch.nn.Module):
    """Prenet module for decoder of Spectrogram prediction network.

    This is a module of Prenet in the decoder of Spectrogram prediction network, which described in `Natural TTS
    Synthesis by Conditioning WaveNet on Mel Spectrogram Predictions`_. The Prenet performs nonlinear conversion
    of inputs before input to auto-regressive lstm, which helps to learn diagonal attentions.

    Note:
        This module always applies dropout even in evaluation. See the detail in `Natural TTS Synthesis by
        Conditioning WaveNet on Mel Spectrogram Predictions`_.

    .. _`Natural TTS Synthesis by Conditioning WaveNet on Mel Spectrogram Predictions`:
       https://arxiv.org/abs/1712.05884

    """

    def __init__(self, idim, n_layers=2, n_units=256, dropout_rate=0.5):
        """Initialize prenet module.

        Args:
            idim (int): Dimension of the inputs.
            n_layers (int, optional): The number of prenet layers.
            n_units (int, optional): The number of prenet units (also the output dimension
                when ``n_layers`` > 0).
            dropout_rate (float, optional): Dropout rate applied after every layer.

        """
        super(Prenet, self).__init__()
        self.dropout_rate = dropout_rate
        self.prenet = torch.nn.ModuleList()
        for layer in range(n_layers):
            # only the first layer sees the raw input dimension
            n_inputs = idim if layer == 0 else n_units
            self.prenet.append(torch.nn.Sequential(
                torch.nn.Linear(n_inputs, n_units),
                torch.nn.ReLU()))

    def forward(self, x):
        """Calculate forward propagation.

        Args:
            x (Tensor): Batch of input tensors (B, *, idim).

        Returns:
            Tensor: Batch of output tensors (B, *, n_units). With zero layers the input is returned as is.

        """
        for layer in self.prenet:
            # F.dropout defaults to training=True, so dropout stays active in
            # evaluation mode as well (intended, see the class-level note).
            x = F.dropout(layer(x), self.dropout_rate)
        return x
class Postnet(torch.nn.Module):
    """Postnet module for Spectrogram prediction network.

    This is a module of Postnet in Spectrogram prediction network, which described in `Natural TTS Synthesis by
    Conditioning WaveNet on Mel Spectrogram Predictions`_. The Postnet refines the predicted
    Mel-filterbank of the decoder, which helps to compensate the detail structure of spectrogram.

    .. _`Natural TTS Synthesis by Conditioning WaveNet on Mel Spectrogram Predictions`:
       https://arxiv.org/abs/1712.05884

    """

    def __init__(self, idim, odim, n_layers=5, n_chans=512, n_filts=5, dropout_rate=0.5, use_batch_norm=True):
        """Initialize postnet module.

        Args:
            idim (int): Dimension of the inputs. Accepted for interface compatibility but not
                used: the first convolution takes ``odim`` input channels.
            odim (int): Dimension of the outputs (and of the tensor fed to the postnet).
            n_layers (int, optional): The number of layers.
            n_chans (int, optional): The number of filter channels.
            n_filts (int, optional): The number of filter size.
            dropout_rate (float, optional): Dropout rate.
            use_batch_norm (bool, optional): Whether to use batch normalization.

        """
        super(Postnet, self).__init__()
        self.postnet = torch.nn.ModuleList()
        # hidden layers: conv (-> batch norm) -> tanh -> dropout, always n_chans outputs
        for layer in range(n_layers - 1):
            ichans = odim if layer == 0 else n_chans
            self.postnet.append(self._make_layer(
                ichans, n_chans, n_filts, dropout_rate, use_batch_norm, use_tanh=True))
        # final layer: conv (-> batch norm) -> dropout, projects back to odim without tanh
        ichans = n_chans if n_layers != 1 else odim
        self.postnet.append(self._make_layer(
            ichans, odim, n_filts, dropout_rate, use_batch_norm, use_tanh=False))

    @staticmethod
    def _make_layer(ichans, ochans, n_filts, dropout_rate, use_batch_norm, use_tanh):
        """Build one postnet layer as a ``torch.nn.Sequential``."""
        modules = [torch.nn.Conv1d(ichans, ochans, n_filts, stride=1,
                                   padding=(n_filts - 1) // 2, bias=False)]
        if use_batch_norm:
            modules.append(torch.nn.BatchNorm1d(ochans))
        if use_tanh:
            modules.append(torch.nn.Tanh())
        modules.append(torch.nn.Dropout(dropout_rate))
        return torch.nn.Sequential(*modules)

    def forward(self, xs):
        """Calculate forward propagation.

        Args:
            xs (Tensor): Batch of the sequences of padded input tensors (B, odim, Tmax).

        Returns:
            Tensor: Batch of padded output tensor. (B, odim, Tmax).

        """
        for layer in self.postnet:
            xs = layer(xs)
        return xs
class Decoder(torch.nn.Module):
    """Decoder module of Spectrogram prediction network.

    This is a module of decoder of Spectrogram prediction network in Tacotron2, which described in `Natural TTS
    Synthesis by Conditioning WaveNet on Mel Spectrogram Predictions`_. The decoder generates the sequence of
    features from the sequence of the hidden states.

    .. _`Natural TTS Synthesis by Conditioning WaveNet on Mel Spectrogram Predictions`:
       https://arxiv.org/abs/1712.05884

    """

    def __init__(self, idim, odim, att,
                 dlayers=2,
                 dunits=1024,
                 prenet_layers=2,
                 prenet_units=256,
                 postnet_layers=5,
                 postnet_chans=512,
                 postnet_filts=5,
                 output_activation_fn=None,
                 cumulate_att_w=True,
                 use_batch_norm=True,
                 use_concate=True,
                 dropout_rate=0.5,
                 zoneout_rate=0.1,
                 reduction_factor=1):
        """Initialize Tacotron2 decoder module.

        Args:
            idim (int): Dimension of the inputs.
            odim (int): Dimension of the outputs.
            att (torch.nn.Module): Instance of attention class.
            dlayers (int, optional): The number of decoder lstm layers.
            dunits (int, optional): The number of decoder lstm units.
            prenet_layers (int, optional): The number of prenet layers.
            prenet_units (int, optional): The number of prenet units.
            postnet_layers (int, optional): The number of postnet layers.
            postnet_filts (int, optional): The number of postnet filter size.
            postnet_chans (int, optional): The number of postnet filter channels.
            output_activation_fn (torch.nn.Module, optional): Activation function for outputs.
            cumulate_att_w (bool, optional): Whether to cumulate previous attention weight.
            use_batch_norm (bool, optional): Whether to use batch normalization.
            use_concate (bool, optional): Whether to concatenate encoder embedding with decoder lstm outputs.
            dropout_rate (float, optional): Dropout rate.
            zoneout_rate (float, optional): Zoneout rate.
            reduction_factor (int, optional): Reduction factor.

        """
        super(Decoder, self).__init__()

        # store the hyperparameters
        self.idim = idim
        self.odim = odim
        self.att = att
        self.output_activation_fn = output_activation_fn
        self.cumulate_att_w = cumulate_att_w
        self.use_concate = use_concate
        self.reduction_factor = reduction_factor

        # check attention type: AttForwardTA additionally receives the previous output
        self.use_att_extra_inputs = isinstance(self.att, AttForwardTA)

        # define lstm network
        # without a prenet the previous output (odim) is fed to the first lstm directly
        prenet_units = prenet_units if prenet_layers != 0 else odim
        self.lstm = torch.nn.ModuleList()
        for layer in range(dlayers):
            iunits = idim + prenet_units if layer == 0 else dunits
            lstm = torch.nn.LSTMCell(iunits, dunits)
            if zoneout_rate > 0.0:
                lstm = ZoneOutCell(lstm, zoneout_rate)
            self.lstm.append(lstm)

        # define prenet
        if prenet_layers > 0:
            self.prenet = Prenet(
                idim=odim,
                n_layers=prenet_layers,
                n_units=prenet_units,
                dropout_rate=dropout_rate)
        else:
            self.prenet = None

        # define postnet
        if postnet_layers > 0:
            self.postnet = Postnet(
                idim=idim,
                odim=odim,
                n_layers=postnet_layers,
                n_chans=postnet_chans,
                n_filts=postnet_filts,
                use_batch_norm=use_batch_norm,
                dropout_rate=dropout_rate)
        else:
            self.postnet = None

        # define projection layers
        iunits = idim + dunits if use_concate else dunits
        self.feat_out = torch.nn.Linear(iunits, odim * reduction_factor, bias=False)
        self.prob_out = torch.nn.Linear(iunits, reduction_factor)

        # initialize
        self.apply(decoder_init)

    def _zero_state(self, hs):
        """Return a zero state (B, dunits) created from ``hs`` via ``new_zeros``."""
        return hs.new_zeros(hs.size(0), self.lstm[0].hidden_size)

    def _init_states(self, hs):
        """Return zero-initialized ``(c_list, z_list)`` with one entry per lstm layer."""
        n_layers = len(self.lstm)
        c_list = [self._zero_state(hs) for _ in range(n_layers)]
        z_list = [self._zero_state(hs) for _ in range(n_layers)]
        return c_list, z_list

    def _attend(self, hs, hlens, z, prev_att_w, prev_out):
        """Call the attention module, passing ``prev_out`` only when it takes extra inputs."""
        if self.use_att_extra_inputs:
            return self.att(hs, hlens, z, prev_att_w, prev_out)
        return self.att(hs, hlens, z, prev_att_w)

    def _update_lstm(self, att_c, prev_out, z_list, c_list):
        """Advance every lstm layer by one step, updating ``z_list`` / ``c_list`` in place."""
        prenet_out = self.prenet(prev_out) if self.prenet is not None else prev_out
        xs = torch.cat([att_c, prenet_out], dim=1)
        z_list[0], c_list[0] = self.lstm[0](xs, (z_list[0], c_list[0]))
        for i in range(1, len(self.lstm)):
            z_list[i], c_list[i] = self.lstm[i](z_list[i - 1], (z_list[i], c_list[i]))

    def _projection_input(self, z, att_c):
        """Return the input of the output projections (lstm output, optionally with context)."""
        if self.use_concate:
            return torch.cat([z, att_c], dim=1)
        return z

    def _next_att_w(self, prev_att_w, att_w):
        """Return the attention weight handed to the next step (cumulated if enabled)."""
        if self.cumulate_att_w and prev_att_w is not None:
            # out-of-place on purpose: the original code notes that `+=` raises an error
            return prev_att_w + att_w
        return att_w

    def forward(self, hs, hlens, ys):
        """Calculate forward propagation.

        Args:
            hs (Tensor): Batch of the sequences of padded hidden states (B, Tmax, idim).
            hlens (LongTensor): Batch of lengths of each input batch (B,).
            ys (Tensor): Batch of the sequences of padded target features (B, Lmax, odim).

        Returns:
            Tensor: Batch of output tensors after postnet (B, Lmax, odim).
            Tensor: Batch of output tensors before postnet (B, Lmax, odim).
            Tensor: Batch of logits of stop prediction (B, Lmax).
            Tensor: Batch of attention weights (B, Lmax, Tmax).

        Note:
            This computation is performed in teacher-forcing manner.

        """
        # thin out frames (B, Lmax, odim) -> (B, Lmax/r, odim)
        if self.reduction_factor > 1:
            ys = ys[:, self.reduction_factor - 1::self.reduction_factor]

        # length list should be list of int
        hlens = list(map(int, hlens))

        # initialize hidden states of decoder
        c_list, z_list = self._init_states(hs)
        prev_out = hs.new_zeros(hs.size(0), self.odim)

        # initialize attention
        prev_att_w = None
        self.att.reset()

        # loop for an output sequence
        outs, logits, att_ws = [], [], []
        for y in ys.transpose(0, 1):
            att_c, att_w = self._attend(hs, hlens, z_list[0], prev_att_w, prev_out)
            self._update_lstm(att_c, prev_out, z_list, c_list)
            zcs = self._projection_input(z_list[-1], att_c)
            outs.append(self.feat_out(zcs).view(hs.size(0), self.odim, -1))  # (B, odim, r)
            logits.append(self.prob_out(zcs))  # (B, r)
            att_ws.append(att_w)
            prev_out = y  # teacher forcing
            prev_att_w = self._next_att_w(prev_att_w, att_w)

        logits = torch.cat(logits, dim=1)  # (B, Lmax)
        # concatenating the (B, odim, r) chunks along the last axis already gives (B, odim, Lmax)
        before_outs = torch.cat(outs, dim=2)  # (B, odim, Lmax)
        att_ws = torch.stack(att_ws, dim=1)  # (B, Lmax, Tmax)

        if self.postnet is not None:
            # the postnet output is added to the decoder output as a residual
            after_outs = before_outs + self.postnet(before_outs)  # (B, odim, Lmax)
        else:
            after_outs = before_outs
        before_outs = before_outs.transpose(2, 1)  # (B, Lmax, odim)
        after_outs = after_outs.transpose(2, 1)  # (B, Lmax, odim)

        # apply activation function for scaling
        if self.output_activation_fn is not None:
            before_outs = self.output_activation_fn(before_outs)
            after_outs = self.output_activation_fn(after_outs)

        return after_outs, before_outs, logits, att_ws

    def inference(self, h, threshold=0.5, minlenratio=0.0, maxlenratio=10.0):
        """Generate the sequence of features given the sequences of characters.

        Args:
            h (Tensor): Input sequence of encoder hidden states (T, C).
            threshold (float, optional): Threshold to stop generation.
            minlenratio (float, optional): Minimum length ratio. If set to 1.0 and the length of input is 10,
                the minimum length of outputs will be 10 * 1 = 10.
            maxlenratio (float, optional): Maximum length ratio. If set to 10 and the length of input is 10,
                the maximum length of outputs will be 10 * 10 = 100.

        Returns:
            Tensor: Output sequence of features (L, odim).
            Tensor: Output sequence of stop probabilities (L,).
            Tensor: Attention weights (L, T).

        Note:
            This computation is performed in auto-regressive manner.

        """
        # setup
        assert len(h.size()) == 2
        hs = h.unsqueeze(0)
        ilens = [h.size(0)]
        maxlen = int(h.size(0) * maxlenratio)
        minlen = int(h.size(0) * minlenratio)

        # initialize hidden states of decoder
        c_list, z_list = self._init_states(hs)
        prev_out = hs.new_zeros(1, self.odim)

        # initialize attention
        prev_att_w = None
        self.att.reset()

        # loop for an output sequence
        idx = 0
        outs, att_ws, probs = [], [], []
        while True:
            # updated index (each step emits `reduction_factor` frames)
            idx += self.reduction_factor

            # decoder calculation
            att_c, att_w = self._attend(hs, ilens, z_list[0], prev_att_w, prev_out)
            att_ws.append(att_w)
            self._update_lstm(att_c, prev_out, z_list, c_list)
            zcs = self._projection_input(z_list[-1], att_c)
            outs.append(self.feat_out(zcs).view(1, self.odim, -1))  # [(1, odim, r), ...]
            probs.append(torch.sigmoid(self.prob_out(zcs))[0])  # [(r), ...]

            # the last generated frame becomes the next decoder input
            prev_out = outs[-1][:, :, -1]  # (1, odim)
            if self.output_activation_fn is not None:
                prev_out = self.output_activation_fn(prev_out)
            prev_att_w = self._next_att_w(prev_att_w, att_w)

            # finish when a stop probability reaches the threshold or the maximum
            # length is hit, but never before the minimum length
            should_stop = bool((probs[-1] >= threshold).any()) or idx >= maxlen
            if should_stop and idx >= minlen:
                break

        outs = torch.cat(outs, dim=2)  # (1, odim, L)
        if self.postnet is not None:
            outs = outs + self.postnet(outs)  # (1, odim, L)
        outs = outs.transpose(2, 1).squeeze(0)  # (L, odim)
        probs = torch.cat(probs, dim=0)
        att_ws = torch.cat(att_ws, dim=0)

        if self.output_activation_fn is not None:
            outs = self.output_activation_fn(outs)

        return outs, probs, att_ws

    def calculate_all_attentions(self, hs, hlens, ys):
        """Calculate all of the attention weights.

        Args:
            hs (Tensor): Batch of the sequences of padded hidden states (B, Tmax, idim).
            hlens (LongTensor): Batch of lengths of each input batch (B,).
            ys (Tensor): Batch of the sequences of padded target features (B, Lmax, odim).

        Returns:
            Tensor: Batch of attention weights (B, Lmax, Tmax).

        Note:
            This computation is performed in teacher-forcing manner.

        """
        # thin out frames (B, Lmax, odim) -> (B, Lmax/r, odim)
        if self.reduction_factor > 1:
            ys = ys[:, self.reduction_factor - 1::self.reduction_factor]

        # length list should be list of int
        hlens = list(map(int, hlens))

        # initialize hidden states of decoder
        c_list, z_list = self._init_states(hs)
        prev_out = hs.new_zeros(hs.size(0), self.odim)

        # initialize attention
        prev_att_w = None
        self.att.reset()

        # loop for an output sequence
        att_ws = []
        for y in ys.transpose(0, 1):
            att_c, att_w = self._attend(hs, hlens, z_list[0], prev_att_w, prev_out)
            att_ws.append(att_w)
            self._update_lstm(att_c, prev_out, z_list, c_list)
            prev_out = y  # teacher forcing
            prev_att_w = self._next_att_w(prev_att_w, att_w)

        att_ws = torch.stack(att_ws, dim=1)  # (B, Lmax, Tmax)

        return att_ws