/// matrix transpose ops
module grain.ops.transposed;

import grain.tensor : Tensor;
import grain.ops.common : apply;


/++
Transpose of the last two dimensions of an N-dimensional tensor.

`forward` only exchanges the last two entries of `lengths` and `strides`
on the tensor value it receives; no element access appears in this struct.

Params:
    N = number of dimensions (must be at least 2)
    T = element type
    Storage = storage type forwarded to `Tensor`
+/
struct Transposed(size_t N, T, Storage)
{
    static assert(N >= 2);

    /// Returns `x` with the last two lengths and the last two strides exchanged.
    static Tensor!(N, T, Storage) forward(Tensor!(N, T, Storage) x)
    {
        // `x` is not a `ref` parameter; the shape metadata is edited on the
        // argument as received and that value is handed back.
        // NOTE(review): presumably the result still refers to the same
        // underlying data as the input -- confirm against `Tensor`.
        const innerStride = x.strides[$ - 1];
        x.strides[$ - 1] = x.strides[$ - 2];
        x.strides[$ - 2] = innerStride;

        const innerLength = x.lengths[$ - 1];
        x.lengths[$ - 1] = x.lengths[$ - 2];
        x.lengths[$ - 2] = innerLength;

        return x;
    }

    /// Applies the same last-two-dimension transpose to `gy`
    /// (presumably the gradient w.r.t. the forward output -- confirm with callers).
    static Tensor!(N, T, Storage) backward(Tensor!(N, T, Storage) gy)
    {
        auto gx = gy.transposed;
        return gx;
    }
}

/// ditto
auto transposed(size_t N, T, Storage)(Tensor!(N, T, Storage) x)
{
    // Dispatch goes through `apply` from grain.ops.common.
    alias Op = Transposed!(N, T, Storage);
    return Op().apply(x);
}

///
@nogc unittest
{
    import mir.ndslice.topology : iota;

    // source tensor:
    //   [0, 1, 2]
    //   [3, 4, 5]
    // iter = p
    // lengths = [2, 3]
    // strides = [3, 1]
    auto x = Tensor!(2, size_t)(2, 3);
    x.asSlice[] = iota(2, 3);

    // transposed view:
    //   [0, 3]
    //   [1, 4]
    //   [2, 5]
    // lengths = [3, 2]
    // strides = [1, 3]
    auto xt = x.transposed;
    assert(!xt.isContiguous);
    assert(xt.lengths[1] == x.lengths[0]);
    assert(xt.lengths[0] == x.lengths[1]);

    // element (i, j) of the transpose equals element (j, i) of the source
    assert(xt.asSlice[0, 0] == x.asSlice[0, 0]);
    assert(xt.asSlice[0, 1] == x.asSlice[1, 0]);

    assert(xt.asSlice[1, 0] == x.asSlice[0, 1]);
    assert(xt.asSlice[1, 1] == x.asSlice[1, 1]);

    assert(xt.asSlice[2, 0] == x.asSlice[0, 2]);
    assert(xt.asSlice[2, 1] == x.asSlice[1, 2]);

    // transposing twice gives back the original layout
    assert(xt.transposed.asSlice == x.asSlice);
}