/// Tensor data structure module
module grain.tensor;

import std.numeric : CustomFloat;

import grain.storage : RCStorage, RCIter, DefaultCPUStorage;
debug import grain.testing : assertAllClose, assertEqual;

/// IEEE 754-2008 half: https://en.wikipedia.org/wiki/Half-precision_floating-point_format
alias half = CustomFloat!(10, 5);
/// bfloat16: https://en.wikipedia.org/wiki/Bfloat16_floating-point_format
alias bfloat16 = CustomFloat!(7, 8);


/// Per-tensor options: device selection and the autograd flag.
struct Opt
{
    /// device index, for CUDA and CL
    int deviceId = 0;
    /// platform index, for CL
    int platformId = 0;
    /// autograd flag
    bool requireGrad = false;

    /// Returns: a textual form "Opt(requireGrad=..., deviceId=..., platformId=...)".
    pure @nogc nothrow @safe
    const(char)[] toString() const
    {
        import mir.format;
        // NOTE(review): the returned slice is taken from a temporary stringBuf;
        // confirm that it stays valid after this function returns.
        return (stringBuf()
                << "Opt("
                << "requireGrad=" << this.requireGrad
                << ", deviceId=" << this.deviceId
                << ", platformId=" << this.platformId
                << ")"
                << getData);
    }
}

/**
N-dimensional strided tensor over a reference-counted storage.

Params:
    _dim = number of dimensions
    T = element type
    Storage = storage type (DefaultCPUStorage unless specified); it provides
              `deviceof`, an allocator type and `iterator!(T*)`
*/
struct Tensor(size_t _dim, T, Storage = DefaultCPUStorage)
{
    import mir.ndslice.slice : Slice, Universal, Structure;

    /// number of dimensions
    alias dim = _dim;
    /// device name taken from the storage type (compared with "cpu" / "cuda" below)
    alias deviceof = Storage.deviceof;
    /// alias of lengths
    alias shape = lengths;

    /// length of each dimension
    size_t[dim] lengths;
    /// stride of each dimension, counted in elements
    ptrdiff_t[dim] strides;
    /// underlying storage
    Storage payload;
    /// offset added to the payload iterator to reach the first element
    ptrdiff_t offset = 0;

    /// options; its fields are reachable directly on the tensor through alias this
    Opt opt;
    alias opt this;

    /**
    Allocates a tensor with the given options and lengths.

    Params:
        opt = options stored before delegating to the lengths-only constructor
        lengths = length of each dimension
    */
    this(Opt opt, size_t[dim] lengths...)
    {
        this.opt = opt;
        this(lengths);
    }

    /**
    Allocates a tensor with the given lengths and row-major strides.

    Params:
        lengths = length of each dimension
    */
    this(size_t[dim] lengths...)
    {
        import mir.ndslice.topology : iota;

        // CPU tensors carry deviceId == -1
        static if (deviceof == "cpu") this.deviceId = -1;
        this.lengths = lengths;
        // borrow the strides of a contiguous iota slice with the same lengths
        this.strides = lengths.iota.strides;
        auto al = typeof(Storage.init.allocator)(this.opt);
        // size in bytes: element size * total element count
        size_t n = T.sizeof * this.strides[0] * this.lengths[0];
        this.payload = typeof(payload)(n, al);
    }

    /**
    Checks whether the strides describe a dense row-major layout, i.e.
    the last stride is 1 and strides[i] == strides[i + 1] * lengths[i + 1].

    Returns: true if the layout is row-major contiguous
    */
    bool isContiguous() const
    {
        // Stride a dense row-major tensor must have at the current dimension.
        // The previous check compared strides[i] with lengths[i + 1] alone,
        // which holds only for dim <= 2 and rejected dense 3-D+ tensors.
        ptrdiff_t expected = 1;
        foreach_reverse (i; 0 .. dim)
        {
            if (this.strides[i] != expected) return false;
            expected *= cast(ptrdiff_t) this.lengths[i];
        }
        return true;
    }

    /// Returns: the total number of elements (product of all lengths)
    size_t numel() const
    {
        size_t ret = 1;
        foreach (l; this.lengths) ret *= l;
        return ret;
    }

    /// Returns: the storage iterator advanced by offset
    RCIter!(T*, Storage) iterator() @property
    {
        static if (deviceof == "cuda")
        {
            import grain.dpp.cuda_runtime_api : cudaSetDevice;
            // select this tensor's device before handing out the iterator
            cudaSetDevice(this.deviceId);
        }
        return payload.iterator!(T*) + offset;
    }

    /// Returns: the raw pointer obtained from the iterator's lightScope
    T* ptr()() scope return @property @trusted
    {
        return this.iterator.lightScope;
    }

    static if (Storage.deviceof == "cpu")
    {
        /// Returns: a Universal mir slice with this tensor's lengths and strides over its iterator
        Slice!(typeof(this.iterator()), dim, Universal) asSlice()()
        {
            import std.meta : AliasSeq;
            alias structure = AliasSeq!(this.lengths, this.strides);
            return typeof(return)(structure, this.iterator);
        }

        /// Returns: a Universal mir slice with this tensor's lengths and strides over the raw pointer
        Slice!(T*, dim, Universal) lightScope()() scope return @property @trusted
        {
            import std.meta : AliasSeq;
            alias structure = AliasSeq!(this.lengths, this.strides);
            return typeof(return)(structure, this.ptr);
        }
    }
}

/// Evaluates to true if T is implicitly convertible to some Tensor instance
template isTensor(T)
{
    static if (is(T : Tensor!(N, E, S), E, size_t N, S))
        enum bool isTensor = true;
    else
        enum bool isTensor = false;
}


/// 2-D tensor basics
@nogc unittest
{
    auto x = Tensor!(2, double)(2, 3);
    static assert(isTensor!(typeof(x)));
    static assert(x.deviceof == "cpu");
    assertEqual(x.strides[0], 3);
    assertEqual(x.strides[1], 1);
    assert(x.isContiguous);
    assert(x.numel == 2 * 3);
}

/// contiguity check on more than two dimensions
@nogc unittest
{
    auto y = Tensor!(3, double)(2, 3, 4);
    assert(y.strides[0] == 12);
    assert(y.strides[1] == 4);
    assert(y.strides[2] == 1);
    assert(y.isContiguous);
    assert(y.numel == 2 * 3 * 4);
}