1 /// CUDA module
2 module grain.cuda;
3 
4 import grain.dpp.cublas : cublasHandle_t, cublasCreate_v2, cublasDestroy_v2;
5 import grain.dpp.cudnn : cudnnHandle_t, cudnnCreate, cudnnDestroy;
6 
7 public import grain.cuda.allocator;
8 public import grain.cuda.compiler;
9 public import grain.cuda.cudnn;
10 public import grain.cuda.device;
11 public import grain.cuda.ops;
12 public import grain.cuda.testing;
13 
__gshared
{
    /// Process-wide cuBLAS handle; created in the module constructor and
    /// destroyed in the module destructor of this module.
    cublasHandle_t cublasHandle;

    /// Process-wide cuDNN handle; created in the module constructor and
    /// destroyed in the module destructor of this module.
    cudnnHandle_t cudnnHandle;
}
16 
17 
/// Global CUDA initialization.
///
/// Runs once per process (shared module constructor) and fills in the
/// module-level `cublasHandle` and `cudnnHandle`. Each creation status is
/// passed to the matching `check*` helper.
shared static this() @nogc
{
    // cuBLAS: create the handle, then validate the returned status.
    const cublasStatus = cublasCreate_v2(&cublasHandle);
    checkCublas(cublasStatus);

    // cuDNN: same pattern.
    const cudnnStatus = cudnnCreate(&cudnnHandle);
    checkCudnn(cudnnStatus);
}
25 
26 
/// Global CUDA teardown.
///
/// Runs once per process (shared module destructor) and releases the
/// module-level `cublasHandle` and `cudnnHandle` that the shared module
/// constructor created. Both destroy statuses are passed to the matching
/// `check*` helper, so a failed cuBLAS teardown is no longer silently ignored.
@nogc shared static ~this()
{
    // Previously the cuBLAS status was discarded; check it the same way the
    // cuDNN status (and both creation statuses) are checked.
    checkCublas( cublasDestroy_v2(cublasHandle) );
    checkCudnn( cudnnDestroy(cudnnHandle) );
}