Skip to content

Commit 7bab839

Browse files
masakistanEC2 Default User
authored and
EC2 Default User
committed
Merge branch 'master' into fix_instance_norm
Signed-off-by: Stanley Fujimoto <[email protected]>
2 parents ac891dc + a742d29 commit 7bab839

14 files changed

+504
-135
lines changed

doc/CLI.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ optional arguments:
2222
### Convert:
2323

2424
#### From ONNX to Tensorflow:
25-
`onnx-tf convert -i /path/to/input.onnx -o /path/to/output.pb`
25+
`onnx-tf convert -i /path/to/input.onnx -o /path/to/output`
2626

2727
More information: `onnx-tf convert -h`
2828
```

doc/CLI_template.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ More information: `onnx-tf -h`
1414
### Convert:
1515

1616
#### From ONNX to Tensorflow:
17-
`onnx-tf convert -i /path/to/input.onnx -o /path/to/output.pb`
17+
`onnx-tf convert -i /path/to/input.onnx -o /path/to/output`
1818

1919
More information: `onnx-tf convert -h`
2020
```

onnx_tf/backend.py

+33-12
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
from onnx_tf.common import supports_device as common_supports_device
2727
from onnx_tf.common.handler_helper import get_all_backend_handlers
2828
from onnx_tf.pb_wrapper import OnnxNode
29-
from onnx_tf.backend_tf_module import BackendTFModule
29+
from onnx_tf.backend_tf_module import BackendTFModule, TFModule
3030
import onnx_tf.common as common
3131

3232

@@ -160,8 +160,39 @@ def _onnx_graph_to_tensorflow_rep(cls, graph_def, opset, strict, **kwargs):
160160
tf_rep.signatures = signatures
161161
tf_rep.tensor_dict = module.gen_tensor_dict(
162162
input_dict) if gen_tensor_dict else None
163+
tf_rep.onnx_op_list = cls._get_onnx_op_list(graph_def)
163164
return tf_rep
164165

166+
@classmethod
167+
def _get_onnx_op_list(cls, graph_def):
168+
""" Get ONNX operator counts of the model.
169+
170+
:param graph_def: ONNX GraphProto object.
171+
:return: Dictionary of all operators counts in the model.
172+
"""
173+
174+
def get_onnx_op_from_graph_and_subgraph(graph, op_list):
175+
for node in graph.node:
176+
op_list[node.op_type] = 1 if node.op_type not in op_list.keys(
177+
) else op_list[node.op_type] + 1
178+
if node.op_type in ['Loop', 'Scan']:
179+
onnx_node = OnnxNode(node)
180+
body = onnx_node.attrs["body"]
181+
op_list = get_onnx_op_from_graph_and_subgraph(body, op_list)
182+
elif node.op_type == 'If':
183+
onnx_node = OnnxNode(node)
184+
then_branch = onnx_node.attrs['then_branch']
185+
op_list = get_onnx_op_from_graph_and_subgraph(then_branch, op_list)
186+
else_branch = onnx_node.attrs['else_branch']
187+
op_list = get_onnx_op_from_graph_and_subgraph(else_branch, op_list)
188+
return op_list
189+
190+
op_list = get_onnx_op_from_graph_and_subgraph(graph_def, dict())
191+
sorted_op_list = dict()
192+
for key in sorted(op_list):
193+
sorted_op_list[key] = op_list[key]
194+
return sorted_op_list
195+
165196
@classmethod
166197
def run_node(cls, node, inputs, device='CPU', outputs_info=None, **kwargs):
167198
""" Run ONNX node.
@@ -174,16 +205,6 @@ def run_node(cls, node, inputs, device='CPU', outputs_info=None, **kwargs):
174205
:return: Outputs.
175206
"""
176207

177-
class TFModule(tf.Module):
178-
179-
def __init__(self, node):
180-
super(TFModule, self).__init__()
181-
self.node = node
182-
183-
@tf.function
184-
def __call__(self, **input_dict):
185-
return cls._onnx_node_to_tensorflow_op(self.node, input_dict)
186-
187208
super(TensorflowBackend, cls).run_node(node, inputs, device)
188209
common.sys_config.device = device
189210

@@ -202,7 +223,7 @@ def __call__(self, **input_dict):
202223
input_dict = dict([(x[0], tf.constant(x[1])) for x in feed_dict_raw.items()
203224
])
204225

