module grain.random;

import std.traits : isFloatingPoint;

import mir.random : Random;

import grain.tensor : Tensor;


/**
Process-wide random number generator (rng) used by this module.

It is declared `__gshared`, so one instance is shared by every thread rather
than each thread getting its own copy. Draws from it (e.g. in `normal_`) are
not synchronized.
*/
__gshared Random rng = void;

/// Seed `rng` once, before `main` runs, with an unpredictable seed.
shared static this()
{
    import mir.random : unpredictableSeed;
    rng = Random(unpredictableSeed);
}

/**
Reseed the module-level random number generator `rng`.

Params:
    i = seed value; reseeding with the same value reproduces the same
        sequence of draws (exercised by the unittest below).
*/
void setSeed(uint i)
{
    // An expression-less `synchronized` statement locks a mutex private to
    // this statement: it only serializes concurrent `setSeed` calls. Code that
    // draws from `rng` does not take this lock.
    synchronized
    {
        rng = Random(i);
    }
}

///
unittest
{
    import mir.random.variable : NormalVariable;

    auto rv = NormalVariable!double(0, 1);
    setSeed(0);
    auto a = rv(rng);
    auto b = rv(rng);
    assert(a != b);
    setSeed(0);
    auto a2 = rv(rng);
    assert(a == a2);
}


/**
Overwrite every element of `a.asSlice` with a sample from a normal
distribution, drawn using the module-level `rng`.

NOTE(review): this presumably fills `a` in place because `asSlice` is a view
over the tensor's storage — confirm against `grain.tensor`.

Params:
    a = tensor whose elements are overwritten.
    mean = location of the normal distribution (defaults to 0).
    stddev = scale (standard deviation) of the normal distribution
             (defaults to 1).

Returns: `a`, after its elements have been assigned.
*/
Tensor!(dim, T) normal_(size_t dim, T)(Tensor!(dim, T) a, T mean = 0, T stddev = 1)
    if (isFloatingPoint!T)
{
    import mir.ndslice : each;
    import mir.random.variable : NormalVariable;

    auto rv = NormalVariable!T(mean, stddev);
    a.asSlice.each!((ref x) { x = rv(rng); });
    return a;
}