Skip to content

Commit 09bba9a

Browse files
committed
use tf.shape instead of .shape for dynamic axes
Signed-off-by: masakistan <[email protected]>
2 parents f884d78 + f616d65 commit 09bba9a

File tree

4 files changed

+187
-19
lines changed

4 files changed

+187
-19
lines changed
+105
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
from __future__ import absolute_import
2+
from __future__ import division
3+
from __future__ import print_function
4+
from __future__ import unicode_literals
5+
6+
import unittest
7+
import numpy as np
8+
9+
import onnx
10+
from onnx import helper
11+
from onnx import TensorProto
12+
import tensorflow as tf
13+
import onnxruntime.backend as ort
14+
15+
import onnx_tf.backend as otf
16+
from onnx_tf.common import data_type
17+
18+
19+
def find_between(s, first, last):
20+
try:
21+
start = s.index(first)
22+
end = s.index(last) + len(last)
23+
return s[start:end]
24+
except ValueError:
25+
return ""
26+
27+
28+
class TestMnistModel(unittest.TestCase):
29+
# Make sure the onnx file path is correct, assuming copied to the
30+
# current directory
31+
model_path = 'mnist-8.onnx'
32+
33+
def test(self):
34+
_model = onnx.load(self.model_path)
35+
print("Total node count in model: ", len(_model.graph.node))
36+
37+
# The input tensors could be provided as constants
38+
# The example below illustrates such a dictionary could be
39+
# provided for models with unknown input shapes. Since
40+
# mnist has known input shape, we don't provide input tensors.
41+
# input_tensors = {'Input3': tf.constant(0, dtype = tf.float32,
42+
# name='Input3',
43+
# shape=[1, 1, 28, 28])}
44+
input_tensors = {}
45+
tensor_dict = otf.prepare(_model,
46+
gen_tensor_dict=True,
47+
input_tensor_dict=input_tensors).tensor_dict
48+
more_outputs = []
49+
output_to_check = []
50+
for node in _model.graph.node:
51+
# add the first output of each node to the model output
52+
output_tensor = None
53+
for i in range(len(_model.graph.value_info)):
54+
if _model.graph.value_info[i].name == node.output[0]:
55+
output_tensor = _model.graph.value_info[i]
56+
57+
for i in range(len(_model.graph.initializer)):
58+
if _model.graph.initializer[i].name == node.output[0]:
59+
output_tensor = _model.graph.initializer[i]
60+
61+
# assume the first output is a tensor
62+
tensor = tensor_dict[node.output[0]]
63+
output_tensor = helper.make_tensor_value_info(
64+
node.output[0], data_type.tf2onnx(tensor.dtype),
65+
tensor.shape) if output_tensor is None else output_tensor
66+
more_outputs.append(output_tensor)
67+
output_to_check.append(node.output[0])
68+
_model.graph.output.extend(more_outputs)
69+
70+
tf_rep = otf.prepare(_model)
71+
rt_rep = ort.prepare(_model)
72+
73+
# prepare input data
74+
mnist = tf.keras.datasets.mnist
75+
(x_train, y_train), (x_test, y_test) = mnist.load_data()
76+
x_train, x_test = x_train / 255.0, x_test / 255.0
77+
sample = x_test[:1].reshape(1, 1, 28, 28).astype(np.float32)
78+
79+
inputs = [sample]
80+
my_out = tf_rep.run(inputs)
81+
rt_out = rt_rep.run(inputs)
82+
83+
for op in output_to_check:
84+
for i in range(len(my_out)):
85+
# find the index of output in the list
86+
if my_out[op] is my_out[i]:
87+
88+
try:
89+
np.savetxt(op.replace("/", "__") + ".rt",
90+
rt_out[i].flatten(),
91+
delimiter='\t')
92+
np.savetxt(op.replace("/", "__") + ".tf",
93+
my_out[i].flatten(),
94+
delimiter='\t')
95+
np.testing.assert_allclose(my_out[i], rt_out[i], rtol=1e-2)
96+
print(op, "results of this layer are correct within tolerence.")
97+
except Exception as e:
98+
np.set_printoptions(threshold=np.inf)
99+
mismatch_percent = (find_between(str(e), "(mismatch", "%)"))
100+
print(op, "mismatch with percentage {} %".format(mismatch_percent))
101+
102+
103+
if __name__ == '__main__':
104+
unittest.main()
105+
pass

