@@ -1422,7 +1422,7 @@ def infer_shape(self, fgraph, node, shapes):
1422
1422
def _c_all (self , node , name , input_names , output_names , sub ):
1423
1423
[inp ] = node .inputs
1424
1424
[out ] = node .outputs
1425
- ndim = inp .type .ndim
1425
+ inp_ndim = inp .type .ndim
1426
1426
1427
1427
[inp_name ] = input_names
1428
1428
[out_name ] = output_names
@@ -1454,10 +1454,10 @@ def _c_all(self, node, name, input_names, output_names, sub):
1454
1454
assert var .dtype == node .outputs [0 ].dtype
1455
1455
return var .owner .op ._c_all (var .owner , name , input_names , output_names , sub )
1456
1456
1457
- inp_dims = list (range (ndim ))
1457
+ inp_dims = list (range (inp_ndim ))
1458
1458
non_reduced_dims = [i for i in inp_dims if i not in axis ]
1459
- counter = iter (range (ndim ))
1460
- acc_dims = ["x" if i in axis else next (counter ) for i in range (ndim )]
1459
+ counter = iter (range (inp_ndim ))
1460
+ acc_dims = ["x" if i in axis else next (counter ) for i in range (inp_ndim )]
1461
1461
1462
1462
sub = sub .copy ()
1463
1463
sub ["lv0" ] = inp_name
@@ -1484,7 +1484,9 @@ def _c_all(self, node, name, input_names, output_names, sub):
1484
1484
cgen .make_declare (
1485
1485
[acc_dims ], [out_dtype ], out_sub , compute_stride_jump = False
1486
1486
)
1487
- + cgen .make_alloc ([non_reduced_dims ], out_dtype , sub )
1487
+ + cgen .make_careduce_alloc (
1488
+ inp_name , out_name , inp_ndim , axis , out_dtype , sub ["fail" ]
1489
+ )
1488
1490
+ cgen .make_checks (
1489
1491
[acc_dims ], [out_dtype ], out_sub , compute_stride_jump = False
1490
1492
)
@@ -1500,7 +1502,10 @@ def _c_all(self, node, name, input_names, output_names, sub):
1500
1502
cgen .make_declare (
1501
1503
[acc_dims ], [acc_dtype ], acc_sub , compute_stride_jump = False
1502
1504
)
1503
- + cgen .make_alloc ([non_reduced_dims ], acc_dtype , sub )
1505
+ + cgen .make_careduce_alloc (
1506
+ inp_name , acc_name , inp_ndim , axis , out_dtype , sub ["fail" ]
1507
+ )
1508
+ + cgen .make_careduce_alloc ([non_reduced_dims ], acc_dtype , sub )
1504
1509
+ cgen .make_checks (
1505
1510
[acc_dims ], [acc_dtype ], acc_sub , compute_stride_jump = False
1506
1511
)
@@ -1524,8 +1529,6 @@ def _c_all(self, node, name, input_names, output_names, sub):
1524
1529
elif identity is None :
1525
1530
raise TypeError (f"The { self .scalar_op } does not define an identity." )
1526
1531
1527
- initial_value = f"{ acc_name } _i = { identity } ;"
1528
-
1529
1532
inner_task = self .scalar_op .c_code (
1530
1533
Apply (
1531
1534
self .scalar_op ,
@@ -1544,28 +1547,16 @@ def _c_all(self, node, name, input_names, output_names, sub):
1544
1547
sub ,
1545
1548
)
1546
1549
1547
- if out .type .ndim == 0 :
1548
- # Simple case where everything is reduced, no need for loop ordering
1549
- loop = cgen .make_complete_loop_careduce (
1550
- inp_var = inp_name ,
1551
- acc_var = acc_name ,
1552
- inp_dtype = inp_dtype ,
1553
- acc_dtype = acc_dtype ,
1554
- initial_value = initial_value ,
1555
- inner_task = inner_task ,
1556
- fail_code = sub ["fail" ],
1557
- )
1558
- else :
1559
- loop = cgen .make_reordered_loop_careduce (
1560
- inp_var = inp_name ,
1561
- acc_var = acc_name ,
1562
- inp_dtype = inp_dtype ,
1563
- acc_dtype = acc_dtype ,
1564
- inp_ndim = ndim ,
1565
- reduction_axes = axis ,
1566
- initial_value = initial_value ,
1567
- inner_task = inner_task ,
1568
- )
1550
+ loop = cgen .make_reordered_loop_careduce (
1551
+ inp_var = inp_name ,
1552
+ acc_var = acc_name ,
1553
+ inp_dtype = inp_dtype ,
1554
+ acc_dtype = acc_dtype ,
1555
+ inp_ndim = inp_ndim ,
1556
+ reduction_axes = axis ,
1557
+ initial_value = identity ,
1558
+ inner_task = inner_task ,
1559
+ )
1569
1560
1570
1561
if acc_dtype != out_dtype :
1571
1562
cast = dedent (
@@ -1589,7 +1580,7 @@ def c_headers(self, **kwargs):
1589
1580
1590
1581
def c_code_cache_version_apply (self , node ):
1591
1582
# the version corresponding to the c code in this Op
1592
- version = [10 ]
1583
+ version = [11 ]
1593
1584
1594
1585
# now we insert versions for the ops on which we depend...
1595
1586
scalar_node = Apply (
0 commit comments