Source code for espnet.asr.asr_utils

#!/usr/bin/env python

# Copyright 2017 Johns Hopkins University (Shinji Watanabe)
# Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)

import argparse
import copy
import json
import logging
# file system related
import os
import shutil
import tempfile

# chainer related
import chainer

from chainer import training
from chainer.training import extension

from chainer.serializers.npz import DictionarySerializer
from chainer.serializers.npz import NpzDeserializer

# numerical / plotting related
import matplotlib
import numpy as np
import torch
matplotlib.use('Agg')


# * -------------------- training iterator related -------------------- *


class CompareValueTrigger(object):
    """Trigger that fires when a monitored value compares unfavourably to the best one.

    The observed values of ``key`` are accumulated and, each time the interval
    trigger fires, their mean is compared with the stored best value through
    ``compare_fn(best_value, mean_value)``.

    Args:
        key (str) : Key of value.
        compare_fn ((float, float) -> bool) : Function to compare the values.
        trigger (tuple(int, str)) : Trigger that decide the comparison interval.

    """

    def __init__(self, key, compare_fn, trigger=(1, 'epoch')):
        self._key = key
        self._compare_fn = compare_fn
        self._best_value = None
        self._interval_trigger = training.util.get_trigger(trigger)
        self._init_summary()

    def __call__(self, trainer):
        """Accumulate the observed value and compare its mean with the best value.

        Returns:
            bool: True only when the interval trigger fired and
                ``compare_fn(best_value, mean_value)`` is truthy; in every
                other case False (the best value is replaced by the new mean
                whenever the comparison is not truthy).

        """
        name = self._key
        observed = trainer.observation
        if name in observed:
            self._summary.add({name: observed[name]})

        if not self._interval_trigger(trainer):
            return False

        current = float(self._summary.compute_mean()[name])  # copy to CPU
        self._init_summary()

        previous = self._best_value
        if previous is not None and self._compare_fn(previous, current):
            return True

        # first measurement, or the comparison did not hold: remember this value
        self._best_value = current
        return False

    def _init_summary(self):
        # start a fresh accumulation window
        self._summary = chainer.reporter.DictSummary()
