1 /// copy ops
2 module grain.ops.copy;
3 
4 import core.stdc.stdio;
5 
6 import mir.format;
7 
8 import grain.tensor : Tensor, Opt, isTensor;
9 import grain.ops.common : apply;
10 
11 
/// Return a tensor with a contiguous memory layout.
///
/// If `x.isContiguous` already holds, `x` itself is returned; otherwise a copy
/// into the same storage type (`copy!Storage`) is returned.
@nogc nothrow
auto contiguous(size_t dim, T, Storage)(Tensor!(dim, T, Storage) x)
{
    // Fast path: nothing to rearrange, hand back the input unchanged.
    if (x.isContiguous)
    {
        return x;
    }
    // Slow path: materialise a fresh tensor in the same storage type.
    return copy!Storage(x);
}
18 
/// Copy a tensor from storage `Src` into storage `Dst` (possibly on another device).
///
/// `opt` holds the options used for the destination tensor. `srcOpt` is filled
/// in by `forward` with the input tensor's options, so that `backward` can
/// rebuild a gradient with the same options as the original input.
struct Copy(size_t N, T, Src, Dst)
{
    @nogc nothrow:

    alias dsrc = Src.deviceof;
    alias ddst = Dst.deviceof;
    Opt opt, srcOpt;

    /// Map the gradient `gy` (held in `Dst`) back to a tensor held in `Src`.
    ///
    /// Params:
    ///     gy = gradient w.r.t. the copied (destination) tensor
    /// Returns: `gy` itself when source and destination are the same storage
    ///     type with identical options, otherwise a copy of `gy` into `Src`
    ///     created with `srcOpt`.
    Tensor!(N, T, Src) backward(Tensor!(N, T, Dst) gy)
    {
        // `gy` can only be returned unchanged when its type already is the
        // source tensor type. Comparing device names alone is not enough,
        // because two distinct storage types may report the same device, and
        // then `return gy` would not type-check.
        static if (is(Src == Dst))
        {
            if (this.opt == this.srcOpt)
                return gy;
        }
        return gy.copy!Src(this.srcOpt);
    }
}
41 
42 
/// Forward pass of `Copy` when both source and destination live on the CPU.
///
/// Stores the input's options in `self.srcOpt` (read later by `backward`),
/// allocates an output tensor with `self.opt` and the input's shape, and
/// copies all elements of `x` into it.
Tensor!(N, T, Dst)
forward(size_t N, T, Src, Dst)(
    ref Copy!(N, T, Src, Dst) self,
    Tensor!(N, T, Src) x)
if (Src.deviceof == "cpu" && Dst.deviceof == "cpu")
{
    // Remember the source options so the gradient can be copied back.
    self.srcOpt = x.opt;

    // Destination buffer: caller-requested options, same shape as the input.
    Tensor!(N, T, Dst) dst = Tensor!(N, T, Dst)(self.opt, x.shape);

    // Element-wise copy of the whole input into the destination.
    dst.lightScope[] = x.lightScope;
    return dst;
}
54 
55 import grain.storage : DefaultCPUStorage;
56 
/// Copy tensor `x` into storage `Dst`, creating the result with options `opt`.
///
/// The work is done by applying a `Copy!(N, T, Src, Dst)` op configured with
/// `opt`. For a non-CPU destination, a `deviceId` of -1 (the CPU marker) is
/// replaced by device 0.
@nogc nothrow
auto copy(Dst, size_t N, T, Src)(Tensor!(N, T, Src) x, Opt opt)
{
    static if (Dst.deviceof != "cpu")
    {
        // -1 means "CPU"; a non-CPU target needs a concrete device id,
        // so default to the first device.
        if (opt.deviceId == -1)
            opt.deviceId = 0;
    }

    Copy!(N, T, Src, Dst) op;
    op.opt = opt;
    return op.apply(x);
}
69 
/// Copy tensor `x` into storage `Dst`, reusing `x.opt` as the destination options.
auto copy(Dst, size_t N, T, Src)(Tensor!(N, T, Src) x)
{
    return x.copy!Dst(x.opt);
}
75 
76 
77 
/// Copy tensor `x` into a new tensor of the same storage type `Src`,
/// reusing `x.opt` as the destination options.
auto copy(size_t N, T, Src)(Tensor!(N, T, Src) x)
{
    // Same-storage copy: delegate to the explicit-destination overload,
    // which keeps the source options.
    return copy!Src(x);
}
83 
84 
85 
86 
///
@system @nogc
unittest
{
    import mir.ndslice.topology : iota, as;
    import grain.tensor;
    import std.meta : AliasSeq;

    static foreach (dtype; AliasSeq!(float, double, int, long))
    {
        {
            auto x = Tensor!(2, dtype)(2, 3);
            x.asSlice[] = iota(x.shape).as!dtype;

            // copy yields a tensor with equal values...
            auto y = x.copy;
            assert(y.asSlice == x.asSlice);

            // ...backed by independent storage: writing into x leaves y intact.
            // x[0, 0] starts at 0 because iota begins at 0.
            x.asSlice[0, 0] = 1;
            assert(y.asSlice != x.asSlice);
            assert(y.asSlice[0, 0] == 0);

            // contiguous must preserve values whether or not it copies.
            auto z = x.contiguous;
            assert(z.asSlice == x.asSlice);
        }
    }
}