1 /// CUDA device manager
2 module grain.cuda.device;
3 
4 import grain.cuda.testing : checkCuda;
5 import grain.dpp.cuda_runtime_api;
6 
7 @nogc nothrow:
8 
/**
Per-device state: one lazily created CUDA stream per visible device.

The static constructor queries the device count, allocates a table with one
`CuDevice` entry per device and tries to enable peer-to-peer access between
every ordered pair of distinct devices. `get` selects a device and returns its
entry, creating the stream on first use.
*/
struct CuDevice
{
    @nogc nothrow:

    /// Stream associated with this device; `null` until `get` creates it.
    cudaStream_t stream = null;

    /// Start of the malloc-allocated device table (`devicesLength` entries).
    static CuDevice* devicesPtr;
    /// Number of devices reported by `cudaGetDeviceCount`.
    static int devicesLength;

    /// Builds the device table and enables peer access where it is supported.
    static this()
    {
        debug import core.stdc.stdio : printf;
        import core.exception : onOutOfMemoryError;
        import core.memory : pureMalloc;

        // Whatever happens below, leave device 0 selected on exit.
        scope (exit) cudaSetDevice(0);

        // init cuda devices
        checkCuda(cudaGetDeviceCount(&devicesLength));
        debug
        {
            printf("[grain.info]: device count %d\n", devicesLength);
        }
        devicesPtr = cast(CuDevice*) pureMalloc(CuDevice.sizeof * devicesLength);
        // malloc may legitimately return null for a zero-sized request, so
        // only treat null as a failure when entries were actually requested.
        if (devicesPtr is null && devicesLength > 0)
        {
            onOutOfMemoryError();
        }
        // The memory is uninitialized; give every entry a null stream.
        devicesPtr[0 .. devicesLength] = CuDevice.init;

        // init P2P
        foreach (i; 0 .. devicesLength)
        {
            // Peer access is enabled from the currently selected device, so a
            // failed switch must not go unnoticed.
            checkCuda(cudaSetDevice(i));
            foreach (j; 0 .. devicesLength)
            {
                if (i == j) continue;
                int ok = 0;
                checkCuda(cudaDeviceCanAccessPeer(&ok, i, j));
                if (!ok)
                {
                    debug
                    {
                        printf("[grain.warn]: no GPU P2P: %d -> %d\n", i, j);
                    }
                    continue;
                }
                // Deliberately unchecked (best effort): this constructor may
                // run more than once per process, in which case the runtime
                // reports that peer access is already enabled.
                cudaDeviceEnablePeerAccess(j, 0);
            }
        }
    }

    /// Releases the device table allocated by the static constructor.
    static ~this()
    {
        import core.memory : pureFree;
        pureFree(devicesPtr);
        devicesPtr = null;
    }

    /// Returns: the number of devices found by the static constructor.
    static count()
    {
        return devicesLength;
    }

    /// Returns: a slice over the device table (one entry per device).
    static CuDevice[] devices()
    {
        return devicesPtr[0 .. devicesLength];
    }

    /**
    Makes device `index` current and returns its entry, creating the
    non-blocking stream on the first call for that device.

    Params:
        index = device ordinal, `0 <= index < count`

    Returns: a copy of the table entry; its `stream` handle is the one cached
    in the table, so repeated calls return the same stream.
    */
    static get(int index)
    {
        assert(index >= 0);
        assert(index < devicesLength);
        checkCuda(cudaSetDevice(index));
        // Work through a pointer into the table: CuDevice is a value type, so
        // a plain copy would receive the new stream and then be discarded,
        // leaving the table entry null and leaking one stream per call.
        auto dev = &devices()[index];
        if (dev.stream is null)
        {
            checkCuda(cudaStreamCreateWithFlags(&dev.stream, cudaStreamNonBlocking));
        }
        return *dev;
    }
}