class PlotAttentionReport(extension.Extension):
    """Plot attention reporter.

    Args:
        att_vis_fn (espnet.nets.*_backend.e2e_asr.E2E.calculate_all_attentions):
            Function of attention visualization.
        data (list[tuple(str, dict[str, list[Any]])]): List json utt key items.
            A deep copy is stored, so later changes by the caller are not seen.
        outdir (str): Directory to save figures (created if missing).
        converter (espnet.asr.*_backend.asr.CustomConverter): Function to convert data.
        transform (callable): Function applied to ``data`` before the result is
            handed to ``converter``.
        device (int | torch.device): Device.
        reverse (bool): If True, input and output length are reversed.
        ikey (str): Key to access input (for ASR ikey="input", for MT ikey="output".)
        iaxis (int): Dimension to access input (for ASR iaxis=0, for MT iaxis=1.)
        okey (str): Key to access output (defaults to "output").
        oaxis (int): Dimension to access output (defaults to 0).

    """

    def __init__(self, att_vis_fn, data, outdir, converter, transform, device,
                 reverse=False, ikey="input", iaxis=0, okey="output", oaxis=0):
        self.att_vis_fn = att_vis_fn
        self.data = copy.deepcopy(data)
        self.outdir = outdir
        self.converter = converter
        self.transform = transform
        self.device = device
        self.reverse = reverse
        self.ikey = ikey
        self.iaxis = iaxis
        self.okey = okey
        self.oaxis = oaxis
        # Create the directory without a separate existence check, so that a
        # concurrent creation of the same directory does not raise here.
        try:
            os.makedirs(self.outdir)
        except OSError:
            if not os.path.isdir(self.outdir):
                raise

    def __call__(self, trainer):
        """Plot and save image file of att_ws matrix."""
        att_ws = self.get_attention_weights()
        # Format the epoch on its own so that braces in ``outdir`` or in an
        # utterance id are never interpreted as format fields.
        epoch = "{.updater.epoch}".format(trainer)
        for idx, att_w in enumerate(att_ws):
            filename = "%s/%s.ep.%s.png" % (self.outdir, self.data[idx][0], epoch)
            att_w = self.get_attention_weight(idx, att_w)
            self._plot_and_save_attention(att_w, filename)

    def log_attentions(self, logger, step):
        """Add image files of att_ws matrix to the tensorboard."""
        att_ws = self.get_attention_weights()
        for idx, att_w in enumerate(att_ws):
            att_w = self.get_attention_weight(idx, att_w)
            plot = self.draw_attention_plot(att_w)
            logger.add_figure("%s" % (self.data[idx][0]), plot.gcf(), step)
            plot.clf()

    def get_attention_weights(self):
        """Return attention weights.

        Returns:
            numpy.ndarray: attention weights.float. Its shape would be differ from backend.
                * pytorch-> 1) multi-head case => (B, H, Lmax, Tmax), 2) other case => (B, Lmax, Tmax).
                * chainer-> (B, Lmax, Tmax)

        """
        batch = self.converter([self.transform(self.data)], self.device)
        # the converter may return positional (tuple) or keyword (mapping) inputs
        if isinstance(batch, tuple):
            att_ws = self.att_vis_fn(*batch)
        else:
            att_ws = self.att_vis_fn(**batch)
        return att_ws

    def get_attention_weight(self, idx, att_w):
        """Transform attention matrix with regard to self.reverse.

        The matrix is cropped to the decoder/encoder lengths read from the
        ``'shape'`` entries of utterance ``idx``.
        """
        info = self.data[idx][1]
        in_len = int(info[self.ikey][self.iaxis]['shape'][0])
        out_len = int(info[self.okey][self.oaxis]['shape'][0])
        if self.reverse:
            dec_len, enc_len = in_len, out_len
        else:
            dec_len, enc_len = out_len, in_len
        if len(att_w.shape) == 3:
            # leading axis is kept untouched (one slice per head)
            att_w = att_w[:, :dec_len, :enc_len]
        else:
            att_w = att_w[:dec_len, :enc_len]
        return att_w

    def draw_attention_plot(self, att_w):
        """Plot the att_w matrix.

        Returns:
            matplotlib.pyplot: pyplot object with attention matrix image.

        """
        import matplotlib.pyplot as plt
        att_w = att_w.astype(np.float32)
        if len(att_w.shape) == 3:
            # one subplot per slice along the first axis, laid out in one row
            for h, aw in enumerate(att_w, 1):
                plt.subplot(1, len(att_w), h)
                plt.imshow(aw, aspect="auto")
                plt.xlabel("Encoder Index")
                plt.ylabel("Decoder Index")
        else:
            plt.imshow(att_w, aspect="auto")
            plt.xlabel("Encoder Index")
            plt.ylabel("Decoder Index")
        plt.tight_layout()
        return plt

    def _plot_and_save_attention(self, att_w, filename):
        plt = self.draw_attention_plot(att_w)
        try:
            plt.savefig(filename)
        finally:
            # always release the figure, even when saving fails
            plt.close()
def restore_snapshot(model, snapshot, load_fn=chainer.serializers.load_npz):
    """Build a trainer extension that reloads ``snapshot`` into ``model``.

    Args:
        model: Object handed to ``load_fn`` as the load target.
        snapshot: Snapshot handed to ``load_fn`` as the source.
        load_fn (callable): Loader called as ``load_fn(snapshot, model)``.

    Returns:
        An extension function (registered with trigger ``(1, 'epoch')``).

    """
    # NOTE: the inner function keeps the name ``restore_snapshot`` on purpose;
    # renaming it would change the function's ``__name__``.
    def restore_snapshot(trainer):
        _restore_snapshot(model, snapshot, load_fn)

    as_extension = training.make_extension(trigger=(1, 'epoch'))
    return as_extension(restore_snapshot)
def _restore_snapshot(model, snapshot, load_fn=chainer.serializers.load_npz):
    """Load ``snapshot`` into ``model`` with ``load_fn`` and log the restore."""
    load_fn(snapshot, model)
    message = 'restored from ' + str(snapshot)
    logging.info(message)
def adadelta_eps_decay(eps_decay):
    """Build a trainer extension that decays the adadelta eps.

    Args:
        eps_decay (float): Decay rate of eps.

    Returns:
        An extension function (registered with trigger ``(1, 'epoch')``).

    """
    # NOTE: the inner function keeps the name ``adadelta_eps_decay`` on
    # purpose; renaming it would change the function's ``__name__``.
    def adadelta_eps_decay(trainer):
        _adadelta_eps_decay(trainer, eps_decay)

    as_extension = training.make_extension(trigger=(1, 'epoch'))
    return as_extension(adadelta_eps_decay)
