# Source code for espnet.transform.transformation

from collections import OrderedDict
import copy
import io
import logging
import sys

import yaml

from espnet.utils.dynamic_import import dynamic_import


# True when the running interpreter's major version is 2.
PY2 = sys.version_info[0] == 2

# Bind `Sequence` and `signature` from whichever module provides them for
# the running interpreter, so the rest of the file can use the same names.
if PY2:
    # Python 2: `Sequence` is imported from `collections` itself and
    # `signature` comes from the `funcsigs` module.
    from collections import Sequence
    from funcsigs import signature
else:
    # The ABC aliases in 'collections' are deprecated in favour of
    # 'collections.abc' (they were removed in Python 3.10).
    from collections.abc import Sequence
    from inspect import signature


# TODO(karita): inherit TransformInterface
# TODO(karita): register cmd arguments in asr_train.py
# Short names accepted as a process ``type``, each mapped to a
# "module.path:ClassName" string.  The table is handed to ``dynamic_import``
# together with the requested type when a Transformation is built.
import_alias = {
    'identity': 'espnet.transform.transform_interface:Identity',
    'time_warp': 'espnet.transform.spec_augment:TimeWarp',
    'time_mask': 'espnet.transform.spec_augment:TimeMask',
    'freq_mask': 'espnet.transform.spec_augment:FreqMask',
    'spec_augment': 'espnet.transform.spec_augment:SpecAugment',
    'speed_perturbation': 'espnet.transform.perturb:SpeedPerturbation',
    'volume_perturbation': 'espnet.transform.perturb:VolumePerturbation',
    'noise_injection': 'espnet.transform.perturb:NoiseInjection',
    'bandpass_perturbation': 'espnet.transform.perturb:BandpassPerturbation',
    'rir_convolve': 'espnet.transform.perturb:RIRConvolve',
    'delta': 'espnet.transform.add_deltas:AddDeltas',
    'cmvn': 'espnet.transform.cmvn:CMVN',
    'utterance_cmvn': 'espnet.transform.cmvn:UtteranceCMVN',
    'fbank': 'espnet.transform.spectrogram:LogMelSpectrogram',
    'spectrogram': 'espnet.transform.spectrogram:Spectrogram',
    'stft': 'espnet.transform.spectrogram:Stft',
    'istft': 'espnet.transform.spectrogram:IStft',
    'stft2fbank': 'espnet.transform.spectrogram:Stft2LogMelSpectrogram',
    'wpe': 'espnet.transform.wpe:WPE',
    'channel_selector': 'espnet.transform.channel_selector:ChannelSelector',
}


