Commit c63d435

Use user-defined device choice during prepare model (#726)

1. Added sys_config.device and set it to the user-defined device choice ("CUDA" or "CPU") in prepare and run_node.
2. Updated conv_mixin.py to use sys_config.device and eliminate an unnecessary transpose.
3. Updated pool_mixin.py to eliminate the mandatory conversion of the input x to channel-last (NHWC) format. Added a function that converts NHWC indices to NCHW indices for MaxPool_with_Argmax, fixing issue #719.
4. Updated unpool_mixin.py to eliminate the mandatory conversion of the input x to channel-last (NHWC) format.
5. Updated dilated_pooling.py to process NCHW-format input instead of NHWC for all pooling operators except MaxPool_with_Argmax and MaxPool_with_dilation_not_equal_to_1_and_spatial_size_equal_to_2 (tf.nn.max_pool_with_argmax and tf.nn.dilation2d only support NHWC format).
6. Added a dynamic_shape test for MaxPool_with_Argmax.
7. Set the device in run_node for operators that behave differently in NCHW/NHWC format in test_node.py.

Signed-off-by: Winnie Tsang <[email protected]>
Co-authored-by: Chin Huang <[email protected]>
1 parent 651c90f commit c63d435

12 files changed: +706 −552 lines changed

doc/API.md (+1 −1)

```diff
@@ -20,7 +20,7 @@ _params_:
 `model` : The ONNX model to be converted.
 
 
-`device` : The device to execute this model on.
+`device` : The device to execute this model on. It can be either CPU (default) or CUDA.
 
 
 `strict` : Whether to enforce semantic equivalence between the original model
```
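
The documented parameter maps directly onto `prepare`; a minimal usage sketch (the model path is illustrative):

```python
import onnx
from onnx_tf.backend import prepare

model = onnx.load("model.onnx")         # illustrative path
tf_rep = prepare(model, device="CUDA")  # "CPU" is the default
```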

doc/CLI.md (+2 −2)

```diff
@@ -40,8 +40,8 @@ optional arguments:
                         Output directory.
 
 backend arguments (onnx -> tf):
-  --device DEVICE       The device to execute this model on. (from
-                        onnx_tf.backend.prepare)
+  --device DEVICE       The device to execute this model on. It can be either
+                        CPU (default) or CUDA. (from onnx_tf.backend.prepare)
   --strict STRICT       Whether to enforce semantic equivalence between the
                         original model and the converted tensorflow model,
                         defaults to True (yes, enforce semantic equivalence).
```
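
On the command line the same choice is passed as a flag; an illustrative invocation (the input/output paths and the `-i`/`-o` short flags are assumptions, not shown in this diff):

```sh
onnx-tf convert -i model.onnx -o output_dir --device CUDA
```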

onnx_tf/backend.py (+3 −1)

```diff
@@ -49,7 +49,7 @@ def prepare(cls,
     the converted representation.
 
     :param model: The ONNX model to be converted.
-    :param device: The device to execute this model on.
+    :param device: The device to execute this model on. It can be either CPU (default) or CUDA.
     :param strict: Whether to enforce semantic equivalence between the original model
       and the converted tensorflow model, defaults to True (yes, enforce semantic equivalence).
       Changing to False is strongly discouraged.
@@ -65,6 +65,7 @@ def prepare(cls,
     common.logger.setLevel(logging_level)
     common.logger.handlers[0].setLevel(logging_level)
     common.sys_config.auto_cast = auto_cast
+    common.sys_config.device = device
 
     return cls.onnx_model_to_tensorflow_rep(model, strict, **kwargs)
 
@@ -184,6 +185,7 @@ def __call__(self, **input_dict):
         return cls._onnx_node_to_tensorflow_op(self.node, input_dict)
 
     super(TensorflowBackend, cls).run_node(node, inputs, device)
+    common.sys_config.device = device
 
     node = OnnxNode(node)
     input_tensors = []
```
onnx_tf/common/__init__.py (+3 −2)

```diff
@@ -28,6 +28,8 @@ class SysConfig:
 
   def __init__(self):
     self.auto_cast = False
+    self.device = 'CPU'
+
 
 
 sys_config = SysConfig()
@@ -160,7 +162,7 @@ def get_data_format(x_rank):
   sp_dim_string = "".join(reversed(sp_dim_lst))
   storage_format = "NC" + sp_dim_string
 
-  if supports_device("CUDA"):
+  if sys_config.device == "CUDA":
     compute_format = "NC" + sp_dim_string
   else:
     compute_format = "N" + sp_dim_string + "C"
@@ -169,7 +171,6 @@ def get_data_format(x_rank):
 
 def supports_device(device):
   """ Check if support target device.
-
   :param device: CUDA or CPU.
   :return: If supports.
   """
```

onnx_tf/common/pooling_helper.py (+7 −1)

```diff
@@ -158,6 +158,9 @@ def py_pool(input, kernel_shape, strides=None, dilations=None,
 
   def _loop_over_output(batch, channel):
     dims = [range(output_sp_shape[d]) for d in range(spatial_size)]
+    image_size = 1
+    for d in input_shape[2:]:
+      image_size *= d
     for counters in itertools.product(*dims):
       input_ranges = []
       for dim in range(spatial_size):
@@ -189,7 +192,10 @@ def _loop_over_output(batch, channel):
         else:
           if val > maxval:
             maxval = val
-            ind = 0
+            # batch_offset = batch * C * image_size
+            # channel_offset = channel * image_size
+            # ind = batch_offset + channel_offset
+            ind = image_size * (batch * input_shape[1] + channel)
             for i in range(spatial_size):
               coef = 1
               for j in range(i+1, spatial_size):
```
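
Written out, the new argmax index is the row-major NCHW flat index: the batch/channel offset computed above plus the spatial offset accumulated by the loop that follows it. A worked example, using a hypothetical standalone helper (not repo code) that mirrors the same math:

```python
# Hypothetical helper mirroring the flat-index math in py_pool above.
def nchw_flat_index(batch, channel, coords, input_shape):
    C, spatial = input_shape[1], input_shape[2:]
    image_size = 1
    for d in spatial:                          # product of spatial dims
        image_size *= d
    ind = image_size * (batch * C + channel)   # batch + channel offsets
    coef = image_size
    for i, c in enumerate(coords):             # row-major spatial offset
        coef //= spatial[i]
        ind += c * coef
    return ind

# Shape (N=2, C=3, H=4, W=4): element (n=1, c=2, h=0, w=1) -> 16*5 + 1 = 81
assert nchw_flat_index(1, 2, (0, 1), (2, 3, 4, 4)) == 81
```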

onnx_tf/handlers/backend/conv_mixin.py (+5 −6)

```diff
@@ -1,10 +1,10 @@
 import tensorflow as tf
 
+from onnx_tf.common import exception
 from onnx_tf.common import get_data_format
 from onnx_tf.common import get_perm_from_formats
-from onnx_tf.common import supports_device
-from onnx_tf.common import exception
 from onnx_tf.common.tf_helper import tf_shape
+from onnx_tf.common import sys_config
 from .broadcast_mixin import BroadcastMixin
 from .pad_mixin import PadMixin
 
@@ -31,7 +31,6 @@ def conv(cls, node, input_dict, transpose=False):
     x_shape = tf_shape(x, tf.int32)
     spatial_size = x_rank - 2
 
-    support_cuda = supports_device("CUDA")
     storage_format, compute_format = get_data_format(x_rank)
     compute_c_idx = compute_format.find("C")
     spatial_format = "".join([d for d in compute_format if d not in ["N", "C"]])
@@ -94,7 +93,7 @@ def conv(cls, node, input_dict, transpose=False):
 
     weight_groups = tf.split(weights, num_or_size_splits=group, axis=-1)
 
-    if support_cuda:
+    if sys_config.device == 'CUDA':
       xs = tf.split(x, num_or_size_splits=group, axis=1)
     else:
       x = tf.transpose(x,
@@ -236,7 +235,7 @@ def conv(cls, node, input_dict, transpose=False):
     ]
 
    if len(node.inputs) == 2:
-      if support_cuda:
+      if sys_config.device == 'CUDA':
         output = tf.concat(convolved, axis=1)
       else:
         output = tf.concat(convolved, axis=-1)
@@ -247,7 +246,7 @@ def conv(cls, node, input_dict, transpose=False):
       bias = input_dict[node.inputs[2]]
      bias = cls.explicit_broadcast([x, bias], compute_c_idx)
 
-      if support_cuda:
+      if sys_config.device == 'CUDA':
         output = tf.concat(convolved, axis=1)
         output = tf.add(output, bias)
       else:
```
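
The device check replaces the old probe at each of the three sites above. A condensed sketch of the pattern, with illustrative tensor shapes and `group=1` (names borrowed from conv_mixin.py):

```python
import tensorflow as tf
from onnx_tf.common import get_data_format, get_perm_from_formats, sys_config

x = tf.zeros([1, 3, 8, 8])                # ONNX storage layout is NCHW
storage_format, compute_format = get_data_format(4)

if sys_config.device == "CUDA":
    # Already channel-first: split groups on the channel axis, no transpose.
    xs = tf.split(x, num_or_size_splits=1, axis=1)
else:
    # Transpose once to the NHWC compute format, then split on the last axis.
    x = tf.transpose(x, perm=get_perm_from_formats(storage_format, compute_format))
    xs = tf.split(x, num_or_size_splits=1, axis=-1)
```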
