Skip to content

Commit 3af923b

Browse files
authored
Fix nan in jax implementation of Multinomial (#1328)
1 parent 0b56ed9 commit 3af923b

File tree

2 files changed

+19
-5
lines changed

2 files changed

+19
-5
lines changed

pytensor/link/jax/dispatch/random.py

+7-5
Original file line numberDiff line numberDiff line change
@@ -409,12 +409,14 @@ def sample_fn(rng_key, size, dtype, n, p):
409409
sampling_rng = jax.random.split(rng_key, binom_p.shape[0])
410410

411411
def _binomial_sample_fn(carry, p_rng):
412-
s, rho = carry
412+
remaining_n, remaining_p = carry
413413
p, rng = p_rng
414-
samples = jax.random.binomial(rng, s, p / rho)
415-
s = s - samples
416-
rho = rho - p
417-
return ((s, rho), samples)
414+
samples = jnp.where(
415+
p == 0, 0, jax.random.binomial(rng, remaining_n, p / remaining_p)
416+
)
417+
remaining_n -= samples
418+
remaining_p -= p
419+
return ((remaining_n, remaining_p), samples)
418420

419421
(remain, _), samples = jax.lax.scan(
420422
_binomial_sample_fn,

tests/link/jax/test_random.py

+12
Original file line numberDiff line numberDiff line change
@@ -733,6 +733,18 @@ def test_multinomial():
733733
samples.std(axis=0), np.sqrt(n[0, :, None] * p * (1 - p)), rtol=0.1
734734
)
735735

736+
# Test with p=0
737+
g = pt.random.multinomial(n=5, p=pt.eye(4))
738+
g_fn = compile_random_function([], g, mode="JAX")
739+
samples = g_fn()
740+
np.testing.assert_array_equal(samples, np.eye(4) * 5)
741+
742+
# Test with n=0
743+
g = pt.random.multinomial(n=0, p=np.ones(4) / 4)
744+
g_fn = compile_random_function([], g, mode="JAX")
745+
samples = g_fn()
746+
np.testing.assert_array_equal(samples, np.zeros(4))
747+
736748

737749
@pytest.mark.skipif(not numpyro_available, reason="VonMises dispatch requires numpyro")
738750
def test_vonmises_mu_outside_circle():

0 commit comments

Comments
 (0)