class Transformation(object):
    """Apply a configured chain of functions to a mini-batch.

    The configuration is a dict (or a YAML file that loads to a dict).
    Its ``process`` entry is a list of dicts; the ``type`` key of each
    entry is resolved through ``dynamic_import`` with ``import_alias``
    and the remaining keys are passed to the resolved callable as keyword
    arguments.  Only ``mode: sequential`` (the default when ``mode`` is
    absent) is supported: the resulting functions are applied one after
    another, in the order they are listed.

    Examples:
        >>> kwargs = {"process": [{"type": "fbank",
        ...                        "n_mels": 80,
        ...                        "fs": 16000},
        ...                       {"type": "cmvn",
        ...                        "stats": "data/train/cmvn.ark",
        ...                        "norm_vars": True},
        ...                       {"type": "delta", "window": 2, "order": 2}]}
        >>> transform = Transformation(kwargs)
        >>> bs = 10
        >>> xs = [np.random.randn(100, 80).astype(np.float32)
        ...       for _ in range(bs)]
        >>> xs = transform(xs)
    """

    def __init__(self, conffile=None):
        """Load the configuration and instantiate every process entry.

        :param Union[dict, str, None] conffile: a configuration dict
            (deep-copied), a path to a YAML file, or None for an empty
            sequential chain
        :raises NotImplementedError: if ``mode`` is not ``'sequential'``
        :raises TypeError: re-raised when a resolved callable rejects the
            options given in its process entry
        """
        if conffile is None:
            self.conf = {'mode': 'sequential', 'process': []}
        elif isinstance(conffile, dict):
            # Work on a private copy so that later changes made by the
            # caller to its own dict do not leak into this object.
            self.conf = copy.deepcopy(conffile)
        else:
            with io.open(conffile, encoding='utf-8') as f:
                self.conf = yaml.safe_load(f)
            assert isinstance(self.conf, dict), type(self.conf)

        self.functions = OrderedDict()

        mode = self.conf.get('mode', 'sequential')
        if mode != 'sequential':
            raise NotImplementedError(
                'Not supporting mode={}'.format(mode))

        for idx, process in enumerate(self.conf['process']):
            assert isinstance(process, dict), type(process)
            # Copy before popping so the stored configuration keeps 'type'
            opts = dict(process)
            process_type = opts.pop('type')
            class_obj = dynamic_import(process_type, import_alias)
            # TODO(karita): assert issubclass(class_obj, TransformInterface)
            try:
                self.functions[idx] = class_obj(**opts)
            except TypeError:
                # Most likely the options do not match the constructor:
                # log the expected signature to help fix the config, then
                # let the original error propagate.
                try:
                    signa = signature(class_obj)
                except ValueError:
                    # Some callables, e.g. built-in functions, have no
                    # retrievable signature
                    pass
                else:
                    logging.error('Expected signature: {}({})'
                                  .format(class_obj.__name__, signa))
                raise

    def __repr__(self):
        """Return the class name followed by one ``index: function`` line
        per configured function."""
        rep = '\n' + '\n'.join(
            ' {}: {}'.format(k, v) for k, v in self.functions.items())
        return '{}({})'.format(self.__class__.__name__, rep)

    def __call__(self, xs, uttid_list=None, **kwargs):
        """Return new mini-batch

        Every configured function is applied to each element of ``xs`` in
        turn.  A function receives only those entries of ``kwargs`` whose
        names appear in its signature, and it is given the utterance id as
        second positional argument only if it has a parameter named
        ``uttid`` and ``uttid_list`` is not None.

        :param Union[Sequence[np.ndarray], np.ndarray] xs: a batch, or a
            single item (anything that is not a ``Sequence``)
        :param Union[Sequence[str], str] uttid_list: one id per element of
            ``xs``; a single str is repeated for every element
        :return: batch: the transformed batch, or the single transformed
            item when ``xs`` was not a ``Sequence``
        :rtype: List[np.ndarray]
        :raises NotImplementedError: if ``mode`` is not ``'sequential'``
        """
        is_batch = isinstance(xs, Sequence)
        if not is_batch:
            # Treat a lone item as a batch of one; unwrapped on return
            xs = [xs]

        if isinstance(uttid_list, str):
            uttid_list = [uttid_list] * len(xs)

        mode = self.conf.get('mode', 'sequential')
        if mode != 'sequential':
            raise NotImplementedError(
                'Not supporting mode={}'.format(mode))

        for idx, func in self.functions.items():
            # TODO(karita): use TrainingTrans and UttTrans to check __call__ args
            # Derive only the args which the func has
            try:
                param = signature(func).parameters
            except ValueError:
                # Some callables, e.g. built-in functions, have no
                # retrievable signature: pass them no extra arguments
                param = {}
            _kwargs = {k: v for k, v in kwargs.items() if k in param}

            try:
                if uttid_list is not None and 'uttid' in param:
                    xs = [func(x, u, **_kwargs)
                          for x, u in zip(xs, uttid_list)]
                else:
                    xs = [func(x, **_kwargs) for x in xs]
            except Exception:
                # Report which step failed, then propagate unchanged
                logging.critical('Catch a exception from {}th func: {}'
                                 .format(idx, func))
                raise

        return xs if is_batch else xs[0]