From 4a1c85191f9ebae29d053152783113b4b5e17137 Mon Sep 17 00:00:00 2001
From: Yifei Teng
Date: Sun, 15 Jul 2018 14:04:17 -0700
Subject: [PATCH 1/4] fix GRU handler; add hacky support for
 linear_before_reset

---
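ONNX's GRU attribute linear_before_reset=1 only changes the candidate
state: the recurrent matmul is applied to the previous hidden state
*before* the reset gate, i.e. ht = tanh(Xt*Wh + Wbh + rt . (Ht-1*Rh + Rbh)),
whereas stock tf.nn.rnn_cell.GRUCell gates the state first and keeps a
single merged candidate bias. With linear_before_reset=1 the two candidate
biases Wbh and Rbh can no longer be folded into one, so the variant needs
its own cell rather than a weight-remapping trick. A rough NumPy sketch of
one step of the cell this patch adds (names are illustrative, not part of
the patch):

    import numpy as np

    def sigmoid(x):
        return 1.0 / (1.0 + np.exp(-x))

    # Kernels are laid out [in_dim, hidden] / [hidden, hidden], i.e. already
    # transposed relative to ONNX's W[zrh] / R[zrh] packing.
    def gru_step_lbr(x, h, Wz, Wr, Wh, Rz, Rr, Rh,
                     Wbz, Wbr, Wbh, Rbz, Rbr, Rbh):
        z = sigmoid(x @ Wz + h @ Rz + Wbz + Rbz)  # update gate
        r = sigmoid(x @ Wr + h @ Rr + Wbr + Rbr)  # reset gate
        h_lin = h @ Rh + Rbh                   # hidden-state matmul first...
        c = np.tanh(x @ Wh + Wbh + r * h_lin)  # ...then apply the reset gate
        return z * h + (1.0 - z) * c           # Ht = zt*Ht-1 + (1-zt)*ht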
 onnx_tf/handlers/backend/gru.py | 155 +++++++++++++++++++++++++++++---
 1 file changed, 144 insertions(+), 11 deletions(-)

diff --git a/onnx_tf/handlers/backend/gru.py b/onnx_tf/handlers/backend/gru.py
index 4d30e393a..fb1be6012 100644
--- a/onnx_tf/handlers/backend/gru.py
+++ b/onnx_tf/handlers/backend/gru.py
@@ -7,6 +7,126 @@
 from onnx_tf.handlers.backend_handler import BackendHandler
 from onnx_tf.handlers.handler import onnx_op
 from .rnn_mixin import RNNMixin
+from tensorflow.python.layers import base as base_layer
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import clip_ops
+from tensorflow.python.ops import init_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import nn_ops
+
+
+_BIAS_VARIABLE_NAME = "bias"
+_WEIGHTS_VARIABLE_NAME = "kernel"
+
+
+class GRUCellWithLinearBeforeReset(tf.contrib.rnn.LayerRNNCell):
+  """Gated Recurrent Unit cell (cf. http://arxiv.org/abs/1406.1078).
+  Args:
+    num_units: int, The number of units in the GRU cell.
+    activation: Nonlinearity to use. Default: `tanh`.
+    reuse: (optional) Python boolean describing whether to reuse variables
+      in an existing scope. If not `True`, and the existing scope already has
+      the given variables, an error is raised.
+    kernel_initializer: (optional) The initializer to use for the weight and
+      projection matrices.
+    bias_initializer: (optional) The initializer to use for the bias.
+    name: String, the name of the layer. Layers with the same name will
+      share weights, but to avoid mistakes we require reuse=True in such
+      cases.
+    dtype: Default dtype of the layer (default of `None` means use the type
+      of the first input). Required when `build` is called before `call`.
+  """
+
+  def __init__(self,
+               num_units,
+               activation=None,
+               reuse=None,
+               kernel_initializer=None,
+               bias_initializer=None,
+               name=None,
+               dtype=None):
+    super(GRUCellWithLinearBeforeReset, self).__init__(_reuse=reuse, name=name, dtype=dtype)
+
+    # Inputs must be 2-dimensional.
+    self.input_spec = base_layer.InputSpec(ndim=2)
+
+    self._num_units = num_units
+    self._activation = activation or math_ops.tanh
+    self._kernel_initializer = kernel_initializer
+    self._bias_initializer = bias_initializer
+
+  @property
+  def state_size(self):
+    return self._num_units
+
+  @property
+  def output_size(self):
+    return self._num_units
+
+  def build(self, inputs_shape):
+    if inputs_shape[1].value is None:
+      raise ValueError("Expected inputs.shape[-1] to be known, saw shape: %s"
+                       % inputs_shape)
+
+    input_depth = inputs_shape[1].value
+    self._gate_kernel = self.add_variable(
+        "gates/%s" % _WEIGHTS_VARIABLE_NAME,
+        shape=[input_depth + self._num_units, 2 * self._num_units],
+        initializer=self._kernel_initializer)
+    self._gate_bias = self.add_variable(
+        "gates/%s" % _BIAS_VARIABLE_NAME,
+        shape=[2 * self._num_units],
+        initializer=(
+            self._bias_initializer
+            if self._bias_initializer is not None
+            else init_ops.constant_initializer(1.0, dtype=self.dtype)))
+    self._candidate_bias_rbh = self.add_variable(
+        "candidate_rbh/%s" % _BIAS_VARIABLE_NAME,
+        shape=[self._num_units],
+        initializer=(
+            self._bias_initializer
+            if self._bias_initializer is not None
+            else init_ops.zeros_initializer(dtype=self.dtype)))
+    self._candidate_bias_wbh = self.add_variable(
+        "candidate_wbh/%s" % _BIAS_VARIABLE_NAME,
+        shape=[self._num_units],
+        initializer=(
+            self._bias_initializer
+            if self._bias_initializer is not None
+            else init_ops.zeros_initializer(dtype=self.dtype)))
+    self._candidate_kernel_rh = self.add_variable(
+        "candidate_rh/%s" % _WEIGHTS_VARIABLE_NAME,
+        shape=[self._num_units, self._num_units],
+        initializer=self._kernel_initializer)
+    self._candidate_kernel_wh = self.add_variable(
+        "candidate_wh/%s" % _WEIGHTS_VARIABLE_NAME,
+        shape=[input_depth, self._num_units],
+        initializer=self._kernel_initializer)
+
+    self.built = True
+
+  def call(self, inputs, state):
+    """Gated recurrent unit (GRU) step with linear_before_reset semantics."""
+
+    gate_inputs = math_ops.matmul(
+        array_ops.concat([inputs, state], 1), self._gate_kernel)
+    gate_inputs = nn_ops.bias_add(gate_inputs, self._gate_bias)
+
+    value = math_ops.sigmoid(gate_inputs)
+    r, u = array_ops.split(value=value, num_or_size_splits=2, axis=1)
+
+    b_rh, b_wh = (self._candidate_bias_rbh, self._candidate_bias_wbh)
+    # Apply the hidden-state matmul (and its bias Rbh) before the reset gate.
+    linear_gate_state = math_ops.matmul(state, self._candidate_kernel_rh)
+    linear_gate_state = nn_ops.bias_add(linear_gate_state, b_rh)
+    r_state = r * linear_gate_state
+
+    candidate = math_ops.matmul(inputs, self._candidate_kernel_wh)
+    candidate = nn_ops.bias_add(candidate, b_wh)
+
+    c = self._activation(candidate + r_state)
+    new_h = u * state + (1 - u) * c
+    return new_h, new_h
 
 
 @onnx_op("GRU")
@@ -18,9 +138,6 @@ def args_check(cls, node, **kwargs):
     num_directions = 2 if direction == "bidirectional" else 1
     if "clip" in node.attrs:
       exception.OP_UNSUPPORTED_EXCEPT("GRU with clip", "Tensorflow")
-    if node.attrs.get("linear_before_reset", 0):
-      exception.OP_UNSUPPORTED_EXCEPT("GRU with linear_before_reset",
-                                      "Tensorflow")
     if "activations" in node.attrs:
       activations = list(map(lambda x: x.lower(), node.attrs["activations"]))
       if activations[0] != "sigmoid":
@@ -63,11 +180,16 @@ def _custom_getter(cls,
       if names[-2] == "gates":
         new_w = tf.transpose(tf.concat([w_r, w_z], 0))
         new_r = tf.transpose(tf.concat([r_r, r_z], 0))
-      elif names[-2] == "candidate":
+      elif names[-2] in ("candidate", "candidate_rh", "candidate_wh"):
         new_w = tf.transpose(w_h)
         new_r = tf.transpose(r_h)
-      kernel = tf.concat([new_w, new_r], 0)
-      return kernel
+      if names[-2] == "candidate_rh":
+        return new_r
+      elif names[-2] == "candidate_wh":
+        return new_w
+      else:
+        kernel = tf.concat([new_w, new_r], 0)
+        return kernel
     if names[-1] == "bias":
       if len(node.inputs) >= 4:
         # onnx Wb[zrh], Rb[zrh]
@@ -81,10 +203,15 @@
       if names[-2] == "gates":
         w_b = tf.transpose(tf.concat([w_b_r, w_b_z], 0))
         r_b = tf.transpose(tf.concat([r_b_r, r_b_z], 0))
-      elif names[-2] == "candidate":
+      elif names[-2] in ("candidate", "candidate_rbh", "candidate_wbh"):
        w_b = tf.transpose(w_b_h)
         r_b = tf.transpose(r_b_h)
-      return tf.add(w_b, r_b)
+      if names[-2] == "candidate_rbh":
+        return r_b
+      elif names[-2] == "candidate_wbh":
+        return w_b
+      else:
+        return tf.add(w_b, r_b)
       return getter(name, *args, **kwargs)

     return getter(name, *args, **kwargs)
@@ -105,7 +232,7 @@ def _common(cls, node, **kwargs):
     # process input if it comes from other previous cell
     # which has shape [seq_length, num_directions, batch_size, hidden_size]
     if len(input_shape) == 4 and input_shape[1] == 1:
-      x = tf.squeeze(x)
+      x = tf.squeeze(x, axis=[1])
 
     sequence_length = None
     if input_size >= 5 and node.inputs[4] in tensor_dict:
@@ -158,13 +285,19 @@ def _common(cls, node, **kwargs):
     rnn_kwargs["time_major"] = True
     rnn_kwargs["dtype"] = tf.float32
 
-    outputs, states = cls.rnn(x, tf.nn.rnn_cell.GRUCell, cell_kwargs,
+    if node.attrs.get("linear_before_reset", 0):
+      cell_class = GRUCellWithLinearBeforeReset
+    else:
+      cell_class = tf.nn.rnn_cell.GRUCell
+
+    outputs, states = cls.rnn(x, cell_class, cell_kwargs,
                               rnn_kwargs, tf_activations, direction)
 
     if num_directions == 1:
       state = states[0]
       h = tf.expand_dims(state, 0)
-      output = tf.expand_dims(outputs, 1)
+      # output = tf.expand_dims(outputs, 1)
+      output = outputs
     else:
       state_fw = states[0][0]
       state_bw = states[1][0]
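For reference, the _custom_getter in the patch above repacks ONNX parameters
into TF's variable layout on the fly. ONNX stores W as
[3*hidden_size, input_size] with gates packed in z, r, h order, while TF's
fused gate kernel is [input_size + hidden_size, 2*hidden_size] with gates
ordered r, u and the candidate kernel kept separate. A minimal sketch of the
kernel mapping (standalone names here are hypothetical; the real handler
pulls w_r, w_z, w_h and friends out of the node's initializers):

    import numpy as np

    hidden_size, input_size = 3, 4
    W = np.random.randn(3 * hidden_size, input_size)   # ONNX W[zrh]
    R = np.random.randn(3 * hidden_size, hidden_size)  # ONNX R[zrh]
    w_z, w_r, w_h = np.split(W, 3, axis=0)
    r_z, r_r, r_h = np.split(R, 3, axis=0)

    # TF fused gate kernel: [input + hidden, 2*hidden], gates ordered [r, u].
    gate_kernel = np.concatenate([
        np.concatenate([w_r, w_z], axis=0).T,  # rows fed by the input
        np.concatenate([r_r, r_z], axis=0).T,  # rows fed by the hidden state
    ], axis=0)
    assert gate_kernel.shape == (input_size + hidden_size, 2 * hidden_size)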
From 2a7a4e319963eb4ec7edd7c0fc5b5b83d45d9cae Mon Sep 17 00:00:00 2001
From: Yifei Teng
Date: Sun, 15 Jul 2018 14:36:46 -0700
Subject: [PATCH 2/4] fix activations list length for bidirectional RNN

---
 onnx_tf/handlers/backend/gru.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/onnx_tf/handlers/backend/gru.py b/onnx_tf/handlers/backend/gru.py
index fb1be6012..b7d898151 100644
--- a/onnx_tf/handlers/backend/gru.py
+++ b/onnx_tf/handlers/backend/gru.py
@@ -240,7 +240,7 @@ def _common(cls, node, **kwargs):
 
     cell_kwargs = {}
 
-    tf_activations = [tf.nn.tanh]
+    tf_activations = [tf.nn.tanh] * num_directions
     if "activations" in node.attrs:
       activations = list(map(lambda x: x.lower(), node.attrs["activations"]))
       activation_alpha = node.attrs.get("activation_alpha", [None] * 4)
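Context for the next patch: with time_major=True,
tf.nn.bidirectional_dynamic_rnn yields output_fw and output_bw of shape
[seq_length, batch_size, hidden_size] each, and the patch switches the merge
from stacking a num_directions axis to concatenating along the feature axis.
A shape-only sketch of the two layouts (toy sizes, illustrative only):

    import tensorflow as tf

    seq_len, batch, hidden = 5, 2, 3
    output_fw = tf.zeros([seq_len, batch, hidden])  # forward direction
    output_bw = tf.zeros([seq_len, batch, hidden])  # backward direction

    # old merge: insert a num_directions axis (the ONNX Y layout)
    old = tf.concat((tf.expand_dims(output_fw, 1),
                     tf.expand_dims(output_bw, 1)), axis=1)
    print(old.shape)  # (5, 2, 2, 3) = [seq, num_directions, batch, hidden]

    # new merge: fuse the two directions on the feature axis, so the result
    # can be fed straight into a following TF layer
    new = tf.concat((output_fw, output_bw), axis=-1)
    print(new.shape)  # (5, 2, 6) = [seq, batch, 2 * hidden]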
From 11a5c4d2000dedaf913589fc5d0fdd0abceaf1d2 Mon Sep 17 00:00:00 2001
From: Yifei Teng
Date: Sun, 15 Jul 2018 18:55:58 -0700
Subject: [PATCH 3/4] fix bidi-rnn output shape

---
 onnx_tf/handlers/backend/gru.py  | 5 ++---
 onnx_tf/handlers/backend/lstm.py | 4 +---
 2 files changed, 3 insertions(+), 6 deletions(-)

diff --git a/onnx_tf/handlers/backend/gru.py b/onnx_tf/handlers/backend/gru.py
index b7d898151..3feb768fe 100644
--- a/onnx_tf/handlers/backend/gru.py
+++ b/onnx_tf/handlers/backend/gru.py
@@ -290,6 +290,7 @@ def _common(cls, node, **kwargs):
     else:
       cell_class = tf.nn.rnn_cell.GRUCell
 
+    print("Input to GRU: %s" % str(x.shape))
     outputs, states = cls.rnn(x, cell_class, cell_kwargs,
                               rnn_kwargs, tf_activations, direction)
@@ -306,9 +307,7 @@ def _common(cls, node, **kwargs):
       h_fw = tf.expand_dims(state_fw, 0)
       h_bw = tf.expand_dims(state_bw, 0)
       h = tf.concat((h_fw, h_bw), axis=0)
-      output_fw = tf.expand_dims(output_fw, 1)
-      output_bw = tf.expand_dims(output_bw, 1)
-      output = tf.concat((output_fw, output_bw), axis=1)
+      output = tf.concat((output_fw, output_bw), axis=-1)
 
       return [output, h] if output_sequence == 0 else [h]

diff --git a/onnx_tf/handlers/backend/lstm.py b/onnx_tf/handlers/backend/lstm.py
index de9f169fc..0dadcaccb 100644
--- a/onnx_tf/handlers/backend/lstm.py
+++ b/onnx_tf/handlers/backend/lstm.py
@@ -195,9 +195,7 @@ def _common(cls, node, **kwargs):
       h_fw = tf.expand_dims(state_fw[1], 0)
       h_bw = tf.expand_dims(state_bw[1], 0)
       h = tf.concat((h_fw, h_bw), axis=0)
-      output_fw = tf.expand_dims(output_fw, 1)
-      output_bw = tf.expand_dims(output_bw, 1)
-      output = tf.concat((output_fw, output_bw), axis=1)
+      output = tf.concat((output_fw, output_bw), axis=-1)
 
       return [output, h, c] if output_sequence == 0 else [h, c]

From 175d3f3643653b893b9f14a32a60081a43f8d5c8 Mon Sep 17 00:00:00 2001
From: Yifei Teng
Date: Thu, 19 Jul 2018 12:10:40 -0700
Subject: [PATCH 4/4] gardening

---
 onnx_tf/handlers/backend/gru.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/onnx_tf/handlers/backend/gru.py b/onnx_tf/handlers/backend/gru.py
index 3feb768fe..6e79339fa 100644
--- a/onnx_tf/handlers/backend/gru.py
+++ b/onnx_tf/handlers/backend/gru.py
@@ -290,14 +290,12 @@ def _common(cls, node, **kwargs):
     if node.attrs.get("linear_before_reset", 0):
       cell_class = GRUCellWithLinearBeforeReset
     else:
       cell_class = tf.nn.rnn_cell.GRUCell
 
-    print("Input to GRU: %s" % str(x.shape))
     outputs, states = cls.rnn(x, cell_class, cell_kwargs,
                               rnn_kwargs, tf_activations, direction)
 
     if num_directions == 1:
       state = states[0]
       h = tf.expand_dims(state, 0)
-      # output = tf.expand_dims(outputs, 1)
       output = outputs
     else:
       state_fw = states[0][0]
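With the series applied, a GRU node carrying the previously rejected
attribute should convert end to end. A smoke test along these lines (names
and shapes are illustrative, and rep.run's exact input format may vary
across onnx-tf versions; this is not part of the series):

    import numpy as np
    from onnx import helper, numpy_helper, TensorProto
    from onnx_tf.backend import prepare

    seq_len, batch, input_size, hidden = 4, 1, 3, 5
    W = np.random.randn(1, 3 * hidden, input_size).astype(np.float32)
    R = np.random.randn(1, 3 * hidden, hidden).astype(np.float32)
    B = np.random.randn(1, 6 * hidden).astype(np.float32)

    # Single-direction GRU exercising the new code path.
    node = helper.make_node(
        "GRU", inputs=["X", "W", "R", "B"], outputs=["Y", "Y_h"],
        hidden_size=hidden, linear_before_reset=1)
    graph = helper.make_graph(
        [node], "gru_lbr_smoke_test",
        inputs=[helper.make_tensor_value_info(
            "X", TensorProto.FLOAT, [seq_len, batch, input_size])],
        outputs=[helper.make_tensor_value_info(
            "Y_h", TensorProto.FLOAT, [1, batch, hidden])],
        initializer=[numpy_helper.from_array(W, "W"),
                     numpy_helper.from_array(R, "R"),
                     numpy_helper.from_array(B, "B")])
    model = helper.make_model(graph)

    rep = prepare(model)
    x = np.random.randn(seq_len, batch, input_size).astype(np.float32)
    print(rep.run({"X": x}).Y_h)  # expected shape: [1, batch, hidden]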