From 979ed328f19b5e52946ecc01159d094580916712 Mon Sep 17 00:00:00 2001 From: Swarnim Shekhar Date: Tue, 4 Mar 2025 22:00:32 +0530 Subject: [PATCH] Rewrite scalar dot as multiplication #1205 --- pytensor/tensor/rewriting/math.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/pytensor/tensor/rewriting/math.py b/pytensor/tensor/rewriting/math.py index 9694a022e3..558abb4460 100644 --- a/pytensor/tensor/rewriting/math.py +++ b/pytensor/tensor/rewriting/math.py @@ -295,6 +295,10 @@ def local_blockwise_dot_to_mul(fgraph, node): new_b = b else: return None + + # new condition to handle (1,1) @ (1,1) + if a.ndim == 2 and b.ndim == 2 and a.shape == (1, 1) and b.shape == (1, 1): + return [a * b] # Direct elementwise multiplication new_a = copy_stack_trace(a, new_a) new_b = copy_stack_trace(b, new_b)