Compute negative log-likelihood: -logP(y=t)
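Concretely, for an input x of log-probabilities with shape (batch, class) and integer targets t, the values asserted in the examples below correspond to the mean over non-ignored targets (a sketch of the definition; |V| is the number of targets different from ignoreIndex):

$$ L(x, t) = -\frac{1}{|V|} \sum_{i \in V} x_{i,\, t_i}, \qquad V = \{\, i \mid t_i \neq \mathrm{ignoreIndex} \,\} $$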
test NLL on a simple case: gradCheck and CPU/CUDA equality
/++ equivalent torch v0.4 code
>>> x = torch.FloatTensor([[0.2, 0.4, 0.4], [0.1, 0.5, 0.4]])
>>> x.requires_grad = True
>>> t = torch.LongTensor([1, 0])
>>> l = torch.nn.functional.nll_loss(x, t)
>>> print(l)
tensor(-0.2500)

>>> l.backward()
>>> print(x.grad)
tensor([[0.0, -0.5, 0.0], [-0.5, 0.0, 0.0]])
+/
import std.typecons;
import grain.testing;

NegativeLogLikelihood!(float, int) func;
// the third target equals func.ignoreIndex, so that row does not count
auto hx = [[0.2f, 0.4f, 0.4f], [0.1f, 0.5f, 0.4f], [0.1f, 0.5f, 0.4f]]
    .variable;
auto ht = [1, 0, func.ignoreIndex].variable;
auto hl = func.forward(hx, ht);
assert(func._normalize == 0.5); // 1 / (number of non-ignored targets) = 1/2
assert(hl.sliced == [-(0.4f + 0.1f + 0.0f) / 2]);
auto hgx = func.backward(1.0f.variable);
assert(hgx[0].sliced == [[0.0, -0.5, 0.0], [-0.5, 0.0, 0.0], [0.0, 0.0, 0.0]]);
assert(!hgx[1].defined); // integer targets receive no gradient
gradCheck(func, tuple(hx, ht), 1.0f.variable);

version (grain_cuda) {
    // the same computation on GPU storage must match the CPU result
    auto dx = hx.to!DeviceStorage;
    auto dt = ht.to!DeviceStorage;
    auto dl = func.forward(dx, dt);
    assert(func._normalize == 0.5);
    assert(dl.to!HostStorage.sliced == [-(0.4f + 0.1f + 0.0f) / 2]);
    auto dgx = func.backward(1.0f.variable.to!DeviceStorage);
    assert(dgx[0].to!HostStorage.sliced ==
           [[0.0, -0.5, 0.0],
            [-0.5, 0.0, 0.0],
            [0.0, 0.0, 0.0]]);
    assert(!dgx[1].defined);
}
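The gradient checked above follows from that definition: only entries selected by a non-ignored target receive gradient (a sketch; g is the upstream gradient passed to backward):

$$ \frac{\partial L}{\partial x_{i,c}} = \begin{cases} -g / |V| & \text{if } i \in V \text{ and } c = t_i, \\ 0 & \text{otherwise.} \end{cases} $$

With g = 1 and |V| = 2 non-ignored targets, the two selected entries get -0.5 and the ignored third row stays zero, which is exactly the hgx[0] (and dgx[0]) result asserted above.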
test variable.backward (backpropagation through the autograd graph)
import std.typecons;
import grain.testing;
import mir.ndslice;
static import grain.config;

// enable graph recording so that hl.backward can reach hx.grad
grain.config.backprop = true;

NegativeLogLikelihood!(float, int) func;
auto hx = [[0.2f, 0.4f, 0.4f], [0.1f, 0.5f, 0.4f], [0.1f, 0.5f, 0.4f]]
    .variable;
hx.requiresGrad = true;
auto ht = [1, 0, func.ignoreIndex].variable;
auto hl = func.applyForward(hx, ht);

assert(func._normalize == 0.5);
assert(hl.sliced == [-(0.4f + 0.1f + 0.0f) / 2]);
auto u = UntypedVariable(1.0f.variable); // upstream gradient dL/dL = 1
hl.backward(&u);

assert(hx.grad[].sliced(3, 3) ==
       [[0.0, -0.5, 0.0], [-0.5, 0.0, 0.0], [0.0, 0.0, 0.0]]);
// TODO assert(!ht.grad.defined);
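The -0.5 entries in hx.grad (and the gradCheck call in the first example) can also be sanity-checked with a central finite difference over a single input entry. The following is only a rough sketch reusing the API shown above (NegativeLogLikelihood.forward, .variable, .sliced), not grain's gradCheck implementation; the two-row input and eps are arbitrary choices:

import mir.ndslice;

enum eps = 1e-3f;
NegativeLogLikelihood!(float, int) f;
auto t = [1, 0].variable;
// perturb x[0, 1], the entry selected by the first target, up and down
auto xp = [[0.2f, 0.4f + eps, 0.4f], [0.1f, 0.5f, 0.4f]].variable;
auto xm = [[0.2f, 0.4f - eps, 0.4f], [0.1f, 0.5f, 0.4f]].variable;
auto lp = f.forward(xp, t).sliced[0];
auto lm = f.forward(xm, t).sliced[0];
auto numeric = (lp - lm) / (2 * eps);
// numeric should be close to -0.5, the analytic gradient at x[0, 1]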