Adam

struct Adam(Chain) {
    Chain* target;
    float lr;
    float beta1;
    float beta2;
    float eps;
    StateDict moment1;
    StateDict moment2;
}

Constructors

this
this(ref Chain target, float lr, float eps = 1e-8)

Members

Functions

initStates
void initStates(string name, ref V field)
step
void step(string name, ref V field)

Examples

1 import grain.autograd;
2 import numir;
3 
4 {
5     auto model = MLP!(float, HostStorage)(3);
6     auto optim = Adam!(typeof(model))(model, 1e-3);
7     static assert(isOptimizer!(typeof(optim)));
8     model.fc1.weight.data.zero_();
9     model.fc1.weight.grad = [[0.2f, 0.0f, 0.0f], [0.0f, 0.0f, 0.0f]].variable
10         .data;
11     optim.update();
12     auto w = model.fc1.weight;
13     auto m1 = (1.0 - optim.beta1) * (0.2 - 0.0) + 0.0;
14     auto m2 = (1.0 - optim.beta2) * (0.2 * 0.2 - 0.0) + 0.0;
15     assert(approxEqual(w.sliced, [[-optim.lr * m1 / (m2 + optim.eps) ^^ 0.5,
16             0.0, 0.0], [0.0, 0.0, 0.0]].nparray));
17     auto m1_ = optim.moment1[".fc1.weight"].to!(typeof(w));
18     assert(approxEqual(m1_.sliced, [[m1, 0.0, 0.0], [0.0, 0.0, 0.0]].nparray));
19     auto m2_ = optim.moment2[".fc1.weight"].to!(typeof(w));
20     assert(approxEqual(m2_.sliced, [[m2, 0.0, 0.0], [0.0, 0.0, 0.0]].nparray));
21 }
22 version (grain_cuda) {
23     auto model = MLP!(float, DeviceStorage)(3);
24     auto optim = Adam!(typeof(model))(model, 0.1);
25     optim.update();
26 }

Meta