/// copy ops
module grain.ops.copy;

import core.stdc.stdio;

import mir.format;

import grain.tensor : Tensor, Opt, isTensor;
import grain.ops.common : apply;


/**
Returns a contiguous version of a tensor.

Params:
    x = input tensor
Returns:
    `x` itself when `x.isContiguous` is true, otherwise the result of
    `x.copy!Storage` (a copy made with the same storage type).
*/
@nogc nothrow
auto contiguous(size_t dim, T, Storage)(Tensor!(dim, T, Storage) x)
{
    if (x.isContiguous) return x;
    return x.copy!Storage;
}

/// copy tensor between devices
struct Copy(size_t N, T, Src, Dst)
{
    @nogc nothrow:

    /// `deviceof` of the source storage type
    alias dsrc = Src.deviceof;
    /// `deviceof` of the destination storage type
    alias ddst = Dst.deviceof;
    /// `opt`: options used to build the destination tensor.
    /// `srcOpt`: options of the source tensor, recorded by `forward`.
    Opt opt, srcOpt;

    /**
    Maps the gradient of the copied tensor back to the source side.

    Params:
        gy = gradient with respect to the output of `forward`
    Returns:
        `gy` itself when source and destination share the same storage type
        and the same options; otherwise `gy` copied to `Src` with `srcOpt`.
    */
    Tensor!(N, T, Src) backward(Tensor!(N, T, Dst) gy)
    {
        // `gy` may only be returned untouched when it already has the return
        // type. Comparing device names alone is not enough: two distinct
        // storage types can live on the same device (the cpu `forward`
        // accepts such a pair), and `Tensor!(N, T, Dst)` is then a different
        // type from `Tensor!(N, T, Src)`.
        static if (is(Src == Dst))
        {
            if (this.opt == this.srcOpt)
                return gy;
        }
        return gy.copy!Src(this.srcOpt);
    }
}


/**
Forward pass of `Copy` for the case where both storages are on "cpu".

Records the options of `x` in `self.srcOpt` (read later by `backward`),
constructs the destination tensor from `self.opt` and the shape of `x`,
and assigns the contents of `x` to it through the `lightScope` views.

Params:
    self = copy operation; its `srcOpt` field is overwritten
    x = source tensor
Returns:
    the newly constructed destination tensor holding the values of `x`
*/
Tensor!(N, T, Dst)
forward(size_t N, T, Src, Dst)(
    ref Copy!(N, T, Src, Dst) self,
    Tensor!(N, T, Src) x)
if (Src.deviceof == "cpu" && Dst.deviceof == "cpu")
{
    self.srcOpt = x.opt;
    auto y = typeof(return)(self.opt, x.shape);
    y.lightScope[] = x.lightScope;
    return y;
}

import grain.storage : DefaultCPUStorage;

/**
Copies a tensor to storage type `Dst` using the given options.

Params:
    x = source tensor
    opt = options for the destination tensor; when `Dst` is not a "cpu"
          storage and `opt.deviceId` is -1 (the CPU id), the id is replaced
          by 0 before the copy
Returns:
    the result of applying the `Copy` operation to `x`
    (NOTE(review): dispatch goes through `grain.ops.common.apply`, which is
    not visible here - confirm its exact contract there)
*/
@nogc nothrow
auto copy(Dst, size_t N, T, Src)(Tensor!(N, T, Src) x, Opt opt)
{
    static if (Dst.deviceof != "cpu")
    {
        // update the default device id (-1 is CPU)
        if (opt.deviceId == -1) opt.deviceId = 0;
    }
    Copy!(N, T, Src, Dst) f = {opt: opt};
    return f.apply(x);
}

/**
Copies a tensor to storage type `Dst`, reusing the options of `x`.

Params:
    x = source tensor
Returns:
    same as `copy!Dst(x, x.opt)`
*/
auto copy(Dst, size_t N, T, Src)(Tensor!(N, T, Src) x)
{
    return copy!Dst(x, x.opt);
}



/**
Copies a tensor keeping both its storage type and its options.

Params:
    x = source tensor
Returns:
    same as `copy!Src(x, x.opt)`
*/
auto copy(size_t N, T, Src)(Tensor!(N, T, Src) x)
{
    return copy!Src(x, x.opt);
}




///
@system @nogc
unittest
{
    import grain.ops.transposed : transposed;
    import mir.ndslice.topology : iota, as;
    import grain.tensor;
    import std.meta : AliasSeq;
    import std.typecons : tuple;

    static foreach (dtype; AliasSeq!(float, double, int, long))
    {
        {
            auto x = Tensor!(2, dtype)(2, 3);
            x.asSlice[] = iota(x.shape).as!dtype;
            auto y = x.copy;
            // the copy starts out equal to the source ...
            assert(y.asSlice == x.asSlice);
            // ... and does not follow later writes to the source
            x.asSlice[0, 0] = 1;
            assert(y.asSlice != x.asSlice);
        }
    }
}