# Source code for espnet.nets.pytorch_backend.e2e_tts_tacotron2

#!/usr/bin/env python3
# -*- coding: utf-8 -*-

# Copyright 2018 Nagoya University (Tomoki Hayashi)
#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)

"""Tacotron 2 related modules."""

import logging

from distutils.util import strtobool

import numpy as np
import torch
import torch.nn.functional as F

from espnet.nets.pytorch_backend.nets_utils import make_non_pad_mask
from espnet.nets.pytorch_backend.rnn.attentions import AttForward
from espnet.nets.pytorch_backend.rnn.attentions import AttForwardTA
from espnet.nets.pytorch_backend.rnn.attentions import AttLoc
from espnet.nets.pytorch_backend.tacotron2.cbhg import CBHG
from espnet.nets.pytorch_backend.tacotron2.cbhg import CBHGLoss
from espnet.nets.pytorch_backend.tacotron2.decoder import Decoder
from espnet.nets.pytorch_backend.tacotron2.encoder import Encoder
from espnet.nets.tts_interface import TTSInterface
from espnet.utils.fill_missing_args import fill_missing_args


class GuidedAttentionLoss(torch.nn.Module):
    """Guided attention loss function module.

    This module calculates the guided attention loss described in
    `Efficiently Trainable Text-to-Speech System Based on Deep Convolutional Networks with Guided Attention`_,
    which forces the attention to be diagonal.

    .. _`Efficiently Trainable Text-to-Speech System Based on Deep Convolutional Networks with Guided Attention`:
        https://arxiv.org/abs/1710.08969

    """

    def __init__(self, sigma=0.4, alpha=1.0, reset_always=True):
        """Initialize guided attention loss module.

        Args:
            sigma (float, optional): Standard deviation to control how close attention to a diagonal.
            alpha (float, optional): Scaling coefficient (lambda).
            reset_always (bool, optional): Whether to always reset masks. If False, the masks
                built on the first call are cached and reused on every later call, so they are
                only correct while later batches have the same ``ilens`` / ``olens``.

        """
        super(GuidedAttentionLoss, self).__init__()
        self.sigma = sigma
        self.alpha = alpha
        self.reset_always = reset_always
        # Cached (B, T_max_out, T_max_in) penalty masks and non-pad masks.
        self.guided_attn_masks = None
        self.masks = None

    def _reset_masks(self):
        """Drop the cached masks so they are rebuilt on the next call."""
        self.guided_attn_masks = None
        self.masks = None

    def forward(self, att_ws, ilens, olens):
        """Calculate forward propagation.

        Args:
            att_ws (Tensor): Batch of attention weights (B, T_max_out, T_max_in).
            ilens (LongTensor): Batch of input lengths (B,).
            olens (LongTensor): Batch of output lengths (B,).

        Returns:
            Tensor: Guided attention loss value, i.e. ``alpha`` times the mean of the
                penalised attention weights over the non-padded positions.

        """
        if self.guided_attn_masks is None:
            self.guided_attn_masks = self._make_guided_attention_masks(ilens, olens).to(att_ws.device)
        if self.masks is None:
            self.masks = self._make_masks(ilens, olens).to(att_ws.device)
        losses = self.guided_attn_masks * att_ws
        # Average only over valid (non-padded) positions.
        loss = torch.mean(losses.masked_select(self.masks))
        if self.reset_always:
            self._reset_masks()
        return self.alpha * loss

    def _make_guided_attention_masks(self, ilens, olens):
        """Build a zero-padded batch of guided attention masks (B, max(olens), max(ilens))."""
        n_batches = len(ilens)
        max_ilen = max(ilens)
        max_olen = max(olens)
        guided_attn_masks = torch.zeros((n_batches, max_olen, max_ilen))
        for idx, (ilen, olen) in enumerate(zip(ilens, olens)):
            # Padded region outside [:olen, :ilen] stays 0 (no penalty).
            guided_attn_masks[idx, :olen, :ilen] = self._make_guided_attention_mask(ilen, olen, self.sigma)
        return guided_attn_masks

    @staticmethod
    def _make_guided_attention_mask(ilen, olen, sigma):
        """Make guided attention mask.

        The value at (t_out, t_in) is ``1 - exp(-(t_in / ilen - t_out / olen) ** 2 / (2 * sigma ** 2))``,
        which is 0 on the normalized diagonal and grows towards 1 away from it.

        Examples:
            >>> guided_attn_mask = GuidedAttentionLoss._make_guided_attention_mask(5, 5, 0.4)
            >>> guided_attn_mask.shape
            torch.Size([5, 5])
            >>> guided_attn_mask
            tensor([[0.0000, 0.1175, 0.3935, 0.6753, 0.8647],
                    [0.1175, 0.0000, 0.1175, 0.3935, 0.6753],
                    [0.3935, 0.1175, 0.0000, 0.1175, 0.3935],
                    [0.6753, 0.3935, 0.1175, 0.0000, 0.1175],
                    [0.8647, 0.6753, 0.3935, 0.1175, 0.0000]])
            >>> guided_attn_mask = GuidedAttentionLoss._make_guided_attention_mask(3, 6, 0.4)
            >>> guided_attn_mask.shape
            torch.Size([6, 3])
            >>> guided_attn_mask
            tensor([[0.0000, 0.2934, 0.7506],
                    [0.0831, 0.0831, 0.5422],
                    [0.2934, 0.0000, 0.2934],
                    [0.5422, 0.0831, 0.0831],
                    [0.7506, 0.2934, 0.0000],
                    [0.8858, 0.5422, 0.0831]])

        """
        # grid_x indexes output frames (rows), grid_y indexes input tokens (columns).
        grid_x, grid_y = torch.meshgrid(torch.arange(olen), torch.arange(ilen))
        grid_x, grid_y = grid_x.float(), grid_y.float()
        return 1.0 - torch.exp(-(grid_y / ilen - grid_x / olen) ** 2 / (2 * (sigma ** 2)))

    @staticmethod
    def _make_masks(ilens, olens):
        """Make masks indicating non-padded part.

        Examples:
            >>> ilens, olens = [5, 2], [8, 5]
            >>> GuidedAttentionLoss._make_masks(ilens, olens)
            tensor([[[1, 1, 1, 1, 1],
                     [1, 1, 1, 1, 1],
                     [1, 1, 1, 1, 1],
                     [1, 1, 1, 1, 1],
                     [1, 1, 1, 1, 1],
                     [1, 1, 1, 1, 1],
                     [1, 1, 1, 1, 1],
                     [1, 1, 1, 1, 1]],

                    [[1, 1, 0, 0, 0],
                     [1, 1, 0, 0, 0],
                     [1, 1, 0, 0, 0],
                     [1, 1, 0, 0, 0],
                     [1, 1, 0, 0, 0],
                     [0, 0, 0, 0, 0],
                     [0, 0, 0, 0, 0],
                     [0, 0, 0, 0, 0]]], dtype=torch.uint8)

        """
        in_masks = make_non_pad_mask(ilens)  # (B, T_in)
        out_masks = make_non_pad_mask(olens)  # (B, T_out)
        return out_masks.unsqueeze(-1) & in_masks.unsqueeze(-2)  # (B, T_out, T_in)
