Skip to content

Commit f10a603

Browse files
committed
Use direct imports in blas.py
1 parent efd9f49 commit f10a603

File tree

1 file changed

+17
-17
lines changed

1 file changed

+17
-17
lines changed

pytensor/tensor/blas.py

+17-17
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@
103103
from pytensor.link.c.params_type import ParamsType
104104
from pytensor.printing import FunctionPrinter, pprint
105105
from pytensor.scalar import bool as bool_t
106-
from pytensor.tensor import basic as ptb
106+
from pytensor.tensor.basic import as_tensor_variable, cast
107107
from pytensor.tensor.blas_headers import blas_header_text, blas_header_version
108108
from pytensor.tensor.math import dot, tensordot
109109
from pytensor.tensor.shape import specify_broadcastable
@@ -157,11 +157,11 @@ def __str__(self):
157157
return f"{self.__class__.__name__}{{no_inplace}}"
158158

159159
def make_node(self, y, alpha, A, x, beta):
160-
y = ptb.as_tensor_variable(y)
161-
x = ptb.as_tensor_variable(x)
162-
A = ptb.as_tensor_variable(A)
163-
alpha = ptb.as_tensor_variable(alpha)
164-
beta = ptb.as_tensor_variable(beta)
160+
y = as_tensor_variable(y)
161+
x = as_tensor_variable(x)
162+
A = as_tensor_variable(A)
163+
alpha = as_tensor_variable(alpha)
164+
beta = as_tensor_variable(beta)
165165
if y.dtype != A.dtype or y.dtype != x.dtype:
166166
raise TypeError(
167167
"Gemv requires matching dtypes", (y.dtype, A.dtype, x.dtype)
@@ -257,10 +257,10 @@ def __str__(self):
257257
return f"{self.__class__.__name__}{{non-destructive}}"
258258

259259
def make_node(self, A, alpha, x, y):
260-
A = ptb.as_tensor_variable(A)
261-
y = ptb.as_tensor_variable(y)
262-
x = ptb.as_tensor_variable(x)
263-
alpha = ptb.as_tensor_variable(alpha)
260+
A = as_tensor_variable(A)
261+
y = as_tensor_variable(y)
262+
x = as_tensor_variable(x)
263+
alpha = as_tensor_variable(alpha)
264264
if not (A.dtype == x.dtype == y.dtype == alpha.dtype):
265265
raise TypeError(
266266
"ger requires matching dtypes", (A.dtype, alpha.dtype, x.dtype, y.dtype)
@@ -859,7 +859,7 @@ def __getstate__(self):
859859
return rval
860860

861861
def make_node(self, *inputs):
862-
inputs = list(map(ptb.as_tensor_variable, inputs))
862+
inputs = list(map(as_tensor_variable, inputs))
863863

864864
if any(not isinstance(i.type, DenseTensorType) for i in inputs):
865865
raise NotImplementedError("Only dense tensor types are supported")
@@ -1129,8 +1129,8 @@ class Dot22(GemmRelated):
11291129
check_input = False
11301130

11311131
def make_node(self, x, y):
1132-
x = ptb.as_tensor_variable(x)
1133-
y = ptb.as_tensor_variable(y)
1132+
x = as_tensor_variable(x)
1133+
y = as_tensor_variable(y)
11341134

11351135
if any(not isinstance(i.type, DenseTensorType) for i in (x, y)):
11361136
raise NotImplementedError("Only dense tensor types are supported")
@@ -1322,8 +1322,8 @@ class BatchedDot(COp):
13221322
gufunc_signature = "(b,m,k),(b,k,n)->(b,m,n)"
13231323

13241324
def make_node(self, x, y):
1325-
x = ptb.as_tensor_variable(x)
1326-
y = ptb.as_tensor_variable(y)
1325+
x = as_tensor_variable(x)
1326+
y = as_tensor_variable(y)
13271327

13281328
if not (
13291329
isinstance(x.type, DenseTensorType) and isinstance(y.type, DenseTensorType)
@@ -1357,7 +1357,7 @@ def extract_static_dim(dim_x, dim_y):
13571357

13581358
# Change dtype if needed
13591359
dtype = pytensor.scalar.upcast(x.type.dtype, y.type.dtype)
1360-
x, y = ptb.cast(x, dtype), ptb.cast(y, dtype)
1360+
x, y = cast(x, dtype), cast(y, dtype)
13611361
out = tensor(dtype=dtype, shape=out_shape)
13621362
return Apply(self, [x, y], [out])
13631363

@@ -1738,7 +1738,7 @@ def batched_dot(a, b):
17381738
"Use `dot` in conjution with `tensor.vectorize` or `graph.replace.vectorize_graph`",
17391739
FutureWarning,
17401740
)
1741-
a, b = ptb.as_tensor_variable(a), ptb.as_tensor_variable(b)
1741+
a, b = as_tensor_variable(a), as_tensor_variable(b)
17421742

17431743
if a.ndim == 0:
17441744
raise TypeError("a must have at least one (batch) axis")

0 commit comments

Comments (0)