# -*- coding: utf-8 -*-
# Copyright 2018 Nagoya University (Tomoki Hayashi)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""TTS Interface related modules."""
import chainer
class Reporter(chainer.Chain):
    """Reporter module.

    A ``chainer.Chain`` that serves as the observer when values are
    handed to ``chainer.reporter.report``.
    """

    def report(self, dicts):
        """Report values from the given dicts.

        Args:
            dicts (iterable): Iterable of value collections (presumably
                dicts mapping a name to a value). Each element is passed
                on its own to ``chainer.reporter.report`` with this
                instance as the observer.

        """
        for values in dicts:
            chainer.reporter.report(values, self)
class TTSInterface(object):
    """TTS Interface for ESPnet model implementation.

    The base class only provides a reporter and the default plot settings;
    ``forward``, ``inference`` and ``calculate_all_attentions`` raise
    ``NotImplementedError`` until a subclass overrides them.
    """

    @staticmethod
    def add_arguments(parser):
        """Add model specific arguments to parser.

        The base implementation registers nothing.

        Args:
            parser: Parser object to extend.

        Returns:
            The same ``parser`` object that was passed in.

        """
        return parser

    def __init__(self):
        """Initialize TTS module."""
        self.reporter = Reporter()

    def forward(self, *args, **kwargs):
        """Calculate TTS forward propagation.

        Returns:
            Tensor: Loss value.

        Raises:
            NotImplementedError: Always, in this base class.

        """
        raise NotImplementedError("forward method is not implemented")

    def inference(self, *args, **kwargs):
        """Generate the sequence of features given the sequences of characters.

        Returns:
            Tensor: The sequence of generated features (L, odim).
            Tensor: The sequence of stop probabilities (L,).
            Tensor: The sequence of attention weights (L, T).

        Raises:
            NotImplementedError: Always, in this base class.

        """
        raise NotImplementedError("inference method is not implemented")

    def calculate_all_attentions(self, *args, **kwargs):
        """Calculate TTS attention weights.

        Returns:
            Tensor: Batch of attention weights (B, Lmax, Tmax).

        Raises:
            NotImplementedError: Always, in this base class.

        """
        raise NotImplementedError("calculate_all_attentions method is not implemented")

    @property
    def attention_plot_class(self):
        """Return the class used to plot attention weights.

        Returns:
            The ``PlotAttentionReport`` class from ``espnet.asr.asr_utils``.

        """
        # Imported on access rather than when this module is loaded.
        from espnet.asr.asr_utils import PlotAttentionReport

        return PlotAttentionReport

    @property
    def base_plot_keys(self):
        """Return base key names to plot during training.

        The keys should match what `chainer.reporter` reports.
        If you add the key `loss`, the reporter will report `main/loss`
        and `validation/main/loss` values.
        Also `loss.png` will be created as a figure visualizing `main/loss`
        and `validation/main/loss` values.

        Returns:
            list[str]: Base keys to plot during training.

        """
        return ['loss']