class Tacotron2Loss(torch.nn.Module):
    """Loss function module for Tacotron2."""

    def __init__(self, use_masking=True, bce_pos_weight=20.0):
        """Set up the Tacotron2 loss.

        Args:
            use_masking (bool): If True, padded frames (beyond ``olens``) are excluded
                from every loss term.
            bce_pos_weight (float): Weight applied to positive (stop) samples in the
                stop-token binary cross entropy.

        """
        super().__init__()
        self.use_masking = use_masking
        self.bce_pos_weight = bce_pos_weight

    def forward(self, after_outs, before_outs, logits, ys, labels, olens):
        """Compute the feature-reconstruction and stop-token losses.

        Args:
            after_outs (Tensor): Batch of outputs after postnets (B, Lmax, odim).
            before_outs (Tensor): Batch of outputs before postnets (B, Lmax, odim).
            logits (Tensor): Batch of stop logits (B, Lmax).
            ys (Tensor): Batch of padded target features (B, Lmax, odim).
            labels (Tensor): Batch of stop token labels (B, Lmax); must be a float
                tensor because it is the target of ``binary_cross_entropy_with_logits``.
            olens (LongTensor): Batch of the lengths of each target (B,). Only read
                when masking is enabled.

        Returns:
            Tensor: L1 loss of the post-net outputs plus that of the pre-net outputs.
            Tensor: Mean square error of the post-net outputs plus that of the pre-net outputs.
            Tensor: Binary cross entropy of the stop logits, weighted by ``bce_pos_weight``.

        """
        if self.use_masking:
            # The (B, Lmax, 1) mask broadcasts over the feature axis; its single channel
            # doubles as the per-frame (B, Lmax) mask for the stop-token tensors.
            feat_mask = make_non_pad_mask(olens).unsqueeze(-1).to(ys.device)
            frame_mask = feat_mask[:, :, 0]
            ys, after_outs, before_outs = [
                feats.masked_select(feat_mask) for feats in (ys, after_outs, before_outs)
            ]
            labels = labels.masked_select(frame_mask)
            logits = logits.masked_select(frame_mask)

        pos_weight = torch.tensor(self.bce_pos_weight, device=ys.device)
        l1_loss = F.l1_loss(after_outs, ys) + F.l1_loss(before_outs, ys)
        mse_loss = F.mse_loss(after_outs, ys) + F.mse_loss(before_outs, ys)
        bce_loss = F.binary_cross_entropy_with_logits(logits, labels, pos_weight=pos_weight)
        return l1_loss, mse_loss, bce_loss
