# Source code for espnet.nets.pytorch_backend.transformer.optimizer

import torch


class NoamOpt(object):
    """Optimizer wrapper that implements the Noam learning-rate schedule."""

    def __init__(self, model_size, factor, warmup, optimizer):
        """Construct a NoamOpt object."""
        self.optimizer = optimizer
        self._step = 0
        self.warmup = warmup
        self.factor = factor
        self.model_size = model_size
        self._rate = 0

    @property
    def param_groups(self):
        """Return the param_groups of the wrapped optimizer."""
        return self.optimizer.param_groups
    def step(self):
        """Update parameters and learning rate."""
        self._step += 1
        rate = self.rate()
        for p in self.optimizer.param_groups:
            p['lr'] = rate
        self._rate = rate
        self.optimizer.step()
    def rate(self, step=None):
        """Compute the Noam learning rate for the given (or current) step."""
        if step is None:
            step = self._step
        return self.factor * self.model_size ** (-0.5) \
            * min(step ** (-0.5), step * self.warmup ** (-1.5))
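    # The expression in ``rate`` is the Transformer (Noam) schedule from
    # "Attention Is All You Need":
    #     lr = factor * model_size ** -0.5 * min(step ** -0.5, step * warmup ** -1.5)
    # i.e. linear warmup for the first ``warmup`` steps, then inverse
    # square-root decay; the two branches meet, and lr peaks, at step == warmup.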
    def zero_grad(self):
        """Reset the gradients of the wrapped optimizer."""
        self.optimizer.zero_grad()
    def state_dict(self):
        """Return the state of the wrapper and the wrapped optimizer."""
        return {
            "_step": self._step,
            "warmup": self.warmup,
            "factor": self.factor,
            "model_size": self.model_size,
            "_rate": self._rate,
            "optimizer": self.optimizer.state_dict(),
        }
    def load_state_dict(self, state_dict):
        """Load a state dict produced by ``state_dict``."""
        for key, value in state_dict.items():
            if key == "optimizer":
                self.optimizer.load_state_dict(state_dict["optimizer"])
            else:
                setattr(self, key, value)
def get_std_opt(model, d_model, warmup, factor):
    """Get a standard NoamOpt: Adam wrapped with the Noam schedule."""
    base = torch.optim.Adam(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9)
    return NoamOpt(d_model, factor, warmup, base)
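
# A minimal usage sketch, assuming a toy stand-in model and illustrative
# hyperparameter values (none of this is part of the original module):
if __name__ == "__main__":
    model = torch.nn.Linear(256, 256)  # stand-in for a Transformer model
    opt = get_std_opt(model, d_model=256, warmup=25000, factor=1.0)
    loss = model(torch.randn(8, 256)).pow(2).mean()
    opt.zero_grad()
    loss.backward()
    opt.step()  # increments the step count, sets lr via rate(), runs Adam
    # The wrapper's state_dict carries the schedule state as well as Adam's,
    # so a checkpoint restores both the step count and the current rate.
    opt.load_state_dict(opt.state_dict())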