1
1
import tensorflow as tf
2
+ from onnx_tf .common import exception
3
+ from onnx_tf .common import get_variable_name
2
4
from onnx_tf .pb_wrapper import OnnxNode
3
5
4
6
5
7
class BackendTFModule (tf .Module ):
8
+ """ BackendTFModule is the tf.Module class used in backend.prepare,
9
+ tf_rep.export_graph and tf_rep.run
10
+ """
6
11
7
12
def __init__ (self , handlers , opset , strict , graph_def , backend ):
8
13
super (BackendTFModule , self ).__init__ ()
@@ -14,6 +19,8 @@ def __init__(self, handlers, opset, strict, graph_def, backend):
14
19
self .outputs = []
15
20
self .initializer_dict = self ._get_initializer_from_graph_and_subgraphs (
16
21
self .graph_def , dict ())
22
+ self .handler_variables = self ._create_handlers_variables (
23
+ self .graph_def , dict ())
17
24
18
25
# get initializer from the main graph and all subgraphs in loop or if or scan
19
26
# into tensor_dict
@@ -37,10 +44,43 @@ def _get_initializer_from_graph_and_subgraphs(self, graph, graph_tensor_dict):
37
44
else_branch , graph_tensor_dict )
38
45
return graph_tensor_dict
39
46
47
+ # create tf.Variable for handlers that required to use variable in handler
48
+ def _create_handlers_variables (self , graph , vars_dict ):
49
+ if self .handlers :
50
+ handlers = self .backend ._get_handlers (self .opset )
51
+ for node in graph .node :
52
+ handler = handlers [node .domain ].get (
53
+ node .op_type , None ) if node .domain in handlers else None
54
+ if handler and bool (
55
+ handler .get_req_vars_template (node , self .initializer_dict )):
56
+ for v_name , v_template in handler .get_req_vars_template (
57
+ node , self .initializer_dict ).items ():
58
+ v_init , v_shape = v_template
59
+ v_name = get_variable_name (node , v_name )
60
+ if v_name in vars_dict .keys ():
61
+ # found duplicated variable name due to non unique node name
62
+ exception .NON_UNIQUE_NODE_NAME_EXCEPT ()
63
+ vars_dict [v_name ] = tf .Variable (v_init ,
64
+ dtype = v_init .dtype ,
65
+ shape = v_shape ,
66
+ name = v_name )
67
+ if node .op_type in ['Loop' , 'Scan' ]:
68
+ onnx_node = OnnxNode (node )
69
+ body = onnx_node .attrs ["body" ]
70
+ vars_dict = self ._create_handlers_variables (body , vars_dict )
71
+ elif node .op_type == 'If' :
72
+ onnx_node = OnnxNode (node )
73
+ then_branch = onnx_node .attrs ['then_branch' ]
74
+ vars_dict = self ._create_handlers_variables (then_branch , vars_dict )
75
+ else_branch = onnx_node .attrs ['else_branch' ]
76
+ vars_dict = self ._create_handlers_variables (else_branch , vars_dict )
77
+ return vars_dict
78
+
40
79
@tf .function
41
80
def gen_tensor_dict (self , input_dict ):
42
81
tensor_dict = dict (input_dict )
43
82
tensor_dict .update (self .initializer_dict )
83
+ tensor_dict .update (self .handler_variables )
44
84
45
85
for node in self .graph_def .node :
46
86
onnx_node = OnnxNode (node )
@@ -58,6 +98,7 @@ def gen_tensor_dict(self, input_dict):
58
98
def __call__ (self , ** kwargs ):
59
99
tensor_dict = kwargs
60
100
tensor_dict .update (self .initializer_dict )
101
+ tensor_dict .update (self .handler_variables )
61
102
62
103
for node in self .graph_def .node :
63
104
onnx_node = OnnxNode (node )
@@ -70,4 +111,41 @@ def __call__(self, **kwargs):
70
111
tensor_dict .update (curr_node_output_map )
71
112
72
113
outputs = [tensor_dict [output ] for output in self .outputs ]
114
+
115
+ return outputs
116
+
117
+
118
+ class TFModule (tf .Module ):
119
+ """ TFModule is the tf.Module class used in backend.run_node.
120
+ """
121
+
122
+ def __init__ (self , node , backend ):
123
+ super (TFModule , self ).__init__ ()
124
+ self .node = node
125
+ self .backend = backend
126
+ self .handlers = backend ._get_handlers (opset = None )
127
+ self .handler_variables = self ._create_handlers_variables (dict ())
128
+
129
+ def _create_handlers_variables (self , vars_dict ):
130
+ if self .handlers :
131
+ handler = self .handlers [self .node .domain ].get (
132
+ self .node .op_type ,
133
+ None ) if self .node .domain in self .handlers else None
134
+ if handler and bool (
135
+ handler .get_req_vars_template (self .node , self .node .attrs )):
136
+ for v_name , v_template in handler .get_req_vars_template (
137
+ self .node , self .node .attrs ).items ():
138
+ v_init , v_shape = v_template
139
+ v_name = get_variable_name (self .node , v_name )
140
+ vars_dict [v_name ] = tf .Variable (v_init ,
141
+ dtype = v_init .dtype ,
142
+ shape = v_shape ,
143
+ name = v_name )
144
+ return vars_dict
145
+
146
+ @tf .function
147
+ def __call__ (self , ** input_dict ):
148
+ input_dict .update (self .handler_variables )
149
+ outputs = self .backend ._onnx_node_to_tensorflow_op (self .node , input_dict ,
150
+ self .handlers )
73
151
return outputs
0 commit comments