Skip to content

Commit 8112576

Browse files
committed
Stop using deprecated numpy.product
1 parent 22e9233 commit 8112576

File tree

4 files changed

+7
-7
lines changed

4 files changed

+7
-7
lines changed

pytensor/scalar/basic.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1895,7 +1895,7 @@ class Mul(ScalarOp):
18951895
nfunc_variadic = "product"
18961896

18971897
def impl(self, *inputs):
1898-
return np.product(inputs)
1898+
return np.prod(inputs)
18991899

19001900
def c_code(self, node, name, inputs, outputs, sub):
19011901
(z,) = outputs

tests/link/jax/test_extra_ops.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ def test_extra_ops():
5555
fgraph = FunctionGraph([a], [out])
5656
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
5757

58-
indices = np.arange(np.product((3, 4)))
58+
indices = np.arange(np.prod((3, 4)))
5959
out = at_extra_ops.unravel_index(indices, (3, 4), order="C")
6060
fgraph = FunctionGraph([], out)
6161
compare_jax_and_py(
@@ -100,7 +100,7 @@ def test_extra_ops_omni():
100100
fgraph = FunctionGraph([], [out])
101101
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
102102

103-
multi_index = np.unravel_index(np.arange(np.product((3, 4))), (3, 4))
103+
multi_index = np.unravel_index(np.arange(np.prod((3, 4))), (3, 4))
104104
out = at_extra_ops.ravel_multi_index(multi_index, (3, 4))
105105
fgraph = FunctionGraph([], [out])
106106
compare_jax_and_py(

tests/tensor/test_extra_ops.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -925,7 +925,7 @@ def test_infer_shape(self, x, inp, axis):
925925
class TestUnravelIndex(utt.InferShapeTester):
926926
def test_unravel_index(self):
927927
def check(shape, index_ndim, order):
928-
indices = np.arange(np.product(shape))
928+
indices = np.arange(np.prod(shape))
929929
# test with scalars and higher-dimensional indices
930930
if index_ndim == 0:
931931
indices = indices[-1]
@@ -996,7 +996,7 @@ class TestRavelMultiIndex(utt.InferShapeTester):
996996
def test_ravel_multi_index(self):
997997
def check(shape, index_ndim, mode, order):
998998
multi_index = np.unravel_index(
999-
np.arange(np.product(shape)), shape, order=order
999+
np.arange(np.prod(shape)), shape, order=order
10001000
)
10011001
# create some invalid indices to test the mode
10021002
if mode in ("wrap", "clip"):

tests/tensor/test_subtensor.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1151,7 +1151,7 @@ def test_advanced1_inc_and_set(self):
11511151
for inplace in (False, True):
11521152
for data_shape in ((10,), (4, 5), (1, 2, 3), (4, 5, 6, 7)):
11531153
data_n_dims = len(data_shape)
1154-
data_size = np.product(data_shape)
1154+
data_size = np.prod(data_shape)
11551155
# Corresponding numeric variable.
11561156
data_num_init = np.arange(data_size, dtype=self.dtype)
11571157
data_num_init = data_num_init.reshape(data_shape)
@@ -1203,7 +1203,7 @@ def test_advanced1_inc_and_set(self):
12031203
# The param dtype is needed when inc_shape is empty.
12041204
# By default, it would return a float and rng.uniform
12051205
# with NumPy 1.10 will raise a Deprecation warning.
1206-
inc_size = np.product(inc_shape, dtype="int")
1206+
inc_size = np.prod(inc_shape, dtype="int")
12071207
# Corresponding numeric variable.
12081208
inc_num = rng.uniform(size=inc_size).astype(self.dtype)
12091209
inc_num = inc_num.reshape(inc_shape)

0 commit comments

Comments
 (0)