Source code for espnet2.train.gan_trainer

# Copyright 2021 Tomoki Hayashi
#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)

"""Trainer module for GAN-based training."""

import argparse
import dataclasses
import logging
import time

from contextlib import contextmanager
from distutils.version import LooseVersion
from typing import Dict
from typing import Iterable
from typing import List
from typing import Optional
from typing import Sequence
from typing import Tuple

import torch

from typeguard import check_argument_types

from espnet2.schedulers.abs_scheduler import AbsBatchStepScheduler
from espnet2.schedulers.abs_scheduler import AbsScheduler
from espnet2.torch_utils.device_funcs import to_device
from espnet2.torch_utils.recursive_op import recursive_average
from espnet2.train.distributed_utils import DistributedOption
from espnet2.train.reporter import SubReporter
from espnet2.train.trainer import Trainer
from espnet2.train.trainer import TrainerOptions
from espnet2.utils.build_dataclass import build_dataclass
from espnet2.utils.types import str2bool

# ReduceOp is only importable when this torch build ships distributed support.
# In this module it is used for the all_reduce() stop-flag exchange.
if torch.distributed.is_available():
    from torch.distributed import ReduceOp

# Automatic mixed precision helpers are taken from torch.cuda.amp on torch>=1.6.0.
if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
    from torch.cuda.amp import autocast
    from torch.cuda.amp import GradScaler
else:
    # Nothing to do if torch<1.6.0
    @contextmanager
    def autocast(enabled=True):  # NOQA
        """No-op stand-in for ``torch.cuda.amp.autocast``; ``enabled`` is ignored."""
        yield

    # No gradient scaler exists on this torch version; the name is kept so that
    # annotations such as Optional[GradScaler] still evaluate.
    GradScaler = None

# fairscale is an optional dependency; the name is bound to None when it is missing.
try:
    import fairscale
except ImportError:
    fairscale = None


@dataclasses.dataclass
class GANTrainerOptions(TrainerOptions):
    """Option container consumed by GANTrainer.

    Extends ``TrainerOptions`` with the single GAN-specific switch below.
    """

    # When True the generator turn runs before the discriminator turn in
    # every iteration; when False the order is reversed.
    generator_first: bool
