Skip to content

Commit ac891dc

Browse files
authored
Update instance_normalization.py
switch to onnx helper func tf_shape instead of tf.shape
1 parent 610feae commit ac891dc

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

onnx_tf/handlers/backend/instance_normalization.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from onnx_tf.handlers.backend_handler import BackendHandler
44
from onnx_tf.handlers.handler import onnx_op
55
from onnx_tf.handlers.handler import tf_func
6+
from onnx_tf.common.tf_helper import tf_shape
67

78

89
@onnx_op("InstanceNormalization")
@@ -31,7 +32,7 @@ def _common(cls, node, **kwargs):
3132
beta = tensor_dict[node.inputs[2]]
3233

3334
inputs = tensor_dict[node.inputs[0]]
34-
inputs_shape = tf.shape(inputs)
35+
inputs_shape = tf_shape(inputs)
3536
inputs_rank = inputs.shape.ndims
3637

3738
moments_axes = list(range(inputs_rank))[2:]

0 commit comments

Comments
 (0)