/// Matrix multiplication ops.
module grain.ops.matmul;

import std.typecons : Tuple, tuple;

import grain.tensor : Tensor;
import grain.testing : assertEqual;
import grain.ops.common : apply;


/// Matrix multiplication op `c = a * b` for 2-D tensors whose storage is on the CPU.
///
/// `forward` computes the product with BLAS `gemm`; `backward` returns the
/// gradients with respect to `a` and `b` given the gradient of `c`.
struct Matmul(T, Storage) if (Storage.deviceof == "cpu")
{
    /// 2-D tensor type shared by both operands, the result and the gradients.
    alias Matrix = Tensor!(2, T, Storage);

    /// Operands of the most recent `forward` call; `backward` reads them.
    Matrix a, b;

    /// Computes the matrix product `a * b`.
    ///
    /// Params:
    ///     a = left operand with lengths (m, k)
    ///     b = right operand with lengths (k, n)
    /// Returns: a newly constructed `Matrix` with lengths (m, n) holding the product.
    Matrix forward(Matrix a, Matrix b)
    in
    {
        // Inner dimensions must agree: a is (m, k), b must be (k, n).
        assertEqual(a.lengths[1], b.lengths[0], "Matmul lengths mismatch");
    }
    do
    {
        import mir.blas : gemm;

        // Keep the operands: backward needs them to compute the gradients.
        this.a = a;
        this.b = b;

        // Allocate with the op's own Storage so the result really is a
        // `Matrix` (Tensor!(2, T, Storage)), not a default-storage tensor.
        auto c = Matrix(a.lengths[0], b.lengths[1]);
        // c = 1 * (a * b) + 0 * c. With beta == 0, BLAS semantics do not
        // require c to be initialised beforehand.
        gemm(cast(T) 1, a.lightScope, b.lightScope, cast(T) 0, c.lightScope);
        return c;
    }

    /// Computes the gradients of the operands saved by the last `forward`.
    ///
    /// Params:
    ///     gc = gradient with respect to the output `c`, lengths (m, n)
    /// Returns: `tuple(ga, gb)` where `ga = gc * transposed(b)` (lengths (m, k))
    ///          and `gb = transposed(a) * gc` (lengths (k, n)).
    Tuple!(Matrix, Matrix) backward(Matrix gc)
    {
        import grain.ops.transposed : transposed;

        auto ga = matmul(gc, this.b.transposed);
        auto gb = matmul(this.a.transposed, gc);
        return tuple(ga, gb);
    }
}

/// ditto
auto matmul(T, Storage)(Tensor!(2, T, Storage) a, Tensor!(2, T, Storage) b)
{
    // Run the op through the generic `apply` helper from grain.ops.common.
    Matmul!(T, Storage) mm;
    return mm.apply(a, b);
}


@system @nogc
unittest
{
    import grain.random : normal_;
    import grain.testing : assertAllClose;

    auto x = Tensor!(2, double)(2, 3).normal_;
    auto y = Tensor!(2, double)(3, 2).normal_;
    auto z = x.matmul(y);

    // Reference product computed with a naive triple loop.
    auto expected = Tensor!(2, double)(x.lengths[0], y.lengths[1]);
    expected.asSlice[] = 0;
    foreach (i; 0 .. x.lengths[0])
    {
        foreach (j; 0 .. y.lengths[1])
        {
            foreach (k; 0 .. x.lengths[1])
            {
                expected.asSlice[i, j] += x.asSlice[i, k] * y.asSlice[k, j];
            }
        }
    }
    assertAllClose(z, expected);
}