1 /// matrix multiplication ops
2 module grain.ops.matmul;
3 
4 import std.typecons : Tuple, tuple;
5 
6 import grain.tensor : Tensor;
7 import grain.testing : assertEqual;
8 import grain.ops.common : apply;
9 
10 
/**
Matrix multiplication op for 2-D tensors on CPU storage: `c = a * b`.

`forward` caches its operands in the `a`/`b` fields so that `backward`
can compute the gradients with respect to both of them.
*/
struct Matmul(T, Storage) if (Storage.deviceof == "cpu")
{
    /// 2-D tensor type (element type `T`, storage `Storage`) handled by this op
    alias Matrix = Tensor!(2, T, Storage);

    /// operands saved by `forward`; read back by `backward`
    Matrix a, b;

    /**
    Compute the matrix product `a * b`.

    Params:
        a = left operand, shape (m, k)
        b = right operand, shape (k, n); `b.lengths[0]` must equal `a.lengths[1]`

    Returns: a newly allocated matrix of shape (m, n)
    */
    Matrix forward(Matrix a, Matrix b)
    in
    {
        assertEqual(a.lengths[1], b.lengths[0], "Matmul lengths mismatch");
    }
    do
    {
        import mir.blas : gemm;

        // Keep the operands: backward needs them (the parameters shadow the fields).
        this.a = a;
        this.b = b;

        // Allocate with this op's own Storage so the result matches the declared
        // return type `Matrix` (the default-storage Tensor!(2, T) may not).
        auto c = Matrix(a.lengths[0], b.lengths[1]);
        // c = 1 * (a * b) + 0 * c; per BLAS gemm semantics, beta == 0 means the
        // prior contents of c are not used, so c need not be zero-initialized.
        gemm(cast(T) 1, a.lightScope, b.lightScope, cast(T) 0, c.lightScope);
        return c;
    }

    /**
    Backpropagate through `c = a * b` using the operands cached by `forward`.

    Params:
        gc = gradient with respect to the output `c`, shape (m, n)

    Returns: `tuple(ga, gb)` with `ga = gc * b^T` (shape (m, k)) and
        `gb = a^T * gc` (shape (k, n))
    */
    Tuple!(Matrix, Matrix) backward(Matrix gc)
    {
        import grain.ops.transposed : transposed;

        auto ga = matmul(gc, this.b.transposed);
        auto gb = matmul(this.a.transposed, gc);
        return tuple(ga, gb);
    }
}
42 
/**
Multiply two 2-D tensors, `a * b`, by applying a freshly constructed
`Matmul!(T, Storage)` op to them.

Params:
    a = left operand
    b = right operand; the `Matmul` contract requires `b.lengths[0] == a.lengths[1]`

Returns: the value produced by `apply` for the `Matmul` op
    (presumably the product tensor — confirm against grain.ops.common.apply)
*/
auto matmul(T, Storage)(Tensor!(2, T, Storage) a, Tensor!(2, T, Storage) b)
{
    Matmul!(T, Storage) op;
    return op.apply(a, b);
}
49 
50 
/// `matmul` agrees with a naive triple-loop matrix product.
@system @nogc
unittest
{
    import grain.random : normal_;
    import grain.testing : assertAllClose;

    auto x = Tensor!(2, double)(2, 3).normal_;
    auto y = Tensor!(2, double)(3, 2).normal_;
    auto z = x.matmul(y);

    // Reference result: expected[i, j] = sum_k x[i, k] * y[k, j].
    // The shape is derived from the operands rather than hard-coded.
    auto expected = Tensor!(2, double)(x.lengths[0], y.lengths[1]);
    expected.asSlice[] = 0;
    foreach (i; 0 .. x.lengths[0])
        foreach (j; 0 .. y.lengths[1])
            foreach (k; 0 .. x.lengths[1])
                expected.asSlice[i, j] += x.asSlice[i, k] * y.asSlice[k, j];

    // Check the result computed above instead of recomputing x.matmul(y).
    assertAllClose(z, expected);
}
74