1 /// Tensor data structure module
2 module grain.tensor;
3 
4 import std.numeric : CustomFloat;
5 
6 import grain.storage : RCStorage, RCIter, DefaultCPUStorage;
7 debug import grain.testing : assertAllClose, assertEqual;
8 
/// Brain floating point ("bfloat16"): 8 exponent bits and 7 mantissa bits.
/// See https://en.wikipedia.org/wiki/Bfloat16_floating-point_format
alias bfloat16 = CustomFloat!(7, 8);
/// IEEE 754-2008 binary16 ("half"): 5 exponent bits and 10 mantissa bits.
/// See https://en.wikipedia.org/wiki/Half-precision_floating-point_format
alias half = CustomFloat!(10, 5);
13 
14 
/// Per-tensor options: device selection and the autograd flag.
struct Opt
{
    int deviceId = 0;           // for CUDA and CL
    int platformId = 0;         // for CL
    bool requireGrad = false;   // autograd

    /**
    Returns: a textual description of the form
    `Opt(requireGrad=..., deviceId=..., platformId=...)`.

    The text is assembled in a temporary `stringBuf()`, which is destroyed at
    the end of the full expression. Its `getData` slice therefore must not
    escape this function, so it is copied to the GC heap with `.idup`. That
    copy is why this function is not `@nogc`.
    */
    pure nothrow @safe
    const(char)[] toString() const
    {
        import mir.format;
        return (stringBuf()
                << "Opt("
                << "requireGrad=" << this.requireGrad
                << ", deviceId=" << this.deviceId
                << ", platformId=" << this.platformId
                << ")"
                << getData).idup;
    }
}
34 
35 // Tensor on CPU implementation
36 struct Tensor(size_t _dim, T, Storage = DefaultCPUStorage)
37 {
38     import mir.ndslice.slice : Slice, Universal, Structure;
39 
40     alias dim = _dim;
41     alias deviceof = Storage.deviceof;
42     alias shape = lengths;
43 
44     size_t[dim] lengths;
45     ptrdiff_t[dim] strides;
46     Storage payload;
47     ptrdiff_t offset = 0;
48 
49     Opt opt;
50     alias opt this;
51 
52     this(Opt opt, size_t[dim] lengths...)
53     {
54         this.opt = opt;
55         this(lengths);
56     }
57 
58     this(size_t[dim] lengths...)
59     {
60         import mir.ndslice.topology : iota;
61 
62         static if (deviceof == "cpu") this.deviceId = -1;
63         this.lengths = lengths;
64         this.strides = lengths.iota.strides;
65         auto al = typeof(Storage.init.allocator)(this.opt);
66         size_t n = T.sizeof * this.strides[0] * this.lengths[0];
67         this.payload = typeof(payload)(n, al);
68     }
69 
70     bool isContiguous() const
71     {
72         if (this.strides[dim - 1] != 1) return false;
73         foreach (i; 0 .. dim - 1)
74         {
75             if (this.strides[i] != this.lengths[i + 1]) return false;
76         }
77         return true;
78     }
79 
80     size_t numel() const
81     {
82         size_t ret = 1;
83         foreach (l; this.lengths) ret *= l;
84         return ret;
85     }
86 
87     RCIter!(T*, Storage) iterator() @property
88     {
89         static if (deviceof == "cuda")
90         {
91             import grain.dpp.cuda_runtime_api : cudaSetDevice;
92             cudaSetDevice(this.deviceId);
93         }
94         return payload.iterator!(T*) + offset;
95     }
96 
97     T* ptr()() scope return @property @trusted
98     {
99         return this.iterator.lightScope;
100     }
101 
102     static if (Storage.deviceof == "cpu")
103     {
104         Slice!(typeof(this.iterator()), dim, Universal) asSlice()()
105         {
106             import std.meta : AliasSeq;
107             alias structure = AliasSeq!(this.lengths, this.strides);
108             return typeof(return)(structure, this.iterator);
109         }
110 
111         Slice!(T*, dim, Universal) lightScope()() scope return @property @trusted
112         {
113             import std.meta : AliasSeq;
114             alias structure = AliasSeq!(this.lengths, this.strides);
115             return typeof(return)(structure, this.ptr);
116         }
117     }
118 }
119 
/// true when `T` is, or implicitly converts to, some `Tensor!(N, E, S)`
/// instantiation (any rank, element type and storage).
enum bool isTensor(T) = is(T : Tensor!(N, E, S), E, size_t N, S);
127 
128 
129 
// A CPU matrix built from its lengths gets a dense row-major layout.
@nogc unittest
{
    alias Matrix = Tensor!(2, double);
    static assert(isTensor!Matrix);

    auto m = Matrix(2, 3);
    static assert(m.deviceof == "cpu");

    // Moving one row skips 3 elements; moving one column skips 1.
    assertEqual(m.strides[0], 3);
    assertEqual(m.strides[1], 1);
    assert(m.isContiguous);
    assert(m.numel == 2 * 3);
}