205-
module = TFModule(node)
226+
module = TFModule(node, cls)
206227

207228
output_vals = module(**input_dict)
208229
output_vals = [

onnx_tf/backend_rep.py

+18-4
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,14 @@ def tensor_dict(self):
5050
def tensor_dict(self, tensor_dict):
5151
self._tensor_dict = tensor_dict
5252

53+
@property
54+
def onnx_op_list(self):
55+
return self._onnx_op_list
56+
57+
@onnx_op_list.setter
58+
def onnx_op_list(self, onnx_op_list):
59+
self._onnx_op_list = onnx_op_list
60+
5361
@property
5462
def tf_module(self):
5563
return self._tf_module
@@ -80,11 +88,13 @@ def run(self, inputs, **kwargs):
8088
# single input
8189
feed_dict = dict([(self.inputs[0], inputs)])
8290

83-
input_dict = dict(
84-
[(x[0], tf.constant(x[1])) for x in feed_dict.items()])
91+
input_dict = dict([(x[0], tf.constant(x[1])) for x in feed_dict.items()])
8592

8693
output_values = self.tf_module(**input_dict)
87-
output_values = [val.numpy() if isinstance(val, tf.Tensor) else val for val in output_values]
94+
output_values = [
95+
val.numpy() if isinstance(val, tf.Tensor) else val
96+
for val in output_values
97+
]
8898

8999
return namedtupledict('Outputs', self.outputs)(*output_values)
90100

@@ -99,4 +109,8 @@ def export_graph(self, path):
99109
100110
:returns: none.
101111
"""
102-
tf.saved_model.save(self.tf_module, path, signatures=self.tf_module.__call__.get_concrete_function(**self.signatures))
112+
tf.saved_model.save(
113+
self.tf_module,
114+
path,
115+
signatures=self.tf_module.__call__.get_concrete_function(
116+
**self.signatures))

onnx_tf/backend_tf_module.py

+78
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,13 @@
11
import tensorflow as tf
2+
from onnx_tf.common import exception
3+
from onnx_tf.common import get_variable_name
24
from onnx_tf.pb_wrapper import OnnxNode
35

46

57
class BackendTFModule(tf.Module):
8+
""" BackendTFModule is the tf.Module class used in backend.prepare,
9+
tf_rep.export_graph and tf_rep.run
10+
"""
611

712
def __init__(self, handlers, opset, strict, graph_def, backend):
813
super(BackendTFModule, self).__init__()
@@ -14,6 +19,8 @@ def __init__(self, handlers, opset, strict, graph_def, backend):
1419
self.outputs = []
1520
self.initializer_dict = self._get_initializer_from_graph_and_subgraphs(
1621
self.graph_def, dict())
22+
self.handler_variables = self._create_handlers_variables(
23+
self.graph_def, dict())
1724

1825
# get initializer from the main graph and all subgraphs in loop or if or scan
1926
# into tensor_dict
@@ -37,10 +44,43 @@ def _get_initializer_from_graph_and_subgraphs(self, graph, graph_tensor_dict):
3744
else_branch, graph_tensor_dict)
3845
return graph_tensor_dict
3946

47+
# create tf.Variable for handlers that required to use variable in handler
48+
def _create_handlers_variables(self, graph, vars_dict):
49+
if self.handlers:
50+
handlers = self.backend._get_handlers(self.opset)
51+
for node in graph.node:
52+
handler = handlers[node.domain].get(
53+
node.op_type, None) if node.domain in handlers else None
54+
if handler and bool(
55+
handler.get_req_vars_template(node, self.initializer_dict)):
56+
for v_name, v_template in handler.get_req_vars_template(
57+
node, self.initializer_dict).items():
58+
v_init, v_shape = v_template
59+
v_name = get_variable_name(node, v_name)
60+
if v_name in vars_dict.keys():
61+
# found duplicated variable name due to non unique node name
62+
exception.NON_UNIQUE_NODE_NAME_EXCEPT()
63+
vars_dict[v_name] = tf.Variable(v_init,
64+
dtype=v_init.dtype,
65+
shape=v_shape,
66+
name=v_name)
67+
if node.op_type in ['Loop', 'Scan']:
68+
onnx_node = OnnxNode(node)
69+
body = onnx_node.attrs["body"]
70+
vars_dict = self._create_handlers_variables(body, vars_dict)
71+
elif node.op_type == 'If':
72+
onnx_node = OnnxNode(node)
73+
then_branch = onnx_node.attrs['then_branch']
74+
vars_dict = self._create_handlers_variables(then_branch, vars_dict)
75+
else_branch = onnx_node.attrs['else_branch']
76+
vars_dict = self._create_handlers_variables(else_branch, vars_dict)
77+
return vars_dict
78+
4079
@tf.function
4180
def gen_tensor_dict(self, input_dict):
4281
tensor_dict = dict(input_dict)
4382
tensor_dict.update(self.initializer_dict)
83+
tensor_dict.update(self.handler_variables)
4484

4585
for node in self.graph_def.node:
4686
onnx_node = OnnxNode(node)
@@ -58,6 +98,7 @@ def gen_tensor_dict(self, input_dict):
5898
def __call__(self, **kwargs):
5999
tensor_dict = kwargs
60100
tensor_dict.update(self.initializer_dict)
101+
tensor_dict.update(self.handler_variables)
61102

62103
for node in self.graph_def.node:
63104
onnx_node = OnnxNode(node)
@@ -70,4 +111,41 @@ def __call__(self, **kwargs):
70111
tensor_dict.update(curr_node_output_map)
71112

72113
outputs = [tensor_dict[output] for output in self.outputs]
114+
115+
return outputs
116+
117+
118+
class TFModule(tf.Module):
119+
""" TFModule is the tf.Module class used in backend.run_node.
120+
"""
121+
122+
def __init__(self, node, backend):
123+
super(TFModule, self).__init__()
124+
self.node = node
125+
self.backend = backend
126+
self.handlers = backend._get_handlers(opset=None)
127+
self.handler_variables = self._create_handlers_variables(dict())
128+
129+
def _create_handlers_variables(self, vars_dict):
130+
if self.handlers:
131+
handler = self.handlers[self.node.domain].get(
132+
self.node.op_type,
133+
None) if self.node.domain in self.handlers else None
134+
if handler and bool(
135+
handler.get_req_vars_template(self.node, self.node.attrs)):
136+
for v_name, v_template in handler.get_req_vars_template(
137+
self.node, self.node.attrs).items():
138+
v_init, v_shape = v_template
139+
v_name = get_variable_name(self.node, v_name)
140+
vars_dict[v_name] = tf.Variable(v_init,
141+
dtype=v_init.dtype,
142+
shape=v_shape,
143+
name=v_name)
144+
return vars_dict
145+
146+
@tf.function
147+
def __call__(self, **input_dict):
148+
input_dict.update(self.handler_variables)
149+
outputs = self.backend._onnx_node_to_tensorflow_op(self.node, input_dict,
150+
self.handlers)
73151
return outputs

onnx_tf/common/__init__.py

+10-1
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@ def __init__(self):
3131
self.device = 'CPU'
3232

3333

34-
3534
sys_config = SysConfig()
3635

3736

@@ -183,6 +182,16 @@ def supports_device(device):
183182
return False
184183

185184

185+
def get_variable_name(node, var_name):
186+
""" Get variable name.
187+
:param node: ONNX NodeProto object
188+
:param var_name: name of the variable
189+
:return: unique variable name
190+
"""
191+
v_name = node.op_type.lower() + '_' + var_name
192+
return v_name + '_' + node.name.lower() if node.name else v_name
193+
194+
186195
CONST_MINUS_ONE_INT32 = "_onnx_tf_internal_minus_one_int32"
187196
CONST_ZERO_INT32 = "_onnx_tf_internal_zero_int32"
188197
CONST_ONE_INT32 = "_onnx_tf_internal_one_int32"

onnx_tf/common/exception.py

+15
Original file line numberDiff line numberDiff line change
@@ -81,8 +81,23 @@ def get_message(self, op, supported_dtypes):
8181
return self._message.format(op, supported_dtypes)
8282

8383

84+
class NonUniqueNodeNameException(object):
85+
86+
def __init__(self):
87+
super(NonUniqueNodeNameException, self).__init__()
88+
self._func = RuntimeError
89+
self._message = "Node name is not unique in your model. Please recreate your model with unique node name."
90+
91+
def __call__(self):
92+
raise self._func(self.get_message())
93+
94+
def get_message(self):
95+
return self._message.format()
96+
97+
8498
IGNORE_UNIMPLEMENTED = False
8599
OP_UNIMPLEMENTED_EXCEPT = OpUnimplementedException()
86100
OP_UNSUPPORTED_EXCEPT = OpUnsupportedException()
87101
CONST_NOT_FOUND_EXCEPT = ConstNotFoundException()
88102
DTYPE_NOT_CAST_EXCEPT = DtypeNotCastException()
103+
NONUNIQUE_NODE_NAME_EXCEPT = NonUniqueNodeNameException()

onnx_tf/handlers/backend/conv_mixin.py

+8-7
Original file line numberDiff line numberDiff line change
@@ -46,14 +46,15 @@ def conv(cls, node, input_dict, transpose=False):
4646

4747
if "kernel_shape" in node.attrs.keys():
4848
kernel_shape = node.attrs["kernel_shape"]
49-
assert in_weights.get_shape().as_list()[2:] == kernel_shape, (
50-
"kernel_shape "
51-
"attr of convolution does not match the actual weight "
52-
"passed to this operation, attr {}, actual {}").format(
53-
kernel_shape,
54-
in_weights.get_shape().as_list())
49+
if in_weights.get_shape().is_fully_defined():
50+
assert in_weights.get_shape().as_list()[2:] == kernel_shape, (
51+
"kernel_shape "
52+
"attr of convolution does not match the actual weight "
53+
"passed to this operation, attr {}, actual {}").format(
54+
kernel_shape,
55+
in_weights.get_shape().as_list())
5556
else:
56-
kernel_shape = in_weights.get_shape().as_list()[2:]
57+
kernel_shape = tf_shape(in_weights, tf.int32)[2:]
5758

5859
weights = tf.transpose(in_weights, perm)
5960
dilations = node.attrs.get("dilations", [1] * spatial_size)

onnx_tf/handlers/backend/dilated_pooling.py

+6-3
Original file line numberDiff line numberDiff line change
@@ -601,9 +601,12 @@ def dilated_maxpool_with_argmax(self, force_custom_impl=False):
601601

602602
# if there was padding, recalculate the returned index
603603
# to exclude the padding
604-
count_nonzero_op = np.count_nonzero if self.is_known_shape else tf.math.count_nonzero
605-
if count_nonzero_op(self.pads) != 0:
606-
new_ind = self._calc_argmax_without_padding(new_ind)
604+
if self.is_known_shape:
605+
if np.count_nonzero(self.pads) != 0:
606+
new_ind = self._calc_argmax_without_padding(new_ind)
607+
else:
608+
new_ind = tf.where(tf.not_equal(tf.math.count_nonzero(self.pads), 0),
609+
self._calc_argmax_without_padding(new_ind), new_ind)
607610

608611
return (pooled, new_ind)
609612

onnx_tf/handlers/backend/dropout.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,12 @@ def _common(cls, node, **kwargs):
1717
x = tensor_dict[node.inputs[0]]
1818
attrs = copy.deepcopy(node.attrs)
1919

20-
if cls.SINCE_VERSION < 7:
20+
if cls.SINCE_VERSION < 7 and attrs.pop("is_test", 0) == 0:
2121
attrs["keep_prob"] = 1 - attrs.pop("ratio", 0.5)
2222
return [cls.make_tensor_from_onnx_node(node, attrs=attrs, **kwargs)]
23-
elif cls.SINCE_VERSION < 12 or attrs.pop("is_test", 0) == 1: # for Opset 7, 10
23+
elif cls.SINCE_VERSION < 12 : # for Opset 7, 10
24+
# at inference mode, is_test attribute is always set to 1
25+
# dropout at inference mode is a no-op
2426
return [x]
2527
else: # for Opset 12, 13
2628
# ratio and training_mode are optional and passed as inputs
@@ -30,7 +32,7 @@ def _common(cls, node, **kwargs):
3032
training_mode = False # default is false
3133
if len(node.inputs) == 3:
3234
training_mode = tensor_dict[node.inputs[2]]
33-
35+
3436
return_mask = len(node.outputs) == 2 # if there are 2 outputs, mask is requested
3537
if ratio == 0 or training_mode is False: # Inferencing
3638
if return_mask is True:

0 commit comments

Comments
 (0)