1 /// Test functions for CUDA
2 module grain.cuda.testing;
3 
4 import std..string : fromStringz;
5 import mir.format : stringBuf, getData;
6 
7 import grain.dpp.cuda_driver;
8 import grain.dpp.cublas;
9 import grain.dpp.nvrtc;
10 
11 
12 /// emit error message string from enum
13 auto cublasGetErrorEnum()(cublasStatus_t error) {
14     switch (error) {
15     case CUBLAS_STATUS_SUCCESS:
16         return "CUBLAS_STATUS_SUCCESS";
17 
18     case CUBLAS_STATUS_NOT_INITIALIZED:
19         return "CUBLAS_STATUS_NOT_INITIALIZED";
20 
21     case CUBLAS_STATUS_ALLOC_FAILED:
22         return "CUBLAS_STATUS_ALLOC_FAILED";
23 
24     case CUBLAS_STATUS_INVALID_VALUE:
25         return "CUBLAS_STATUS_INVALID_VALUE";
26 
27     case CUBLAS_STATUS_ARCH_MISMATCH:
28         return "CUBLAS_STATUS_ARCH_MISMATCH";
29 
30     case CUBLAS_STATUS_MAPPING_ERROR:
31         return "CUBLAS_STATUS_MAPPING_ERROR";
32 
33     case CUBLAS_STATUS_EXECUTION_FAILED:
34         return "CUBLAS_STATUS_EXECUTION_FAILED";
35 
36     case CUBLAS_STATUS_INTERNAL_ERROR:
37         return "CUBLAS_STATUS_INTERNAL_ERROR";
38     default:
39         return "CUBLAS UNKNOWN ERROR";
40     }
41 }
42 
43 /// cublas error checker
44 @nogc
45 void checkCublas(
46     string func = __FUNCTION__,
47     string file = __FILE__,
48     size_t line = __LINE__
49 )(cublasStatus_t err)
50 {
51     assert(
52         err == CUBLAS_STATUS_SUCCESS,
53         stringBuf()
54         << cublasGetErrorEnum(err)
55         << " (func) " << func
56         << " (file) " << file
57         << " (line) " << line
58         << getData);
59 }
60 
61 
62 import grain.dpp.cudnn : cudnnStatus_t, CUDNN_STATUS_SUCCESS, cudnnGetErrorString;
63 
64 /// cudnn error checker
65 @nogc
66 void checkCudnn(
67     string func = __FUNCTION__,
68     string file = __FILE__,
69     size_t line = __LINE__
70 )(cudnnStatus_t err)
71 {
72     assert(
73         err == CUDNN_STATUS_SUCCESS,
74         stringBuf()
75         << cudnnGetErrorString(err).fromStringz
76         << " (func) " << func
77         << " (file) " << file
78         << " (line) " << line
79         << getData);
80 }
81 
82 
83 /// cuda error checker
84 @nogc
85 void checkCuda(
86     string file = __FILE__,
87     size_t line = __LINE__,
88     string func = __FUNCTION__
89 )(CUresult err)
90 {
91     if (err == CUDA_SUCCESS) return;
92     const(char)* name, content;
93     cuGetErrorName(err, &name);
94     cuGetErrorString(err, &content);
95     assert(err == CUDA_SUCCESS,
96            stringBuf()
97            << name.fromStringz
98            << " (info) " << content.fromStringz
99            << " (func) " << func
100            << " (file) " << file
101            << " (line) " << line
102            << getData);
103 }
104 
105 import R = grain.dpp.cuda_runtime_api;
106 @nogc
107 void checkCuda(
108     string file = __FILE__,
109     size_t line = __LINE__,
110     string func = __FUNCTION__
111 )(R.cudaError err)
112 {
113     if (err == R.cudaSuccess) return;
114     // const(char)* name, content;
115     auto name = R.cudaGetErrorName(err);
116     auto content = R.cudaGetErrorString(err);
117     assert(err == R.cudaSuccess,
118            stringBuf()
119            << name.fromStringz
120            << " (info) " << content.fromStringz
121            << " (func) " << func
122            << " (file) " << file
123            << " (line) " << line
124            << getData);
125 }
126 
127 
128 /// nvrtc error checker
129 @nogc
130 void checkNvrtc(
131     string file = __FILE__,
132     size_t line = __LINE__,
133     string func = __FUNCTION__
134 )(nvrtcResult err, const(char)[] info = "")
135 {
136     if (err == NVRTC_SUCCESS) return;
137     if (info != "")
138     {
139         assert(err == NVRTC_SUCCESS,
140                stringBuf()
141                << nvrtcGetErrorString(err).fromStringz
142                << " (func) " << func
143                << " (file) " << file
144                << " (line) " << line
145                << " (info) " << info
146                << getData);
147     }
148     else
149     {
150         assert(err == NVRTC_SUCCESS,
151                stringBuf()
152                << nvrtcGetErrorString(err).fromStringz
153                << " (func) " << func
154                << " (file) " << file
155                << " (line) " << line
156                << getData);
157     }
158 }