module grain.testing;

import mir.format : stringBuf, getData;

/++
Asserts that `actual == desired`.

On failure the assertion message contains both values, the `info` text, and
the caller's function, file and line. The location comes from the defaulted
template parameters, which are evaluated at the call site.

Params:
    actual  = value produced by the code under test
    desired = expected value
    info    = free-form context appended to the failure message
+/
void assertEqual(
    T1, T2,
    string file = __FILE__,
    size_t line = __LINE__,
    string func = __FUNCTION__
    )(T1 actual, T2 desired, string info = "none")
{
    assert(actual == desired,
           stringBuf()
           << "(actual) " << actual
           << " != (desired) " << desired
           << ", (info) " << info
           << " (func) " << func
           << " (file) " << file
           << " (line) " << line
           << getData);
}


/++
Asserts that `actual.shape == desired.shape`.

On failure the assertion message contains both shapes, the `info` text, and
the caller's function, file and line.

Params:
    actual  = value whose `.shape` is checked
    desired = value providing the expected `.shape`
    info    = free-form context appended to the failure message
+/
void assertShapeEqual(
    T1, T2,
    string file = __FILE__,
    size_t line = __LINE__,
    string func = __FUNCTION__
    )(T1 actual, T2 desired, string info = "shape mismatch")
{
    assert(actual.shape == desired.shape,
           stringBuf()
           << "(actual) " << actual.shape
           << " != (desired) " << desired.shape
           << ", (info) " << info
           << " (func) " << func
           << " (file) " << file
           << " (line) " << line
           << getData);
}


/++
Asserts that tensor elements are all close.

Each pair of corresponding elements must satisfy
`abs(actual - desired) <= atol + rtol * abs(desired)`.
The shapes must match first (checked via `assertShapeEqual`).

Arguments for which `isTensor` holds are converted with `asSlice`. Anything
else is used as-is and must support `.shape` and mir's `reshape`. Presumably
these are mir ndslices; TODO confirm for other input types.

Params:
    actual  = values produced by the code under test
    desired = reference values; the relative tolerance is scaled by these
    msg     = optional context included in every failure message
    rtol    = relative tolerance
    atol    = absolute tolerance
+/
void assertAllClose(
    T1, T2,
    string file = __FILE__,
    size_t line = __LINE__,
    string func = __FUNCTION__
    )(
    T1 actual,
    T2 desired,
    string msg = "",
    double rtol = 1e-7,
    double atol = 0,
    )
{
    import grain.tensor : isTensor;
    import mir.ndslice : zip, reshape;
    import std.math : abs;

    // Forward the caller's location explicitly. Otherwise the defaulted
    // __FILE__/__LINE__ of assertShapeEqual would point inside this function.
    assertShapeEqual!(T1, T2, file, line, func)(
        actual, desired, msg.length ? msg : "shape mismatch");

    static if (isTensor!T1)
        auto a = actual.asSlice;
    else
        auto a = actual;
    static if (isTensor!T2)
        auto d = desired.asSlice;
    else
        auto d = desired;

    // Flatten both operands to 1-D so they can be compared elementwise.
    // reshape reports failure through `err`, and the returned slice is not
    // meaningful in that case. If the error were ignored, the loop below
    // could pass without comparing anything.
    int err;
    auto aflat = a.reshape([-1], err);
    assert(err == 0,
           stringBuf() << "cannot flatten (actual), reshape error " << err
           << ", (msg) " << msg
           << " (func) " << func
           << " (file) " << file
           << " (line) " << line
           << getData);
    auto dflat = d.reshape([-1], err);
    assert(err == 0,
           stringBuf() << "cannot flatten (desired), reshape error " << err
           << ", (msg) " << msg
           << " (func) " << func
           << " (file) " << file
           << " (line) " << line
           << getData);

    // Flat (row-major) position of the current element, used only in the
    // failure message.
    size_t index;
    foreach (t; zip(aflat, dflat))
    {
        auto lhs = abs(t[0] - t[1]);
        auto rhs = atol + rtol * abs(t[1]);
        assert(lhs <= rhs,
               stringBuf() << "ASSERT abs(a - b) <= atol + rtol * abs(b): "
               << lhs << " > " << rhs
               << " (index) " << index
               << ", (msg) " << msg
               << " (func) " << func
               << " (file) " << file
               << " (line) " << line
               << getData);
        ++index;
    }
}