Skip to content

Commit d3ef981

Browse files
winstonawsChoiByungWook
authored andcommitted
Allow TensorFlow json serializer to accept dicts with ndarray values (#404)
1 parent 665c30f commit d3ef981

File tree

4 files changed

+32
-36
lines changed

4 files changed

+32
-36
lines changed

CHANGELOG.rst

+1
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ CHANGELOG
66
=========
77

88
* enhancement: Local Mode: add training environment variables for AWS region and job name
9+
* bug-fix: default TensorFlow json serializer accepts dict of numpy arrays
910

1011
1.11.0
1112
======

src/sagemaker/predictor.py

+6-25
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import csv
1717
import json
1818
import numpy as np
19+
import six
1920
from six import StringIO, BytesIO
2021

2122
from sagemaker.content_types import CONTENT_TYPE_JSON, CONTENT_TYPE_CSV, CONTENT_TYPE_NPY
@@ -237,48 +238,28 @@ def __call__(self, data):
237238
Returns:
238239
object: Serialized data used for the request.
239240
"""
240-
if isinstance(data, np.ndarray):
241-
if not data.size > 0:
242-
raise ValueError("empty array can't be serialized")
243-
return _json_serialize_numpy_array(data)
244-
245-
if isinstance(data, list):
246-
if not len(data) > 0:
247-
raise ValueError("empty array can't be serialized")
248-
return _json_serialize_python_object(data)
249-
250241
if isinstance(data, dict):
251-
if not len(data.keys()) > 0:
252-
raise ValueError("empty dictionary can't be serialized")
253-
return _json_serialize_python_object(data)
242+
# convert each value in dict from a numpy array to a list if necessary, so they can be json serialized
243+
return json.dumps({k: _ndarray_to_list(v) for k, v in six.iteritems(data)})
254244

255245
# files and buffers
256246
if hasattr(data, 'read'):
257247
return _json_serialize_from_buffer(data)
258248

259-
raise ValueError("Unable to handle input format: {}".format(type(data)))
249+
return json.dumps(_ndarray_to_list(data))
260250

261251

262252
json_serializer = _JsonSerializer()
263253

264254

265-
def _json_serialize_numpy_array(data):
266-
# numpy arrays can't be serialized but we know they have uniform type
267-
return _json_serialize_python_object(data.tolist())
268-
269-
270-
def _json_serialize_python_object(data):
271-
return _json_serialize_object(data)
255+
def _ndarray_to_list(data):
256+
return data.tolist() if isinstance(data, np.ndarray) else data
272257

273258

274259
def _json_serialize_from_buffer(buff):
275260
return buff.read()
276261

277262

278-
def _json_serialize_object(data):
279-
return json.dumps(data)
280-
281-
282263
class _JsonDeserializer(object):
283264
def __init__(self):
284265
self.accept = CONTENT_TYPE_JSON

tests/unit/test_predictor.py

+4-11
Original file line numberDiff line numberDiff line change
@@ -40,11 +40,8 @@ def test_json_serializer_numpy_valid_2dimensional():
4040
assert result == '[[1, 2, 3], [3, 4, 5]]'
4141

4242

43-
def test_json_serializer_numpy_invalid_empty():
44-
with pytest.raises(ValueError) as invalid_input:
45-
json_serializer(np.array([]))
46-
47-
assert "empty array" in str(invalid_input)
43+
def test_json_serializer_empty():
44+
assert json_serializer(np.array([])) == '[]'
4845

4946

5047
def test_json_serializer_python_array():
@@ -62,15 +59,11 @@ def test_json_serializer_python_dictionary():
6259

6360

6461
def test_json_serializer_python_invalid_empty():
65-
with pytest.raises(ValueError) as error:
66-
json_serializer([])
67-
assert "empty array" in str(error)
62+
assert json_serializer([]) == '[]'
6863

6964

7065
def test_json_serializer_python_dictionary_invalid_empty():
71-
with pytest.raises(ValueError) as error:
72-
json_serializer({})
73-
assert "empty dictionary" in str(error)
66+
assert json_serializer({}) == '{}'
7467

7568

7669
def test_json_serializer_csv_buffer():

tests/unit/test_tf_predictor.py

+21
Original file line numberDiff line numberDiff line change
@@ -336,3 +336,24 @@ def mock_response(expected_response, sagemaker_session, content_type):
336336
sagemaker_session.sagemaker_runtime_client.invoke_endpoint.return_value = {
337337
'ContentType': content_type,
338338
'Body': io.BytesIO(expected_response)}
339+
340+
341+
def test_json_serialize_dict():
342+
data = {'tensor1': [1, 2, 3], 'tensor2': [4, 5, 6]}
343+
serialized = tf_json_serializer(data)
344+
# deserialize again for assertion, since dict order is not guaranteed
345+
deserialized = json.loads(serialized)
346+
assert deserialized == data
347+
348+
349+
def test_json_serialize_dict_with_numpy():
350+
data = {'tensor1': np.asarray([1, 2, 3]), 'tensor2': np.asarray([4, 5, 6])}
351+
serialized = tf_json_serializer(data)
352+
# deserialize again for assertion, since dict order is not guaranteed
353+
deserialized = json.loads(serialized)
354+
assert deserialized == {'tensor1': [1, 2, 3], 'tensor2': [4, 5, 6]}
355+
356+
357+
def test_json_serialize_numpy():
358+
data = np.asarray([[1, 2, 3], [4, 5, 6]])
359+
assert tf_json_serializer(data) == '[[1, 2, 3], [4, 5, 6]]'

0 commit comments

Comments
 (0)