class Tacotron2(TTSInterface, torch.nn.Module):
    """Tacotron2 module for end-to-end text-to-speech (E2E-TTS).

    This is a module of Spectrogram prediction network in Tacotron2 described in
    `Natural TTS Synthesis by Conditioning WaveNet on Mel Spectrogram Predictions`_,
    which converts the sequence of characters into the sequence of Mel-filterbanks.

    .. _`Natural TTS Synthesis by Conditioning WaveNet on Mel Spectrogram Predictions`:
       https://arxiv.org/abs/1712.05884

    """

    @staticmethod
    def add_arguments(parser):
        """Add model-specific arguments to the parser."""
        group = parser.add_argument_group("tacotron 2 model setting")
        # encoder
        group.add_argument('--embed-dim', default=512, type=int,
                           help='Number of dimension of embedding')
        group.add_argument('--elayers', default=1, type=int,
                           help='Number of encoder layers')
        group.add_argument('--eunits', '-u', default=512, type=int,
                           help='Number of encoder hidden units')
        group.add_argument('--econv-layers', default=3, type=int,
                           help='Number of encoder convolution layers')
        group.add_argument('--econv-chans', default=512, type=int,
                           help='Number of encoder convolution channels')
        group.add_argument('--econv-filts', default=5, type=int,
                           help='Filter size of encoder convolution')
        # attention
        group.add_argument('--atype', default="location", type=str,
                           choices=["forward_ta", "forward", "location"],
                           help='Type of attention mechanism')
        group.add_argument('--adim', default=512, type=int,
                           help='Number of attention transformation dimensions')
        group.add_argument('--aconv-chans', default=32, type=int,
                           help='Number of attention convolution channels')
        group.add_argument('--aconv-filts', default=15, type=int,
                           help='Filter size of attention convolution')
        group.add_argument('--cumulate-att-w', default=True, type=strtobool,
                           help="Whether or not to cumulate attention weights")
        # decoder
        group.add_argument('--dlayers', default=2, type=int,
                           help='Number of decoder layers')
        group.add_argument('--dunits', default=1024, type=int,
                           help='Number of decoder hidden units')
        group.add_argument('--prenet-layers', default=2, type=int,
                           help='Number of prenet layers')
        group.add_argument('--prenet-units', default=256, type=int,
                           help='Number of prenet hidden units')
        group.add_argument('--postnet-layers', default=5, type=int,
                           help='Number of postnet layers')
        group.add_argument('--postnet-chans', default=512, type=int,
                           help='Number of postnet channels')
        group.add_argument('--postnet-filts', default=5, type=int,
                           help='Filter size of postnet')
        group.add_argument('--output-activation', default=None, type=str, nargs='?',
                           help='Output activation function')
        # cbhg
        group.add_argument('--use-cbhg', default=False, type=strtobool,
                           help='Whether to use CBHG module')
        group.add_argument('--cbhg-conv-bank-layers', default=8, type=int,
                           help='Number of convoluional bank layers in CBHG')
        group.add_argument('--cbhg-conv-bank-chans', default=128, type=int,
                           help='Number of convoluional bank channles in CBHG')
        group.add_argument('--cbhg-conv-proj-filts', default=3, type=int,
                           help='Filter size of convoluional projection layer in CBHG')
        group.add_argument('--cbhg-conv-proj-chans', default=256, type=int,
                           help='Number of convoluional projection channels in CBHG')
        group.add_argument('--cbhg-highway-layers', default=4, type=int,
                           help='Number of highway layers in CBHG')
        group.add_argument('--cbhg-highway-units', default=128, type=int,
                           help='Number of highway units in CBHG')
        group.add_argument('--cbhg-gru-units', default=256, type=int,
                           help='Number of GRU units in CBHG')
        # model (parameter) related
        group.add_argument('--use-batch-norm', default=True, type=strtobool,
                           help='Whether to use batch normalization')
        group.add_argument('--use-concate', default=True, type=strtobool,
                           help='Whether to concatenate encoder embedding with decoder outputs')
        group.add_argument('--use-residual', default=True, type=strtobool,
                           help='Whether to use residual connection in conv layer')
        group.add_argument('--dropout-rate', default=0.5, type=float,
                           help='Dropout rate')
        group.add_argument('--zoneout-rate', default=0.1, type=float,
                           help='Zoneout rate')
        group.add_argument('--reduction-factor', default=1, type=int,
                           help='Reduction factor')
        group.add_argument("--spk-embed-dim", default=None, type=int,
                           help="Number of speaker embedding dimensions")
        group.add_argument("--spc-dim", default=None, type=int,
                           help="Number of spectrogram dimensions")
        # loss related
        group.add_argument('--use-masking', default=False, type=strtobool,
                           help='Whether to use masking in calculation of loss')
        # NOTE(review): Tacotron2Loss applies pos_weight whether or not masking is enabled,
        # despite what this help text says.
        group.add_argument('--bce-pos-weight', default=20.0, type=float,
                           help='Positive sample weight in BCE calculation (only for use-masking=True)')
        group.add_argument("--use-guided-attn-loss", default=False, type=strtobool,
                           help="Whether to use guided attention loss")
        group.add_argument("--guided-attn-loss-sigma", default=0.4, type=float,
                           help="Sigma in guided attention loss")
        group.add_argument("--guided-attn-loss-lambda", default=1.0, type=float,
                           help="Lambda in guided attention loss")
        return parser

    def __init__(self, idim, odim, args=None):
        """Initialize Tacotron2 module.

        Args:
            idim (int): Dimension of the inputs.
            odim (int): Dimension of the outputs.
            args (Namespace, optional): Model hyperparameters. Attributes missing from it
                are filled with the defaults of ``add_arguments`` via ``fill_missing_args``.
                - embed_dim (int): Dimension of character embedding.
                - elayers (int): The number of encoder blstm layers.
                - eunits (int): The number of encoder blstm units.
                - econv_layers (int): The number of encoder conv layers.
                - econv_filts (int): The number of encoder conv filter size.
                - econv_chans (int): The number of encoder conv filter channels.
                - atype (str): Attention type ("location", "forward" or "forward_ta").
                - dlayers (int): The number of decoder lstm layers.
                - dunits (int): The number of decoder lstm units.
                - prenet_layers (int): The number of prenet layers.
                - prenet_units (int): The number of prenet units.
                - postnet_layers (int): The number of postnet layers.
                - postnet_filts (int): The number of postnet filter size.
                - postnet_chans (int): The number of postnet filter channels.
                - output_activation (str): The name of a ``torch.nn.functional`` activation
                  function applied to the outputs, or None.
                - adim (int): The number of dimension of mlp in attention.
                - aconv_chans (int): The number of attention conv filter channels.
                - aconv_filts (int): The number of attention conv filter size.
                - cumulate_att_w (bool): Whether to cumulate previous attention weight
                  (forced to False for forward attentions).
                - use_batch_norm (bool): Whether to use batch normalization.
                - use_concate (bool): Whether to concatenate encoder embedding with decoder lstm outputs.
                - use_residual (bool): Whether to use residual connection in encoder conv layers.
                - dropout_rate (float): Dropout rate.
                - zoneout_rate (float): Zoneout rate.
                - reduction_factor (int): Reduction factor.
                - spk_embed_dim (int): Number of speaker embedding dimensions (None disables it).
                - spc_dim (int): Number of spectrogram dimensions (only for use_cbhg=True).
                - use_cbhg (bool): Whether to use CBHG module.
                - cbhg_conv_bank_layers (int): The number of convolutional banks in CBHG.
                - cbhg_conv_bank_chans (int): The number of channels of convolutional bank in CBHG.
                - cbhg_conv_proj_filts (int): The filter size of projection layer in CBHG.
                - cbhg_conv_proj_chans (int): The number of channels of projection layer in CBHG.
                - cbhg_highway_layers (int): The number of layers of highway network in CBHG.
                - cbhg_highway_units (int): The number of units of highway network in CBHG.
                - cbhg_gru_units (int): The number of units of GRU in CBHG.
                - use_masking (bool): Whether to mask padded part in loss calculation.
                - bce_pos_weight (float): Weight of positive sample of stop token.
                - use_guided_attn_loss (bool): Whether to use guided attention loss.
                - guided_attn_loss_sigma (float): Sigma in guided attention loss.
                - guided_attn_loss_lambda (float): Lambda in guided attention loss.

        Raises:
            ValueError: If ``output_activation`` is not found in ``torch.nn.functional``.
            NotImplementedError: If ``atype`` is not a supported attention type.

        """
        # initialize base classes
        TTSInterface.__init__(self)
        torch.nn.Module.__init__(self)

        # fill missing arguments
        args = fill_missing_args(args, self.add_arguments)

        # store hyperparameters
        self.idim = idim
        self.odim = odim
        self.spk_embed_dim = args.spk_embed_dim
        self.cumulate_att_w = args.cumulate_att_w
        self.reduction_factor = args.reduction_factor
        self.use_cbhg = args.use_cbhg
        self.use_guided_attn_loss = args.use_guided_attn_loss

        # define activation function for the final output
        if args.output_activation is None:
            self.output_activation_fn = None
        elif hasattr(F, args.output_activation):
            self.output_activation_fn = getattr(F, args.output_activation)
        else:
            raise ValueError('there is no such an activation function. (%s)' % args.output_activation)

        # set padding idx
        padding_idx = 0

        # define network modules
        self.enc = Encoder(idim=idim,
                           embed_dim=args.embed_dim,
                           elayers=args.elayers,
                           eunits=args.eunits,
                           econv_layers=args.econv_layers,
                           econv_chans=args.econv_chans,
                           econv_filts=args.econv_filts,
                           use_batch_norm=args.use_batch_norm,
                           use_residual=args.use_residual,
                           dropout_rate=args.dropout_rate,
                           padding_idx=padding_idx)
        # speaker embeddings are concatenated to the encoder outputs (see forward)
        dec_idim = args.eunits if args.spk_embed_dim is None else args.eunits + args.spk_embed_dim
        if args.atype == "location":
            att = AttLoc(dec_idim, args.dunits, args.adim, args.aconv_chans, args.aconv_filts)
        elif args.atype == "forward":
            att = AttForward(dec_idim, args.dunits, args.adim, args.aconv_chans, args.aconv_filts)
        elif args.atype == "forward_ta":
            att = AttForwardTA(dec_idim, args.dunits, args.adim, args.aconv_chans, args.aconv_filts, odim)
        else:
            raise NotImplementedError("Support only location, forward or forward_ta")
        if args.atype in ("forward", "forward_ta") and self.cumulate_att_w:
            logging.warning("cumulation of attention weights is disabled in forward attention.")
            self.cumulate_att_w = False
        self.dec = Decoder(idim=dec_idim,
                           odim=odim,
                           att=att,
                           dlayers=args.dlayers,
                           dunits=args.dunits,
                           prenet_layers=args.prenet_layers,
                           prenet_units=args.prenet_units,
                           postnet_layers=args.postnet_layers,
                           postnet_chans=args.postnet_chans,
                           postnet_filts=args.postnet_filts,
                           output_activation_fn=self.output_activation_fn,
                           cumulate_att_w=self.cumulate_att_w,
                           use_batch_norm=args.use_batch_norm,
                           use_concate=args.use_concate,
                           dropout_rate=args.dropout_rate,
                           zoneout_rate=args.zoneout_rate,
                           reduction_factor=args.reduction_factor)
        self.taco2_loss = Tacotron2Loss(use_masking=args.use_masking,
                                        bce_pos_weight=args.bce_pos_weight)
        if self.use_guided_attn_loss:
            self.attn_loss = GuidedAttentionLoss(
                sigma=args.guided_attn_loss_sigma,
                alpha=args.guided_attn_loss_lambda,
            )
        if self.use_cbhg:
            self.cbhg = CBHG(idim=odim,
                             odim=args.spc_dim,
                             conv_bank_layers=args.cbhg_conv_bank_layers,
                             conv_bank_chans=args.cbhg_conv_bank_chans,
                             conv_proj_filts=args.cbhg_conv_proj_filts,
                             conv_proj_chans=args.cbhg_conv_proj_chans,
                             highway_layers=args.cbhg_highway_layers,
                             highway_units=args.cbhg_highway_units,
                             gru_units=args.cbhg_gru_units)
            self.cbhg_loss = CBHGLoss(use_masking=args.use_masking)

    def forward(self, xs, ilens, ys, labels, olens, spembs=None, spcs=None, *args, **kwargs):
        """Calculate forward propagation.

        Args:
            xs (Tensor): Batch of padded character ids (B, Tmax).
            ilens (LongTensor): Batch of lengths of each input batch (B,).
            ys (Tensor): Batch of padded target features (B, Lmax, odim).
            labels (Tensor): Batch of padded stop token labels (B, Lmax). Not modified in place.
            olens (LongTensor): Batch of the lengths of each target (B,).
            spembs (Tensor, optional): Batch of speaker embedding vectors (B, spk_embed_dim).
            spcs (Tensor, optional): Batch of groundtruth spectrograms (B, Lmax, spc_dim).

        Returns:
            Tensor: Loss value.

        """
        # remove unnecessary padded part (for multi-gpus)
        max_in = max(ilens)
        max_out = max(olens)
        if max_in != xs.shape[1]:
            xs = xs[:, :max_in]
        if max_out != ys.shape[1]:
            ys = ys[:, :max_out]
            labels = labels[:, :max_out]

        # calculate tacotron2 outputs
        hs, hlens = self.enc(xs, ilens)
        if self.spk_embed_dim is not None:
            spembs = F.normalize(spembs).unsqueeze(1).expand(-1, hs.size(1), -1)
            hs = torch.cat([hs, spembs], dim=-1)
        after_outs, before_outs, logits, att_ws = self.dec(hs, hlens, ys)

        # modify mod part of groundtruth
        if self.reduction_factor > 1:
            olens = olens.new([olen - olen % self.reduction_factor for olen in olens])
            max_out = max(olens)
            ys = ys[:, :max_out]
            labels = labels[:, :max_out]
            # Truncation may drop each utterance's original stop frame, so mark the last
            # kept frame of *every* utterance as a stop frame. ``scatter`` is out-of-place,
            # so the caller's ``labels`` tensor (of which this slice is a view) is untouched.
            stop_idx = (olens - 1).unsqueeze(1).to(labels.device)
            labels = labels.scatter(1, stop_idx, 1.0)

        # calculate taco2 loss
        l1_loss, mse_loss, bce_loss = self.taco2_loss(
            after_outs, before_outs, logits, ys, labels, olens)
        loss = l1_loss + mse_loss + bce_loss
        report_keys = [
            {'l1_loss': l1_loss.item()},
            {'mse_loss': mse_loss.item()},
            {'bce_loss': bce_loss.item()},
        ]

        # calculate attention loss
        if self.use_guided_attn_loss:
            # NOTE(kan-bayashi): length of output for auto-regressive input will be changed when r > 1
            if self.reduction_factor > 1:
                olens_in = olens.new([olen // self.reduction_factor for olen in olens])
            else:
                olens_in = olens
            attn_loss = self.attn_loss(att_ws, ilens, olens_in)
            loss = loss + attn_loss
            report_keys += [
                {'attn_loss': attn_loss.item()},
            ]

        # calculate cbhg loss
        if self.use_cbhg:
            # remove unnecessary padded part (for multi-gpus)
            if max_out != spcs.shape[1]:
                spcs = spcs[:, :max_out]

            # calculate cbhg outputs & loss and report them
            cbhg_outs, _ = self.cbhg(after_outs, olens)
            cbhg_l1_loss, cbhg_mse_loss = self.cbhg_loss(cbhg_outs, spcs, olens)
            loss = loss + cbhg_l1_loss + cbhg_mse_loss
            report_keys += [
                {'cbhg_l1_loss': cbhg_l1_loss.item()},
                {'cbhg_mse_loss': cbhg_mse_loss.item()},
            ]

        report_keys += [{'loss': loss.item()}]
        # NOTE(review): ``self.reporter`` is presumably provided by TTSInterface.
        self.reporter.report(report_keys)

        return loss

    def inference(self, x, inference_args, spemb=None, *args, **kwargs):
        """Generate the sequence of features given the sequences of characters.

        Args:
            x (Tensor): Input sequence of characters (T,).
            inference_args (Namespace):
                - threshold (float): Threshold in inference.
                - minlenratio (float): Minimum length ratio in inference.
                - maxlenratio (float): Maximum length ratio in inference.
            spemb (Tensor, optional): Speaker embedding vector (spk_embed_dim).

        Returns:
            Tensor: Output sequence of features (L, odim), or the CBHG outputs when
                ``use_cbhg`` is enabled.
            Tensor: Output sequence of stop probabilities (L,).
            Tensor: Attention weights (L, T).

        """
        # get options
        threshold = inference_args.threshold
        minlenratio = inference_args.minlenratio
        maxlenratio = inference_args.maxlenratio

        # inference
        h = self.enc.inference(x)
        if self.spk_embed_dim is not None:
            spemb = F.normalize(spemb, dim=0).unsqueeze(0).expand(h.size(0), -1)
            h = torch.cat([h, spemb], dim=-1)
        outs, probs, att_ws = self.dec.inference(h, threshold, minlenratio, maxlenratio)

        if self.use_cbhg:
            cbhg_outs = self.cbhg.inference(outs)
            return cbhg_outs, probs, att_ws
        else:
            return outs, probs, att_ws

    def calculate_all_attentions(self, xs, ilens, ys, spembs=None, *args, **kwargs):
        """Calculate all of the attention weights.

        The model is switched to eval mode for the computation and then restored to
        whichever mode (train or eval) it was in before the call.

        Args:
            xs (Tensor): Batch of padded character ids (B, Tmax).
            ilens (LongTensor): Batch of lengths of each input batch (B,).
            ys (Tensor): Batch of padded target features (B, Lmax, odim).
            spembs (Tensor, optional): Batch of speaker embedding vectors (B, spk_embed_dim).

        Returns:
            numpy.ndarray: Batch of attention weights (B, Lmax, Tmax).

        """
        # check ilens type (should be list of int)
        if isinstance(ilens, (torch.Tensor, np.ndarray)):
            ilens = list(map(int, ilens))

        # Remember the current mode so an eval-mode model is not flipped back to training,
        # and restore it even if the computation raises.
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                hs, hlens = self.enc(xs, ilens)
                if self.spk_embed_dim is not None:
                    spembs = F.normalize(spembs).unsqueeze(1).expand(-1, hs.size(1), -1)
                    hs = torch.cat([hs, spembs], dim=-1)
                att_ws = self.dec.calculate_all_attentions(hs, hlens, ys)
        finally:
            self.train(was_training)

        return att_ws.cpu().numpy()

    @property
    def base_plot_keys(self):
        """Return base key names to plot during training.

        keys should match what `chainer.reporter` reports.
        If you add the key `loss`, the reporter will report `main/loss` and `validation/main/loss` values.
        also `loss.png` will be created as a figure visualizing `main/loss` and `validation/main/loss` values.

        Returns:
            list: List of strings which are base keys to plot during training.

        """
        plot_keys = ['loss', 'l1_loss', 'mse_loss', 'bce_loss']
        if self.use_guided_attn_loss:
            plot_keys += ['attn_loss']
        if self.use_cbhg:
            plot_keys += ['cbhg_l1_loss', 'cbhg_mse_loss']
        return plot_keys