Skip to content

Commit a742d29

Browse files
authored
Modify handlers variables creation process (#801)
1. Create unique handlers' variable name by adding node.name to it. If cannot create unique variable name with node.name then throw exception. 2. Allow handler to set the variable shape base on node.attrs values 3. Move TFModule class from backend.run_node to backend_tf_module.py 4. Create handlers' variables in TFModule.init Signed-off-by: Winnie Tsang <[email protected]>
1 parent 7e4802c commit a742d29

File tree

7 files changed

+118
-63
lines changed

7 files changed

+118
-63
lines changed

onnx_tf/backend.py

+2-12
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
from onnx_tf.common import supports_device as common_supports_device
2727
from onnx_tf.common.handler_helper import get_all_backend_handlers
2828
from onnx_tf.pb_wrapper import OnnxNode
29-
from onnx_tf.backend_tf_module import BackendTFModule
29+
from onnx_tf.backend_tf_module import BackendTFModule, TFModule
3030
import onnx_tf.common as common
3131

3232

@@ -205,16 +205,6 @@ def run_node(cls, node, inputs, device='CPU', outputs_info=None, **kwargs):
205205
:return: Outputs.
206206
"""
207207

208-
class TFModule(tf.Module):
209-
210-
def __init__(self, node):
211-
super(TFModule, self).__init__()
212-
self.node = node
213-
214-
@tf.function
215-
def __call__(self, **input_dict):
216-
return cls._onnx_node_to_tensorflow_op(self.node, input_dict)
217-
218208
super(TensorflowBackend, cls).run_node(node, inputs, device)
219209
common.sys_config.device = device
220210

@@ -233,7 +223,7 @@ def __call__(self, **input_dict):
233223
input_dict = dict([(x[0], tf.constant(x[1])) for x in feed_dict_raw.items()
234224
])
235225

236-
module = TFModule(node)
226+
module = TFModule(node, cls)
237227

238228
output_vals = module(**input_dict)
239229
output_vals = [

onnx_tf/backend_tf_module.py

+69-35
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,13 @@
1-
from onnx.defs import ONNX_DOMAIN
21
import tensorflow as tf
2+
from onnx_tf.common import exception
3+
from onnx_tf.common import get_variable_name
34
from onnx_tf.pb_wrapper import OnnxNode
45

56

67
class BackendTFModule(tf.Module):
8+
""" BackendTFModule is the tf.Module class used in backend.prepare,
9+
tf_rep.export_graph and tf_rep.run
10+
"""
711

812
def __init__(self, handlers, opset, strict, graph_def, backend):
913
super(BackendTFModule, self).__init__()
@@ -42,31 +46,34 @@ def _get_initializer_from_graph_and_subgraphs(self, graph, graph_tensor_dict):
4246

4347
# create tf.Variable for handlers that required to use variable in handler
4448
def _create_handlers_variables(self, graph, vars_dict):
45-
handlers = self.backend._get_handlers(self.opset)
46-
for node in graph.node:
47-
handler = handlers[node.domain].get(
48-
node.op_type, None) if node.domain in handlers else None
49-
if handler and bool(handler.get_req_vars_template()):
50-
for v_name, v_template in handler.get_req_vars_template().items():
51-
v_init, v_shape = v_template
52-
v_count = 0
53-
for var_name in vars_dict.keys():
54-
v_count = v_count + 1 if var_name.startswith(v_name) else v_count
55-
v_name = v_name + '_' + str(v_count)
56-
vars_dict[v_name] = tf.Variable(v_init,
57-
dtype=v_init.dtype,
58-
shape=v_shape,
59-
name=v_name)
60-
if node.op_type in ['Loop', 'Scan']:
61-
onnx_node = OnnxNode(node)
62-
body = onnx_node.attrs["body"]
63-
vars_dict = self._create_handlers_variables(body, vars_dict)
64-
elif node.op_type == 'If':
65-
onnx_node = OnnxNode(node)
66-
then_branch = onnx_node.attrs['then_branch']
67-
vars_dict = self._create_handlers_variables(then_branch, vars_dict)
68-
else_branch = onnx_node.attrs['else_branch']
69-
vars_dict = self._create_handlers_variables(else_branch, vars_dict)
49+
if self.handlers:
50+
handlers = self.backend._get_handlers(self.opset)
51+
for node in graph.node:
52+
handler = handlers[node.domain].get(
53+
node.op_type, None) if node.domain in handlers else None
54+
if handler and bool(
55+
handler.get_req_vars_template(node, self.initializer_dict)):
56+
for v_name, v_template in handler.get_req_vars_template(
57+
node, self.initializer_dict).items():
58+
v_init, v_shape = v_template
59+
v_name = get_variable_name(node, v_name)
60+
if v_name in vars_dict.keys():
61+
# found duplicated variable name due to non unique node name
62+
exception.NON_UNIQUE_NODE_NAME_EXCEPT()
63+
vars_dict[v_name] = tf.Variable(v_init,
64+
dtype=v_init.dtype,
65+
shape=v_shape,
66+
name=v_name)
67+
if node.op_type in ['Loop', 'Scan']:
68+
onnx_node = OnnxNode(node)
69+
body = onnx_node.attrs["body"]
70+
vars_dict = self._create_handlers_variables(body, vars_dict)
71+
elif node.op_type == 'If':
72+
onnx_node = OnnxNode(node)
73+
then_branch = onnx_node.attrs['then_branch']
74+
vars_dict = self._create_handlers_variables(then_branch, vars_dict)
75+
else_branch = onnx_node.attrs['else_branch']
76+
vars_dict = self._create_handlers_variables(else_branch, vars_dict)
7077
return vars_dict
7178

7279
@tf.function
@@ -85,11 +92,6 @@ def gen_tensor_dict(self, input_dict):
8592
curr_node_output_map = dict(zip(onnx_node.outputs, output_ops))
8693
tensor_dict.update(curr_node_output_map)
8794

88-
# reset VAR_COUNT in handlers(currently all handlers are in ONNX_DOMAIN)
89-
# TODO update this when we support handlers in other domain
90-
for _, handler in self.handlers[ONNX_DOMAIN].items():
91-
handler.VAR_COUNT = 0
92-
9395
return tensor_dict
9496

9597
@tf.function
@@ -110,8 +112,40 @@ def __call__(self, **kwargs):
110112

111113
outputs = [tensor_dict[output] for output in self.outputs]
112114

113-
# reset VAR_COUNT in handlers(currently all handlers are in ONNX_DOMAIN)
114-
# TODO update this when we support handlers in other domain
115-
for _, handler in self.handlers[ONNX_DOMAIN].items():
116-
handler.VAR_COUNT = 0
115+
return outputs
116+
117+
118+
class TFModule(tf.Module):
119+
""" TFModule is the tf.Module class used in backend.run_node.
120+
"""
121+
122+
def __init__(self, node, backend):
123+
super(TFModule, self).__init__()
124+
self.node = node
125+
self.backend = backend
126+
self.handlers = backend._get_handlers(opset=None)
127+
self.handler_variables = self._create_handlers_variables(dict())
128+
129+
def _create_handlers_variables(self, vars_dict):
130+
if self.handlers:
131+
handler = self.handlers[self.node.domain].get(
132+
self.node.op_type,
133+
None) if self.node.domain in self.handlers else None
134+
if handler and bool(
135+
handler.get_req_vars_template(self.node, self.node.attrs)):
136+
for v_name, v_template in handler.get_req_vars_template(
137+
self.node, self.node.attrs).items():
138+
v_init, v_shape = v_template
139+
v_name = get_variable_name(self.node, v_name)
140+
vars_dict[v_name] = tf.Variable(v_init,
141+
dtype=v_init.dtype,
142+
shape=v_shape,
143+
name=v_name)
144+
return vars_dict
145+
146+
@tf.function
147+
def __call__(self, **input_dict):
148+
input_dict.update(self.handler_variables)
149+
outputs = self.backend._onnx_node_to_tensorflow_op(self.node, input_dict,
150+
self.handlers)
117151
return outputs

onnx_tf/common/__init__.py

+10-1
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@ def __init__(self):
3131
self.device = 'CPU'
3232

3333

34-
3534
sys_config = SysConfig()
3635

3736

@@ -183,6 +182,16 @@ def supports_device(device):
183182
return False
184183

185184

185+
def get_variable_name(node, var_name):
186+
""" Get variable name.
187+
:param node: ONNX NodeProto object
188+
:param var_name: name of the variable
189+
:return: unique variable name
190+
"""
191+
v_name = node.op_type.lower() + '_' + var_name
192+
return v_name + '_' + node.name.lower() if node.name else v_name
193+
194+
186195
CONST_MINUS_ONE_INT32 = "_onnx_tf_internal_minus_one_int32"
187196
CONST_ZERO_INT32 = "_onnx_tf_internal_zero_int32"
188197
CONST_ONE_INT32 = "_onnx_tf_internal_one_int32"

onnx_tf/common/exception.py

+15
Original file line numberDiff line numberDiff line change
@@ -81,8 +81,23 @@ def get_message(self, op, supported_dtypes):
8181
return self._message.format(op, supported_dtypes)
8282

8383

84+
class NonUniqueNodeNameException(object):
85+
86+
def __init__(self):
87+
super(NonUniqueNodeNameException, self).__init__()
88+
self._func = RuntimeError
89+
self._message = "Node name is not unique in your model. Please recreate your model with unique node name."
90+
91+
def __call__(self):
92+
raise self._func(self.get_message())
93+
94+
def get_message(self):
95+
return self._message.format()
96+
97+
8498
IGNORE_UNIMPLEMENTED = False
8599
OP_UNIMPLEMENTED_EXCEPT = OpUnimplementedException()
86100
OP_UNSUPPORTED_EXCEPT = OpUnsupportedException()
87101
CONST_NOT_FOUND_EXCEPT = ConstNotFoundException()
88102
DTYPE_NOT_CAST_EXCEPT = DtypeNotCastException()
103+
NONUNIQUE_NODE_NAME_EXCEPT = NonUniqueNodeNameException()

onnx_tf/handlers/backend/non_max_suppression.py

+11-8
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,26 @@
11
import tensorflow as tf
22

3+
from onnx_tf.common import get_variable_name
34
from onnx_tf.common.tf_helper import tf_shape
45
from onnx_tf.handlers.backend_handler import BackendHandler
56
from onnx_tf.handlers.handler import onnx_op
67

78

89
@onnx_op("NonMaxSuppression")
910
class NonMaxSuppression(BackendHandler):
10-
var_prefix = 'non_max_suppression_result'
11+
var_name = 'result'
1112

1213
@classmethod
13-
def get_req_vars_template(cls):
14-
""" Get required variables template.
15-
16-
:return: Dict.
14+
def get_req_vars_template(cls, node, init_dict):
15+
""" Get required variables template, which is a
16+
dictionary of variable names with initial value and
17+
shape.
18+
:param node: ONNX NodeProto object.
19+
:param init_dict: initializer dictionary of the graph.
20+
:return: Dictionary.
1721
"""
1822
return {
19-
cls.var_prefix: [
23+
cls.var_name: [
2024
tf.constant([[0, 0, 0]], dtype=tf.int64),
2125
tf.TensorShape([None, 3])
2226
]
@@ -96,10 +100,9 @@ def create_nodes(boxes, scores, max_output_boxes_per_class, iou_threshold,
96100
result = output if tf.equal(batch_i, 0) and tf.equal(
97101
class_j, 0) else tf.concat([result, output], 0)
98102

99-
cls.VAR_COUNT = cls.VAR_COUNT + 1
100103
return result
101104

102-
result = tensor_dict[cls.var_prefix + '_' + str(cls.VAR_COUNT)]
105+
result = tensor_dict[get_variable_name(node, cls.var_name)]
103106
return [
104107
create_nodes(boxes, scores, max_output_boxes_per_class, iou_threshold,
105108
score_threshold, result)

onnx_tf/handlers/backend_handler.py

+7-5
Original file line numberDiff line numberDiff line change
@@ -24,13 +24,15 @@ class BackendHandler(Handler):
2424
"""
2525

2626
TF_FUNC = None
27-
VAR_COUNT = 0
2827

2928
@classmethod
30-
def get_req_vars_template(cls):
31-
""" Get required variables template.
32-
33-
:return: Dict.
29+
def get_req_vars_template(cls, node, init_dict):
30+
""" Get required variables template, which is a
31+
dictionary of variable names with initial value and
32+
shape
33+
:param node: ONNX NodeProto object.
34+
:param init_dict: initializer dictionary of the graph.
35+
:return: Dictionary.
3436
"""
3537
return {}
3638

test/backend/test_dynamic_shape.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -550,12 +550,14 @@ def test_non_max_suppression_with_if(self):
550550
"NonMaxSuppression",
551551
["boxes", "scores", "max_output_boxes_per_class", "iou_threshold"],
552552
["selected_indices_1"],
553-
center_point_box=0)
553+
center_point_box=0,
554+
name='NonMaxSuppression_1')
554555
non_max_suppression_node_2 = helper.make_node("NonMaxSuppression", [
555556
"boxes", "scores", "max_output_boxes_per_class", "iou_threshold",
556557
"score_threshold"
557558
], ["selected_indices_2"],
558-
center_point_box=0)
559+
center_point_box=0,
560+
name='NonMaxSuppression_2')
559561

560562
then_graph = helper.make_graph(nodes=[non_max_suppression_node_1],
561563
name="then_graph",

0 commit comments

Comments
 (0)