Source code for espnet.optimizer.factory

"""Import optimizer class dynamically."""
import argparse

from espnet.utils.dynamic_import import dynamic_import
from espnet.utils.fill_missing_args import fill_missing_args


class OptimizerFactoryInterface:
    """Optimizer adaptor: turns command-line or keyword options into an optimizer.

    The base class does not know how to build anything: ``from_args`` raises
    ``NotImplementedError`` and ``add_arguments`` registers no options, so a
    concrete factory is expected to override at least ``from_args``.
    """

    @staticmethod
    def from_args(target, args: argparse.Namespace):
        """Create an optimizer from an ``argparse.Namespace``.

        Args:
            target: for pytorch `model.parameters()`, for chainer `model`
            args (argparse.Namespace): parsed command-line args

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError()

    @staticmethod
    def add_arguments(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
        """Register optimizer options on ``parser`` and return it.

        The base implementation adds nothing and returns the very same parser.
        """
        return parser

    @classmethod
    def build(cls, target, **kwargs):
        """Create an optimizer from plain Python keyword arguments.

        ``kwargs`` are packed into an ``argparse.Namespace`` and passed through
        ``fill_missing_args`` together with this class's ``add_arguments``
        (presumably to supply parser defaults for options the caller omitted;
        see that helper). The result is handed to ``from_args``.

        Args:
            target: for pytorch `model.parameters()`, for chainer `model`

        Returns:
            new Optimizer, as produced by ``from_args``
        """
        completed = fill_missing_args(argparse.Namespace(**kwargs), cls.add_arguments)
        return cls.from_args(target, completed)
def dynamic_import_optimizer(name: str, backend: str) -> OptimizerFactoryInterface:
    """Import optimizer factory class dynamically.

    Args:
        name (str): alias name registered in the backend's
            ``OPTIMIZER_FACTORY_DICT``, or dynamic import syntax `module:class`
        backend (str): backend name e.g., chainer or pytorch

    Returns:
        For a registered alias, the entry stored under ``name`` in the backend's
        ``OPTIMIZER_FACTORY_DICT`` (returned unchanged). For `module:class`, the
        imported class, which is checked to subclass OptimizerFactoryInterface.
        In both cases a class is returned, not an instance.

    Raises:
        NotImplementedError: if ``backend`` is neither "pytorch" nor "chainer".
        KeyError: if ``name`` is not a registered alias and does not use the
            `module:class` syntax.
        TypeError: if the object imported via `module:class` is not a subclass
            of OptimizerFactoryInterface.
        Errors raised by ``dynamic_import`` (e.g. a missing module or
        attribute) propagate unchanged.
    """
    # Backend modules are imported lazily so that only the selected framework
    # has to be installed.
    if backend == "pytorch":
        from espnet.optimizer.pytorch import OPTIMIZER_FACTORY_DICT
    elif backend == "chainer":
        from espnet.optimizer.chainer import OPTIMIZER_FACTORY_DICT
    else:
        raise NotImplementedError(f"unsupported backend: {backend}")

    # Registered aliases take precedence.
    if name in OPTIMIZER_FACTORY_DICT:
        return OPTIMIZER_FACTORY_DICT[name]

    # Previously this path was unreachable: every name went through the dict
    # lookup, so `module:class` names always failed with KeyError. Unknown
    # plain aliases still raise KeyError, as before.
    if ":" not in name:
        raise KeyError(
            f"unknown {backend} optimizer {name!r}: expected one of "
            f"{list(OPTIMIZER_FACTORY_DICT)} or `module:class` import syntax"
        )

    factory_class = dynamic_import(name)
    # Explicit check instead of `assert`, which is stripped under `python -O`;
    # isinstance guards issubclass against non-class objects.
    if not (
        isinstance(factory_class, type)
        and issubclass(factory_class, OptimizerFactoryInterface)
    ):
        raise TypeError(
            f"{name} does not refer to a subclass of OptimizerFactoryInterface"
        )
    return factory_class