Skip to content

Commit 293af91

Browse files
committed
Make CAReduce more SIMD friendly and do better allocation of output memory
1 parent a2b7985 commit 293af91

File tree

3 files changed

+240
-188
lines changed

pytensor/tensor/elemwise.py

+22-31
Original file line numberDiff line numberDiff line change
@@ -1422,7 +1422,7 @@ def infer_shape(self, fgraph, node, shapes):
14221422
def _c_all(self, node, name, input_names, output_names, sub):
14231423
[inp] = node.inputs
14241424
[out] = node.outputs
1425-
ndim = inp.type.ndim
1425+
inp_ndim = inp.type.ndim
14261426

14271427
[inp_name] = input_names
14281428
[out_name] = output_names
@@ -1454,10 +1454,10 @@ def _c_all(self, node, name, input_names, output_names, sub):
14541454
assert var.dtype == node.outputs[0].dtype
14551455
return var.owner.op._c_all(var.owner, name, input_names, output_names, sub)
14561456

1457-
inp_dims = list(range(ndim))
1457+
inp_dims = list(range(inp_ndim))
14581458
non_reduced_dims = [i for i in inp_dims if i not in axis]
1459-
counter = iter(range(ndim))
1460-
acc_dims = ["x" if i in axis else next(counter) for i in range(ndim)]
1459+
counter = iter(range(inp_ndim))
1460+
acc_dims = ["x" if i in axis else next(counter) for i in range(inp_ndim)]
14611461

14621462
sub = sub.copy()
14631463
sub["lv0"] = inp_name
@@ -1484,7 +1484,9 @@ def _c_all(self, node, name, input_names, output_names, sub):
14841484
cgen.make_declare(
14851485
[acc_dims], [out_dtype], out_sub, compute_stride_jump=False
14861486
)
1487-
+ cgen.make_alloc([non_reduced_dims], out_dtype, sub)
1487+
+ cgen.make_careduce_alloc(
1488+
inp_name, out_name, inp_ndim, axis, out_dtype, sub["fail"]
1489+
)
14881490
+ cgen.make_checks(
14891491
[acc_dims], [out_dtype], out_sub, compute_stride_jump=False
14901492
)
@@ -1500,7 +1502,10 @@ def _c_all(self, node, name, input_names, output_names, sub):
15001502
cgen.make_declare(
15011503
[acc_dims], [acc_dtype], acc_sub, compute_stride_jump=False
15021504
)
1503-
+ cgen.make_alloc([non_reduced_dims], acc_dtype, sub)
1505+
+ cgen.make_careduce_alloc(
1506+
inp_name, acc_name, inp_ndim, axis, out_dtype, sub["fail"]
1507+
)
1508+
+ cgen.make_careduce_alloc([non_reduced_dims], acc_dtype, sub)
15041509
+ cgen.make_checks(
15051510
[acc_dims], [acc_dtype], acc_sub, compute_stride_jump=False
15061511
)
@@ -1524,8 +1529,6 @@ def _c_all(self, node, name, input_names, output_names, sub):
15241529
elif identity is None:
15251530
raise TypeError(f"The {self.scalar_op} does not define an identity.")
15261531

1527-
initial_value = f"{acc_name}_i = {identity};"
1528-
15291532
inner_task = self.scalar_op.c_code(
15301533
Apply(
15311534
self.scalar_op,
@@ -1544,28 +1547,16 @@ def _c_all(self, node, name, input_names, output_names, sub):
15441547
sub,
15451548
)
15461549

1547-
if out.type.ndim == 0:
1548-
# Simple case where everything is reduced, no need for loop ordering
1549-
loop = cgen.make_complete_loop_careduce(
1550-
inp_var=inp_name,
1551-
acc_var=acc_name,
1552-
inp_dtype=inp_dtype,
1553-
acc_dtype=acc_dtype,
1554-
initial_value=initial_value,
1555-
inner_task=inner_task,
1556-
fail_code=sub["fail"],
1557-
)
1558-
else:
1559-
loop = cgen.make_reordered_loop_careduce(
1560-
inp_var=inp_name,
1561-
acc_var=acc_name,
1562-
inp_dtype=inp_dtype,
1563-
acc_dtype=acc_dtype,
1564-
inp_ndim=ndim,
1565-
reduction_axes=axis,
1566-
initial_value=initial_value,
1567-
inner_task=inner_task,
1568-
)
1550+
loop = cgen.make_reordered_loop_careduce(
1551+
inp_var=inp_name,
1552+
acc_var=acc_name,
1553+
inp_dtype=inp_dtype,
1554+
acc_dtype=acc_dtype,
1555+
inp_ndim=inp_ndim,
1556+
reduction_axes=axis,
1557+
initial_value=identity,
1558+
inner_task=inner_task,
1559+
)
15691560

15701561
if acc_dtype != out_dtype:
15711562
cast = dedent(
@@ -1589,7 +1580,7 @@ def c_headers(self, **kwargs):
15891580

15901581
def c_code_cache_version_apply(self, node):
15911582
# the version corresponding to the c code in this Op
1592-
version = [10]
1583+
version = [11]
15931584

15941585
# now we insert versions for the ops on which we depend...
15951586
scalar_node = Apply(

0 commit comments

Comments (0)