/// Transpose operations on tensors (swapping the last two axes).
2 module grain.ops.transposed;
3 
4 import grain.tensor : Tensor;
5 import grain.ops.common : apply;
6 
7 
/// Differentiable op that exchanges the two innermost axes of an N-dimensional tensor.
///
/// Only shape metadata is edited: `forward` swaps the last two entries of
/// `lengths` and the last two entries of `strides`, and never reads or writes
/// tensor elements. Applying it twice restores the original lengths and strides.
struct Transposed(size_t N, T, Storage)
{
    // there must be at least two axes for a "last two" to exist
    static assert(N >= 2);

    /// Params:
    ///     x = input tensor, received by value
    /// Returns: `x` with its last two `lengths` entries exchanged and its
    ///          last two `strides` entries exchanged.
    static Tensor!(N, T, Storage) forward(Tensor!(N, T, Storage) x)
    {
        // remember the innermost axis before it gets overwritten
        const innerLength = x.lengths[$ - 1];
        const innerStride = x.strides[$ - 1];

        // innermost axis takes over the second-innermost axis' shape ...
        x.lengths[$ - 1] = x.lengths[$ - 2];
        x.strides[$ - 1] = x.strides[$ - 2];

        // ... and the second-innermost axis takes the remembered one
        x.lengths[$ - 2] = innerLength;
        x.strides[$ - 2] = innerStride;

        return x;
    }

    /// The gradient of a transpose is the transpose of the incoming gradient `gy`.
    static Tensor!(N, T, Storage) backward(Tensor!(N, T, Storage) gy)
    {
        // UFCS form kept on purpose: it is how the original resolves `transposed`
        return gy.transposed;
    }
}
31 
/// Swaps the last two axes of `x` by running the `Transposed` op through `apply`.
///
/// Returns: whatever `apply` yields for `Transposed!(N, T, Storage)` on `x`.
auto transposed(size_t N, T, Storage)(Tensor!(N, T, Storage) x)
{
    alias TransposeOp = Transposed!(N, T, Storage);
    return TransposeOp().apply(x);
}
37 
///
@nogc unittest
{
    import mir.ndslice.topology : iota;

    // x = [[0, 1, 2],
    //      [3, 4, 5]]
    // lengths = [2, 3], strides = [3, 1]
    auto x = Tensor!(2, size_t)(2, 3);
    x.asSlice[] = iota(2, 3);

    // t = [[0, 3],
    //      [1, 4],
    //      [2, 5]]
    // lengths = [3, 2], strides = [1, 3]
    auto t = x.transposed;
    assert(!t.isContiguous);

    // lengths and strides of the two axes are exchanged
    assert(t.lengths[0] == x.lengths[1]);
    assert(t.lengths[1] == x.lengths[0]);
    assert(t.strides[0] == x.strides[1]);
    assert(t.strides[1] == x.strides[0]);

    // the input tensor keeps its own shape
    assert(x.lengths[0] == 2);
    assert(x.lengths[1] == 3);

    // every element at (j, i) in x appears at (i, j) in t
    foreach (size_t i; 0 .. 3)
        foreach (size_t j; 0 .. 2)
            assert(t.asSlice[i, j] == x.asSlice[j, i]);

    // transposing twice gives back the original values
    assert(t.transposed.asSlice == x.asSlice);
}