Skip to content

Commit 164b9ff

Browse files
Memory usage optimization via reuse of SchemaValidator and SchemaSerializer (#1616)
1 parent 3707dcd commit 164b9ff

File tree

8 files changed

+224
-2
lines changed

8 files changed

+224
-2
lines changed

src/common/mod.rs

+1
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
1+
pub(crate) mod prebuilt;
12
pub(crate) mod union;

src/common/prebuilt.rs

+43
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
use pyo3::intern;
2+
use pyo3::prelude::*;
3+
use pyo3::types::{PyAny, PyDict, PyType};
4+
5+
use crate::tools::SchemaDict;
6+
7+
pub fn get_prebuilt<T>(
8+
type_: &str,
9+
schema: &Bound<'_, PyDict>,
10+
prebuilt_attr_name: &str,
11+
extractor: impl FnOnce(Bound<'_, PyAny>) -> PyResult<T>,
12+
) -> PyResult<Option<T>> {
13+
let py = schema.py();
14+
15+
// we can only use prebuilt validators / serializers from models, typed dicts, and dataclasses
16+
// however, we don't want to use a prebuilt structure from dataclasses if we have a generic_origin
17+
// because the validator / serializer is cached on the unparametrized dataclass
18+
if !matches!(type_, "model" | "typed-dict")
19+
|| matches!(type_, "dataclass") && schema.contains(intern!(py, "generic_origin"))?
20+
{
21+
return Ok(None);
22+
}
23+
24+
let class: Bound<'_, PyType> = schema.get_as_req(intern!(py, "cls"))?;
25+
26+
// Note: we NEED to use the __dict__ here (and perform get_item calls rather than getattr)
27+
// because we don't want to fetch prebuilt validators from parent classes.
28+
// We don't downcast here because __dict__ on a class is a readonly mappingproxy,
29+
// so we can just leave it as is and do get_item checks.
30+
let class_dict = class.getattr(intern!(py, "__dict__"))?;
31+
32+
let is_complete: bool = class_dict
33+
.get_item(intern!(py, "__pydantic_complete__"))
34+
.is_ok_and(|b| b.extract().unwrap_or(false));
35+
36+
if !is_complete {
37+
return Ok(None);
38+
}
39+
40+
// Retrieve the prebuilt validator / serializer if available
41+
let prebuilt: Bound<'_, PyAny> = class_dict.get_item(prebuilt_attr_name)?;
42+
extractor(prebuilt).map(Some)
43+
}

src/serializers/mod.rs

+1
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ mod fields;
2424
mod filter;
2525
mod infer;
2626
mod ob_type;
27+
mod prebuilt;
2728
pub mod ser;
2829
mod shared;
2930
mod type_serializers;

src/serializers/prebuilt.rs

+68
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
use std::borrow::Cow;
2+
3+
use pyo3::prelude::*;
4+
use pyo3::types::PyDict;
5+
6+
use crate::common::prebuilt::get_prebuilt;
7+
use crate::SchemaSerializer;
8+
9+
use super::extra::Extra;
10+
use super::shared::{CombinedSerializer, TypeSerializer};
11+
12+
#[derive(Debug)]
13+
pub struct PrebuiltSerializer {
14+
schema_serializer: Py<SchemaSerializer>,
15+
}
16+
17+
impl PrebuiltSerializer {
18+
pub fn try_get_from_schema(type_: &str, schema: &Bound<'_, PyDict>) -> PyResult<Option<CombinedSerializer>> {
19+
get_prebuilt(type_, schema, "__pydantic_serializer__", |py_any| {
20+
py_any
21+
.extract::<Py<SchemaSerializer>>()
22+
.map(|schema_serializer| Self { schema_serializer }.into())
23+
})
24+
}
25+
}
26+
27+
impl_py_gc_traverse!(PrebuiltSerializer { schema_serializer });
28+
29+
impl TypeSerializer for PrebuiltSerializer {
30+
fn to_python(
31+
&self,
32+
value: &Bound<'_, PyAny>,
33+
include: Option<&Bound<'_, PyAny>>,
34+
exclude: Option<&Bound<'_, PyAny>>,
35+
extra: &Extra,
36+
) -> PyResult<PyObject> {
37+
self.schema_serializer
38+
.get()
39+
.serializer
40+
.to_python(value, include, exclude, extra)
41+
}
42+
43+
fn json_key<'a>(&self, key: &'a Bound<'_, PyAny>, extra: &Extra) -> PyResult<Cow<'a, str>> {
44+
self.schema_serializer.get().serializer.json_key(key, extra)
45+
}
46+
47+
fn serde_serialize<S: serde::ser::Serializer>(
48+
&self,
49+
value: &Bound<'_, PyAny>,
50+
serializer: S,
51+
include: Option<&Bound<'_, PyAny>>,
52+
exclude: Option<&Bound<'_, PyAny>>,
53+
extra: &Extra,
54+
) -> Result<S::Ok, S::Error> {
55+
self.schema_serializer
56+
.get()
57+
.serializer
58+
.serde_serialize(value, serializer, include, exclude, extra)
59+
}
60+
61+
fn get_name(&self) -> &str {
62+
self.schema_serializer.get().serializer.get_name()
63+
}
64+
65+
fn retry_with_lax_check(&self) -> bool {
66+
self.schema_serializer.get().serializer.retry_with_lax_check()
67+
}
68+
}

src/serializers/shared.rs

+11-1
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,8 @@ combined_serializer! {
8484
Function: super::type_serializers::function::FunctionPlainSerializer;
8585
FunctionWrap: super::type_serializers::function::FunctionWrapSerializer;
8686
Fields: super::fields::GeneralFieldsSerializer;
87+
// prebuilt serializers are manually constructed, and thus manually added to the `CombinedSerializer` enum
88+
Prebuilt: super::prebuilt::PrebuiltSerializer;
8789
}
8890
// `find_only` is for type_serializers which are built directly via the `type` key and `find_serializer`
8991
// but aren't actually used for serialization, e.g. their `build` method must return another serializer
@@ -195,7 +197,14 @@ impl CombinedSerializer {
195197
}
196198

197199
let type_: Bound<'_, PyString> = schema.get_as_req(type_key)?;
198-
Self::find_serializer(type_.to_str()?, schema, config, definitions)
200+
let type_ = type_.to_str()?;
201+
202+
// if we have a SchemaValidator on the type already, use it
203+
if let Ok(Some(prebuilt_serializer)) = super::prebuilt::PrebuiltSerializer::try_get_from_schema(type_, schema) {
204+
return Ok(prebuilt_serializer);
205+
}
206+
207+
Self::find_serializer(type_, schema, config, definitions)
199208
}
200209
}
201210

