import torch
class NoamOpt(object):
    "Optim wrapper that implements the Noam learning-rate schedule."
    def __init__(self, model_size, factor, warmup, optimizer):
        self.optimizer = optimizer
        self._step = 0
        self.warmup = warmup
        self.factor = factor
        self.model_size = model_size
        self._rate = 0

    @property
    def param_groups(self):
        # Expose the wrapped optimizer's parameter groups.
        return self.optimizer.param_groups

    def step(self):
        "Update parameters and learning rate."
        self._step += 1
        rate = self.rate()
        # Push the scheduled learning rate into every parameter group before stepping.
        for p in self.optimizer.param_groups:
            p['lr'] = rate
        self._rate = rate
        self.optimizer.step()

    def rate(self, step=None):
        "Compute the Noam learning rate for `step` (defaults to the current step)."
        if step is None:
            step = self._step
        # Linear warmup for `warmup` steps, then decay proportional to step**-0.5.
        return self.factor * self.model_size ** (-0.5) \
            * min(step ** (-0.5), step * self.warmup ** (-1.5))
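
    # Worked example (not from the original source), assuming model_size=512,
    # factor=1.0, warmup=4000: both terms of the min() meet at step == warmup,
    # giving a peak lr of 512**-0.5 * 4000**-0.5 ≈ 7.0e-4, after which the rate
    # decays as step**-0.5.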

    def zero_grad(self):
        self.optimizer.zero_grad()

    def state_dict(self):
        # Capture the schedule state together with the wrapped optimizer's state.
        return {
            "_step": self._step,
            "warmup": self.warmup,
            "factor": self.factor,
            "model_size": self.model_size,
            "_rate": self._rate,
            "optimizer": self.optimizer.state_dict(),
        }

    def load_state_dict(self, state_dict):
        for key, value in state_dict.items():
            if key == "optimizer":
                # The wrapped optimizer restores its own nested state dict.
                self.optimizer.load_state_dict(value)
            else:
                setattr(self, key, value)


def get_std_opt(model, d_model, warmup, factor):
    "Build a NoamOpt wrapping Adam with the standard Transformer settings (betas=(0.9, 0.98), eps=1e-9)."
    base = torch.optim.Adam(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9)
    return NoamOpt(d_model, factor, warmup, base)
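

# A minimal usage sketch, not part of the original module: it assumes a stand-in
# torch.nn.Linear model and illustrative hyperparameters (d_model=512, warmup=4000,
# factor=1.0) purely to show how the wrapper is driven inside a training loop.
if __name__ == "__main__":
    model = torch.nn.Linear(512, 512)   # placeholder for a real Transformer
    opt = get_std_opt(model, d_model=512, warmup=4000, factor=1.0)

    x = torch.randn(8, 512)
    for _ in range(3):
        opt.zero_grad()                 # clear gradients on the wrapped Adam
        loss = model(x).pow(2).mean()   # dummy loss, for illustration only
        loss.backward()
        opt.step()                      # set lr from the schedule, then step Adam
        print(opt._step, opt.rate())    # current step and scheduled learning rate

    # The wrapper round-trips through state_dict()/load_state_dict() like a plain
    # optimizer, so it can be checkpointed alongside the model.
    checkpoint = opt.state_dict()
    opt.load_state_dict(checkpoint)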