/**
Metrics for evaluating model outputs (e.g., accuracy).

TODO: perplexity, AUC, F1, BLEU, edit distance
*/
module grain.metric;

import grain.autograd : isVariable;


/**
Compute classification accuracy by comparing predictions `y` to target class ids `t`.

For each sample `i`, the predicted class is the index of the largest value in
`y[i]` (argmax via `maxIndex`). The sample counts as correct when that index
equals `t[i]`.

Params:
    y = per-sample prediction scores, indexed as `y[i]`. Presumably shape
        [batch, nclass], holding logits or probabilities; TODO confirm.
    t = target class ids, one per sample. Its first dimension defines the
        batch size.

Returns:
    The fraction of correctly classified samples as a `double` in [0, 1].
    Returns NaN when the batch is empty (0 / 0).
*/
auto accuracy(Vy, Vt)(Vy y, Vt t) if (isVariable!Vy && isVariable!Vt) {
    import mir.ndslice : maxIndex;
    import grain.autograd : to, HostStorage;

    auto nbatch = t.shape[0];
    // y needs exactly one prediction row per target. Otherwise extra rows of y
    // would be silently ignored, or missing rows indexed out of range.
    assert(y.shape[0] == nbatch, "accuracy: batch size mismatch between y and t");

    // Copy both variables into host storage so they can be indexed element-wise.
    auto hy = y.to!HostStorage.sliced;
    auto ht = t.to!HostStorage.sliced;

    size_t ncorrect = 0;
    foreach (i; 0 .. nbatch) {
        // maxIndex returns an index array. hy[i] is a single row, so element [0]
        // is the argmax position within that row.
        if (hy[i].maxIndex[0] == ht[i]) {
            ++ncorrect;
        }
    }
    // Floating-point division: an empty batch yields NaN, as before.
    return cast(double) ncorrect / nbatch;
}