1 module grain.testing;
2 
3 import mir.format : stringBuf, getData;
4 
/// Assert that `actual == desired` (compared with D's `==` operator).
///
/// On failure the message reports both values, the optional `info` string,
/// and the calling function, file and line. These are captured through the
/// `__FUNCTION__`, `__FILE__` and `__LINE__` template defaults.
///
/// Params:
///     actual  = value produced by the code under test
///     desired = expected value
///     info    = extra context appended to the failure message
void assertEqual(
    T1, T2,
    string file = __FILE__,
    size_t line = __LINE__,
    string func = __FUNCTION__
)(T1 actual, T2 desired, string info = "none")
{
    // `getData` yields a view into the temporary string buffer, which is
    // destroyed while the AssertError propagates. Copy the message into GC
    // memory with `.idup` so it is still valid when the error is reported.
    assert(actual == desired,
           (stringBuf()
            << "(actual) " << actual
            << " != (desired) " << desired
            << ", (info) " << info
            << " (func) " << func
            << " (file) " << file
            << " (line) " << line
            << getData).idup);
}
23 
24 
/// Assert that `actual.shape == desired.shape`.
///
/// Both arguments must expose a `shape` member. On failure the message
/// reports both shapes, the `info` string, and the calling function, file and
/// line. These are captured through the `__FUNCTION__`, `__FILE__` and
/// `__LINE__` template defaults.
///
/// Params:
///     actual  = value produced by the code under test
///     desired = value whose shape is expected
///     info    = extra context appended to the failure message
void assertShapeEqual(
    T1, T2,
    string file = __FILE__,
    size_t line = __LINE__,
    string func = __FUNCTION__
)(T1 actual, T2 desired, string info = "shape mismatch")
{
    // `getData` yields a view into the temporary string buffer, which is
    // destroyed while the AssertError propagates. Copy the message into GC
    // memory with `.idup` so it is still valid when the error is reported.
    assert(actual.shape == desired.shape,
           (stringBuf()
            << "(actual) " << actual.shape
            << " != (desired) " << desired.shape
            << ", (info) " << info
            << " (func) " << func
            << " (file) " << file
            << " (line) " << line
            << getData).idup);
}
42 
43 
/// Assert that every element of `actual` is close to the corresponding
/// element of `desired`, i.e. `abs(a - d) <= atol + rtol * abs(d)`.
///
/// Arguments for which `isTensor` holds are converted with `asSlice`. Other
/// arguments are used as-is and must support `shape` and mir's `reshape`.
/// Shapes are checked first via `assertShapeEqual`.
///
/// Params:
///     actual  = values produced by the code under test
///     desired = reference values
///     msg     = extra context included in failure messages
///     rtol    = relative tolerance
///     atol    = absolute tolerance
void assertAllClose(
    T1, T2,
    string file = __FILE__,
    size_t line = __LINE__,
    string func = __FUNCTION__
)(
    T1 actual,
    T2 desired,
    string msg = "",
    double rtol = 1e-7,
    double atol = 0,
)
{
    import grain.tensor : isTensor;
    import mir.ndslice : zip, reshape;
    import std.math : abs;

    // Forward the caller's location so a shape mismatch points at the test,
    // not at this helper.
    assertShapeEqual!(T1, T2, file, line, func)(
        actual, desired, msg.length ? msg : "shape mismatch");

    static if (isTensor!T1)
        auto a = actual.asSlice;
    else
        auto a = actual;
    static if (isTensor!T2)
        auto d = desired.asSlice;
    else
        auto d = desired;

    // `reshape` reports failure through its `err` out-parameter rather than
    // throwing. Give each call its own code and check both. Otherwise the
    // comparison below would run over whatever `reshape` returned on failure.
    int aerr, derr;
    auto aflat = a.reshape([-1], aerr);
    auto dflat = d.reshape([-1], derr);
    assert(aerr == 0 && derr == 0,
           (stringBuf() << "reshape to 1-D failed: (actual err) " << aerr
            << ", (desired err) " << derr
            << ", (info) " << msg
            << " (func) " << func
            << " (file) " << file
            << " (line) " << line
            << getData).idup);

    foreach (t; zip(aflat, dflat))
    {
        auto lhs = abs(t[0] - t[1]);
        auto rhs = atol + rtol * abs(t[1]);
        // `.idup` copies the message out of the temporary buffer, which does
        // not outlive the thrown AssertError.
        assert(lhs <= rhs,
               (stringBuf() << "ASSERT abs(a - b) <= atol + rtol * abs(b): "
                << lhs << " > " << rhs
                << ", (info) " << msg
                << " (func) " << func
                << " (file) " << file
                << " (line) " << line
                << getData).idup);
    }
}