Skip to content

Commit 52bbf59

Browse files
authored
Modify atleast_Nd to accept only one positional argument (#1291)
1 parent f10a603 commit 52bbf59

File tree

2 files changed

+14
-19
lines changed

2 files changed

+14
-19
lines changed

pytensor/tensor/basic.py

+12-18
Original file line numberDiff line numberDiff line change
@@ -4355,28 +4355,22 @@ def empty_like(
43554355

43564356

43574357
def atleast_Nd(
4358-
*arys: np.ndarray | TensorVariable, n: int = 1, left: bool = True
4358+
arry: np.ndarray | TensorVariable, *, n: int = 1, left: bool = True
43594359
) -> TensorVariable:
4360-
"""Convert inputs to arrays with at least `n` dimensions."""
4361-
res = []
4362-
for ary in arys:
4363-
ary = as_tensor(ary)
4360+
"""Convert input to an array with at least `n` dimensions."""
43644361

4365-
if ary.ndim >= n:
4366-
result = ary
4367-
else:
4368-
result = (
4369-
shape_padleft(ary, n - ary.ndim)
4370-
if left
4371-
else shape_padright(ary, n - ary.ndim)
4372-
)
4362+
arry = as_tensor(arry)
43734363

4374-
res.append(result)
4375-
4376-
if len(res) == 1:
4377-
return res[0]
4364+
if arry.ndim >= n:
4365+
result = arry
43784366
else:
4379-
return res
4367+
result = (
4368+
shape_padleft(arry, n - arry.ndim)
4369+
if left
4370+
else shape_padright(arry, n - arry.ndim)
4371+
)
4372+
4373+
return result
43804374

43814375

43824376
atleast_1d = partial(atleast_Nd, n=1)

tests/tensor/test_basic.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -4364,7 +4364,8 @@ def test_atleast_Nd():
43644364

43654365
for n in range(1, 3):
43664366
ary1, ary2 = dscalar(), dvector()
4367-
res_ary1, res_ary2 = atleast_Nd(ary1, ary2, n=n)
4367+
res_ary1 = atleast_Nd(ary1, n=n)
4368+
res_ary2 = atleast_Nd(ary2, n=n)
43684369

43694370
assert res_ary1.ndim == n
43704371
if n == ary2.ndim:

0 commit comments

Comments
 (0)