1 module grain.random;
2 
3 import std.traits : isFloatingPoint;
4 
5 import mir.random : Random;
6 
7 import grain.tensor : Tensor;
8 
9 
/++
Process-wide random number generator. It is declared `__gshared`, so one
single instance is visible to every thread.

It is declared with `= void` (no default initialization). Its initial state
comes from the shared static constructor that follows, which seeds it from
mir's `unpredictableSeed`. Call `setSeed` to get a reproducible sequence.
+/
__gshared Random rng = void;

shared static this()
{
    // Runs once at program start-up, before `main`.
    import mir.random : freshSeed = unpredictableSeed;
    rng = Random(freshSeed);
}
18 
/++
Re-seeds the module-level generator `rng`.

After calling this with a given seed, the draws taken from `rng` follow the
same sequence every time that seed is used again.

Params:
    i = seed value forwarded to the `Random` constructor
+/
void setSeed(uint i)
{
    // Assign the freshly built engine directly as a temporary.
    // NOTE(review): mir.random engines may disable copying, so keep this an
    // rvalue rather than going through a named local.
    synchronized rng = Random(i);
}
27 
/// Reseeding with the same value replays the same sequence of draws.
unittest
{
    import mir.random.variable : NormalVariable;

    auto rv = NormalVariable!double(0, 1);

    setSeed(0);
    immutable a = rv(rng);
    immutable b = rv(rng);
    // consecutive draws from one seed should differ
    assert(a != b);

    setSeed(0);
    // the whole sequence, not just its first element, must repeat
    assert(rv(rng) == a);
    assert(rv(rng) == b);
}
42 
43 
/++
Overwrites every element viewed by `a.asSlice` with a sample from a normal
distribution, then returns `a`.

Samples are drawn one at a time from the module-level generator `rng`, so
the values repeat after the same `setSeed` call.

Params:
    a = tensor whose elements are overwritten. NOTE(review): this presumes
        `Tensor` shares its storage, so the caller observes the change even
        though `a` is passed by value. Confirm against `grain.tensor`.
    mean = location of the normal distribution (default 0, the previous
        hard-coded value)
    stddev = scale of the normal distribution (default 1, the previous
        hard-coded value)

Returns: `a`, after filling.

NOTE(review): `rng` is mutated here without taking the lock used by
`setSeed`. Concurrent calls from several threads would race on the
generator state.
+/
Tensor!(dim, T) normal_(size_t dim, T)(Tensor!(dim, T) a, T mean = 0, T stddev = 1)
    if (isFloatingPoint!T)
{
    import mir.ndslice : each;
    import mir.random.variable : NormalVariable;

    auto rv = NormalVariable!T(mean, stddev);
    a.asSlice.each!((ref x) { x = rv(rng); });
    return a;
}