1 module grain.cuda.compiler;
2 
3 import grain.storage : RCString;
4 import grain.cuda.testing : checkNvrtc, checkCuda;
5 import grain.dpp.cuda_driver : CUmodule, CUfunction;
6 
/**
Options for `compileModule` / `compile`.

`headerSources` and `headerNames` are parallel arrays handed to
`nvrtcCreateProgram`, `options` is handed to `nvrtcCompileProgram`, and
`jitOptions` / `jitOptionValues` are handed to `cuModuleLoadDataEx`.
*/
struct CompileOpt
{
    import grain.dpp.cuda_driver : CUjit_option;

    /// source text of each header, paired index-by-index with `headerNames`
    string[] headerSources;
    /// name under which each header is registered, parallel to `headerSources`
    string[] headerNames;
    /// NVRTC compiler options
    string[] options;
    /// JIT options used when loading the generated PTX
    CUjit_option[] jitOptions;
    /// raw array of values for `jitOptions` (presumably one per option — confirm with the CUDA driver docs)
    void** jitOptionValues;

    /// device index; NOTE(review): not read by `compileModule`/`compile` in this file
    int deviceId = 0;

    /// Returns: the number of headers as an `int`.
    /// Asserts that `headerSources` and `headerNames` have the same length.
    int numHeaders() const pure nothrow @nogc @safe
    {
        const count = this.headerSources.length;
        assert(count == this.headerNames.length);
        return cast(int) count;
    }

    /// Returns: the number of compiler options as an `int`.
    int numOptions() const pure nothrow @nogc @safe
    {
        const count = this.options.length;
        return cast(int) count;
    }
}
30 
/// Copy `s` into a new `pureMalloc`-ed buffer followed by a NUL byte, so it can
/// be handed to C APIs. The caller must release the result with `pureFree`.
private char* mallocStringz(scope const(char)[] s) nothrow @nogc
{
    import core.memory : pureMalloc;

    auto p = cast(char*) pureMalloc(s.length + 1);
    if (p is null) assert(0, "mallocStringz: out of memory");
    p[0 .. s.length] = s[];
    p[s.length] = '\0';
    return p;
}

/// NUL-terminated copies of every string in `strs`, stored in a `pureMalloc`-ed
/// pointer array (`null` when `strs` is empty). Release with `freeStringzArray`.
private char** mallocStringzArray(scope const(string)[] strs) nothrow @nogc
{
    import core.memory : pureMalloc;

    if (strs.length == 0) return null;
    auto arr = cast(char**) pureMalloc((char*).sizeof * strs.length);
    if (arr is null) assert(0, "mallocStringzArray: out of memory");
    foreach (i, s; strs)
        arr[i] = mallocStringz(s);
    return arr;
}

/// Release an array from `mallocStringzArray` together with its `n` strings.
private void freeStringzArray(char** arr, size_t n) nothrow @nogc
{
    import core.memory : pureFree;

    if (arr is null) return;
    foreach (i; 0 .. n)
        pureFree(arr[i]);
    pureFree(arr);
}