def _adadelta_eps_decay(trainer, eps_decay):
    """Multiply the eps of the trainer's ``'main'`` optimizer by ``eps_decay``.

    An optimizer exposing an ``eps`` attribute is updated through that
    attribute (chainer); otherwise every entry of ``param_groups`` has its
    ``"eps"`` item scaled (pytorch). Each new value is logged.
    """
    optimizer = trainer.updater.get_optimizer('main')

    if not hasattr(optimizer, 'eps'):
        # pytorch
        for group in optimizer.param_groups:
            group["eps"] *= eps_decay
            logging.info('adadelta eps decayed to ' + str(group["eps"]))
        return

    # for chainer
    optimizer.eps = optimizer.eps * eps_decay
    logging.info('adadelta eps decayed to ' + str(optimizer.eps))
def torch_snapshot(savefun=torch.save, filename='snapshot.ep.{.updater.epoch}'):
    """Build a trainer extension that takes a snapshot of the trainer for pytorch.

    Args:
        savefun (callable): Function forwarded to the snapshot writer.
        filename (str): Format string; it is formatted with the trainer object.

    Returns:
        An extension function (trigger ``(1, 'epoch')``, priority ``-100``).

    """
    # NOTE: the inner function keeps the name ``torch_snapshot`` on purpose;
    # renaming it would change the function's ``__name__``.
    def torch_snapshot(trainer):
        snapshot_name = filename.format(trainer)
        _torch_snapshot_object(trainer, trainer, snapshot_name, savefun)

    as_extension = extension.make_extension(trigger=(1, 'epoch'), priority=-100)
    return as_extension(torch_snapshot)
def _torch_snapshot_object(trainer, target, filename, savefun):
    """Serialize trainer, model and optimizer states into one snapshot file.

    The snapshot is first written into a temporary directory created inside
    ``trainer.out`` and then moved to ``trainer.out``; the temporary directory
    is removed afterwards in every case. ``target`` is not used here.
    """
    # make snapshot_dict dictionary
    serializer = DictionarySerializer()
    serializer.save(trainer)

    # peel off the optional ``.model`` (for TTS) and ``.module`` wrappers
    net = trainer.updater.model
    if hasattr(net, "model"):
        net = net.model
    if hasattr(net, "module"):
        net = net.module

    snapshot_dict = {
        "trainer": serializer.target,
        "model": net.state_dict(),
        "optimizer": trainer.updater.get_optimizer('main').state_dict(),
    }

    # save snapshot dictionary
    fn = filename.format(trainer)
    tmpdir = tempfile.mkdtemp(prefix='tmp' + fn, dir=trainer.out)
    tmppath = os.path.join(tmpdir, fn)
    destination = os.path.join(trainer.out, fn)
    try:
        savefun(snapshot_dict, tmppath)
        shutil.move(tmppath, destination)
    finally:
        shutil.rmtree(tmpdir)
