Skip to content

Commit 95f9329

Browse files
committed
recursive sort dictionary
1 parent 81d080a commit 95f9329

File tree

7 files changed

+331
-81
lines changed

7 files changed

+331
-81
lines changed

python/pydantic_core/_pydantic_core.pyi

+7-7
Original file line numberDiff line numberDiff line change
@@ -305,7 +305,7 @@ class SchemaSerializer:
305305
exclude_unset: bool = False,
306306
exclude_defaults: bool = False,
307307
exclude_none: bool = False,
308-
sort_keys: bool = False,
308+
sort_keys: Literal['recursive', 'top-level', 'unsorted'] = 'unsorted',
309309
round_trip: bool = False,
310310
warnings: bool | Literal['none', 'warn', 'error'] = True,
311311
fallback: Callable[[Any], Any] | None = None,
@@ -327,7 +327,7 @@ class SchemaSerializer:
327327
exclude_defaults: Whether to exclude fields that are equal to their default value.
328328
exclude_none: Whether to exclude fields that have a value of `None`.
329329
round_trip: Whether to enable serialization and validation round-trip support.
330-
sort_keys: Whether to sort dictionary keys at the root level.
330+
sort_keys: Whether to sort dictionary keys, either `'recursive'`, `'top-level'`, or `'unsorted'`.
331331
warnings: How to handle invalid fields. False/"none" ignores them, True/"warn" logs errors,
332332
"error" raises a [`PydanticSerializationError`][pydantic_core.PydanticSerializationError].
333333
fallback: A function to call when an unknown value is encountered,
@@ -354,7 +354,7 @@ class SchemaSerializer:
354354
exclude_defaults: bool = False,
355355
exclude_none: bool = False,
356356
round_trip: bool = False,
357-
sort_keys: bool = False,
357+
sort_keys: Literal['recursive', 'top-level', 'unsorted'] = 'unsorted',
358358
warnings: bool | Literal['none', 'warn', 'error'] = True,
359359
fallback: Callable[[Any], Any] | None = None,
360360
serialize_as_any: bool = False,
@@ -374,7 +374,7 @@ class SchemaSerializer:
374374
exclude_defaults: Whether to exclude fields that are equal to their default value.
375375
exclude_none: Whether to exclude fields that have a value of `None`.
376376
round_trip: Whether to enable serialization and validation round-trip support.
377-
sort_keys: Whether to sort dictionary keys at the root level.
377+
sort_keys: Whether to sort dictionary keys, either `'recursive'`, `'top-level'`, or `'unsorted'`.
378378
warnings: How to handle invalid fields. False/"none" ignores them, True/"warn" logs errors,
379379
"error" raises a [`PydanticSerializationError`][pydantic_core.PydanticSerializationError].
380380
fallback: A function to call when an unknown value is encountered,
@@ -402,7 +402,7 @@ def to_json(
402402
by_alias: bool = True,
403403
exclude_none: bool = False,
404404
round_trip: bool = False,
405-
sort_keys: bool = False,
405+
sort_keys: Literal['recursive', 'top-level', 'unsorted'] = 'unsorted',
406406
timedelta_mode: Literal['iso8601', 'float'] = 'iso8601',
407407
bytes_mode: Literal['utf8', 'base64', 'hex'] = 'utf8',
408408
inf_nan_mode: Literal['null', 'constants', 'strings'] = 'constants',
@@ -424,7 +424,7 @@ def to_json(
424424
by_alias: Whether to use the alias names of fields.
425425
exclude_none: Whether to exclude fields that have a value of `None`.
426426
round_trip: Whether to enable serialization and validation round-trip support.
427-
sort_keys: Whether to sort dictionary keys at the root level.
427+
sort_keys: Whether to sort dictionary keys, either `'recursive'`, `'top-level'`, or `'unsorted'`.
428428
timedelta_mode: How to serialize `timedelta` objects, either `'iso8601'` or `'float'`.
429429
bytes_mode: How to serialize `bytes` objects, either `'utf8'`, `'base64'`, or `'hex'`.
430430
inf_nan_mode: How to serialize `Infinity`, `-Infinity` and `NaN` values, either `'null'`, `'constants'`, or `'strings'`.
@@ -483,7 +483,7 @@ def to_jsonable_python(
483483
by_alias: bool = True,
484484
exclude_none: bool = False,
485485
round_trip: bool = False,
486-
sort_keys: bool = False,
486+
sort_keys: Literal['recursive', 'top-level', 'unsorted'] = 'unsorted',
487487
timedelta_mode: Literal['iso8601', 'float'] = 'iso8601',
488488
bytes_mode: Literal['utf8', 'base64', 'hex'] = 'utf8',
489489
inf_nan_mode: Literal['null', 'constants', 'strings'] = 'constants',

src/errors/validation_exception.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ use crate::build_tools::py_schema_error_type;
1818
use crate::errors::LocItem;
1919
use crate::get_pydantic_version;
2020
use crate::input::InputType;
21-
use crate::serializers::{DuckTypingSerMode, Extra, SerMode, SerializationState};
21+
use crate::serializers::{DuckTypingSerMode, Extra, SerMode, SerializationState, SortKeysMode};
2222
use crate::tools::{safe_repr, write_truncated_to_limited_bytes, SchemaDict};
2323

2424
use super::line_error::ValLineError;
@@ -347,7 +347,7 @@ impl ValidationError {
347347
None,
348348
false,
349349
false,
350-
false,
350+
&SortKeysMode::Unsorted,
351351
true,
352352
None,
353353
DuckTypingSerMode::SchemaBased,

src/serializers/extra.rs

+58-6
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ impl SerializationState {
8686
by_alias: Option<bool>,
8787
exclude_none: bool,
8888
round_trip: bool,
89-
sort_keys: bool,
89+
sort_keys: &'py SortKeysMode,
9090
serialize_unknown: bool,
9191
fallback: Option<&'py Bound<'_, PyAny>>,
9292
duck_typing_ser_mode: DuckTypingSerMode,
@@ -128,7 +128,7 @@ pub(crate) struct Extra<'a> {
128128
pub exclude_defaults: bool,
129129
pub exclude_none: bool,
130130
pub round_trip: bool,
131-
pub sort_keys: bool,
131+
pub sort_keys: &'a SortKeysMode,
132132
pub config: &'a SerializationConfig,
133133
pub rec_guard: &'a SerRecursionState,
134134
// the next two are used for union logic
@@ -155,7 +155,7 @@ impl<'a> Extra<'a> {
155155
exclude_defaults: bool,
156156
exclude_none: bool,
157157
round_trip: bool,
158-
sort_keys: bool,
158+
sort_keys: &'a SortKeysMode,
159159
config: &'a SerializationConfig,
160160
rec_guard: &'a SerRecursionState,
161161
serialize_unknown: bool,
@@ -241,7 +241,7 @@ pub(crate) struct ExtraOwned {
241241
exclude_defaults: bool,
242242
exclude_none: bool,
243243
round_trip: bool,
244-
sort_keys: bool,
244+
sort_keys: SortKeysMode,
245245
config: SerializationConfig,
246246
rec_guard: SerRecursionState,
247247
check: SerCheck,
@@ -263,7 +263,7 @@ impl ExtraOwned {
263263
exclude_defaults: extra.exclude_defaults,
264264
exclude_none: extra.exclude_none,
265265
round_trip: extra.round_trip,
266-
sort_keys: extra.sort_keys,
266+
sort_keys: *extra.sort_keys,
267267
config: extra.config.clone(),
268268
rec_guard: extra.rec_guard.clone(),
269269
check: extra.check,
@@ -286,7 +286,7 @@ impl ExtraOwned {
286286
exclude_defaults: self.exclude_defaults,
287287
exclude_none: self.exclude_none,
288288
round_trip: self.round_trip,
289-
sort_keys: self.sort_keys,
289+
sort_keys: &self.sort_keys,
290290
config: &self.config,
291291
rec_guard: &self.rec_guard,
292292
check: self.check,
@@ -387,6 +387,58 @@ impl From<bool> for WarningsMode {
387387
}
388388
}
389389

390+
// #[derive(Debug, Clone, Copy, Eq, PartialEq)]
391+
#[derive(Debug, Clone, Copy)]
392+
pub enum SortKeysMode {
393+
Recursive,
394+
TopLevel,
395+
Unsorted,
396+
}
397+
398+
impl<'py> FromPyObject<'py> for SortKeysMode {
399+
fn extract_bound(ob: &Bound<'py, PyAny>) -> PyResult<SortKeysMode> {
400+
if let Ok(str_mode) = ob.extract::<&str>() {
401+
match str_mode {
402+
"recursive" => Ok(Self::Recursive),
403+
"top-level" => Ok(Self::TopLevel),
404+
"unsorted" => Ok(Self::Unsorted),
405+
_ => Err(PyValueError::new_err(
406+
"Invalid sort_keys parameter, should be `'recursive'`, `'top-level'`, `'unsorted'`",
407+
)),
408+
}
409+
} else {
410+
Err(PyTypeError::new_err(
411+
"Invalid warnings parameter, should be `'none'`, `'warn'`, `'error'` or a `bool`",
412+
))
413+
}
414+
}
415+
}
416+
417+
impl From<&str> for SortKeysMode {
418+
fn from(s: &str) -> Self {
419+
match s {
420+
"recursive" => SortKeysMode::Recursive,
421+
"top-level" => SortKeysMode::TopLevel,
422+
"unsorted" => SortKeysMode::Unsorted,
423+
_ => SortKeysMode::Unsorted,
424+
}
425+
}
426+
}
427+
428+
impl<'py> IntoPyObject<'py> for &'_ SortKeysMode {
429+
type Target = PyString;
430+
type Output = Bound<'py, PyString>;
431+
type Error = Infallible;
432+
433+
fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
434+
match self {
435+
SortKeysMode::Recursive => Ok(intern!(py, "recursive").clone()),
436+
SortKeysMode::TopLevel => Ok(intern!(py, "top-level").clone()),
437+
SortKeysMode::Unsorted => Ok(intern!(py, "unsorted").clone()),
438+
}
439+
}
440+
}
441+
390442
#[cfg_attr(debug_assertions, derive(Debug))]
391443
pub(crate) struct CollectWarnings {
392444
mode: WarningsMode,

src/serializers/fields.rs

+97-25
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ use crate::PydanticSerializationUnexpectedValue;
1313

1414
use super::computed_fields::ComputedFields;
1515
use super::errors::py_err_se_err;
16-
use super::extra::Extra;
16+
use super::extra::{Extra, SortKeysMode};
1717
use super::filter::SchemaFilter;
1818
use super::infer::{infer_json_key, infer_serialize, infer_to_python, SerializeInfer};
1919
use super::shared::PydanticSerializer;
@@ -156,7 +156,7 @@ impl GeneralFieldsSerializer {
156156
let output_dict = PyDict::new(py);
157157
let mut used_req_fields: usize = 0;
158158

159-
if !extra.sort_keys {
159+
if matches!(extra.sort_keys, SortKeysMode::Unsorted) {
160160
for result in main_iter {
161161
let (key, value) = result?;
162162
if let Some(is_required) =
@@ -201,6 +201,23 @@ impl GeneralFieldsSerializer {
201201
}
202202
}
203203

204+
fn sort_dict_recursive<'py>(py: Python<'py>, value: &Bound<'py, PyAny>) -> PyResult<Bound<'py, PyDict>> {
205+
let dict = value.downcast::<PyDict>()?;
206+
let mut items = dict_items(dict).collect::<PyResult<Vec<_>>>()?;
207+
items.sort_by_cached_key(|(key, _)| key_str(key).unwrap_or_default().to_string());
208+
209+
let sorted_dict = PyDict::new(py);
210+
for (k, v) in items {
211+
if v.downcast::<PyDict>().is_ok() {
212+
let sorted_v = Self::sort_dict_recursive(py, &v)?;
213+
sorted_dict.set_item(k, sorted_v)?;
214+
} else {
215+
sorted_dict.set_item(k, v)?;
216+
}
217+
}
218+
Ok(sorted_dict)
219+
}
220+
204221
fn process_field_entry_python<'py>(
205222
&self,
206223
key: &Bound<'py, PyAny>,
@@ -235,10 +252,21 @@ impl GeneralFieldsSerializer {
235252
if let Some(field) = op_field {
236253
if let Some(ref serializer) = field.serializer {
237254
if !exclude_default(value, &field_extra, serializer)? {
238-
let value =
239-
serializer.to_python(value, next_include.as_ref(), next_exclude.as_ref(), &field_extra)?;
255+
let processed_value = if matches!(extra.sort_keys, SortKeysMode::Recursive)
256+
&& value.downcast::<PyDict>().is_ok()
257+
{
258+
let sorted_dict = Self::sort_dict_recursive(value.py(), value)?;
259+
serializer.to_python(
260+
sorted_dict.as_ref(),
261+
next_include.as_ref(),
262+
next_exclude.as_ref(),
263+
&field_extra,
264+
)?
265+
} else {
266+
serializer.to_python(value, next_include.as_ref(), next_exclude.as_ref(), &field_extra)?
267+
};
240268
let output_key = field.get_key_py(output_dict.py(), &field_extra);
241-
output_dict.set_item(output_key, value)?;
269+
output_dict.set_item(output_key, processed_value)?;
242270
}
243271
}
244272

@@ -247,13 +275,33 @@ impl GeneralFieldsSerializer {
247275
}
248276
return Ok(Some(false));
249277
} else if self.mode == FieldsMode::TypedDictAllow {
250-
let value = match &self.extra_serializer {
251-
Some(serializer) => {
252-
serializer.to_python(value, next_include.as_ref(), next_exclude.as_ref(), &field_extra)?
278+
let processed_value = if matches!(extra.sort_keys, SortKeysMode::Recursive)
279+
&& value.downcast::<PyDict>().is_ok()
280+
{
281+
let sorted_dict = Self::sort_dict_recursive(value.py(), value)?;
282+
match &self.extra_serializer {
283+
Some(serializer) => serializer.to_python(
284+
sorted_dict.as_ref(),
285+
next_include.as_ref(),
286+
next_exclude.as_ref(),
287+
&field_extra,
288+
)?,
289+
None => infer_to_python(
290+
sorted_dict.as_ref(),
291+
next_include.as_ref(),
292+
next_exclude.as_ref(),
293+
&field_extra,
294+
)?,
295+
}
296+
} else {
297+
match &self.extra_serializer {
298+
Some(serializer) => {
299+
serializer.to_python(value, next_include.as_ref(), next_exclude.as_ref(), &field_extra)?
300+
}
301+
None => infer_to_python(value, next_include.as_ref(), next_exclude.as_ref(), &field_extra)?,
253302
}
254-
None => infer_to_python(value, next_include.as_ref(), next_exclude.as_ref(), &field_extra)?,
255303
};
256-
output_dict.set_item(key, value)?;
304+
output_dict.set_item(key, processed_value)?;
257305
return Ok(None);
258306
} else if field_extra.check == SerCheck::Strict {
259307
return Err(PydanticSerializationUnexpectedValue::new(
@@ -281,7 +329,7 @@ impl GeneralFieldsSerializer {
281329
// we don't both with `used_fields` here because on unions, `to_python(..., mode='json')` is used
282330
let mut map = serializer.serialize_map(Some(expected_len))?;
283331

284-
if !extra.sort_keys {
332+
if matches!(extra.sort_keys, SortKeysMode::Unsorted) {
285333
for result in main_iter {
286334
let (key, value) = result.map_err(py_err_se_err)?;
287335
self.process_field_entry::<S>(&key, &value, &mut map, include, exclude, &extra)?;
@@ -308,32 +356,56 @@ impl GeneralFieldsSerializer {
308356
if extra.exclude_none && value.is_none() {
309357
return Ok(());
310358
}
311-
let key_str = key_str(key).map_err(py_err_se_err)?;
359+
let field_key_str = key_str(key).map_err(py_err_se_err)?;
312360
let field_extra = Extra {
313-
field_name: Some(key_str),
361+
field_name: Some(field_key_str),
314362
..*extra
315363
};
316364

317365
let filter = self.filter.key_filter(key, include, exclude).map_err(py_err_se_err)?;
318366
if let Some((next_include, next_exclude)) = filter {
319-
if let Some(field) = self.fields.get(key_str) {
367+
if let Some(field) = self.fields.get(field_key_str) {
320368
if let Some(ref serializer) = field.serializer {
321369
if !exclude_default(value, &field_extra, serializer).map_err(py_err_se_err)? {
322-
let s = PydanticSerializer::new(
323-
value,
324-
serializer,
325-
next_include.as_ref(),
326-
next_exclude.as_ref(),
327-
&field_extra,
328-
);
329-
let output_key = field.get_key_json(key_str, &field_extra);
330-
map.serialize_entry(&output_key, &s)?;
370+
if matches!(extra.sort_keys, SortKeysMode::Recursive) && value.downcast::<PyDict>().is_ok() {
371+
let sorted_dict = Self::sort_dict_recursive(value.py(), value).map_err(py_err_se_err)?;
372+
let s = PydanticSerializer::new(
373+
sorted_dict.as_ref(),
374+
serializer,
375+
next_include.as_ref(),
376+
next_exclude.as_ref(),
377+
&field_extra,
378+
);
379+
let output_key = field.get_key_json(field_key_str, &field_extra);
380+
map.serialize_entry(&output_key, &s)?;
381+
} else {
382+
let s = PydanticSerializer::new(
383+
value,
384+
serializer,
385+
next_include.as_ref(),
386+
next_exclude.as_ref(),
387+
&field_extra,
388+
);
389+
let output_key = field.get_key_json(field_key_str, &field_extra);
390+
map.serialize_entry(&output_key, &s)?;
391+
}
331392
}
332393
}
333394
} else if self.mode == FieldsMode::TypedDictAllow {
334395
let output_key = infer_json_key(key, &field_extra).map_err(py_err_se_err)?;
335-
let s = SerializeInfer::new(value, next_include.as_ref(), next_exclude.as_ref(), &field_extra);
336-
map.serialize_entry(&output_key, &s)?;
396+
if matches!(extra.sort_keys, SortKeysMode::Recursive) && value.downcast::<PyDict>().is_ok() {
397+
let sorted_dict = Self::sort_dict_recursive(value.py(), value).map_err(py_err_se_err)?;
398+
let s = SerializeInfer::new(
399+
sorted_dict.as_ref(),
400+
next_include.as_ref(),
401+
next_exclude.as_ref(),
402+
&field_extra,
403+
);
404+
map.serialize_entry(&output_key, &s)?;
405+
} else {
406+
let s = SerializeInfer::new(value, next_include.as_ref(), next_exclude.as_ref(), &field_extra);
407+
map.serialize_entry(&output_key, &s)?;
408+
}
337409
}
338410
// no error case here since unions (which need the error case) use `to_python(..., mode='json')`
339411
}

0 commit comments

Comments
 (0)