Source code for espnet.lm.pytorch_backend.lm

#!/usr/bin/env python3
# Copyright 2017 Johns Hopkins University (Shinji Watanabe)
#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)
# This code is ported from the following implementation written in Chainer.
# https://github.com/chainer/chainer/blob/master/examples/ptb/train_ptb_custom_loop.py

"""LM training in pytorch."""

import copy
import json
import logging
import numpy as np

import torch
import torch.nn as nn
from torch.nn.parallel import data_parallel

from chainer import Chain
from chainer.dataset import convert
from chainer import reporter
from chainer import training
from chainer.training import extensions

from espnet.lm.lm_utils import count_tokens
from espnet.lm.lm_utils import load_dataset
from espnet.lm.lm_utils import MakeSymlinkToBestModel
from espnet.lm.lm_utils import ParallelSentenceIterator
from espnet.lm.lm_utils import read_tokens
from espnet.nets.lm_interface import dynamic_import_lm
from espnet.nets.lm_interface import LMInterface
from espnet.optimizer.factory import dynamic_import_optimizer
from espnet.scheduler.pytorch import PyTorchScheduler
from espnet.scheduler.scheduler import dynamic_import_scheduler

from espnet.asr.asr_utils import snapshot_object
from espnet.asr.asr_utils import torch_load
from espnet.asr.asr_utils import torch_resume
from espnet.asr.asr_utils import torch_snapshot

from espnet.utils.training.tensorboard_logger import TensorboardLogger

from espnet.utils.deterministic_utils import set_deterministic_pytorch
from espnet.utils.training.evaluator import BaseEvaluator
from espnet.utils.training.iterators import ShufflingEnabler
from espnet.utils.training.train_utils import check_early_stop
from espnet.utils.training.train_utils import set_early_stop



def compute_perplexity(result):
    """Compute and add the perplexity to the LogReport.

    :param dict result: The current observations
    """
    # Routine to rewrite the result dictionary of LogReport to add perplexity values
    result["perplexity"] = np.exp(result["main/nll"] / result["main/count"])
    if "validation/main/nll" in result:
        result["val_perplexity"] = np.exp(
            result["validation/main/nll"] / result["validation/main/count"]
        )
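
# Illustrative sketch (not part of the original module): a hypothetical
# LogReport observation dict showing how compute_perplexity adds perplexity
# values; the nll/count numbers below are made up.
def _example_compute_perplexity():
    result = {"main/nll": 230.2585, "main/count": 100}
    compute_perplexity(result)
    # exp(230.2585 / 100) == exp(ln(10)), so result["perplexity"] is ~10.0
    return result["perplexity"]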

class Reporter(Chain):
    """Dummy module to use chainer's trainer."""

    def report(self, loss):
        """Report nothing."""
        pass

def concat_examples(batch, device=None, padding=None):
    """Concat examples in minibatch.

    :param np.ndarray batch: The batch to concatenate
    :param int device: The device to send to
    :param Tuple[int,int] padding: The padding to use
    :return: (inputs, targets)
    :rtype: (torch.Tensor, torch.Tensor)
    """
    x, t = convert.concat_examples(batch, padding=padding)
    x = torch.from_numpy(x)
    t = torch.from_numpy(t)
    if device is not None and device >= 0:
        x = x.cuda(device)
        t = t.cuda(device)
    return x, t
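
# Illustrative sketch (not part of the original module): concat_examples pads a
# minibatch of variable-length token-ID arrays into (input, target) tensors.
# The toy batch below is hypothetical; padding=(0, -100) mirrors what the
# updater passes, -100 being the default ignore_index of PyTorch's
# cross-entropy loss.
def _example_concat_examples():
    batch = [
        (np.array([1, 2, 3]), np.array([2, 3, 4])),
        (np.array([1, 2]), np.array([2, 3])),
    ]
    x, t = concat_examples(batch, device=-1, padding=(0, -100))
    # x.shape == t.shape == (2, 3); the shorter sentence is padded with 0 in x
    # and with -100 in t.
    return x, t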

class BPTTUpdater(training.StandardUpdater):
    """An updater for a pytorch LM."""

    def __init__(
        self,
        train_iter,
        model,
        optimizer,
        schedulers,
        device,
        gradclip=None,
        use_apex=False,
        accum_grad=1,
    ):
        """Initialize class.

        Args:
            train_iter (chainer.dataset.Iterator): The train iterator
            model (LMInterface): The model to update
            optimizer (torch.optim.Optimizer): The optimizer for training
            schedulers (espnet.scheduler.scheduler.SchedulerInterface):
                The schedulers of `optimizer`
            device (int): The device id
            gradclip (float): The gradient clipping value to use
            use_apex (bool): The flag to use Apex in backprop.
            accum_grad (int): The number of gradient accumulation.

        """
        super(BPTTUpdater, self).__init__(train_iter, optimizer)
        self.model = model
        self.device = device
        self.gradclip = gradclip
        self.use_apex = use_apex
        self.scheduler = PyTorchScheduler(schedulers, optimizer)
        self.accum_grad = accum_grad

    # The core part of the update routine can be customized by overriding.
    def update_core(self):
        """Update the model."""
        # When we pass one iterator and optimizer to StandardUpdater.__init__,
        # they are automatically named 'main'.
        train_iter = self.get_iterator("main")
        optimizer = self.get_optimizer("main")
        # Progress the dataset iterator for sentences at each iteration.
        self.model.zero_grad()  # Clear the parameter gradients
        accum = {"loss": 0.0, "nll": 0.0, "count": 0}
        for _ in range(self.accum_grad):
            batch = train_iter.__next__()
            # Concatenate the token IDs to matrices and send them to the device
            # self.converter does this job
            # (it is chainer.dataset.concat_examples by default)
            x, t = concat_examples(batch, device=self.device[0], padding=(0, -100))
            if self.device[0] == -1:
                loss, nll, count = self.model(x, t)
            else:
                # apex does not support torch.nn.DataParallel
                loss, nll, count = data_parallel(self.model, (x, t), self.device)

            # backward
            loss = loss.mean() / self.accum_grad
            if self.use_apex:
                from apex import amp

                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()  # Backprop
            # accumulate stats
            accum["loss"] += float(loss)
            accum["nll"] += float(nll.sum())
            accum["count"] += int(count.sum())

        for k, v in accum.items():
            reporter.report({k: v}, optimizer.target)
        if self.gradclip is not None:
            nn.utils.clip_grad_norm_(self.model.parameters(), self.gradclip)
        optimizer.step()  # Update the parameters
        self.scheduler.step(n_iter=self.iteration)
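
# Illustrative note (not part of the original module): BPTTUpdater consumes
# `accum_grad` minibatches per optimizer step and divides each loss by
# `accum_grad`, so the accumulated gradient averages over all of them. The
# numbers below are hypothetical and mirror the batch-size arithmetic used in
# train() (batch_size = args.batchsize * max(args.ngpu, 1)).
def _example_effective_batch_size(batchsize=32, ngpu=2, accum_grad=4):
    sentences_per_forward = batchsize * max(ngpu, 1)
    sentences_per_update = sentences_per_forward * accum_grad
    return sentences_per_forward, sentences_per_update  # (64, 256)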

class LMEvaluator(BaseEvaluator):
    """A custom evaluator for a pytorch LM."""

    def __init__(self, val_iter, eval_model, reporter, device):
        """Initialize class.

        :param chainer.dataset.Iterator val_iter : The validation iterator
        :param LMInterface eval_model : The model to evaluate
        :param chainer.Reporter reporter : The observations reporter
        :param int device : The device id to use
        """
        super(LMEvaluator, self).__init__(val_iter, reporter, device=-1)
        self.model = eval_model
        self.device = device

    def evaluate(self):
        """Evaluate the model."""
        val_iter = self.get_iterator("main")
        loss = 0
        nll = 0
        count = 0
        self.model.eval()
        with torch.no_grad():
            for batch in copy.copy(val_iter):
                x, t = concat_examples(batch, device=self.device[0], padding=(0, -100))
                if self.device[0] == -1:
                    l, n, c = self.model(x, t)
                else:
                    # apex does not support torch.nn.DataParallel
                    l, n, c = data_parallel(self.model, (x, t), self.device)
                loss += float(l.sum())
                nll += float(n.sum())
                count += int(c.sum())
        self.model.train()
        # report validation loss
        observation = {}
        with reporter.report_scope(observation):
            reporter.report({"loss": loss}, self.model.reporter)
            reporter.report({"nll": nll}, self.model.reporter)
            reporter.report({"count": count}, self.model.reporter)
        return observation

def train(args):
    """Train with the given args.

    :param Namespace args: The program arguments
    """
    model_class = dynamic_import_lm(args.model_module, args.backend)
    assert issubclass(model_class, LMInterface), "model should implement LMInterface"

    # display torch version
    logging.info("torch version = " + torch.__version__)

    set_deterministic_pytorch(args)

    # check cuda and cudnn availability
    if not torch.cuda.is_available():
        logging.warning("cuda is not available")

    # get special label ids
    unk = args.char_list_dict["<unk>"]
    eos = args.char_list_dict["<eos>"]
    # read tokens as a sequence of sentences
    val, n_val_tokens, n_val_oovs = load_dataset(
        args.valid_label, args.char_list_dict, args.dump_hdf5_path
    )
    train, n_train_tokens, n_train_oovs = load_dataset(
        args.train_label, args.char_list_dict, args.dump_hdf5_path
    )
    logging.info("#vocab = " + str(args.n_vocab))
    logging.info("#sentences in the training data = " + str(len(train)))
    logging.info("#tokens in the training data = " + str(n_train_tokens))
    logging.info(
        "oov rate in the training data = %.2f %%"
        % (n_train_oovs / n_train_tokens * 100)
    )
    logging.info("#sentences in the validation data = " + str(len(val)))
    logging.info("#tokens in the validation data = " + str(n_val_tokens))
    logging.info(
        "oov rate in the validation data = %.2f %%"
        % (n_val_oovs / n_val_tokens * 100)
    )

    use_sortagrad = args.sortagrad == -1 or args.sortagrad > 0
    # Create the dataset iterators
    batch_size = args.batchsize * max(args.ngpu, 1)
    if batch_size * args.accum_grad > args.batchsize:
        logging.info(
            f"batch size is automatically increased "
            f"({args.batchsize} -> {batch_size * args.accum_grad})"
        )
    train_iter = ParallelSentenceIterator(
        train,
        batch_size,
        max_length=args.maxlen,
        sos=eos,
        eos=eos,
        shuffle=not use_sortagrad,
    )
    val_iter = ParallelSentenceIterator(
        val, batch_size, max_length=args.maxlen, sos=eos, eos=eos, repeat=False
    )
    epoch_iters = int(len(train_iter.batch_indices) / args.accum_grad)
    logging.info("#iterations per epoch = %d" % epoch_iters)
    logging.info("#total iterations = " + str(args.epoch * epoch_iters))

    # Prepare an RNNLM model
    if args.train_dtype in ("float16", "float32", "float64"):
        dtype = getattr(torch, args.train_dtype)
    else:
        dtype = torch.float32
    model = model_class(args.n_vocab, args).to(dtype=dtype)
    if args.ngpu > 0:
        model.to("cuda")
        gpu_id = list(range(args.ngpu))
    else:
        gpu_id = [-1]

    # Save model conf to json
    model_conf = args.outdir + "/model.json"
    with open(model_conf, "wb") as f:
        logging.info("writing a model config file to " + model_conf)
        f.write(
            json.dumps(
                vars(args), indent=4, ensure_ascii=False, sort_keys=True
            ).encode("utf_8")
        )

    logging.warning(
        "num. model params: {:,} (num. trained: {:,} ({:.1f}%))".format(
            sum(p.numel() for p in model.parameters()),
            sum(p.numel() for p in model.parameters() if p.requires_grad),
            sum(p.numel() for p in model.parameters() if p.requires_grad)
            * 100.0
            / sum(p.numel() for p in model.parameters()),
        )
    )

    # Set up an optimizer
    opt_class = dynamic_import_optimizer(args.opt, args.backend)
    optimizer = opt_class.from_args(model.parameters(), args)
    if args.schedulers is None:
        schedulers = []
    else:
        schedulers = [dynamic_import_scheduler(v)(k, args) for k, v in args.schedulers]

    # setup apex.amp
    if args.train_dtype in ("O0", "O1", "O2", "O3"):
        try:
            from apex import amp
        except ImportError as e:
            logging.error(
                f"You need to install apex for --train-dtype {args.train_dtype}. "
                "See https://github.com/NVIDIA/apex#linux"
            )
            raise e
        model, optimizer = amp.initialize(model, optimizer, opt_level=args.train_dtype)
        use_apex = True
    else:
        use_apex = False

    # FIXME: TOO DIRTY HACK
    reporter = Reporter()
    setattr(model, "reporter", reporter)
    setattr(optimizer, "target", reporter)
    setattr(optimizer, "serialize", lambda s: reporter.serialize(s))

    updater = BPTTUpdater(
        train_iter,
        model,
        optimizer,
        schedulers,
        gpu_id,
        gradclip=args.gradclip,
        use_apex=use_apex,
        accum_grad=args.accum_grad,
    )
    trainer = training.Trainer(updater, (args.epoch, "epoch"), out=args.outdir)
    trainer.extend(LMEvaluator(val_iter, model, reporter, device=gpu_id))
    trainer.extend(
        extensions.LogReport(
            postprocess=compute_perplexity,
            trigger=(args.report_interval_iters, "iteration"),
        )
    )
    trainer.extend(
        extensions.PrintReport(
            [
                "epoch",
                "iteration",
                "main/loss",
                "perplexity",
                "val_perplexity",
                "elapsed_time",
            ]
        ),
        trigger=(args.report_interval_iters, "iteration"),
    )
    trainer.extend(extensions.ProgressBar(update_interval=args.report_interval_iters))
    # Save best models
    trainer.extend(torch_snapshot(filename="snapshot.ep.{.updater.epoch}"))
    trainer.extend(snapshot_object(model, "rnnlm.model.{.updater.epoch}"))
    # T.Hori: MinValueTrigger should be used, but it fails when resuming
    trainer.extend(MakeSymlinkToBestModel("validation/main/loss", "rnnlm.model"))

    if use_sortagrad:
        trainer.extend(
            ShufflingEnabler([train_iter]),
            trigger=(args.sortagrad if args.sortagrad != -1 else args.epoch, "epoch"),
        )
    if args.resume:
        logging.info("resumed from %s" % args.resume)
        torch_resume(args.resume, trainer)

    set_early_stop(trainer, args, is_lm=True)
    if args.tensorboard_dir is not None and args.tensorboard_dir != "":
        from torch.utils.tensorboard import SummaryWriter

        writer = SummaryWriter(args.tensorboard_dir)
        trainer.extend(
            TensorboardLogger(writer),
            trigger=(args.report_interval_iters, "iteration"),
        )

    trainer.run()
    check_early_stop(trainer, args.epoch)

    # compute perplexity for test set
    if args.test_label:
        logging.info("test the best model")
        torch_load(args.outdir + "/rnnlm.model.best", model)
        test = read_tokens(args.test_label, args.char_list_dict)
        n_test_tokens, n_test_oovs = count_tokens(test, unk)
        logging.info("#sentences in the test data = " + str(len(test)))
        logging.info("#tokens in the test data = " + str(n_test_tokens))
        logging.info(
            "oov rate in the test data = %.2f %%"
            % (n_test_oovs / n_test_tokens * 100)
        )
        test_iter = ParallelSentenceIterator(
            test, batch_size, max_length=args.maxlen, sos=eos, eos=eos, repeat=False
        )
        evaluator = LMEvaluator(test_iter, model, reporter, device=gpu_id)
        result = evaluator()
        compute_perplexity(result)
        logging.info(f"test perplexity: {result['perplexity']}")