def add_gradient_noise(model, iteration, duration=100, eta=1.0, scale_factor=0.55):
    """Adds noise from a standard normal distribution to the gradients.

    The standard deviation (`sigma`) is controlled by the three hyper-parameters below.
    `sigma` goes to zero (no noise) with more iterations:
    ``sigma = eta / ((iteration // duration) + 1) ** scale_factor``.

    Parameters whose ``grad`` is None are left untouched; all other gradients
    are modified in place.

    Args:
        model (torch.nn.Module): Model.
        iteration (int): Number of iterations.
        duration (int) {100, 1000}: Number of durations to control the interval of the `sigma` change.
        eta (float) {0.01, 0.3, 1.0}: The magnitude of `sigma`.
        scale_factor (float) {0.55}: The scale of `sigma`.

    """
    interval = (iteration // duration) + 1
    sigma = eta / interval ** scale_factor
    for param in model.parameters():
        grad = param.grad
        if grad is None:
            continue
        # Sample directly on the gradient's device instead of sampling on the
        # CPU and copying the tensor over for every parameter.
        noise = torch.randn(grad.size(), device=grad.device)
        grad.add_(sigma * noise)
# * -------------------- general -------------------- *
def get_model_conf(model_path, conf_path=None):
    """Get model config information by reading a model config file (model.json).

    Args:
        model_path (str): Model path. Only its directory is used, and only
            when ``conf_path`` is None: the config is then read from
            ``model.json`` next to the model file.
        conf_path (str): Optional model config path.

    Returns:
        argparse.Namespace | tuple(int, int, argparse.Namespace): When the json
            file holds a dict (lm case) a single namespace built from it;
            otherwise (asr, tts, mt) the three json items ``idim, odim, args``
            with ``args`` converted to a namespace.

    """
    if conf_path is None:
        # os.path.join keeps the path relative when ``model_path`` has no
        # directory part (plain concatenation would yield '/model.json').
        model_conf = os.path.join(os.path.dirname(model_path), 'model.json')
    else:
        model_conf = conf_path
    with open(model_conf, "rb") as f:
        logging.info('reading a config file from ' + model_conf)
        confs = json.load(f)
    if isinstance(confs, dict):
        # for lm
        return argparse.Namespace(**confs)
    # for asr, tts, mt
    idim, odim, args = confs
    return idim, odim, argparse.Namespace(**args)
def chainer_load(path, model):
    """Load chainer model parameters.

    A file whose *name* contains ``'snapshot'`` is read with the sub-path
    ``'updater/model:main/'``; any other file is loaded directly. Only the
    file name is inspected, so a directory called e.g. ``snapshots`` does not
    affect how a plain model file is loaded.

    Args:
        path (str): Model path or snapshot file path to be loaded.
        model (chainer.Chain): Chainer model.

    """
    if 'snapshot' in os.path.basename(path):
        chainer.serializers.load_npz(path, model, path='updater/model:main/')
    else:
        chainer.serializers.load_npz(path, model)
def torch_save(path, model):
    """Save torch model states.

    When ``model`` has a ``module`` attribute, the state dict of that inner
    module is saved instead of the wrapper's.

    Args:
        path (str): Model path to be saved.
        model (torch.nn.Module): Torch model.

    """
    net = model.module if hasattr(model, 'module') else model
    torch.save(net.state_dict(), path)
def snapshot_object(target, filename):
    """Returns a trainer extension to take snapshots of a given object.

    Args:
        target (model): Object to serialize.
        filename (str): Name of the file into which the object is serialized. It can
            be a format string, where the trainer object is passed to
            the :meth: `str.format` method. For example,
            ``'snapshot_{.updater.iteration}'`` is converted to
            ``'snapshot_10000'`` at the 10,000th iteration.

    Returns:
        An extension function (trigger ``(1, 'epoch')``, priority ``-100``).

    """
    # NOTE: the inner function keeps the name ``snapshot_object`` on purpose;
    # renaming it would change the function's ``__name__``.
    def snapshot_object(trainer):
        destination = os.path.join(trainer.out, filename.format(trainer))
        torch_save(destination, target)

    as_extension = extension.make_extension(trigger=(1, 'epoch'), priority=-100)
    return as_extension(snapshot_object)
def torch_load(path, model):
    """Load torch model states.

    A file whose *name* contains ``'snapshot'`` is treated as a snapshot
    dictionary and its ``'model'`` entry is used; any other file is taken to
    be the state dict itself. Only the file name is inspected, so a directory
    called e.g. ``snapshots`` does not make a plain model file look like a
    snapshot. Tensors are mapped to the CPU while loading. When ``model`` has
    a ``module`` attribute, the states are loaded into that inner module.

    Args:
        path (str): Model path or snapshot file path to be loaded.
        model (torch.nn.Module): Torch model.

    """
    loaded = torch.load(path, map_location=lambda storage, loc: storage)
    if 'snapshot' in os.path.basename(path):
        model_state_dict = loaded['model']
    else:
        model_state_dict = loaded

    target = model.module if hasattr(model, 'module') else model
    target.load_state_dict(model_state_dict)
def torch_resume(snapshot_path, trainer):
    """Resume from snapshot for pytorch.

    Restores, in this order, the trainer states, the model states and the
    states of the ``'main'`` optimizer from the snapshot dictionary.

    Args:
        snapshot_path (str): Snapshot file path.
        trainer (chainer.training.Trainer): Chainer's trainer instance.

    """
    # load snapshot (tensors are mapped to the CPU)
    snapshot_dict = torch.load(snapshot_path, map_location=lambda storage, loc: storage)

    # restore trainer states
    NpzDeserializer(snapshot_dict['trainer']).load(trainer)

    # restore model states: peel off the optional ``.model`` (for TTS model)
    # and ``.module`` wrappers before loading
    net = trainer.updater.model
    if hasattr(net, "model"):
        net = net.model
    if hasattr(net, "module"):
        net = net.module
    net.load_state_dict(snapshot_dict['model'])

    # restore optimizer states
    optimizer = trainer.updater.get_optimizer('main')
    optimizer.load_state_dict(snapshot_dict['optimizer'])

    # delete opened snapshot
    del snapshot_dict
# * ------------------ recognition related ------------------ *
def parse_hypothesis(hyp, char_list):
    """Parse hypothesis.

    Args:
        hyp (dict[str, Any]): Recognition hypothesis; its ``'yseq'`` and
            ``'score'`` entries are read.
        char_list (list[str]): List of characters.

    Returns:
        tuple(str, str, str, float): ``(text, token, tokenid, score)`` where
            ``token`` and ``tokenid`` are space-joined and ``text`` is the
            concatenated tokens with ``'<space>'`` replaced by a blank.

    """
    # remove sos and get results
    token_ids = [int(tid) for tid in hyp['yseq'][1:]]
    tokens = [char_list[tid] for tid in token_ids]
    score = float(hyp['score'])

    # convert to string
    tokenid = " ".join(map(str, token_ids))
    token = " ".join(tokens)
    text = "".join(tokens).replace('<space>', ' ')

    return text, token, tokenid, score
def add_results_to_json(js, nbest_hyps, char_list):
    """Add N-best results to json.

    Args:
        js (dict[str, Any]): Groundtruth utterance dict; its ``'utt2spk'`` and
            ``'output'`` entries are read.
        nbest_hyps (list[dict[str, Any]]): List of hypotheses, best first.
        char_list (list[str]): List of characters.

    Returns:
        dict[str, Any]: N-best results added utterance dict.

    """
    # copy old json info
    new_js = {'utt2spk': js['utt2spk'], 'output': []}

    for rank, hyp in enumerate(nbest_hyps, 1):
        # parse hypothesis
        rec_text, rec_token, rec_tokenid, score = parse_hypothesis(hyp, char_list)

        if len(js['output']) == 0:
            # for no reference case (e.g., speech translation)
            out_dic = {'name': ''}
        else:
            # copy ground-truth
            out_dic = {key: value for key, value in js['output'][0].items()}

        # update name
        out_dic['name'] += '[%d]' % rank

        # add recognition results
        out_dic.update({
            'rec_text': rec_text,
            'rec_token': rec_token,
            'rec_tokenid': rec_tokenid,
            'score': score,
        })

        # add to list of N-best result dicts
        new_js['output'].append(out_dic)

        # show 1-best result
        if rank == 1:
            if 'text' in out_dic:
                logging.info('groundtruth: %s' % out_dic['text'])
            logging.info('prediction : %s' % out_dic['rec_text'])

    return new_js
def plot_spectrogram(plt, spec, mode='db', fs=None, frame_shift=None,
                     bottom=True, left=True, right=True, top=False,
                     labelbottom=True, labelleft=True, labelright=True, labeltop=False,
                     cmap='inferno'):
    """Plot spectrogram using matplotlib.

    Args:
        plt (matplotlib.pyplot): pyplot object.
        spec (numpy.ndarray): Input stft (Freq, Time)
        mode (str): db or linear. Any other value raises ``ValueError``.
        fs (int): Sample frequency. To convert y-axis to kHz unit.
        frame_shift (int): The frame shift of stft. To convert x-axis to second unit
            (only applied when ``fs`` is given as well).
        bottom (bool): Whether to draw the respective ticks.
        left (bool):
        right (bool):
        top (bool):
        labelbottom (bool): Whether to draw the respective tick labels.
        labelleft (bool):
        labelright (bool):
        labeltop (bool):
        cmap (str): Colormap defined in matplotlib.

    """
    magnitude = np.abs(spec)
    if mode == 'linear':
        x = magnitude
    elif mode == 'db':
        # eps keeps log10 away from zero-valued bins
        x = 20 * np.log10(magnitude + np.finfo(magnitude.dtype).eps)
    else:
        raise ValueError(mode)

    # vertical axis: kHz when the sample frequency is known, bins otherwise
    if fs is None:
        ytop, ylabel = x.shape[0], 'bin'
    else:
        ytop, ylabel = fs / 2000, 'kHz'

    # horizontal axis: seconds need both the frame shift and the sample frequency
    if frame_shift is None or fs is None:
        xtop, xlabel = x.shape[1], 'frame'
    else:
        xtop, xlabel = x.shape[1] * frame_shift / fs, 's'

    plt.imshow(x[::-1], cmap=cmap, extent=(0, xtop, 0, ytop))

    if labelbottom:
        plt.xlabel('time [{}]'.format(xlabel))
    if labelleft:
        plt.ylabel('freq [{}]'.format(ylabel))
    colorbar = plt.colorbar()
    colorbar.set_label('{}'.format(mode))

    plt.tick_params(bottom=bottom, left=left, right=right, top=top,
                    labelbottom=labelbottom, labelleft=labelleft,
                    labelright=labelright, labeltop=labeltop)
    plt.axis('auto')