1 /**
2 Metric (e.g., accuracy)
3 
4 TODO: perplexity, AUC, F1, BLEU, edit distance
5 */
6 module grain.metric;
7 
8 import grain.autograd : isVariable;
9 
10 
/**
Compute classification accuracy by comparing predictions `y` with target ids `t`.

For each sample `i`, the predicted class is the arg-max of row `y[i]`. The
prediction counts as correct when that index equals `t[i]`.

Params:
    y = prediction scores, one row per sample (presumably shaped
        (batch, nclass), e.g. probabilities or logits; TODO confirm)
    t = target class ids, one per sample; its first dimension defines the batch size

Returns:
    the fraction (a `double` in [0, 1]) of samples whose arg-max prediction
    equals the target id, or `double.nan` when the batch is empty.

Both variables are converted with `to!HostStorage` before comparison.
*/
auto accuracy(Vy, Vt)(Vy y, Vt t) if (isVariable!Vy && isVariable!Vt) {
    import mir.ndslice : maxIndex;
    import grain.autograd : to, HostStorage;

    const nbatch = t.shape[0];
    // Every target needs a prediction row. Otherwise the loop below would
    // index past the end of `y`, or silently ignore surplus rows of `y`.
    assert(y.shape[0] == nbatch, "accuracy: batch size of y and t must match");
    // Accuracy over zero samples is undefined; return NaN explicitly rather
    // than relying on 0.0 / 0.
    if (nbatch == 0) {
        return double.nan;
    }

    auto hy = y.to!HostStorage.sliced;
    auto ht = t.to!HostStorage.sliced;
    size_t ncorrect = 0;
    foreach (i; 0 .. nbatch) {
        // maxIndex yields an index array; for a 1-D row, element [0] is the
        // position of the maximum score, i.e. the predicted class id.
        if (hy[i].maxIndex[0] == ht[i]) {
            ++ncorrect;
        }
    }
    return cast(double) ncorrect / nbatch;
}