Skip to content

Commit 4d63050

Browse files
author
Jonathan Sparling
committed
Add limited SoftmaxCrossEntropyLoss support
1 parent f9ebc35 commit 4d63050

File tree

2 files changed

+43
-0
lines changed

2 files changed

+43
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
import tensorflow as tf
2+
3+
from onnx_tf.common import exception
4+
from onnx_tf.handlers.backend_handler import BackendHandler
5+
from onnx_tf.handlers.handler import onnx_op
6+
from onnx_tf.handlers.handler import tf_func
7+
8+
9+
@onnx_op("SoftmaxCrossEntropyLoss")
10+
@tf_func(tf.nn.sparse_softmax_cross_entropy_with_logits)
11+
class SoftmaxCrossEntropyLoss(BackendHandler):
12+
@classmethod
13+
def _common(cls, node, **kwargs):
14+
logits = kwargs["tensor_dict"][node.inputs[0]]
15+
labels = kwargs["tensor_dict"][node.inputs[1]]
16+
17+
labels_shape = tf.shape(labels)
18+
if labels_shape.shape[0] > 1:
19+
raise NotImplementedError(
20+
"SoftmaxCrossEntropyLoss support is limited to rank 1 label tensors."
21+
.format(spatial_size))
22+
23+
return [
24+
tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, logits=logits)
25+
]
26+
27+
@classmethod
28+
def version_12(cls, node, **kwargs):
29+
return cls._common(node, **kwargs)
30+
31+
@classmethod
32+
def version_13(cls, node, **kwargs):
33+
return cls._common(node, **kwargs)

test/backend/test_node.py

+10
Original file line numberDiff line numberDiff line change
@@ -3987,6 +3987,16 @@ def test_softplus(self):
39873987
np.log(np.exp(x) + 1),
39883988
decimal=5)
39893989

3990+
def test_softmax_cross_entropy_loss(self):
3991+
node_def = helper.make_node("SoftmaxCrossEntropyLoss", ["X", "Y"], ["Z"])
3992+
classes = 10
3993+
x = self._get_rnd_float32(shape=[1,classes])
3994+
y = self._get_rnd_int(0, classes-1, [1], np.int32)
3995+
output = run_node(node_def, [x, y])
3996+
np.testing.assert_almost_equal(output["Z"],
3997+
-np.log(np.exp(x)[0][y]/np.sum(np.exp(x))),
3998+
decimal=5)
3999+
39904000
def test_softsign(self):
39914001
node_def = helper.make_node("Softsign", ["X"], ["Y"])
39924002
x = self._get_rnd_float32(shape=[3, 4, 5])

0 commit comments

Comments
 (0)