Skip to content

Commit 6ff5d33

Browse files
committed
Attempt fixes for AT_CUDA_ENABLED changes
1 parent d7488f4 commit 6ff5d33

File tree

4 files changed

+80
-68
lines changed

4 files changed

+80
-68
lines changed

aten/src/ATen/SharedDist.cu

+24
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
#include "ATen/ATen.h"
2+
#include "ATen/TensorUtils.h"
3+
#include "ATen/NativeFunctions.h"
4+
#include "ATen/Dispatch.h"
5+
#include "ATen/Config.h"
6+
7+
#include <nvfunctional>
8+
9+
namespace at {
10+
namespace native {
11+
namespace dist {
12+
template<typename precision_t>
13+
struct baseSampler {
14+
nvstd::function<precision_t(void)> sampler;
15+
baseSampler(nvstd::function<precision_t(void)> sampler): sampler(sampler) {}
16+
precision_t sample() {
17+
return sampler();
18+
}
19+
};
20+
}
21+
}
22+
}
23+
24+
// this version is only linked if CUDA is enabled, so we can safely just use CUDA features here

aten/src/ATen/native/Distributions.cpp

+51-4
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
11
#include "ATen/ATen.h"
22
#include "ATen/CPUApplyUtils.h"
33
#include "ATen/Dispatch.h"
4+
#include "ATen/Config.h"
45
#include "ATen/ExpandUtils.h"
56
#include "ATen/NativeFunctions.h"
67

78
#include "ATen/CPUGenerator.h"
89
#include "ATen/CheckGenerator.h"
910
#include "ATen/Generator.h"
1011

11-
#include <ATen/native/Distributions.cuh>
12+
#include <functional>
1213

1314
#include "TH/THRandom.h"
1415

@@ -121,12 +122,23 @@ Tensor _standard_gamma_grad_cuda(const Tensor& self, const Tensor& output) {
121122

122123
/*
123124
* This section is a counterpart to Distributions.cu
124-
*
125125
*/
126126

127127
namespace dist {
128-
// The function `sample_poisson`
129-
// is adapted from Numpy's distributions.c implementation.
128+
129+
#if !AT_CUDA_ENABLED()
130+
template<typename precision_t>
131+
struct baseSampler {
132+
std::function<precision_t(void)> sampler;
133+
baseSampler(std::function<precision_t(void)> sampler): sampler(sampler) {}
134+
precision_t sample() {
135+
return sampler();
136+
}
137+
};
138+
#endif
139+
140+
// The functions `sample_poisson`, `sample_gamma`
141+
// are adapted from Numpy's distributions.c implementation.
130142
// It is MIT licensed, so here is the copyright:
131143

132144
/* Copyright 2005 Robert Kern ([email protected])
@@ -151,6 +163,41 @@ namespace dist {
151163
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
152164
*/
153165

166+
167+
template<typename precision_t>
168+
#if AT_CUDA_ENABLED()
169+
__host__ __device__
170+
#endif
171+
precision_t sample_gamma(precision_t alpha, baseSampler<precision_t>& standard_uniform, baseSampler<precision_t>& standard_normal) {
172+
173+
precision_t scale = 1.0;
174+
175+
// Boost alpha for higher acceptance probability.
176+
if (alpha < 1.0) {
177+
scale *= ::pow(1 - standard_uniform.sample(), 1.0 / alpha);
178+
alpha += 1.0;
179+
}
180+
181+
// This implements the acceptance-rejection method of Marsaglia and Tsang (2000)
182+
// doi:10.1145/358407.358414
183+
const precision_t d = alpha - 1.0 / 3.0;
184+
const precision_t c = 1.0 / ::sqrt(9.0 * d);
185+
for (;;) {
186+
precision_t x, y;
187+
do {
188+
x = standard_normal.sample();
189+
y = 1.0 + c * x;
190+
} while (y <= 0);
191+
const precision_t v = y * y * y;
192+
const precision_t u = 1 - standard_uniform.sample();
193+
const precision_t xx = x * x;
194+
if (u < 1.0 - 0.0331 * xx * xx)
195+
return scale * d * v;
196+
if (::log(u) < 0.5 * xx + d * (1.0 - v + ::log(v)))
197+
return scale * d * v;
198+
}
199+
}
200+
154201
THGenerator * get_generator(Generator *gen) {
155202
auto default_gen = &at::globalContext().defaultGenerator(Backend::CPU);
156203
auto gen_ = check_generator<CPUGenerator>(gen, default_gen);

aten/src/ATen/native/Distributions.cuh

-63
This file was deleted.

aten/src/ATen/native/cuda/Distributions.cu

+5-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
1+
#include "ATen/ATen.h"
2+
#include "ATen/TensorUtils.h"
13
#include "ATen/NativeFunctions.h"
24
#include "ATen/Dispatch.h"
5+
#include "ATen/Config.h"
36
#include "ATen/cuda/CUDAApplyUtils.cuh"
47
#include <curand.h>
58
#include <curand_kernel.h>
@@ -8,7 +11,8 @@
811
#include <functional>
912
#include <nvfunctional>
1013

11-
#include "ATen/native/Distributions.cuh"
14+
#include "ATen/SharedDist.cu"
15+
#include "ATen/native/Distributions.cpp"
1216

1317
#include <TH/THAtomic.h>
1418

0 commit comments

Comments
 (0)