onnx_tf/backend.py

+51-17
Original file line numberDiff line numberDiff line change
@@ -64,12 +64,12 @@ def prepare(cls,
6464
super(TensorflowBackend, cls).prepare(model, device, **kwargs)
6565
common.logger.setLevel(logging_level)
6666
common.logger.handlers[0].setLevel(logging_level)
67-
common.sys_config.auto_cast=auto_cast
67+
common.sys_config.auto_cast = auto_cast
6868

69-
return cls.onnx_model_to_tensorflow_rep(model, strict)
69+
return cls.onnx_model_to_tensorflow_rep(model, strict, **kwargs)
7070

7171
@classmethod
72-
def onnx_model_to_tensorflow_rep(cls, model, strict):
72+
def onnx_model_to_tensorflow_rep(cls, model, strict, **kwargs):
7373
""" Convert ONNX model to TensorflowRep.
7474
7575
:param model: ONNX ModelProto object.
@@ -86,45 +86,68 @@ def onnx_model_to_tensorflow_rep(cls, model, strict):
8686
opset_import = [make_opsetid(defs.ONNX_DOMAIN, 1)]
8787
else:
8888
opset_import = model.opset_import
89-
return cls._onnx_graph_to_tensorflow_rep(model.graph, opset_import, strict)
89+
return cls._onnx_graph_to_tensorflow_rep(model.graph, opset_import, strict,
90+
**kwargs)
9091

9192
@classmethod
92-
def _onnx_graph_to_tensorflow_rep(cls, graph_def, opset, strict):
93+
def _onnx_graph_to_tensorflow_rep(cls, graph_def, opset, strict, **kwargs):
9394
""" Convert ONNX graph to TensorflowRep.
9495
9596
:param graph_def: ONNX GraphProto object.
9697
:param opset: ONNX OperatorSetIdProto list.
9798
:param strict: whether to enforce semantic equivalence between the original model
9899
and the converted tensorflow model.
100+
:kwargs: additional arguements to generate tensor_dict for model debugging
99101
:return: TensorflowRep object.
100102
"""
103+
# To generate tensor_dict or not, default is False
104+
gen_tensor_dict = kwargs[
105+
'gen_tensor_dict'] if 'gen_tensor_dict' in kwargs else False
106+
# User provided input tensors, in the case the model inputs have unknown shapes
107+
input_tensor_dict = kwargs[
108+
'input_tensor_dict'] if 'input_tensor_dict' in kwargs else dict()
109+
101110
handlers = cls._get_handlers(opset)
102111

103112
# initializer: TensorProtos representing the values to initialize
104113
# a given tensor.
105114
# initialized: A list of names of the initialized tensors.
106115

107116
if graph_def.initializer:
117+
input_dict_items = cls._onnx_initializer_to_input_dict_items(
118+
graph_def.initializer)
108119
initialized = {init.name for init in graph_def.initializer}
109120
else:
121+
input_dict_items = []
110122
initialized = set()
111123

112124
module = BackendTFModule(handlers, opset, strict, graph_def, cls)
113125
signatures = dict()
114-
115126
for value_info in graph_def.input:
116127
if value_info.name in initialized:
117128
continue
118129
shape = list(
119130
d.dim_value if (d.dim_value > 0 and d.dim_param == "") else None
120131
for d in value_info.type.tensor_type.shape.dim)
121132
value_info_name = value_info.name.replace(
122-
":", "_tf_") + "_" + get_unique_suffix(
123-
) if ":" in value_info.name else value_info.name
133+
":", "_tf_") + "_" + get_unique_suffix(
134+
) if ":" in value_info.name else value_info.name
124135

125-
tf_spec = tf.TensorSpec(shape, data_type.onnx2tf(value_info.type.tensor_type.elem_type), value_info_name)
136+
tf_spec = tf.TensorSpec(
137+
shape, data_type.onnx2tf(value_info.type.tensor_type.elem_type),
138+
value_info_name)
126139
signatures[value_info.name] = tf_spec
127140

141+
if gen_tensor_dict:
142+
x = tf.constant(
143+
0,
144+
dtype=data_type.onnx2tf(value_info.type.tensor_type.elem_type),
145+
name=value_info_name,
146+
shape=shape
147+
) if value_info.name not in input_tensor_dict else input_tensor_dict[
148+
value_info.name]
149+
input_dict_items.append((value_info_name, x))
150+
128151
tf_rep = TensorflowRep()
129152
tf_rep.inputs = [
130153
value_info.name
@@ -135,6 +158,9 @@ def _onnx_graph_to_tensorflow_rep(cls, graph_def, opset, strict):
135158
module.outputs = tf_rep.outputs
136159
tf_rep.tf_module = module
137160
tf_rep.signatures = signatures
161+
tf_rep.tensor_dict = module.gen_tensor_dict(
162+
input_dict_items) if gen_tensor_dict else None
163+
138164
return tf_rep
139165

140166
@classmethod
@@ -148,7 +174,9 @@ def run_node(cls, node, inputs, device='CPU', outputs_info=None, **kwargs):
148174
:param kwargs: Other args.
149175
:return: Outputs.
150176
"""
177+
151178
class TFModule(tf.Module):
179+
152180
def __init__(self, node):
153181
super(TFModule, self).__init__()
154182
self.node = node
@@ -171,13 +199,16 @@ def __call__(self, **input_dict):
171199
feed_dict_raw = dict(zip(node.inputs, inputs))
172200

173201
# TODO: is constant the best way for feeding inputs?
174-
input_dict = dict(
175-
[(x[0], tf.constant(x[1])) for x in feed_dict_raw.items()])
202+
input_dict = dict([(x[0], tf.constant(x[1])) for x in feed_dict_raw.items()
203+
])
176204

177205
module = TFModule(node)
178206

179207
output_vals = module(**input_dict)
180-
output_vals = [val.numpy() if isinstance(val, tf.Tensor) else val for val in output_vals]
208+
output_vals = [
209+
val.numpy() if isinstance(val, tf.Tensor) else val
210+
for val in output_vals
211+
]
181212

182213
return namedtupledict('Outputs', node.outputs)(*output_vals)
183214

@@ -231,11 +262,13 @@ def _onnx_node_to_tensorflow_op(cls,
231262
"""
232263
handlers = handlers or cls._get_handlers(opset)
233264
if handlers:
234-
handler = handlers[node.domain].get(node.op_type, None) if node.domain in handlers else None
265+
handler = handlers[node.domain].get(
266+
node.op_type, None) if node.domain in handlers else None
235267
if handler:
236268
return handler.handle(node, tensor_dict=tensor_dict, strict=strict)
237269

238-
raise BackendIsNotSupposedToImplementIt("{} is not implemented.".format(node.op_type))
270+
raise BackendIsNotSupposedToImplementIt("{} is not implemented.".format(
271+
node.op_type))
239272

240273
@classmethod
241274
def _get_handlers(cls, opset):
@@ -293,7 +326,8 @@ def onnx_graph_to_tensorflow_ops(cls,
293326
nodes_outputs.append(o_name)
294327
for node in subgraph.node:
295328
for i_name in node.input:
296-
if i_name not in nodes_outputs and i_name not in subgraph_tensor_dict.keys():
329+
if i_name not in nodes_outputs and i_name not in subgraph_tensor_dict.keys(
330+
):
297331
subgraph_tensor_dict[i_name] = tensor_dict[i_name]
298332
onnx_node = OnnxNode(node)
299333
output_ops = cls._onnx_node_to_tensorflow_op(onnx_node,
@@ -305,7 +339,7 @@ def onnx_graph_to_tensorflow_ops(cls,
305339
return subgraph_tensor_dict
306340

307341
@classmethod
308-
def onnx_graph_to_tensorflow_rep(cls, graph_def, strict=True):
342+
def onnx_graph_to_tensorflow_rep(cls, graph_def, strict=True, **kwargs):
309343
"""
310344
Converts ONNX graph to TensorflowRep
311345
Args:
@@ -318,7 +352,7 @@ def onnx_graph_to_tensorflow_rep(cls, graph_def, strict=True):
318352
"""
319353
# get the opset of the installed ONNX
320354
opset = [make_opsetid(defs.ONNX_DOMAIN, defs.onnx_opset_version())]
321-
return cls._onnx_graph_to_tensorflow_rep(graph_def, opset, strict)
355+
return cls._onnx_graph_to_tensorflow_rep(graph_def, opset, strict, **kwargs)
322356

323357

324358
prepare = TensorflowBackend.prepare

onnx_tf/backend_rep.py

+9
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ def __init__(self, graph=None, inputs=None, outputs=None, tensor_dict=None):
1616
self._inputs = inputs or []
1717
self._outputs = outputs or []
1818
self._tensor_dict = tensor_dict or {}
19+
self._tf_module = None
1920

2021
@property
2122
def graph(self):
@@ -49,6 +50,14 @@ def tensor_dict(self):
4950
def tensor_dict(self, tensor_dict):
5051
self._tensor_dict = tensor_dict
5152

53+
@property
54+
def tf_module(self):
55+
return self._tf_module
56+
57+
@tf_module.setter
58+
def tf_module(self, tf_module):
59+
self._tf_module = tf_module
60+
5261
def run(self, inputs, **kwargs):
5362
""" Run TensorflowRep.
5463

onnx_tf/backend_tf_module.py

+22-2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import tensorflow as tf
22
from onnx_tf.pb_wrapper import OnnxNode
33

4+
45
class BackendTFModule(tf.Module):
56

67
def __init__(self, handlers, opset, strict, graph_def, backend):
@@ -12,6 +13,22 @@ def __init__(self, handlers, opset, strict, graph_def, backend):
1213
self.backend = backend
1314
self.outputs = []
1415

16+
@tf.function
17+
def gen_tensor_dict(self, input_dict_items):
18+
tensor_dict = dict(input_dict_items)
19+
20+
for node in self.graph_def.node:
21+
onnx_node = OnnxNode(node)
22+
output_ops = self.backend._onnx_node_to_tensorflow_op(onnx_node,
23+
tensor_dict,
24+
self.handlers,
25+
opset=self.opset,
26+
strict=self.strict)
27+
curr_node_output_map = dict(zip(onnx_node.outputs, output_ops))
28+
tensor_dict.update(curr_node_output_map)
29+
30+
return tensor_dict
31+
1532
@tf.function
1633
def __call__(self, **kwargs):
1734
tensor_dict = kwargs
@@ -26,8 +43,11 @@ def __call__(self, **kwargs):
2643

2744
for node in self.graph_def.node:
2845
onnx_node = OnnxNode(node)
29-
output_ops = self.backend._onnx_node_to_tensorflow_op(
30-
onnx_node, tensor_dict, self.handlers, opset=self.opset, strict=self.strict)
46+
output_ops = self.backend._onnx_node_to_tensorflow_op(onnx_node,
47+
tensor_dict,
48+
self.handlers,
49+
opset=self.opset,
50+
strict=self.strict)
3151
curr_node_output_map = dict(zip(onnx_node.outputs, output_ops))
3252
tensor_dict.update(curr_node_output_map)
3353

0 commit comments

Comments
 (0)