@@ -64,12 +64,12 @@ def prepare(cls,
     super(TensorflowBackend, cls).prepare(model, device, **kwargs)
     common.logger.setLevel(logging_level)
     common.logger.handlers[0].setLevel(logging_level)
-    common.sys_config.auto_cast = auto_cast
+    common.sys_config.auto_cast = auto_cast

-    return cls.onnx_model_to_tensorflow_rep(model, strict)
+    return cls.onnx_model_to_tensorflow_rep(model, strict, **kwargs)

   @classmethod
-  def onnx_model_to_tensorflow_rep(cls, model, strict):
+  def onnx_model_to_tensorflow_rep(cls, model, strict, **kwargs):
     """ Convert ONNX model to TensorflowRep.

     :param model: ONNX ModelProto object.
@@ -86,45 +86,68 @@ def onnx_model_to_tensorflow_rep(cls, model, strict):
       opset_import = [make_opsetid(defs.ONNX_DOMAIN, 1)]
     else:
       opset_import = model.opset_import
-    return cls._onnx_graph_to_tensorflow_rep(model.graph, opset_import, strict)
+    return cls._onnx_graph_to_tensorflow_rep(model.graph, opset_import, strict,
+                                             **kwargs)

   @classmethod
-  def _onnx_graph_to_tensorflow_rep(cls, graph_def, opset, strict):
+  def _onnx_graph_to_tensorflow_rep(cls, graph_def, opset, strict, **kwargs):
     """ Convert ONNX graph to TensorflowRep.

     :param graph_def: ONNX GraphProto object.
     :param opset: ONNX OperatorSetIdProto list.
     :param strict: whether to enforce semantic equivalence between the original model
       and the converted tensorflow model.
+    :param kwargs: additional arguments to generate tensor_dict for model debugging
     :return: TensorflowRep object.
     """
+    # To generate tensor_dict or not, default is False
+    gen_tensor_dict = kwargs[
+        'gen_tensor_dict'] if 'gen_tensor_dict' in kwargs else False
+    # User-provided input tensors, for the case where the model inputs have unknown shapes
+    input_tensor_dict = kwargs[
+        'input_tensor_dict'] if 'input_tensor_dict' in kwargs else dict()
+
     handlers = cls._get_handlers(opset)

     # initializer: TensorProtos representing the values to initialize
     # a given tensor.
     # initialized: A list of names of the initialized tensors.

     if graph_def.initializer:
+      input_dict_items = cls._onnx_initializer_to_input_dict_items(
+          graph_def.initializer)
       initialized = {init.name for init in graph_def.initializer}
     else:
+      input_dict_items = []
       initialized = set()

     module = BackendTFModule(handlers, opset, strict, graph_def, cls)
     signatures = dict()
-
     for value_info in graph_def.input:
       if value_info.name in initialized:
         continue
       shape = list(
           d.dim_value if (d.dim_value > 0 and d.dim_param == "") else None
           for d in value_info.type.tensor_type.shape.dim)
       value_info_name = value_info.name.replace(
-          ":", "_tf_") + "_" + get_unique_suffix(
-      ) if ":" in value_info.name else value_info.name
+          ":", "_tf_") + "_" + get_unique_suffix(
+          ) if ":" in value_info.name else value_info.name

-      tf_spec = tf.TensorSpec(shape, data_type.onnx2tf(value_info.type.tensor_type.elem_type), value_info_name)
+      tf_spec = tf.TensorSpec(
+          shape, data_type.onnx2tf(value_info.type.tensor_type.elem_type),
+          value_info_name)
       signatures[value_info.name] = tf_spec

+      if gen_tensor_dict:
+        x = tf.constant(
+            0,
+            dtype=data_type.onnx2tf(value_info.type.tensor_type.elem_type),
+            name=value_info_name,
+            shape=shape
+        ) if value_info.name not in input_tensor_dict else input_tensor_dict[
+            value_info.name]
+        input_dict_items.append((value_info_name, x))
+
     tf_rep = TensorflowRep()
     tf_rep.inputs = [
         value_info.name
@@ -135,6 +158,9 @@ def _onnx_graph_to_tensorflow_rep(cls, graph_def, opset, strict):
     module.outputs = tf_rep.outputs
     tf_rep.tf_module = module
     tf_rep.signatures = signatures
+    tf_rep.tensor_dict = module.gen_tensor_dict(
+        input_dict_items) if gen_tensor_dict else None
+
     return tf_rep

   @classmethod
@@ -148,7 +174,9 @@ def run_node(cls, node, inputs, device='CPU', outputs_info=None, **kwargs):
     :param kwargs: Other args.
     :return: Outputs.
     """
+
     class TFModule(tf.Module):
+
       def __init__(self, node):
         super(TFModule, self).__init__()
         self.node = node
@@ -171,13 +199,16 @@ def __call__(self, **input_dict):
     feed_dict_raw = dict(zip(node.inputs, inputs))

     # TODO: is constant the best way for feeding inputs?
-    input_dict = dict(
-        [(x[0], tf.constant(x[1])) for x in feed_dict_raw.items()])
+    input_dict = dict([(x[0], tf.constant(x[1])) for x in feed_dict_raw.items()
+                      ])

     module = TFModule(node)

     output_vals = module(**input_dict)
-    output_vals = [val.numpy() if isinstance(val, tf.Tensor) else val for val in output_vals]
+    output_vals = [
+        val.numpy() if isinstance(val, tf.Tensor) else val
+        for val in output_vals
+    ]

     return namedtupledict('Outputs', node.outputs)(*output_vals)

@@ -231,11 +262,13 @@ def _onnx_node_to_tensorflow_op(cls,
     """
     handlers = handlers or cls._get_handlers(opset)
     if handlers:
-      handler = handlers[node.domain].get(node.op_type, None) if node.domain in handlers else None
+      handler = handlers[node.domain].get(
+          node.op_type, None) if node.domain in handlers else None
       if handler:
         return handler.handle(node, tensor_dict=tensor_dict, strict=strict)

-    raise BackendIsNotSupposedToImplementIt("{} is not implemented.".format(node.op_type))
+    raise BackendIsNotSupposedToImplementIt("{} is not implemented.".format(
+        node.op_type))

   @classmethod
   def _get_handlers(cls, opset):
@@ -293,7 +326,8 @@ def onnx_graph_to_tensorflow_ops(cls,
         nodes_outputs.append(o_name)
     for node in subgraph.node:
       for i_name in node.input:
-        if i_name not in nodes_outputs and i_name not in subgraph_tensor_dict.keys():
+        if i_name not in nodes_outputs and i_name not in subgraph_tensor_dict.keys(
+        ):
           subgraph_tensor_dict[i_name] = tensor_dict[i_name]
       onnx_node = OnnxNode(node)
       output_ops = cls._onnx_node_to_tensorflow_op(onnx_node,
@@ -305,7 +339,7 @@ def onnx_graph_to_tensorflow_ops(cls,
     return subgraph_tensor_dict

   @classmethod
-  def onnx_graph_to_tensorflow_rep(cls, graph_def, strict=True):
+  def onnx_graph_to_tensorflow_rep(cls, graph_def, strict=True, **kwargs):
     """
     Converts ONNX graph to TensorflowRep
     Args:
@@ -318,7 +352,7 @@ def onnx_graph_to_tensorflow_rep(cls, graph_def, strict=True):
     """
     # get the opset of the installed ONNX
     opset = [make_opsetid(defs.ONNX_DOMAIN, defs.onnx_opset_version())]
-    return cls._onnx_graph_to_tensorflow_rep(graph_def, opset, strict)
+    return cls._onnx_graph_to_tensorflow_rep(graph_def, opset, strict, **kwargs)


 prepare = TensorflowBackend.prepare
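
A minimal usage sketch of the gen_tensor_dict and input_tensor_dict kwargs this diff threads through prepare(); the model path "model.onnx", the input name "input_0", and the shape are hypothetical placeholders, not part of the change itself.

import onnx
import tensorflow as tf
from onnx_tf.backend import prepare

model = onnx.load("model.onnx")  # hypothetical model file

# Optionally supply concrete tensors for inputs whose shapes are unknown
# in the ONNX graph; names must match graph_def.input entries.
input_tensor_dict = {
    "input_0": tf.constant(0.0, dtype=tf.float32, shape=[1, 3, 224, 224])
}

tf_rep = prepare(model,
                 gen_tensor_dict=True,
                 input_tensor_dict=input_tensor_dict)

# With gen_tensor_dict=True the backend populates tf_rep.tensor_dict with the
# intermediate tensors for debugging; otherwise tf_rep.tensor_dict is None.
print(tf_rep.tensor_dict is not None)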