Skip to content

Commit d7d7d0e

Browse files
committed
Support exclude_if callable at field level
1 parent 2419981 commit d7d7d0e

File tree

8 files changed

+90
-22
lines changed

8 files changed

+90
-22
lines changed

python/pydantic_core/core_schema.py

+16-4
Original file line numberDiff line numberDiff line change
@@ -2817,7 +2817,8 @@ class TypedDictField(TypedDict, total=False):
28172817
validation_alias: Union[str, List[Union[str, int]], List[List[Union[str, int]]]]
28182818
serialization_alias: str
28192819
serialization_exclude: bool # default: False
2820-
metadata: Dict[str, Any]
2820+
exclude_if: Callable[[Any], bool] # default None
2821+
metadata: Any
28212822

28222823

28232824
def typed_dict_field(
@@ -2827,7 +2828,8 @@ def typed_dict_field(
28272828
validation_alias: str | list[str | int] | list[list[str | int]] | None = None,
28282829
serialization_alias: str | None = None,
28292830
serialization_exclude: bool | None = None,
2830-
metadata: Dict[str, Any] | None = None,
2831+
exclude_if: Callable[[Any], bool] | None = None,
2832+
metadata: Any = None,
28312833
) -> TypedDictField:
28322834
"""
28332835
Returns a schema that matches a typed dict field, e.g.:
@@ -2844,6 +2846,7 @@ def typed_dict_field(
28442846
validation_alias: The alias(es) to use to find the field in the validation data
28452847
serialization_alias: The alias to use as a key when serializing
28462848
serialization_exclude: Whether to exclude the field when serializing
2849+
exclude_if: Callable that determines whether to exclude a field during serialization based on its value.
28472850
metadata: Any other information you want to include with the schema, not used by pydantic-core
28482851
"""
28492852
return _dict_not_none(
@@ -2853,6 +2856,7 @@ def typed_dict_field(
28532856
validation_alias=validation_alias,
28542857
serialization_alias=serialization_alias,
28552858
serialization_exclude=serialization_exclude,
2859+
exclude_if=exclude_if,
28562860
metadata=metadata,
28572861
)
28582862

@@ -2943,6 +2947,7 @@ class ModelField(TypedDict, total=False):
29432947
validation_alias: Union[str, List[Union[str, int]], List[List[Union[str, int]]]]
29442948
serialization_alias: str
29452949
serialization_exclude: bool # default: False
2950+
exclude_if: Callable[[Any], bool] # default: None
29462951
frozen: bool
29472952
metadata: Dict[str, Any]
29482953

@@ -2953,6 +2958,7 @@ def model_field(
29532958
validation_alias: str | list[str | int] | list[list[str | int]] | None = None,
29542959
serialization_alias: str | None = None,
29552960
serialization_exclude: bool | None = None,
2961+
exclude_if: Callable[[Any], bool] | None = None,
29562962
frozen: bool | None = None,
29572963
metadata: Dict[str, Any] | None = None,
29582964
) -> ModelField:
@@ -2970,6 +2976,7 @@ def model_field(
29702976
validation_alias: The alias(es) to use to find the field in the validation data
29712977
serialization_alias: The alias to use as a key when serializing
29722978
serialization_exclude: Whether to exclude the field when serializing
2979+
exclude_if: Callable that determines whether to exclude a field during serialization based on its value.
29732980
frozen: Whether the field is frozen
29742981
metadata: Any other information you want to include with the schema, not used by pydantic-core
29752982
"""
@@ -2979,6 +2986,7 @@ def model_field(
29792986
validation_alias=validation_alias,
29802987
serialization_alias=serialization_alias,
29812988
serialization_exclude=serialization_exclude,
2989+
exclude_if=exclude_if,
29822990
frozen=frozen,
29832991
metadata=metadata,
29842992
)
@@ -3171,7 +3179,8 @@ class DataclassField(TypedDict, total=False):
31713179
validation_alias: Union[str, List[Union[str, int]], List[List[Union[str, int]]]]
31723180
serialization_alias: str
31733181
serialization_exclude: bool # default: False
3174-
metadata: Dict[str, Any]
3182+
exclude_if: Callable[[Any], bool] # default: None
3183+
metadata: Any
31753184

31763185

31773186
def dataclass_field(
@@ -3184,7 +3193,8 @@ def dataclass_field(
31843193
validation_alias: str | list[str | int] | list[list[str | int]] | None = None,
31853194
serialization_alias: str | None = None,
31863195
serialization_exclude: bool | None = None,
3187-
metadata: Dict[str, Any] | None = None,
3196+
exclude_if: Callable[[Any], bool] | None = None,
3197+
metadata: Any = None,
31883198
frozen: bool | None = None,
31893199
) -> DataclassField:
31903200
"""
@@ -3210,6 +3220,7 @@ def dataclass_field(
32103220
validation_alias: The alias(es) to use to find the field in the validation data
32113221
serialization_alias: The alias to use as a key when serializing
32123222
serialization_exclude: Whether to exclude the field when serializing
3223+
exclude_if: Callable that determines whether to exclude a field during serialization based on its value.
32133224
metadata: Any other information you want to include with the schema, not used by pydantic-core
32143225
frozen: Whether the field is frozen
32153226
"""
@@ -3223,6 +3234,7 @@ def dataclass_field(
32233234
validation_alias=validation_alias,
32243235
serialization_alias=serialization_alias,
32253236
serialization_exclude=serialization_exclude,
3237+
exclude_if=exclude_if,
32263238
metadata=metadata,
32273239
frozen=frozen,
32283240
)

src/serializers/fields.rs

+25-4
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ pub(super) struct SerField {
2929
// None serializer means exclude
3030
pub serializer: Option<CombinedSerializer>,
3131
pub required: bool,
32+
pub exclude_if: Option<Py<PyAny>>,
3233
}
3334

3435
impl_py_gc_traverse!(SerField { serializer });
@@ -40,6 +41,7 @@ impl SerField {
4041
alias: Option<String>,
4142
serializer: Option<CombinedSerializer>,
4243
required: bool,
44+
exclude_if: Option<Py<PyAny>>,
4345
) -> Self {
4446
let alias_py = alias
4547
.as_ref()
@@ -50,6 +52,7 @@ impl SerField {
5052
alias_py,
5153
serializer,
5254
required,
55+
exclude_if,
5356
}
5457
}
5558

@@ -72,14 +75,30 @@ impl SerField {
7275
}
7376
}
7477

75-
fn exclude_default(value: &Bound<'_, PyAny>, extra: &Extra, serializer: &CombinedSerializer) -> PyResult<bool> {
78+
fn exclude_default_or_if(
79+
exclude_if_callable: &Option<Py<PyAny>>,
80+
value: &Bound<'_, PyAny>,
81+
extra: &Extra,
82+
serializer: &CombinedSerializer,
83+
) -> PyResult<bool> {
84+
let py = value.py();
85+
86+
if let Some(exclude_if_callable) = exclude_if_callable {
87+
let result = exclude_if_callable.call1(py, (value,))?;
88+
let exclude = result.extract::<bool>(py)?;
89+
if exclude {
90+
return Ok(true);
91+
}
92+
}
93+
7694
if extra.exclude_defaults {
77-
if let Some(default) = serializer.get_default(value.py())? {
95+
if let Some(default) = serializer.get_default(py)? {
7896
if value.eq(default)? {
7997
return Ok(true);
8098
}
8199
}
82100
}
101+
// If neither condition is met, do not exclude the field
83102
Ok(false)
84103
}
85104

@@ -176,7 +195,7 @@ impl GeneralFieldsSerializer {
176195
if let Some((next_include, next_exclude)) = self.filter.key_filter(&key, include, exclude)? {
177196
if let Some(field) = op_field {
178197
if let Some(ref serializer) = field.serializer {
179-
if !exclude_default(&value, &field_extra, serializer)? {
198+
if !exclude_default_or_if(&field.exclude_if, &value, &field_extra, serializer)? {
180199
let value = serializer.to_python(
181200
&value,
182201
next_include.as_ref(),
@@ -262,7 +281,9 @@ impl GeneralFieldsSerializer {
262281
if let Some((next_include, next_exclude)) = filter {
263282
if let Some(field) = self.fields.get(key_str) {
264283
if let Some(ref serializer) = field.serializer {
265-
if !exclude_default(&value, &field_extra, serializer).map_err(py_err_se_err)? {
284+
if !exclude_default_or_if(&field.exclude_if, &value, &field_extra, serializer)
285+
.map_err(py_err_se_err)?
286+
{
266287
let s = PydanticSerializer::new(
267288
&value,
268289
serializer,

src/serializers/type_serializers/dataclass.rs

+6-2
Original file line numberDiff line numberDiff line change
@@ -44,14 +44,18 @@ impl BuildSerializer for DataclassArgsBuilder {
4444
let key_py: Py<PyString> = PyString::new_bound(py, &name).into();
4545

4646
if field_info.get_as(intern!(py, "serialization_exclude"))? == Some(true) {
47-
fields.insert(name, SerField::new(py, key_py, None, None, true));
47+
fields.insert(name, SerField::new(py, key_py, None, None, true, None));
4848
} else {
4949
let schema = field_info.get_as_req(intern!(py, "schema"))?;
5050
let serializer = CombinedSerializer::build(&schema, config, definitions)
5151
.map_err(|e| py_schema_error_type!("Field `{}`:\n {}", index, e))?;
5252

5353
let alias = field_info.get_as(intern!(py, "serialization_alias"))?;
54-
fields.insert(name, SerField::new(py, key_py, alias, Some(serializer), true));
54+
let exclude_if: Option<Py<PyAny>> = field_info.get_as(intern!(py, "exclude_if"))?;
55+
fields.insert(
56+
name,
57+
SerField::new(py, key_py, alias, Some(serializer), true, exclude_if),
58+
);
5559
}
5660
}
5761

src/serializers/type_serializers/model.rs

+6-3
Original file line numberDiff line numberDiff line change
@@ -54,15 +54,18 @@ impl BuildSerializer for ModelFieldsBuilder {
5454
let key_py: Py<PyString> = key_py.into();
5555

5656
if field_info.get_as(intern!(py, "serialization_exclude"))? == Some(true) {
57-
fields.insert(key, SerField::new(py, key_py, None, None, true));
57+
fields.insert(key, SerField::new(py, key_py, None, None, true, None));
5858
} else {
5959
let alias: Option<String> = field_info.get_as(intern!(py, "serialization_alias"))?;
60-
60+
let exclude_if: Option<Py<PyAny>> = field_info.get_as(intern!(py, "exclude_if"))?;
6161
let schema = field_info.get_as_req(intern!(py, "schema"))?;
6262
let serializer = CombinedSerializer::build(&schema, config, definitions)
6363
.map_err(|e| py_schema_error_type!("Field `{}`:\n {}", key, e))?;
6464

65-
fields.insert(key, SerField::new(py, key_py, alias, Some(serializer), true));
65+
fields.insert(
66+
key,
67+
SerField::new(py, key_py, alias, Some(serializer), true, exclude_if),
68+
);
6669
}
6770
}
6871

src/serializers/type_serializers/typed_dict.rs

+6-3
Original file line numberDiff line numberDiff line change
@@ -52,14 +52,17 @@ impl BuildSerializer for TypedDictBuilder {
5252
let required = field_info.get_as(intern!(py, "required"))?.unwrap_or(total);
5353

5454
if field_info.get_as(intern!(py, "serialization_exclude"))? == Some(true) {
55-
fields.insert(key, SerField::new(py, key_py, None, None, required));
55+
fields.insert(key, SerField::new(py, key_py, None, None, required, None));
5656
} else {
5757
let alias: Option<String> = field_info.get_as(intern!(py, "serialization_alias"))?;
58-
58+
let exclude_if: Option<Py<PyAny>> = field_info.get_as(intern!(py, "exclude_if"))?;
5959
let schema = field_info.get_as_req(intern!(py, "schema"))?;
6060
let serializer = CombinedSerializer::build(&schema, config, definitions)
6161
.map_err(|e| py_schema_error_type!("Field `{}`:\n {}", key, e))?;
62-
fields.insert(key, SerField::new(py, key_py, alias, Some(serializer), required));
62+
fields.insert(
63+
key,
64+
SerField::new(py, key_py, alias, Some(serializer), required, exclude_if),
65+
);
6366
}
6467
}
6568

tests/serializers/test_dataclasses.py

+8-2
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ def test_serialization_exclude():
5454
core_schema.dataclass_args_schema(
5555
'Foo',
5656
[
57-
core_schema.dataclass_field(name='a', schema=core_schema.str_schema()),
57+
core_schema.dataclass_field(name='a', schema=core_schema.str_schema(), exclude_if=lambda x: x == 'bye'),
5858
core_schema.dataclass_field(name='b', schema=core_schema.bytes_schema(), serialization_exclude=True),
5959
],
6060
),
@@ -63,12 +63,18 @@ def test_serialization_exclude():
6363
s = SchemaSerializer(schema)
6464
assert s.to_python(Foo(a='hello', b=b'more')) == {'a': 'hello'}
6565
assert s.to_python(Foo(a='hello', b=b'more'), mode='json') == {'a': 'hello'}
66+
# a = 'bye' excludes it
67+
assert s.to_python(Foo(a='bye', b=b'more'), mode='json') == {}
6668
j = s.to_json(Foo(a='hello', b=b'more'))
67-
6869
if on_pypy:
6970
assert json.loads(j) == {'a': 'hello'}
7071
else:
7172
assert j == b'{"a":"hello"}'
73+
j = s.to_json(Foo(a='bye', b=b'more'))
74+
if on_pypy:
75+
assert json.loads(j) == {}
76+
else:
77+
assert j == b'{}'
7278

7379

7480
def test_serialization_alias():

tests/serializers/test_functions.py

+12-2
Original file line numberDiff line numberDiff line change
@@ -511,7 +511,9 @@ def __init__(self, **kwargs):
511511
MyModel,
512512
core_schema.typed_dict_schema(
513513
{
514-
'a': core_schema.typed_dict_field(core_schema.any_schema()),
514+
'a': core_schema.typed_dict_field(
515+
core_schema.any_schema(), exclude_if=lambda x: isinstance(x, int) and x >= 2
516+
),
515517
'b': core_schema.typed_dict_field(core_schema.any_schema()),
516518
'c': core_schema.typed_dict_field(core_schema.any_schema(), serialization_exclude=True),
517519
}
@@ -535,6 +537,14 @@ def __init__(self, **kwargs):
535537
assert s.to_json(m, exclude={'b'}) == b'{"a":1}'
536538
assert calls == 6
537539

540+
m = MyModel(a=2, b=b'foobar', c='excluded')
541+
assert s.to_python(m) == {'b': b'foobar'}
542+
assert calls == 7
543+
assert s.to_python(m, mode='json') == {'b': 'foobar'}
544+
assert calls == 8
545+
assert s.to_json(m) == b'{"b":"foobar"}'
546+
assert calls == 9
547+
538548

539549
def test_function_plain_model():
540550
calls = 0
@@ -553,7 +563,7 @@ def __init__(self, **kwargs):
553563
MyModel,
554564
core_schema.typed_dict_schema(
555565
{
556-
'a': core_schema.typed_dict_field(core_schema.any_schema()),
566+
'a': core_schema.typed_dict_field(core_schema.any_schema(), exclude_if=lambda x: x == 100),
557567
'b': core_schema.typed_dict_field(core_schema.any_schema()),
558568
'c': core_schema.typed_dict_field(core_schema.any_schema(), serialization_exclude=True),
559569
}

tests/serializers/test_typed_dict.py

+11-2
Original file line numberDiff line numberDiff line change
@@ -92,8 +92,12 @@ def test_include_exclude_schema():
9292
{
9393
'0': core_schema.typed_dict_field(core_schema.int_schema(), serialization_exclude=True),
9494
'1': core_schema.typed_dict_field(core_schema.int_schema()),
95-
'2': core_schema.typed_dict_field(core_schema.int_schema(), serialization_exclude=True),
96-
'3': core_schema.typed_dict_field(core_schema.int_schema(), serialization_exclude=False),
95+
'2': core_schema.typed_dict_field(
96+
core_schema.int_schema(), serialization_exclude=True, exclude_if=lambda x: x < 0
97+
),
98+
'3': core_schema.typed_dict_field(
99+
core_schema.int_schema(), serialization_exclude=False, exclude_if=lambda x: x < 0
100+
),
97101
}
98102
)
99103
)
@@ -102,6 +106,11 @@ def test_include_exclude_schema():
102106
assert s.to_python(value, mode='json') == {'1': 1, '3': 3}
103107
assert json.loads(s.to_json(value)) == {'1': 1, '3': 3}
104108

109+
value = {'0': 0, '1': 1, '2': 2, '3': -3}
110+
assert s.to_python(value) == {'1': 1}
111+
assert s.to_python(value, mode='json') == {'1': 1}
112+
assert json.loads(s.to_json(value)) == {'1': 1}
113+
105114

106115
def test_alias():
107116
s = SchemaSerializer(

0 commit comments

Comments
 (0)