Skip to content

Commit f616d65

Browse files
authored
Add model stepping test for Mnist (#734)
* Add model stepping test for Mnist Add model stepping test for Mnist using ONNX runtime. The assumption is that ONNX runtime is installed and the mnist model from ONNX model zoo is downloaded. Signed-off-by: Chin Huang <[email protected]> * add tensor_dict back in TFRep Signed-off-by: Chin Huang <[email protected]>
1 parent 0110e0e commit f616d65

File tree

4 files changed

+187
-19
lines changed

4 files changed

+187
-19
lines changed
+105
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
from __future__ import absolute_import
2+
from __future__ import division
3+
from __future__ import print_function
4+
from __future__ import unicode_literals
5+
6+
import unittest
7+
import numpy as np
8+
9+
import onnx
10+
from onnx import helper
11+
from onnx import TensorProto
12+
import tensorflow as tf
13+
import onnxruntime.backend as ort
14+
15+
import onnx_tf.backend as otf
16+
from onnx_tf.common import data_type
17+
18+
19+
def find_between(s, first, last):
20+
try:
21+
start = s.index(first)
22+
end = s.index(last) + len(last)
23+
return s[start:end]
24+
except ValueError:
25+
return ""
26+
27+
28+
class TestMnistModel(unittest.TestCase):
29+
# Make sure the onnx file path is correct, assuming copied to the
30+
# current directory
31+
model_path = 'mnist-8.onnx'
32+
33+
def test(self):
34+
_model = onnx.load(self.model_path)
35+
print("Total node count in model: ", len(_model.graph.node))
36+
37+
# The input tensors could be provided as constants
38+
# The example below illustrates such a dictionary could be
39+
# provided for models with unknown input shapes. Since
40+
# mnist has known input shape, we don't provide input tensors.
41+
# input_tensors = {'Input3': tf.constant(0, dtype = tf.float32,
42+
# name='Input3',
43+
# shape=[1, 1, 28, 28])}
44+
input_tensors = {}
45+
tensor_dict = otf.prepare(_model,
46+
gen_tensor_dict=True,
47+
input_tensor_dict=input_tensors).tensor_dict
48+
more_outputs = []
49+
output_to_check = []
50+
for node in _model.graph.node:
51+
# add the first output of each node to the model output
52+
output_tensor = None
53+
for i in range(len(_model.graph.value_info)):
54+
if _model.graph.value_info[i].name == node.output[0]:
55+
output_tensor = _model.graph.value_info[i]
56+
57+
for i in range(len(_model.graph.initializer)):
58+
if _model.graph.initializer[i].name == node.output[0]:
59+
output_tensor = _model.graph.initializer[i]
60+
61+
# assume the first output is a tensor
62+
tensor = tensor_dict[node.output[0]]
63+
output_tensor = helper.make_tensor_value_info(
64+
node.output[0], data_type.tf2onnx(tensor.dtype),
65+
tensor.shape) if output_tensor is None else output_tensor
66+
more_outputs.append(output_tensor)
67+
output_to_check.append(node.output[0])
68+
_model.graph.output.extend(more_outputs)
69+
70+
tf_rep = otf.prepare(_model)
71+
rt_rep = ort.prepare(_model)
72+
73+
# prepare input data
74+
mnist = tf.keras.datasets.mnist
75+
(x_train, y_train), (x_test, y_test) = mnist.load_data()
76+
x_train, x_test = x_train / 255.0, x_test / 255.0
77+
sample = x_test[:1].reshape(1, 1, 28, 28).astype(np.float32)
78+
79+
inputs = [sample]
80+
my_out = tf_rep.run(inputs)
81+
rt_out = rt_rep.run(inputs)
82+
83+
for op in output_to_check:
84+
for i in range(len(my_out)):
85+
# find the index of output in the list
86+
if my_out[op] is my_out[i]:
87+
88+
try:
89+
np.savetxt(op.replace("/", "__") + ".rt",
90+
rt_out[i].flatten(),
91+
delimiter='\t')
92+
np.savetxt(op.replace("/", "__") + ".tf",
93+
my_out[i].flatten(),
94+
delimiter='\t')
95+
np.testing.assert_allclose(my_out[i], rt_out[i], rtol=1e-2)
96+
print(op, "results of this layer are correct within tolerence.")
97+
except Exception as e:
98+
np.set_printoptions(threshold=np.inf)
99+
mismatch_percent = (find_between(str(e), "(mismatch", "%)"))
100+
print(op, "mismatch with percentage {} %".format(mismatch_percent))
101+
102+
103+
if __name__ == '__main__':
104+
unittest.main()
105+
pass

onnx_tf/backend.py

+51-17
Original file line numberDiff line numberDiff line change
@@ -64,12 +64,12 @@ def prepare(cls,
6464
super(TensorflowBackend, cls).prepare(model, device, **kwargs)
6565
common.logger.setLevel(logging_level)
6666
common.logger.handlers[0].setLevel(logging_level)
67-
common.sys_config.auto_cast=auto_cast
67+
common.sys_config.auto_cast = auto_cast
6868

69-
return cls.onnx_model_to_tensorflow_rep(model, strict)
69+
return cls.onnx_model_to_tensorflow_rep(model, strict, **kwargs)
7070

7171
@classmethod
72-
def onnx_model_to_tensorflow_rep(cls, model, strict):
72+
def onnx_model_to_tensorflow_rep(cls, model, strict, **kwargs):
7373
""" Convert ONNX model to TensorflowRep.
7474
7575
:param model: ONNX ModelProto object.
@@ -86,45 +86,68 @@ def onnx_model_to_tensorflow_rep(cls, model, strict):
8686
opset_import = [make_opsetid(defs.ONNX_DOMAIN, 1)]
8787
else:
8888
opset_import = model.opset_import
89-
return cls._onnx_graph_to_tensorflow_rep(model.graph, opset_import, strict)
89+
return cls._onnx_graph_to_tensorflow_rep(model.graph, opset_import, strict,
90+
**kwargs)
9091

9192
@classmethod
92-
def _onnx_graph_to_tensorflow_rep(cls, graph_def, opset, strict):
93+
def _onnx_graph_to_tensorflow_rep(cls, graph_def, opset, strict, **kwargs):
9394
""" Convert ONNX graph to TensorflowRep.
9495
9596
:param graph_def: ONNX GraphProto object.
9697
:param opset: ONNX OperatorSetIdProto list.
9798
:param strict: whether to enforce semantic equivalence between the original model
9899
and the converted tensorflow model.
100+
:kwargs: additional arguements to generate tensor_dict for model debugging
99101
:return: TensorflowRep object.
100102
"""
103+
# To generate tensor_dict or not, default is False
104+
gen_tensor_dict = kwargs[
105+
'gen_tensor_dict'] if 'gen_tensor_dict' in kwargs else False
106+
# User provided input tensors, in the case the model inputs have unknown shapes
107+
input_tensor_dict = kwargs[
108+
'input_tensor_dict'] if 'input_tensor_dict' in kwargs else dict()
109+
101110
handlers = cls._get_handlers(opset)
102111

103112
# initializer: TensorProtos representing the values to initialize
104113
# a given tensor.
105114
# initialized: A list of names of the initialized tensors.
106115

107116
if graph_def.initializer:
117+
input_dict_items = cls._onnx_initializer_to_input_dict_items(
118+
graph_def.initializer)
108119
initialized = {init.name for init in graph_def.initializer}
109120
else:
121+
input_dict_items = []
110122
initialized = set()
111123

112124
module = BackendTFModule(handlers, opset, strict, graph_def, cls)
113125
signatures = dict()
114-
115126
for value_info in graph_def.input:
116127
if value_info.name in initialized:
117128
continue
118129
shape = list(
119130
d.dim_value if (d.dim_value > 0 and d.dim_param == "") else None
120131
for d in value_info.type.tensor_type.shape.dim)
121132
value_info_name = value_info.name.replace(
122-
":", "_tf_") + "_" + get_unique_suffix(
123-
) if ":" in value_info.name else value_info.name
133+
":", "_tf_") + "_" + get_unique_suffix(
134+
) if ":" in value_info.name else value_info.name
124135

125-
tf_spec = tf.TensorSpec(shape, data_type.onnx2tf(value_info.type.tensor_type.elem_type), value_info_name)
136+
tf_spec = tf.TensorSpec(
137+
shape, data_type.onnx2tf(value_info.type.tensor_type.elem_type),
138+
value_info_name)
126139
signatures[value_info.name] = tf_spec
127140

141+
if gen_tensor_dict:
142+
x = tf.constant(
143+
0,
144+
dtype=data_type.onnx2tf(value_info.type.tensor_type.elem_type),
145+
name=value_info_name,
146+
shape=shape
147+
) if value_info.name not in input_tensor_dict else input_tensor_dict[
148+
value_info.name]
149+
input_dict_items.append((value_info_name, x))
150+
128151
tf_rep = TensorflowRep()
129152
tf_rep.inputs = [
130153
value_info.name
@@ -135,6 +158,9 @@ def _onnx_graph_to_tensorflow_rep(cls, graph_def, opset, strict):
135158
module.outputs = tf_rep.outputs
136159
tf_rep.tf_module = module
137160
tf_rep.signatures = signatures
161+
tf_rep.tensor_dict = module.gen_tensor_dict(
162+
input_dict_items) if gen_tensor_dict else None
163+
138164
return tf_rep
139165

140166
@classmethod
@@ -148,7 +174,9 @@ def run_node(cls, node, inputs, device='CPU', outputs_info=None, **kwargs):
148174
:param kwargs: Other args.
149175
:return: Outputs.
150176
"""
177+
151178
class TFModule(tf.Module):
179+
152180
def __init__(self, node):
153181
super(TFModule, self).__init__()
154182
self.node = node
@@ -171,13 +199,16 @@ def __call__(self, **input_dict):
171199
feed_dict_raw = dict(zip(node.inputs, inputs))
172200

173201
# TODO: is constant the best way for feeding inputs?
174-
input_dict = dict(
175-
[(x[0], tf.constant(x[1])) for x in feed_dict_raw.items()])
202+
input_dict = dict([(x[0], tf.constant(x[1])) for x in feed_dict_raw.items()
203+
])
176204

177205
module = TFModule(node)
178206

179207
output_vals = module(**input_dict)
180-
output_vals = [val.numpy() if isinstance(val, tf.Tensor) else val for val in output_vals]
208+
output_vals = [
209+
val.numpy() if isinstance(val, tf.Tensor) else val
210+
for val in output_vals
211+
]
181212

182213
return namedtupledict('Outputs', node.outputs)(*output_vals)
183214

@@ -231,11 +262,13 @@ def _onnx_node_to_tensorflow_op(cls,
231262
"""
232263
handlers = handlers or cls._get_handlers(opset)
233264
if handlers:
234-
handler = handlers[node.domain].get(node.op_type, None) if node.domain in handlers else None
265+
handler = handlers[node.domain].get(
266+
node.op_type, None) if node.domain in handlers else None
235267
if handler:
236268
return handler.handle(node, tensor_dict=tensor_dict, strict=strict)
237269

238-
raise BackendIsNotSupposedToImplementIt("{} is not implemented.".format(node.op_type))
270+
raise BackendIsNotSupposedToImplementIt("{} is not implemented.".format(
271+
node.op_type))
239272

240273
@classmethod
241274
def _get_handlers(cls, opset):
@@ -293,7 +326,8 @@ def onnx_graph_to_tensorflow_ops(cls,
293326
nodes_outputs.append(o_name)
294327
for node in subgraph.node:
295328
for i_name in node.input:
296-
if i_name not in nodes_outputs and i_name not in subgraph_tensor_dict.keys():
329+
if i_name not in nodes_outputs and i_name not in subgraph_tensor_dict.keys(
330+
):
297331
subgraph_tensor_dict[i_name] = tensor_dict[i_name]
298332
onnx_node = OnnxNode(node)
299333
output_ops = cls._onnx_node_to_tensorflow_op(onnx_node,
@@ -305,7 +339,7 @@ def onnx_graph_to_tensorflow_ops(cls,
305339
return subgraph_tensor_dict
306340

307341
@classmethod
308-
def onnx_graph_to_tensorflow_rep(cls, graph_def, strict=True):
342+
def onnx_graph_to_tensorflow_rep(cls, graph_def, strict=True, **kwargs):
309343
"""
310344
Converts ONNX graph to TensorflowRep
311345
Args:
@@ -318,7 +352,7 @@ def onnx_graph_to_tensorflow_rep(cls, graph_def, strict=True):
318352
"""
319353
# get the opset of the installed ONNX
320354
opset = [make_opsetid(defs.ONNX_DOMAIN, defs.onnx_opset_version())]
321-
return cls._onnx_graph_to_tensorflow_rep(graph_def, opset, strict)
355+
return cls._onnx_graph_to_tensorflow_rep(graph_def, opset, strict, **kwargs)
322356

323357

324358
prepare = TensorflowBackend.prepare

onnx_tf/backend_rep.py

+9
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ def __init__(self, graph=None, inputs=None, outputs=None, tensor_dict=None):
1616
self._inputs = inputs or []
1717
self._outputs = outputs or []
1818
self._tensor_dict = tensor_dict or {}
19+
self._tf_module = None
1920

2021
@property
2122
def graph(self):
@@ -49,6 +50,14 @@ def tensor_dict(self):
4950
def tensor_dict(self, tensor_dict):
5051
self._tensor_dict = tensor_dict
5152

53+
@property
54+
def tf_module(self):
55+
return self._tf_module
56+
57+
@tf_module.setter
58+
def tf_module(self, tf_module):
59+
self._tf_module = tf_module
60+
5261
def run(self, inputs, **kwargs):
5362
""" Run TensorflowRep.
5463

onnx_tf/backend_tf_module.py

+22-2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import tensorflow as tf
22
from onnx_tf.pb_wrapper import OnnxNode
33

4+
45
class BackendTFModule(tf.Module):
56

67
def __init__(self, handlers, opset, strict, graph_def, backend):
@@ -12,6 +13,22 @@ def __init__(self, handlers, opset, strict, graph_def, backend):
1213
self.backend = backend
1314
self.outputs = []
1415

16+
@tf.function
17+
def gen_tensor_dict(self, input_dict_items):
18+
tensor_dict = dict(input_dict_items)
19+
20+
for node in self.graph_def.node:
21+
onnx_node = OnnxNode(node)
22+
output_ops = self.backend._onnx_node_to_tensorflow_op(onnx_node,
23+
tensor_dict,
24+
self.handlers,
25+
opset=self.opset,
26+
strict=self.strict)
27+
curr_node_output_map = dict(zip(onnx_node.outputs, output_ops))
28+
tensor_dict.update(curr_node_output_map)
29+
30+
return tensor_dict
31+
1532
@tf.function
1633
def __call__(self, **kwargs):
1734
tensor_dict = kwargs
@@ -26,8 +43,11 @@ def __call__(self, **kwargs):
2643

2744
for node in self.graph_def.node:
2845
onnx_node = OnnxNode(node)
29-
output_ops = self.backend._onnx_node_to_tensorflow_op(
30-
onnx_node, tensor_dict, self.handlers, opset=self.opset, strict=self.strict)
46+
output_ops = self.backend._onnx_node_to_tensorflow_op(onnx_node,
47+
tensor_dict,
48+
self.handlers,
49+
opset=self.opset,
50+
strict=self.strict)
3151
curr_node_output_map = dict(zip(onnx_node.outputs, output_ops))
3252
tensor_dict.update(curr_node_output_map)
3353

0 commit comments

Comments
 (0)