@@ -219,6 +228,7 @@ impl PyGcTraverse for CombinedSerializer {
219228
CombinedSerializer::Function(inner) => inner.py_gc_traverse(visit),
220229
CombinedSerializer::FunctionWrap(inner) => inner.py_gc_traverse(visit),
221230
CombinedSerializer::Fields(inner) => inner.py_gc_traverse(visit),
231+
CombinedSerializer::Prebuilt(inner) => inner.py_gc_traverse(visit),
222232
CombinedSerializer::None(inner) => inner.py_gc_traverse(visit),
223233
CombinedSerializer::Nullable(inner) => inner.py_gc_traverse(visit),
224234
CombinedSerializer::Int(inner) => inner.py_gc_traverse(visit),

src/validators/mod.rs

+11-1
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ mod model;
5252
mod model_fields;
5353
mod none;
5454
mod nullable;
55+
mod prebuilt;
5556
mod set;
5657
mod string;
5758
mod time;
@@ -515,8 +516,15 @@ pub fn build_validator(
515516
definitions: &mut DefinitionsBuilder<CombinedValidator>,
516517
) -> PyResult<CombinedValidator> {
517518
let dict = schema.downcast::<PyDict>()?;
518-
let type_: Bound<'_, PyString> = dict.get_as_req(intern!(schema.py(), "type"))?;
519+
let py = schema.py();
520+
let type_: Bound<'_, PyString> = dict.get_as_req(intern!(py, "type"))?;
519521
let type_ = type_.to_str()?;
522+
523+
// if we have a SchemaValidator on the type already, use it
524+
if let Ok(Some(prebuilt_validator)) = prebuilt::PrebuiltValidator::try_get_from_schema(type_, dict) {
525+
return Ok(prebuilt_validator);
526+
}
527+
520528
validator_match!(
521529
type_,
522530
dict,
@@ -763,6 +771,8 @@ pub enum CombinedValidator {
763771
// input dependent
764772
JsonOrPython(json_or_python::JsonOrPython),
765773
Complex(complex::ComplexValidator),
774+
// uses a reference to an existing SchemaValidator to reduce memory usage
775+
Prebuilt(prebuilt::PrebuiltValidator),
766776
}
767777

768778
/// This trait must be implemented by all validators, it allows various validators to be accessed consistently,

src/validators/prebuilt.rs

+41
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
use pyo3::prelude::*;
2+
use pyo3::types::PyDict;
3+
4+
use crate::common::prebuilt::get_prebuilt;
5+
use crate::errors::ValResult;
6+
use crate::input::Input;
7+
8+
use super::ValidationState;
9+
use super::{CombinedValidator, SchemaValidator, Validator};
10+
11+
#[derive(Debug)]
12+
pub struct PrebuiltValidator {
13+
schema_validator: Py<SchemaValidator>,
14+
}
15+
16+
impl PrebuiltValidator {
17+
pub fn try_get_from_schema(type_: &str, schema: &Bound<'_, PyDict>) -> PyResult<Option<CombinedValidator>> {
18+
get_prebuilt(type_, schema, "__pydantic_validator__", |py_any| {
19+
py_any
20+
.extract::<Py<SchemaValidator>>()
21+
.map(|schema_validator| Self { schema_validator }.into())
22+
})
23+
}
24+
}
25+
26+
impl_py_gc_traverse!(PrebuiltValidator { schema_validator });
27+
28+
impl Validator for PrebuiltValidator {
29+
fn validate<'py>(
30+
&self,
31+
py: Python<'py>,
32+
input: &(impl Input<'py> + ?Sized),
33+
state: &mut ValidationState<'_, 'py>,
34+
) -> ValResult<PyObject> {
35+
self.schema_validator.get().validator.validate(py, input, state)
36+
}
37+
38+
fn get_name(&self) -> &str {
39+
self.schema_validator.get().validator.get_name()
40+
}
41+
}

tests/test_prebuilt.py

+48
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
from pydantic_core import SchemaSerializer, SchemaValidator, core_schema
2+
3+
4+
def test_prebuilt_val_and_ser_used() -> None:
5+
class InnerModel:
6+
x: int
7+
8+
inner_schema = core_schema.model_schema(
9+
InnerModel,
10+
schema=core_schema.model_fields_schema(
11+
{'x': core_schema.model_field(schema=core_schema.int_schema())},
12+
),
13+
)
14+
15+
inner_schema_validator = SchemaValidator(inner_schema)
16+
inner_schema_serializer = SchemaSerializer(inner_schema)
17+
InnerModel.__pydantic_complete__ = True # pyright: ignore[reportAttributeAccessIssue]
18+
InnerModel.__pydantic_validator__ = inner_schema_validator # pyright: ignore[reportAttributeAccessIssue]
19+
InnerModel.__pydantic_serializer__ = inner_schema_serializer # pyright: ignore[reportAttributeAccessIssue]
20+
21+
class OuterModel:
22+
inner: InnerModel
23+
24+
outer_schema = core_schema.model_schema(
25+
OuterModel,
26+
schema=core_schema.model_fields_schema(
27+
{
28+
'inner': core_schema.model_field(
29+
schema=core_schema.model_schema(
30+
InnerModel,
31+
schema=core_schema.model_fields_schema(
32+
# note, we use str schema here even though that's incorrect
33+
# in order to verify that the prebuilt validator is used
34+
# off of InnerModel with the correct int schema, not this str schema
35+
{'x': core_schema.model_field(schema=core_schema.str_schema())},
36+
),
37+
)
38+
)
39+
}
40+
),
41+
)
42+
43+
outer_validator = SchemaValidator(outer_schema)
44+
outer_serializer = SchemaSerializer(outer_schema)
45+
46+
result = outer_validator.validate_python({'inner': {'x': 1}})
47+
assert result.inner.x == 1
48+
assert outer_serializer.to_python(result) == {'inner': {'x': 1}}

0 commit comments

Comments
 (0)