|
103 | 103 | from pytensor.link.c.params_type import ParamsType
|
104 | 104 | from pytensor.printing import FunctionPrinter, pprint
|
105 | 105 | from pytensor.scalar import bool as bool_t
|
106 |
| -from pytensor.tensor import basic as ptb |
| 106 | +from pytensor.tensor.basic import as_tensor_variable, cast |
107 | 107 | from pytensor.tensor.blas_headers import blas_header_text, blas_header_version
|
108 | 108 | from pytensor.tensor.math import dot, tensordot
|
109 | 109 | from pytensor.tensor.shape import specify_broadcastable
|
@@ -157,11 +157,11 @@ def __str__(self):
|
157 | 157 | return f"{self.__class__.__name__}{{no_inplace}}"
|
158 | 158 |
|
159 | 159 | def make_node(self, y, alpha, A, x, beta):
|
160 |
| - y = ptb.as_tensor_variable(y) |
161 |
| - x = ptb.as_tensor_variable(x) |
162 |
| - A = ptb.as_tensor_variable(A) |
163 |
| - alpha = ptb.as_tensor_variable(alpha) |
164 |
| - beta = ptb.as_tensor_variable(beta) |
| 160 | + y = as_tensor_variable(y) |
| 161 | + x = as_tensor_variable(x) |
| 162 | + A = as_tensor_variable(A) |
| 163 | + alpha = as_tensor_variable(alpha) |
| 164 | + beta = as_tensor_variable(beta) |
165 | 165 | if y.dtype != A.dtype or y.dtype != x.dtype:
|
166 | 166 | raise TypeError(
|
167 | 167 | "Gemv requires matching dtypes", (y.dtype, A.dtype, x.dtype)
|
@@ -257,10 +257,10 @@ def __str__(self):
|
257 | 257 | return f"{self.__class__.__name__}{{non-destructive}}"
|
258 | 258 |
|
259 | 259 | def make_node(self, A, alpha, x, y):
|
260 |
| - A = ptb.as_tensor_variable(A) |
261 |
| - y = ptb.as_tensor_variable(y) |
262 |
| - x = ptb.as_tensor_variable(x) |
263 |
| - alpha = ptb.as_tensor_variable(alpha) |
| 260 | + A = as_tensor_variable(A) |
| 261 | + y = as_tensor_variable(y) |
| 262 | + x = as_tensor_variable(x) |
| 263 | + alpha = as_tensor_variable(alpha) |
264 | 264 | if not (A.dtype == x.dtype == y.dtype == alpha.dtype):
|
265 | 265 | raise TypeError(
|
266 | 266 | "ger requires matching dtypes", (A.dtype, alpha.dtype, x.dtype, y.dtype)
|
@@ -859,7 +859,7 @@ def __getstate__(self):
|
859 | 859 | return rval
|
860 | 860 |
|
861 | 861 | def make_node(self, *inputs):
|
862 |
| - inputs = list(map(ptb.as_tensor_variable, inputs)) |
| 862 | + inputs = list(map(as_tensor_variable, inputs)) |
863 | 863 |
|
864 | 864 | if any(not isinstance(i.type, DenseTensorType) for i in inputs):
|
865 | 865 | raise NotImplementedError("Only dense tensor types are supported")
|
@@ -1129,8 +1129,8 @@ class Dot22(GemmRelated):
|
1129 | 1129 | check_input = False
|
1130 | 1130 |
|
1131 | 1131 | def make_node(self, x, y):
|
1132 |
| - x = ptb.as_tensor_variable(x) |
1133 |
| - y = ptb.as_tensor_variable(y) |
| 1132 | + x = as_tensor_variable(x) |
| 1133 | + y = as_tensor_variable(y) |
1134 | 1134 |
|
1135 | 1135 | if any(not isinstance(i.type, DenseTensorType) for i in (x, y)):
|
1136 | 1136 | raise NotImplementedError("Only dense tensor types are supported")
|
@@ -1322,8 +1322,8 @@ class BatchedDot(COp):
|
1322 | 1322 | gufunc_signature = "(b,m,k),(b,k,n)->(b,m,n)"
|
1323 | 1323 |
|
1324 | 1324 | def make_node(self, x, y):
|
1325 |
| - x = ptb.as_tensor_variable(x) |
1326 |
| - y = ptb.as_tensor_variable(y) |
| 1325 | + x = as_tensor_variable(x) |
| 1326 | + y = as_tensor_variable(y) |
1327 | 1327 |
|
1328 | 1328 | if not (
|
1329 | 1329 | isinstance(x.type, DenseTensorType) and isinstance(y.type, DenseTensorType)
|
@@ -1357,7 +1357,7 @@ def extract_static_dim(dim_x, dim_y):
|
1357 | 1357 |
|
1358 | 1358 | # Change dtype if needed
|
1359 | 1359 | dtype = pytensor.scalar.upcast(x.type.dtype, y.type.dtype)
|
1360 |
| - x, y = ptb.cast(x, dtype), ptb.cast(y, dtype) |
| 1360 | + x, y = cast(x, dtype), cast(y, dtype) |
1361 | 1361 | out = tensor(dtype=dtype, shape=out_shape)
|
1362 | 1362 | return Apply(self, [x, y], [out])
|
1363 | 1363 |
|
@@ -1738,7 +1738,7 @@ def batched_dot(a, b):
|
1738 | 1738 | "Use `dot` in conjution with `tensor.vectorize` or `graph.replace.vectorize_graph`",
|
1739 | 1739 | FutureWarning,
|
1740 | 1740 | )
|
1741 |
| - a, b = ptb.as_tensor_variable(a), ptb.as_tensor_variable(b) |
| 1741 | + a, b = as_tensor_variable(a), as_tensor_variable(b) |
1742 | 1742 |
|
1743 | 1743 | if a.ndim == 0:
|
1744 | 1744 | raise TypeError("a must have at least one (batch) axis")
|
|
0 commit comments