diff --git a/python/pydantic_core/core_schema.py b/python/pydantic_core/core_schema.py index 0fcfe9cab..2cb040968 100644 --- a/python/pydantic_core/core_schema.py +++ b/python/pydantic_core/core_schema.py @@ -2817,7 +2817,8 @@ class TypedDictField(TypedDict, total=False): validation_alias: Union[str, List[Union[str, int]], List[List[Union[str, int]]]] serialization_alias: str serialization_exclude: bool # default: False - metadata: Dict[str, Any] + exclude_if: Callable[[Any], bool] # default None + metadata: Any def typed_dict_field( @@ -2827,7 +2828,8 @@ def typed_dict_field( validation_alias: str | list[str | int] | list[list[str | int]] | None = None, serialization_alias: str | None = None, serialization_exclude: bool | None = None, - metadata: Dict[str, Any] | None = None, + exclude_if: Callable[[Any], bool] | None = None, + metadata: Any = None, ) -> TypedDictField: """ Returns a schema that matches a typed dict field, e.g.: @@ -2844,6 +2846,7 @@ def typed_dict_field( validation_alias: The alias(es) to use to find the field in the validation data serialization_alias: The alias to use as a key when serializing serialization_exclude: Whether to exclude the field when serializing + exclude_if: Callable that determines whether to exclude a field during serialization based on its value. metadata: Any other information you want to include with the schema, not used by pydantic-core """ return _dict_not_none( @@ -2853,6 +2856,7 @@ def typed_dict_field( validation_alias=validation_alias, serialization_alias=serialization_alias, serialization_exclude=serialization_exclude, + exclude_if=exclude_if, metadata=metadata, ) @@ -2943,6 +2947,7 @@ class ModelField(TypedDict, total=False): validation_alias: Union[str, List[Union[str, int]], List[List[Union[str, int]]]] serialization_alias: str serialization_exclude: bool # default: False + exclude_if: Callable[[Any], bool] # default: None frozen: bool metadata: Dict[str, Any] @@ -2953,6 +2958,7 @@ def model_field( validation_alias: str | list[str | int] | list[list[str | int]] | None = None, serialization_alias: str | None = None, serialization_exclude: bool | None = None, + exclude_if: Callable[[Any], bool] | None = None, frozen: bool | None = None, metadata: Dict[str, Any] | None = None, ) -> ModelField: @@ -2970,6 +2976,7 @@ def model_field( validation_alias: The alias(es) to use to find the field in the validation data serialization_alias: The alias to use as a key when serializing serialization_exclude: Whether to exclude the field when serializing + exclude_if: Callable that determines whether to exclude a field during serialization based on its value. frozen: Whether the field is frozen metadata: Any other information you want to include with the schema, not used by pydantic-core """ @@ -2979,6 +2986,7 @@ def model_field( validation_alias=validation_alias, serialization_alias=serialization_alias, serialization_exclude=serialization_exclude, + exclude_if=exclude_if, frozen=frozen, metadata=metadata, ) @@ -3171,7 +3179,8 @@ class DataclassField(TypedDict, total=False): validation_alias: Union[str, List[Union[str, int]], List[List[Union[str, int]]]] serialization_alias: str serialization_exclude: bool # default: False - metadata: Dict[str, Any] + exclude_if: Callable[[Any], bool] # default: None + metadata: Any def dataclass_field( @@ -3184,7 +3193,8 @@ def dataclass_field( validation_alias: str | list[str | int] | list[list[str | int]] | None = None, serialization_alias: str | None = None, serialization_exclude: bool | None = None, - metadata: Dict[str, Any] | None = None, + exclude_if: Callable[[Any], bool] | None = None, + metadata: Any = None, frozen: bool | None = None, ) -> DataclassField: """ @@ -3210,6 +3220,7 @@ def dataclass_field( validation_alias: The alias(es) to use to find the field in the validation data serialization_alias: The alias to use as a key when serializing serialization_exclude: Whether to exclude the field when serializing + exclude_if: Callable that determines whether to exclude a field during serialization based on its value. metadata: Any other information you want to include with the schema, not used by pydantic-core frozen: Whether the field is frozen """ @@ -3223,6 +3234,7 @@ def dataclass_field( validation_alias=validation_alias, serialization_alias=serialization_alias, serialization_exclude=serialization_exclude, + exclude_if=exclude_if, metadata=metadata, frozen=frozen, ) diff --git a/src/serializers/fields.rs b/src/serializers/fields.rs index 4498d8fa7..2a38b593d 100644 --- a/src/serializers/fields.rs +++ b/src/serializers/fields.rs @@ -29,6 +29,7 @@ pub(super) struct SerField { // None serializer means exclude pub serializer: Option, pub required: bool, + pub exclude_if: Option>, } impl_py_gc_traverse!(SerField { serializer }); @@ -40,6 +41,7 @@ impl SerField { alias: Option, serializer: Option, required: bool, + exclude_if: Option>, ) -> Self { let alias_py = alias .as_ref() @@ -50,6 +52,7 @@ impl SerField { alias_py, serializer, required, + exclude_if, } } @@ -72,6 +75,18 @@ impl SerField { } } +fn exclude_if(exclude_if_callable: &Option>, value: &Bound<'_, PyAny>) -> PyResult { + if let Some(exclude_if_callable) = exclude_if_callable { + let py = value.py(); + let result = exclude_if_callable.call1(py, (value,))?; + let exclude = result.extract::(py)?; + if exclude { + return Ok(true); + } + } + Ok(false) +} + fn exclude_default(value: &Bound<'_, PyAny>, extra: &Extra, serializer: &CombinedSerializer) -> PyResult { if extra.exclude_defaults { if let Some(default) = serializer.get_default(value.py())? { @@ -80,6 +95,7 @@ fn exclude_default(value: &Bound<'_, PyAny>, extra: &Extra, serializer: &Combine } } } + // If neither condition is met, do not exclude the field Ok(false) } @@ -176,16 +192,16 @@ impl GeneralFieldsSerializer { if let Some((next_include, next_exclude)) = self.filter.key_filter(&key, include, exclude)? { if let Some(field) = op_field { if let Some(ref serializer) = field.serializer { - if !exclude_default(&value, &field_extra, serializer)? { - let value = serializer.to_python( - &value, - next_include.as_ref(), - next_exclude.as_ref(), - &field_extra, - )?; - let output_key = field.get_key_py(output_dict.py(), &field_extra); - output_dict.set_item(output_key, value)?; + if exclude_default(&value, &field_extra, serializer)? { + continue; } + if exclude_if(&field.exclude_if, &value)? { + continue; + } + let value = + serializer.to_python(&value, next_include.as_ref(), next_exclude.as_ref(), &field_extra)?; + let output_key = field.get_key_py(output_dict.py(), &field_extra); + output_dict.set_item(output_key, value)?; } if field.required { @@ -263,17 +279,21 @@ impl GeneralFieldsSerializer { if let Some((next_include, next_exclude)) = filter { if let Some(field) = self.fields.get(key_str) { if let Some(ref serializer) = field.serializer { - if !exclude_default(&value, &field_extra, serializer).map_err(py_err_se_err)? { - let s = PydanticSerializer::new( - &value, - serializer, - next_include.as_ref(), - next_exclude.as_ref(), - &field_extra, - ); - let output_key = field.get_key_json(key_str, &field_extra); - map.serialize_entry(&output_key, &s)?; + if exclude_default(&value, &field_extra, serializer).map_err(py_err_se_err)? { + continue; + } + if exclude_if(&field.exclude_if, &value).map_err(py_err_se_err)? { + continue; } + let s = PydanticSerializer::new( + &value, + serializer, + next_include.as_ref(), + next_exclude.as_ref(), + &field_extra, + ); + let output_key = field.get_key_json(key_str, &field_extra); + map.serialize_entry(&output_key, &s)?; } } else if self.mode == FieldsMode::TypedDictAllow { let output_key = infer_json_key(&key, &field_extra).map_err(py_err_se_err)?; diff --git a/src/serializers/type_serializers/dataclass.rs b/src/serializers/type_serializers/dataclass.rs index ffe71adb9..24004cce9 100644 --- a/src/serializers/type_serializers/dataclass.rs +++ b/src/serializers/type_serializers/dataclass.rs @@ -44,14 +44,18 @@ impl BuildSerializer for DataclassArgsBuilder { let key_py: Py = PyString::new_bound(py, &name).into(); if field_info.get_as(intern!(py, "serialization_exclude"))? == Some(true) { - fields.insert(name, SerField::new(py, key_py, None, None, true)); + fields.insert(name, SerField::new(py, key_py, None, None, true, None)); } else { let schema = field_info.get_as_req(intern!(py, "schema"))?; let serializer = CombinedSerializer::build(&schema, config, definitions) .map_err(|e| py_schema_error_type!("Field `{}`:\n {}", index, e))?; let alias = field_info.get_as(intern!(py, "serialization_alias"))?; - fields.insert(name, SerField::new(py, key_py, alias, Some(serializer), true)); + let exclude_if: Option> = field_info.get_as(intern!(py, "exclude_if"))?; + fields.insert( + name, + SerField::new(py, key_py, alias, Some(serializer), true, exclude_if), + ); } } diff --git a/src/serializers/type_serializers/model.rs b/src/serializers/type_serializers/model.rs index 36ddaf69f..369b4bbdf 100644 --- a/src/serializers/type_serializers/model.rs +++ b/src/serializers/type_serializers/model.rs @@ -54,15 +54,18 @@ impl BuildSerializer for ModelFieldsBuilder { let key_py: Py = key_py.into(); if field_info.get_as(intern!(py, "serialization_exclude"))? == Some(true) { - fields.insert(key, SerField::new(py, key_py, None, None, true)); + fields.insert(key, SerField::new(py, key_py, None, None, true, None)); } else { let alias: Option = field_info.get_as(intern!(py, "serialization_alias"))?; - + let exclude_if: Option> = field_info.get_as(intern!(py, "exclude_if"))?; let schema = field_info.get_as_req(intern!(py, "schema"))?; let serializer = CombinedSerializer::build(&schema, config, definitions) .map_err(|e| py_schema_error_type!("Field `{}`:\n {}", key, e))?; - fields.insert(key, SerField::new(py, key_py, alias, Some(serializer), true)); + fields.insert( + key, + SerField::new(py, key_py, alias, Some(serializer), true, exclude_if), + ); } } diff --git a/src/serializers/type_serializers/typed_dict.rs b/src/serializers/type_serializers/typed_dict.rs index e80a9e9b3..94aafb7df 100644 --- a/src/serializers/type_serializers/typed_dict.rs +++ b/src/serializers/type_serializers/typed_dict.rs @@ -52,14 +52,17 @@ impl BuildSerializer for TypedDictBuilder { let required = field_info.get_as(intern!(py, "required"))?.unwrap_or(total); if field_info.get_as(intern!(py, "serialization_exclude"))? == Some(true) { - fields.insert(key, SerField::new(py, key_py, None, None, required)); + fields.insert(key, SerField::new(py, key_py, None, None, required, None)); } else { let alias: Option = field_info.get_as(intern!(py, "serialization_alias"))?; - + let exclude_if: Option> = field_info.get_as(intern!(py, "exclude_if"))?; let schema = field_info.get_as_req(intern!(py, "schema"))?; let serializer = CombinedSerializer::build(&schema, config, definitions) .map_err(|e| py_schema_error_type!("Field `{}`:\n {}", key, e))?; - fields.insert(key, SerField::new(py, key_py, alias, Some(serializer), required)); + fields.insert( + key, + SerField::new(py, key_py, alias, Some(serializer), required, exclude_if), + ); } } diff --git a/tests/serializers/test_dataclasses.py b/tests/serializers/test_dataclasses.py index eb4bede97..2057227bf 100644 --- a/tests/serializers/test_dataclasses.py +++ b/tests/serializers/test_dataclasses.py @@ -54,7 +54,7 @@ def test_serialization_exclude(): core_schema.dataclass_args_schema( 'Foo', [ - core_schema.dataclass_field(name='a', schema=core_schema.str_schema()), + core_schema.dataclass_field(name='a', schema=core_schema.str_schema(), exclude_if=lambda x: x == 'bye'), core_schema.dataclass_field(name='b', schema=core_schema.bytes_schema(), serialization_exclude=True), ], ), @@ -63,12 +63,18 @@ def test_serialization_exclude(): s = SchemaSerializer(schema) assert s.to_python(Foo(a='hello', b=b'more')) == {'a': 'hello'} assert s.to_python(Foo(a='hello', b=b'more'), mode='json') == {'a': 'hello'} + # a = 'bye' excludes it + assert s.to_python(Foo(a='bye', b=b'more'), mode='json') == {} j = s.to_json(Foo(a='hello', b=b'more')) - if on_pypy: assert json.loads(j) == {'a': 'hello'} else: assert j == b'{"a":"hello"}' + j = s.to_json(Foo(a='bye', b=b'more')) + if on_pypy: + assert json.loads(j) == {} + else: + assert j == b'{}' def test_serialization_alias(): diff --git a/tests/serializers/test_functions.py b/tests/serializers/test_functions.py index a151b7454..893ec046a 100644 --- a/tests/serializers/test_functions.py +++ b/tests/serializers/test_functions.py @@ -511,7 +511,9 @@ def __init__(self, **kwargs): MyModel, core_schema.typed_dict_schema( { - 'a': core_schema.typed_dict_field(core_schema.any_schema()), + 'a': core_schema.typed_dict_field( + core_schema.any_schema(), exclude_if=lambda x: isinstance(x, int) and x >= 2 + ), 'b': core_schema.typed_dict_field(core_schema.any_schema()), 'c': core_schema.typed_dict_field(core_schema.any_schema(), serialization_exclude=True), } @@ -535,6 +537,14 @@ def __init__(self, **kwargs): assert s.to_json(m, exclude={'b'}) == b'{"a":1}' assert calls == 6 + m = MyModel(a=2, b=b'foobar', c='excluded') + assert s.to_python(m) == {'b': b'foobar'} + assert calls == 7 + assert s.to_python(m, mode='json') == {'b': 'foobar'} + assert calls == 8 + assert s.to_json(m) == b'{"b":"foobar"}' + assert calls == 9 + def test_function_plain_model(): calls = 0 @@ -553,7 +563,7 @@ def __init__(self, **kwargs): MyModel, core_schema.typed_dict_schema( { - 'a': core_schema.typed_dict_field(core_schema.any_schema()), + 'a': core_schema.typed_dict_field(core_schema.any_schema(), exclude_if=lambda x: x == 100), 'b': core_schema.typed_dict_field(core_schema.any_schema()), 'c': core_schema.typed_dict_field(core_schema.any_schema(), serialization_exclude=True), } diff --git a/tests/serializers/test_model.py b/tests/serializers/test_model.py index 9fa44032a..486741e6d 100644 --- a/tests/serializers/test_model.py +++ b/tests/serializers/test_model.py @@ -203,6 +203,32 @@ def test_include_exclude_args(params): assert json.loads(s.to_json(value, include=include, exclude=exclude)) == expected +def test_exclude_if(): + s = SchemaSerializer( + core_schema.model_schema( + BasicModel, + core_schema.model_fields_schema( + { + 'a': core_schema.model_field(core_schema.int_schema(), exclude_if=lambda x: x > 1), + 'b': core_schema.model_field(core_schema.str_schema(), exclude_if=lambda x: 'foo' in x), + 'c': core_schema.model_field( + core_schema.str_schema(), serialization_exclude=True, exclude_if=lambda x: 'foo' in x + ), + } + ), + ) + ) + assert s.to_python(BasicModel(a=0, b='bar', c='bar')) == {'a': 0, 'b': 'bar'} + assert s.to_python(BasicModel(a=2, b='bar', c='bar')) == {'b': 'bar'} + assert s.to_python(BasicModel(a=0, b='foo', c='bar')) == {'a': 0} + assert s.to_python(BasicModel(a=2, b='foo', c='bar')) == {} + + assert s.to_json(BasicModel(a=0, b='bar', c='bar')) == b'{"a":0,"b":"bar"}' + assert s.to_json(BasicModel(a=2, b='bar', c='bar')) == b'{"b":"bar"}' + assert s.to_json(BasicModel(a=0, b='foo', c='bar')) == b'{"a":0}' + assert s.to_json(BasicModel(a=2, b='foo', c='bar')) == b'{}' + + def test_alias(): s = SchemaSerializer( core_schema.model_schema( diff --git a/tests/serializers/test_typed_dict.py b/tests/serializers/test_typed_dict.py index df507a248..be68922bb 100644 --- a/tests/serializers/test_typed_dict.py +++ b/tests/serializers/test_typed_dict.py @@ -92,8 +92,12 @@ def test_include_exclude_schema(): { '0': core_schema.typed_dict_field(core_schema.int_schema(), serialization_exclude=True), '1': core_schema.typed_dict_field(core_schema.int_schema()), - '2': core_schema.typed_dict_field(core_schema.int_schema(), serialization_exclude=True), - '3': core_schema.typed_dict_field(core_schema.int_schema(), serialization_exclude=False), + '2': core_schema.typed_dict_field( + core_schema.int_schema(), serialization_exclude=True, exclude_if=lambda x: x < 0 + ), + '3': core_schema.typed_dict_field( + core_schema.int_schema(), serialization_exclude=False, exclude_if=lambda x: x < 0 + ), } ) ) @@ -102,6 +106,11 @@ def test_include_exclude_schema(): assert s.to_python(value, mode='json') == {'1': 1, '3': 3} assert json.loads(s.to_json(value)) == {'1': 1, '3': 3} + value = {'0': 0, '1': 1, '2': 2, '3': -3} + assert s.to_python(value) == {'1': 1} + assert s.to_python(value, mode='json') == {'1': 1} + assert json.loads(s.to_json(value)) == {'1': 1} + def test_alias(): s = SchemaSerializer(