/**
Compile a CUDA source string to PTX with NVRTC and load it as a driver module.

TODO type check, better default option

Params:
    src  = CUDA program source. It need not be NUL-terminated; a
           NUL-terminated copy is handed to NVRTC.
    name = CUDA program name. When empty, `null` is passed to
           `nvrtcCreateProgram`, which then uses `"default_program"`.
    opt  = headers, NVRTC options and JIT options (see `CompileOpt`).

Returns: the `CUmodule` loaded from the generated PTX. It is not unloaded
here; the caller owns it.

Errors from NVRTC and the driver go through `checkNvrtc` / `checkCuda`. The
compile result is checked together with the NVRTC compile log.
*/
nothrow @nogc
CUmodule compileModule(
    scope const(char)[] src,
    scope const(char)[] name = "",
    CompileOpt opt = CompileOpt.init)
{
    import core.memory : pureMalloc, pureFree;
    import grain.dpp.nvrtc;
    import grain.dpp.cuda_driver;

    // NVRTC reads C strings, but D slices are not guaranteed to be
    // NUL-terminated (e.g. a slice into a larger buffer), so pass copies.
    auto srcz = mallocStringz(src);
    scope (exit) pureFree(srcz);
    char* namez = null; // null => NVRTC's "default_program"
    if (name.length > 0) namez = mallocStringz(name);
    scope (exit) pureFree(namez);

    const numHeaders = opt.numHeaders; // also asserts the header arrays match
    auto hss = mallocStringzArray(opt.headerSources);
    scope (exit) freeStringzArray(hss, opt.headerSources.length);
    auto hns = mallocStringzArray(opt.headerNames);
    scope (exit) freeStringzArray(hns, opt.headerNames.length);

    nvrtcProgram prog;
    checkNvrtc(nvrtcCreateProgram(&prog, srcz, namez, numHeaders, hss, hns));
    scope (exit) checkNvrtc(nvrtcDestroyProgram(&prog));

    // compile PTX
    const numOptions = opt.numOptions;
    auto opts = mallocStringzArray(opt.options);
    scope (exit) freeStringzArray(opts, opt.options.length);
    nvrtcResult res = nvrtcCompileProgram(prog, numOptions, opts);

    // fetch the compile log so a failed compilation can report it
    size_t logSize;
    checkNvrtc(nvrtcGetProgramLogSize(prog, &logSize));
    auto log = cast(char*) pureMalloc(char.sizeof * (logSize + 1));
    if (log is null) assert(0, "compileModule: out of memory");
    scope (exit) pureFree(log);
    log[logSize] = '\0'; // guard byte in case the log is not terminated
    checkNvrtc(nvrtcGetProgramLog(prog, log));
    // NVRTC's reported size includes its trailing NUL; keep it out of the message
    checkNvrtc(res, log[0 .. (logSize > 0 ? logSize - 1 : 0)]);

    // fetch PTX (size includes the trailing NUL, so the buffer is a C string)
    size_t ptxSize;
    checkNvrtc(nvrtcGetPTXSize(prog, &ptxSize));
    auto ptx = cast(char*) pureMalloc(char.sizeof * ptxSize);
    if (ptx is null) assert(0, "compileModule: out of memory");
    scope (exit) pureFree(ptx);
    checkNvrtc(nvrtcGetPTX(prog, ptx));

    // load PTX
    CUmodule m;
    checkCuda(cuModuleLoadDataEx(
        &m, ptx,
        cast(int) opt.jitOptions.length,
        opt.jitOptions.ptr,
        opt.jitOptionValues));
    return m;
}
104 
105 
/**
Runtime function compiler: build a single `extern "C" __global__` kernel and
return its handle.

The source is assembled as
`extern "C" __global__ void <name>(<args>) {\n<proc>\n}`, compiled with
`compileModule` (program name = `name`), and the kernel is looked up by `name`.
The module holding the kernel is not unloaded here.

TODO type check version (pick up d-nv impl)

Params:
    name = kernel name (also used as the program name)
    args = kernel parameter list, without parentheses
    proc = kernel body, without braces
    opt  = options forwarded to `compileModule`

Returns: the `CUfunction` for `name`.
*/
@nogc nothrow
CUfunction compile(
    scope string name, scope string args, scope string proc,
    CompileOpt opt = CompileOpt.init
) {
    import grain.dpp.cuda_driver : cuModuleGetFunction;
    import mir.format : stringBuf, getData;
    enum attr = q{extern "C" __global__ void };

    // The buffers must be named locals: `getData` returns a slice into the
    // buffer, which would dangle if the buffer were a temporary that is
    // destroyed at the end of the statement.
    auto srcBuf = stringBuf();
    srcBuf << attr
           << name << "(" << args << ") {\n"
           << proc
           << "\n}" << "\0";
    const(char)[] srcZ = srcBuf << getData;

    auto nameBuf = stringBuf();
    nameBuf << name << "\0";
    const(char)[] nameZ = nameBuf << getData;

    // Drop the trailing NUL from the slices; their `.ptr` stays NUL-terminated
    // for the C APIs that read them.
    auto m = compileModule(srcZ[0 .. $ - 1], nameZ[0 .. $ - 1], opt);
    CUfunction kernel;
    checkCuda(cuModuleGetFunction(&kernel, m, nameZ.ptr));
    return kernel;
}
127 
///
@system nothrow
unittest
{
    import grain.testing : assertAllClose;
    import grain.tensor : Tensor, Opt;
    import grain.cuda : CuTensor, CuDevice;
    import grain.random : normal_;
    import grain.ops : copy;
    import grain.dpp.cuda_driver : cuLaunchKernel, CUstream;

    // compile a vector-add kernel at runtime
    auto cufun = compile(
        "vectorAdd",
        q{const float *A, const float *B, float *C, int numElements},
        q{
            int i = blockDim.x * blockIdx.x + threadIdx.x;
            if (i < numElements) {
                C[i] = A[i] + B[i];
            }
        });
    // the kernel must have been found in the compiled module
    assert(cufun !is null);

    // `int` because `&n` is passed as the kernel's `int numElements` argument
    int n = 50000;
    auto ha = Tensor!(1, float)(n).normal_;
    auto hb = Tensor!(1, float)(n).normal_;

    // TODO: the launch below is disabled; re-enable once it works.
    // auto da = ha.copy!"cuda";
    // auto db = hb.copy!"cuda";
    // auto dc = CuTensor!(1, float)(n);

    // int threadPerBlock = 256;
    // int sharedMemBytes = 0;
    // auto stream = CuDevice.get(dc.deviceId).stream;
    // auto ps = [da.ptr, db.ptr, dc.ptr];
    // scope void*[4] args = [
    //     cast(void*) &ps[0],
    //     cast(void*) &ps[1],
    //     cast(void*) &ps[2],
    //     cast(void*) &n
    // ];
    // void*[] config;
    // // NOTE runtime api failed
    // // NOTE(review): cudaLaunchKernel presumably expects a runtime function
    // // symbol rather than a driver CUfunction, which may explain the
    // // failure -- unverified.
    // // import grain.dpp.cuda_runtime_api;
    // // checkCuda(cudaLaunchKernel(
    // //     cufun,
    // //     // grid: number of blocks
    // //     dim3((n + threadPerBlock - 1) / threadPerBlock, 1, 1),
    // //     // block: threads per block
    // //     dim3(threadPerBlock, 1, 1),
    // //     args.ptr,
    // //     sharedMemBytes, stream));

    // // device api (grid dims come before block dims)
    // checkCuda(cuLaunchKernel(
    //     cufun,
    //     // grid: number of blocks
    //     (n + threadPerBlock - 1) / threadPerBlock, 1, 1,
    //     // block: threads per block
    //     threadPerBlock, 1, 1,
    //     sharedMemBytes, cast(CUstream) stream, args.ptr, config.ptr));

    // auto hc = dc.copy!"cpu";
    // assertAllClose(ha.asSlice + hb.asSlice, hc.asSlice);
}