1
1
#include " ATen/ATen.h"
2
2
#include " ATen/CPUApplyUtils.h"
3
3
#include " ATen/Dispatch.h"
4
+ #include " ATen/Config.h"
4
5
#include " ATen/ExpandUtils.h"
5
6
#include " ATen/NativeFunctions.h"
6
7
7
8
#include " ATen/CPUGenerator.h"
8
9
#include " ATen/CheckGenerator.h"
9
10
#include " ATen/Generator.h"
10
11
11
- #include < ATen/native/Distributions.cuh >
12
+ #include < functional >
12
13
13
14
#include " TH/THRandom.h"
14
15
@@ -121,12 +122,23 @@ Tensor _standard_gamma_grad_cuda(const Tensor& self, const Tensor& output) {
121
122
122
123
/*
123
124
* This section is a counterpart to Distributions.cu
124
- *
125
125
*/
126
126
127
127
namespace dist {
128
- // The function `sample_poisson`
129
- // is adapted from Numpy's distributions.c implementation.
128
+
129
+ #if !AT_CUDA_ENABLED()
130
+ template <typename precision_t >
131
+ struct baseSampler {
132
+ std::function<precision_t (void )> sampler;
133
+ baseSampler (std::function<precision_t (void )> sampler): sampler(sampler) {}
134
+ precision_t sample () {
135
+ return sampler ();
136
+ }
137
+ };
138
+ #endif
139
+
140
+ // The functions `sample_poisson`, `sample_gamma`
141
+ // are adapted from Numpy's distributions.c implementation.
130
142
// It is MIT licensed, so here is the copyright:
131
143
132
144
/* Copyright 2005 Robert Kern ([email protected] )
@@ -151,6 +163,41 @@ namespace dist {
151
163
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
152
164
*/
153
165
166
+
167
+ template <typename precision_t >
168
+ #if AT_CUDA_ENABLED()
169
+ __host__ __device__
170
+ #endif
171
+ precision_t sample_gamma (precision_t alpha, baseSampler<precision_t >& standard_uniform, baseSampler<precision_t >& standard_normal) {
172
+
173
+ precision_t scale = 1.0 ;
174
+
175
+ // Boost alpha for higher acceptance probability.
176
+ if (alpha < 1.0 ) {
177
+ scale *= ::pow (1 - standard_uniform.sample (), 1.0 / alpha);
178
+ alpha += 1.0 ;
179
+ }
180
+
181
+ // This implements the acceptance-rejection method of Marsaglia and Tsang (2000)
182
+ // doi:10.1145/358407.358414
183
+ const precision_t d = alpha - 1.0 / 3.0 ;
184
+ const precision_t c = 1.0 / ::sqrt (9.0 * d);
185
+ for (;;) {
186
+ precision_t x, y;
187
+ do {
188
+ x = standard_normal.sample ();
189
+ y = 1.0 + c * x;
190
+ } while (y <= 0 );
191
+ const precision_t v = y * y * y;
192
+ const precision_t u = 1 - standard_uniform.sample ();
193
+ const precision_t xx = x * x;
194
+ if (u < 1.0 - 0.0331 * xx * xx)
195
+ return scale * d * v;
196
+ if (::log (u) < 0.5 * xx + d * (1.0 - v + ::log (v)))
197
+ return scale * d * v;
198
+ }
199
+ }
200
+
154
201
THGenerator * get_generator (Generator *gen) {
155
202
auto default_gen = &at::globalContext ().defaultGenerator (Backend::CPU);
156
203
auto gen_ = check_generator<CPUGenerator>(gen, default_gen);
0 commit comments