class GANTrainer(Trainer):
    """Trainer for GAN-based training.

    Every mini-batch is processed in two "turns" (generator and discriminator,
    in the order selected by ``options.generator_first``).  In each turn the
    model is called with ``forward_generator=<bool>`` and must return a dict
    with the keys ``loss``, ``stats``, ``weight`` and, optionally,
    ``optim_idx`` selecting which optimizer to step.

    If you'd like to use this trainer, the model must inherit
    espnet2.train.abs_gan_espnet_model.AbsGANESPnetModel.

    """

    @classmethod
    def build_options(cls, args: argparse.Namespace) -> TrainerOptions:
        """Build options consumed by train(), eval(), and plot_attention().

        Args:
            args: Parsed command-line namespace.

        Returns:
            A ``GANTrainerOptions`` instance built from ``args``.

        """
        assert check_argument_types()
        return build_dataclass(GANTrainerOptions, args)

    @classmethod
    def add_arguments(cls, parser: argparse.ArgumentParser):
        """Add additional arguments for GAN-trainer.

        Args:
            parser: Parser that receives the ``--generator_first`` option.

        """
        parser.add_argument(
            "--generator_first",
            type=str2bool,
            default=False,
            help="Whether to update generator first.",
        )

    @staticmethod
    def _resolve_optim_idx(optim_idx):
        """Normalize the model's ``optim_idx`` to ``None`` or a Python scalar.

        Args:
            optim_idx: ``None``, an ``int``, or a 0-dim / 1-dim ``torch.Tensor``
                whose entries are all equal.

        Returns:
            ``None`` or ``int`` unchanged; for tensors, the result of ``.item()``.

        Raises:
            RuntimeError: If the value is of another type, has two or more
                dimensions, or is a 1-dim tensor with differing entries.

        """
        if optim_idx is None or isinstance(optim_idx, int):
            return optim_idx
        if not isinstance(optim_idx, torch.Tensor):
            raise RuntimeError(
                "optim_idx must be int or 1dim torch.Tensor, "
                f"but got {type(optim_idx)}"
            )
        if optim_idx.dim() >= 2:
            raise RuntimeError(
                "optim_idx must be int or 1dim torch.Tensor, "
                f"but got {optim_idx.dim()}dim tensor"
            )
        if optim_idx.dim() == 1:
            for v in optim_idx:
                if v != optim_idx[0]:
                    raise RuntimeError(
                        "optim_idx must be 1dim tensor "
                        "having same values for all entries"
                    )
            return optim_idx[0].item()
        return optim_idx.item()

    @classmethod
    def train_one_epoch(
        cls,
        model: torch.nn.Module,
        iterator: Iterable[Tuple[List[str], Dict[str, torch.Tensor]]],
        optimizers: Sequence[torch.optim.Optimizer],
        schedulers: Sequence[Optional[AbsScheduler]],
        scaler: Optional[GradScaler],
        reporter: SubReporter,
        summary_writer,
        options: GANTrainerOptions,
        distributed_option: DistributedOption,
    ) -> bool:
        """Train one epoch.

        Args:
            model: Model called as ``model(forward_generator=..., **batch)``.
            iterator: Yields ``(ids, batch)`` pairs where ``batch`` is a dict.
            optimizers: Optimizers; ``optim_idx`` returned by the model selects
                the one to step (all of them when it is ``None``).
            schedulers: Schedulers paired with ``optimizers``; only instances
                of ``AbsBatchStepScheduler`` are stepped here.
            scaler: Gradient scaler for mixed precision, or ``None``.
            reporter: Receives stats, timings and learning rates.
            summary_writer: Optional writer passed to
                ``reporter.tensorboard_add_scalar``.
            options: Trainer options.
            distributed_option: Distributed training settings.

        Returns:
            ``True`` if no turn of the epoch produced a usable (finite)
            gradient norm, ``False`` otherwise.

        Raises:
            NotImplementedError: If ``accum_grad > 1`` or ``grad_noise`` is set.
            RuntimeError: If the model output is not a dict or ``optim_idx``
                is malformed.

        """
        assert check_argument_types()

        grad_noise = options.grad_noise
        accum_grad = options.accum_grad
        grad_clip = options.grad_clip
        grad_clip_type = options.grad_clip_type
        log_interval = options.log_interval
        no_forward_run = options.no_forward_run
        ngpu = options.ngpu
        use_wandb = options.use_wandb
        generator_first = options.generator_first
        distributed = distributed_option.distributed

        # Check unavailable options
        # TODO(kan-bayashi): Support the use of these options
        if accum_grad > 1:
            raise NotImplementedError(
                "accum_grad > 1 is not supported in GAN-based training."
            )
        if grad_noise:
            raise NotImplementedError(
                "grad_noise is not supported in GAN-based training."
            )

        if log_interval is None:
            try:
                log_interval = max(len(iterator) // 20, 10)
            except TypeError:
                # The iterator has no length
                log_interval = 100

        # Loop invariants
        device = "cuda" if ngpu > 0 else "cpu"
        if generator_first:
            turns = ["generator", "discriminator"]
        else:
            turns = ["discriminator", "generator"]

        model.train()
        all_steps_are_invalid = True
        # [For distributed] Because iteration counts are not always equals between
        # processes, send stop-flag to the other processes if iterator is finished
        iterator_stop = torch.tensor(0).to(device)

        start_time = time.perf_counter()
        for iiter, (_, batch) in enumerate(
            reporter.measure_iter_time(iterator, "iter_time"), 1
        ):
            assert isinstance(batch, dict), type(batch)

            if distributed:
                torch.distributed.all_reduce(iterator_stop, ReduceOp.SUM)
                if iterator_stop > 0:
                    break

            batch = to_device(batch, device)
            if no_forward_run:
                all_steps_are_invalid = False
                continue

            turn_start_time = time.perf_counter()
            for turn in turns:
                with autocast(scaler is not None):
                    with reporter.measure_time(f"{turn}_forward_time"):
                        retval = model(forward_generator=turn == "generator", **batch)

                        # Only the dict-style return value is supported here.
                        if not isinstance(retval, dict):
                            raise RuntimeError("model output must be dict.")
                        loss = retval["loss"]
                        stats = retval["stats"]
                        weight = retval["weight"]
                        optim_idx = cls._resolve_optim_idx(retval.get("optim_idx"))

                    stats = {k: v for k, v in stats.items() if v is not None}
                    if ngpu > 1 or distributed:
                        # Apply weighted averaging for loss and stats
                        loss = (loss * weight.type(loss.dtype)).sum()

                        # if distributed, this method can also apply all_reduce()
                        stats, weight = recursive_average(stats, weight, distributed)

                        # Now weight is summation over all workers
                        loss /= weight
                    if distributed:
                        # NOTE(kamo): Multiply world_size since DistributedDataParallel
                        # automatically normalizes the gradient by world_size.
                        loss *= torch.distributed.get_world_size()

                reporter.register(stats, weight)

                with reporter.measure_time(f"{turn}_backward_time"):
                    if scaler is not None:
                        # Scales loss. Calls backward() on scaled loss
                        # to create scaled gradients.
                        # Backward passes under autocast are not recommended.
                        # Backward ops run in the same dtype autocast chose
                        # for corresponding forward ops.
                        scaler.scale(loss).backward()
                    else:
                        loss.backward()

                if scaler is not None:
                    # Unscales the gradients of optimizer's assigned params in-place
                    for iopt, optimizer in enumerate(optimizers):
                        if optim_idx is not None and iopt != optim_idx:
                            continue
                        scaler.unscale_(optimizer)

                # TODO(kan-bayashi): Compute grad norm without clipping
                grad_norm = None
                if grad_clip > 0.0:
                    # compute the gradient norm to check if it is normal or not
                    grad_norm = torch.nn.utils.clip_grad_norm_(
                        model.parameters(),
                        max_norm=grad_clip,
                        norm_type=grad_clip_type,
                    )
                    # PyTorch<=1.4, clip_grad_norm_ returns float value
                    if not isinstance(grad_norm, torch.Tensor):
                        grad_norm = torch.tensor(grad_norm)

                if grad_norm is None or torch.isfinite(grad_norm):
                    all_steps_are_invalid = False
                    with reporter.measure_time(f"{turn}_optim_step_time"):
                        stepped = False
                        for iopt, (optimizer, scheduler) in enumerate(
                            zip(optimizers, schedulers)
                        ):
                            if optim_idx is not None and iopt != optim_idx:
                                continue
                            if scaler is not None:
                                # scaler.step() first unscales the gradients of
                                # the optimizer's assigned params.
                                scaler.step(optimizer)
                            else:
                                optimizer.step()
                            stepped = True
                            if isinstance(scheduler, AbsBatchStepScheduler):
                                scheduler.step()
                        # Updates the scale for next iteration.
                        # This must happen once, after every selected optimizer
                        # has stepped: update() resets the scaler's per-optimizer
                        # state, so an update() between two step() calls would
                        # make the second step() unscale its gradients again.
                        if scaler is not None and stepped:
                            scaler.update()
                else:
                    logging.warning(
                        f"The grad norm is {grad_norm}. "
                        "Skipping updating the model."
                    )
                    # Must invoke scaler.update() if unscale_() is used in the
                    # iteration to avoid the following error:
                    #   RuntimeError: unscale_() has already been called
                    #   on this optimizer since the last update().
                    # Note that if the gradient has inf/nan values,
                    # scaler.step skips optimizer.step().
                    if scaler is not None:
                        stepped = False
                        for iopt, optimizer in enumerate(optimizers):
                            if optim_idx is not None and iopt != optim_idx:
                                continue
                            scaler.step(optimizer)
                            stepped = True
                        # Single update() after all step() calls (see above).
                        if stepped:
                            scaler.update()

                for optimizer in optimizers:
                    # NOTE(kan-bayashi): In the case of GAN, we need to clear
                    #   the gradient of both optimizers after every update.
                    optimizer.zero_grad()

                # Register lr and train/load time[sec/step],
                # where step refers to accum_grad * mini-batch
                if optim_idx is None:
                    # No optimizer was singled out by the model, so all of them
                    # were stepped above; report the learning rate of each one
                    # instead of indexing ``optimizers`` with None.
                    lr_stats = {
                        f"optim{iopt}_lr{i}": pg["lr"]
                        for iopt, optimizer in enumerate(optimizers)
                        for i, pg in enumerate(optimizer.param_groups)
                        if "lr" in pg
                    }
                else:
                    lr_stats = {
                        f"optim{optim_idx}_lr{i}": pg["lr"]
                        for i, pg in enumerate(optimizers[optim_idx].param_groups)
                        if "lr" in pg
                    }
                reporter.register(lr_stats)
                reporter.register(
                    {f"{turn}_train_time": time.perf_counter() - turn_start_time}
                )
                turn_start_time = time.perf_counter()

            reporter.register({"train_time": time.perf_counter() - start_time})
            start_time = time.perf_counter()

            # NOTE(kamo): Call log_message() after next()
            reporter.next()
            if iiter % log_interval == 0:
                logging.info(reporter.log_message(-log_interval))
                if summary_writer is not None:
                    reporter.tensorboard_add_scalar(summary_writer, -log_interval)
                if use_wandb:
                    reporter.wandb_log()

        else:
            # The iterator was exhausted without a break: tell the other
            # processes to stop as well.
            if distributed:
                iterator_stop.fill_(1)
                torch.distributed.all_reduce(iterator_stop, ReduceOp.SUM)
        return all_steps_are_invalid

    @classmethod
    @torch.no_grad()
    def validate_one_epoch(
        cls,
        model: torch.nn.Module,
        iterator: Iterable[Dict[str, torch.Tensor]],
        reporter: SubReporter,
        options: GANTrainerOptions,
        distributed_option: DistributedOption,
    ) -> None:
        """Validate one epoch.

        Runs the model in eval mode without gradients for both turns of every
        batch and registers the returned stats to ``reporter``.

        Args:
            model: Model called as ``model(forward_generator=..., **batch)``.
            iterator: Iterated as ``(ids, batch)`` pairs where ``batch`` is a
                dict (the annotation above does not reflect the pair).
            reporter: Receives the stats of each turn.
            options: Trainer options.
            distributed_option: Distributed training settings.

        """
        assert check_argument_types()
        ngpu = options.ngpu
        no_forward_run = options.no_forward_run
        distributed = distributed_option.distributed
        generator_first = options.generator_first

        # Loop invariants
        device = "cuda" if ngpu > 0 else "cpu"
        if generator_first:
            turns = ["generator", "discriminator"]
        else:
            turns = ["discriminator", "generator"]

        model.eval()

        # [For distributed] Because iteration counts are not always equals between
        # processes, send stop-flag to the other processes if iterator is finished
        iterator_stop = torch.tensor(0).to(device)
        for (_, batch) in iterator:
            assert isinstance(batch, dict), type(batch)
            if distributed:
                torch.distributed.all_reduce(iterator_stop, ReduceOp.SUM)
                if iterator_stop > 0:
                    break

            batch = to_device(batch, device)
            if no_forward_run:
                continue

            for turn in turns:
                retval = model(forward_generator=turn == "generator", **batch)
                if isinstance(retval, dict):
                    stats = retval["stats"]
                    weight = retval["weight"]
                else:
                    _, stats, weight = retval
                if ngpu > 1 or distributed:
                    # Apply weighted averaging for stats.
                    # if distributed, this method can also apply all_reduce()
                    stats, weight = recursive_average(stats, weight, distributed)
                reporter.register(stats, weight)
            reporter.next()

        else:
            # The iterator was exhausted without a break: tell the other
            # processes to stop as well.
            if distributed:
                iterator_stop.fill_(1)
                torch.distributed.all_reduce(iterator_stop, ReduceOp.SUM)