/// High-level wrapper of the cudnn library
module grain.cuda.cudnn;

import grain.tensor : Tensor;
import grain.cuda : cudnnHandle;
import grain.dpp.cudnn;
import grain.dpp.cuda_driver : CUdeviceptr;
import grain.cuda.testing : checkCudnn;


// TODO make shared
__gshared bool deterministic = false;
__gshared bool nanProp = true;

/// return the global cudnn determinism setting as a cudnn enum
auto isDeterministic()
{
    return deterministic ? CUDNN_DETERMINISTIC : CUDNN_NON_DETERMINISTIC;
}

/// ditto
auto isNanProp()
{
    return nanProp ? CUDNN_PROPAGATE_NAN : CUDNN_NOT_PROPAGATE_NAN;
}

/// convert a D type (ubyte, byte, int, float, double) into the cudnn data-type enum;
/// when allowSameSize is true, any type of equal size is reinterpreted bitwise
/// (e.g. a 4-byte type maps to CUDNN_DATA_FLOAT)
template cudnnDataType(T, bool allowSameSize)
{
    // TODO support half/int8 natively
    static if (is(T == ubyte))
        alias cudnnDataType = CUDNN_DATA_UINT8;
    else static if (is(T == byte) || (allowSameSize && (T.sizeof == byte.sizeof)))
        alias cudnnDataType = CUDNN_DATA_INT8;
    else static if (//is(T == half) ||
                    (allowSameSize && (T.sizeof == 2))) // half is 2 bytes; T.sizeof counts bytes, not bits
        alias cudnnDataType = CUDNN_DATA_HALF;
    else static if (is(T == int))
        alias cudnnDataType = CUDNN_DATA_INT32;
    else static if (is(T == float) || (allowSameSize && (T.sizeof == float.sizeof)))
        alias cudnnDataType = CUDNN_DATA_FLOAT;
    else static if (is(T == double) || (allowSameSize && (T.sizeof == double.sizeof)))
        alias cudnnDataType = CUDNN_DATA_DOUBLE;
    else
        static assert(false, "unsupported type: " ~ T.stringof);
}
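
/// Compile-time examples of the mapping above (added as a documentation sketch;
/// it exercises only the template, so no GPU is needed). With allowSameSize,
/// same-sized integer types reinterpret to the equally wide cudnn type.
@safe @nogc nothrow pure
unittest
{
    static assert(cudnnDataType!(float, false) == CUDNN_DATA_FLOAT);
    static assert(cudnnDataType!(int, false) == CUDNN_DATA_INT32);
    // bitwise reinterpretation by size: uint (4 bytes) -> float, long (8 bytes) -> double
    static assert(cudnnDataType!(uint, true) == CUDNN_DATA_FLOAT);
    static assert(cudnnDataType!(long, true) == CUDNN_DATA_DOUBLE);
    // any 2-byte type is treated as half when reinterpretation is allowed
    static assert(cudnnDataType!(short, true) == CUDNN_DATA_HALF);
}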
/// a tensor descriptor object pairing a cudnn descriptor with its device pointer
struct TensorDesc
{
    cudnnTensorDescriptor_t desc;
    CUdeviceptr ptr;
    alias desc this;

    /// no copy
    @disable this(this);
    /// no allocation on heap
    @disable new();

    nothrow @nogc ~this()
    {
        checkCudnn( cudnnDestroyTensorDescriptor(desc) );
    }
}

/// convert a tensor into a cudnn tensor descriptor object
TensorDesc makeCudnnTensor(bool allowSameSize = false, T, size_t dim, Storage)(Tensor!(dim, T, Storage) x)
{
    static assert(Storage.deviceof == "cuda");
    static assert(dim < CUDNN_DIM_MAX);
    static if (dim < 4)
    {
        // cudnn requires at least 4 dimensions; pad trailing dims with extent/stride 1
        enum int ddim = 4;
        int[ddim] shape;
        int[ddim] strides;
        shape[] = 1;
        strides[] = 1;
        foreach (d; 0 .. dim)
        {
            assert(x.shape[d] < int.max);
            shape[d] = cast(int) x.shape[d];
            strides[d] = cast(int) x.strides[d];
        }
    }
    else
    {
        enum int ddim = cast(int) dim;
        // cudnnSetTensorNdDescriptor takes const(int)*, so narrow shape/strides to int
        int[ddim] shape;
        int[ddim] strides;
        foreach (d; 0 .. dim)
        {
            assert(x.shape[d] < int.max);
            shape[d] = cast(int) x.shape[d];
            strides[d] = cast(int) x.strides[d];
        }
    }

    auto ptr = cast(CUdeviceptr) x.ptr;
    cudnnTensorDescriptor_t desc;
    checkCudnn(cudnnCreateTensorDescriptor(&desc));
    checkCudnn(cudnnSetTensorNdDescriptor(
                   desc,
                   cudnnDataType!(T, allowSameSize),
                   ddim,
                   shape.ptr,
                   strides.ptr));
    return TensorDesc(desc, ptr);
}

///
@system @nogc
unittest
{
    import grain.cuda.allocator : CuTensor;
    import grain.ops.transposed : transposed;
    import grain.testing : assertEqual;

    auto x = CuTensor!(3, float)(2, 3, 4).transposed;
    auto t = x.makeCudnnTensor;

    cudnnDataType_t dtype;
    int dim;
    int[3] shape;
    int[3] strides;
    cudnnGetTensorNdDescriptor(t.desc, 3, &dtype, &dim, shape.ptr, strides.ptr);
    assert(dtype == CUDNN_DATA_FLOAT);
    assertEqual(dim, 4, "dim < 4 will be padded to 4");
    assert(shape == x.shape);
    assert(strides == x.strides);
}

/// copy src into dst as dst = alpha * src + beta * dst,
/// converting between memory layouts (e.g. transposed/strided tensors)
void transform(T, size_t dim, Storage)(
    Tensor!(dim, T, Storage) src,
    ref Tensor!(dim, T, Storage) dst,
    T alpha = 1,
    T beta = 0
)
{
    static assert(Storage.deviceof == "cuda");
    // assert(is(T == float) || is(T == double), "unsupported type: " ~ T.stringof);
    assert(src.shape == dst.shape);
    checkCudnn(
        cudnnTransformTensor(
            cudnnHandle,
            cast(const void*) &alpha, src.makeCudnnTensor!true, cast(const void*) src.ptr,
            cast(const void*) &beta, dst.makeCudnnTensor!true, cast(void*) dst.ptr
        ) );
}

///
@system @nogc nothrow
unittest
{
    // FIXME: int/long support
    import grain.cuda.allocator : CuTensor;
    import grain.ops.transposed : transposed;

    auto x = CuTensor!(3, float)(2, 3, 4).transposed;
    auto y = CuTensor!(3, float)(x.shape);
    transform(x, y);
    transform(y, x);
}
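
/// A minimal added sketch of the scaling behaviour: with non-default alpha/beta,
/// transform computes dst = alpha * src + beta * dst (cudnnTransformTensor
/// semantics). It assumes CuTensor's variadic shape constructor also works for
/// 2-D tensors (only the 3-D form appears above), and it only exercises the
/// call on contiguous tensors without reading device memory back to verify.
@system @nogc nothrow
unittest
{
    import grain.cuda.allocator : CuTensor;

    auto x = CuTensor!(2, float)(3, 4);
    auto y = CuTensor!(2, float)(3, 4);
    // dst = 2 * src + 0 * dst, i.e. a scaled copy
    transform(x, y, 2f, 0f);
    // accumulate: dst = 1 * src + 1 * dst
    transform(x, y, 1f, 1f);
}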