1
- from onnx .defs import ONNX_DOMAIN
2
1
import tensorflow as tf
2
+ from onnx_tf .common import exception
3
+ from onnx_tf .common import get_variable_name
3
4
from onnx_tf .pb_wrapper import OnnxNode
4
5
5
6
6
7
class BackendTFModule (tf .Module ):
8
+ """ BackendTFModule is the tf.Module class used in backend.prepare,
9
+ tf_rep.export_graph and tf_rep.run
10
+ """
7
11
8
12
def __init__ (self , handlers , opset , strict , graph_def , backend ):
9
13
super (BackendTFModule , self ).__init__ ()
@@ -42,31 +46,34 @@ def _get_initializer_from_graph_and_subgraphs(self, graph, graph_tensor_dict):
42
46
43
47
# create tf.Variable for handlers that required to use variable in handler
44
48
def _create_handlers_variables (self , graph , vars_dict ):
45
- handlers = self .backend ._get_handlers (self .opset )
46
- for node in graph .node :
47
- handler = handlers [node .domain ].get (
48
- node .op_type , None ) if node .domain in handlers else None
49
- if handler and bool (handler .get_req_vars_template ()):
50
- for v_name , v_template in handler .get_req_vars_template ().items ():
51
- v_init , v_shape = v_template
52
- v_count = 0
53
- for var_name in vars_dict .keys ():
54
- v_count = v_count + 1 if var_name .startswith (v_name ) else v_count
55
- v_name = v_name + '_' + str (v_count )
56
- vars_dict [v_name ] = tf .Variable (v_init ,
57
- dtype = v_init .dtype ,
58
- shape = v_shape ,
59
- name = v_name )
60
- if node .op_type in ['Loop' , 'Scan' ]:
61
- onnx_node = OnnxNode (node )
62
- body = onnx_node .attrs ["body" ]
63
- vars_dict = self ._create_handlers_variables (body , vars_dict )
64
- elif node .op_type == 'If' :
65
- onnx_node = OnnxNode (node )
66
- then_branch = onnx_node .attrs ['then_branch' ]
67
- vars_dict = self ._create_handlers_variables (then_branch , vars_dict )
68
- else_branch = onnx_node .attrs ['else_branch' ]
69
- vars_dict = self ._create_handlers_variables (else_branch , vars_dict )
49
+ if self .handlers :
50
+ handlers = self .backend ._get_handlers (self .opset )
51
+ for node in graph .node :
52
+ handler = handlers [node .domain ].get (
53
+ node .op_type , None ) if node .domain in handlers else None
54
+ if handler and bool (
55
+ handler .get_req_vars_template (node , self .initializer_dict )):
56
+ for v_name , v_template in handler .get_req_vars_template (
57
+ node , self .initializer_dict ).items ():
58
+ v_init , v_shape = v_template
59
+ v_name = get_variable_name (node , v_name )
60
+ if v_name in vars_dict .keys ():
61
+ # found duplicated variable name due to non unique node name
62
+ exception .NON_UNIQUE_NODE_NAME_EXCEPT ()
63
+ vars_dict [v_name ] = tf .Variable (v_init ,
64
+ dtype = v_init .dtype ,
65
+ shape = v_shape ,
66
+ name = v_name )
67
+ if node .op_type in ['Loop' , 'Scan' ]:
68
+ onnx_node = OnnxNode (node )
69
+ body = onnx_node .attrs ["body" ]
70
+ vars_dict = self ._create_handlers_variables (body , vars_dict )
71
+ elif node .op_type == 'If' :
72
+ onnx_node = OnnxNode (node )
73
+ then_branch = onnx_node .attrs ['then_branch' ]
74
+ vars_dict = self ._create_handlers_variables (then_branch , vars_dict )
75
+ else_branch = onnx_node .attrs ['else_branch' ]
76
+ vars_dict = self ._create_handlers_variables (else_branch , vars_dict )
70
77
return vars_dict
71
78
72
79
@tf .function
@@ -85,11 +92,6 @@ def gen_tensor_dict(self, input_dict):
85
92
curr_node_output_map = dict (zip (onnx_node .outputs , output_ops ))
86
93
tensor_dict .update (curr_node_output_map )
87
94
88
- # reset VAR_COUNT in handlers(currently all handlers are in ONNX_DOMAIN)
89
- # TODO update this when we support handlers in other domain
90
- for _ , handler in self .handlers [ONNX_DOMAIN ].items ():
91
- handler .VAR_COUNT = 0
92
-
93
95
return tensor_dict
94
96
95
97
@tf .function
@@ -110,8 +112,40 @@ def __call__(self, **kwargs):
110
112
111
113
outputs = [tensor_dict [output ] for output in self .outputs ]
112
114
113
- # reset VAR_COUNT in handlers(currently all handlers are in ONNX_DOMAIN)
114
- # TODO update this when we support handlers in other domain
115
- for _ , handler in self .handlers [ONNX_DOMAIN ].items ():
116
- handler .VAR_COUNT = 0
115
+ return outputs
116
+
117
+
118
+ class TFModule (tf .Module ):
119
+ """ TFModule is the tf.Module class used in backend.run_node.
120
+ """
121
+
122
+ def __init__ (self , node , backend ):
123
+ super (TFModule , self ).__init__ ()
124
+ self .node = node
125
+ self .backend = backend
126
+ self .handlers = backend ._get_handlers (opset = None )
127
+ self .handler_variables = self ._create_handlers_variables (dict ())
128
+
129
+ def _create_handlers_variables (self , vars_dict ):
130
+ if self .handlers :
131
+ handler = self .handlers [self .node .domain ].get (
132
+ self .node .op_type ,
133
+ None ) if self .node .domain in self .handlers else None
134
+ if handler and bool (
135
+ handler .get_req_vars_template (self .node , self .node .attrs )):
136
+ for v_name , v_template in handler .get_req_vars_template (
137
+ self .node , self .node .attrs ).items ():
138
+ v_init , v_shape = v_template
139
+ v_name = get_variable_name (self .node , v_name )
140
+ vars_dict [v_name ] = tf .Variable (v_init ,
141
+ dtype = v_init .dtype ,
142
+ shape = v_shape ,
143
+ name = v_name )
144
+ return vars_dict
145
+
146
+ @tf .function
147
+ def __call__ (self , ** input_dict ):
148
+ input_dict .update (self .handler_variables )
149
+ outputs = self .backend ._onnx_node_to_tensorflow_op (self .node , input_dict ,
150
+ self .handlers )
117
151
return outputs
0 commit comments