Skip to content

Support Thinking part #1142

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 37 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
edffe70
Support Thinking part
Kludex Mar 16, 2025
1127c31
merge
Kludex Mar 27, 2025
ac7ca8d
Merge branch 'main' into support-thinking
Kludex Apr 1, 2025
bc84993
push
Kludex Apr 1, 2025
77a3338
merge
Kludex Apr 18, 2025
e5202cb
ignore
Kludex Apr 18, 2025
f4b7fde
Add more support for thinking part
Kludex Apr 18, 2025
f0da181
Add tests
Kludex Apr 19, 2025
3b92cc0
Add tests
Kludex Apr 19, 2025
985991b
pass tests
Kludex Apr 19, 2025
94abf96
Merge branch 'main' into support-thinking
Kludex Apr 19, 2025
0cb280e
fix pipeline
Kludex Apr 20, 2025
f3600f7
Implement streaming
Kludex Apr 20, 2025
427daf7
Merge branch 'main' into support-thinking
dmontagu Apr 21, 2025
b57502e
Minor cleanup
dmontagu Apr 21, 2025
03e9fd4
merge
Kludex Apr 22, 2025
a04533f
Support Thinking part
Kludex Apr 22, 2025
5fced6a
Support Thinking part
Kludex Apr 22, 2025
eaf70e1
Support Thinking part
Kludex Apr 22, 2025
aef5b47
Support Thinking part
Kludex Apr 22, 2025
4e92754
Support Thinking part
Kludex Apr 22, 2025
2373e39
Merge remote-tracking branch 'origin/main' into support-thinking
Kludex Apr 22, 2025
805088f
Add support for bedrock
Kludex Apr 22, 2025
d8c5861
add tests for bedrock
Kludex Apr 22, 2025
133abe3
Pass tests
Kludex Apr 22, 2025
36008af
fix test
Kludex Apr 22, 2025
155c373
pass tests
Kludex Apr 22, 2025
06fc39e
Merge branch 'main' into support-thinking
Kludex Apr 25, 2025
1a52cee
bump boto3
Kludex Apr 25, 2025
76a1d48
Coverage on Bedrock
Kludex Apr 25, 2025
7b56087
Bump openai
Kludex Apr 25, 2025
3483663
Merge branch 'main' into support-thinking
Kludex Apr 25, 2025
e5e901f
Fix openai provider streaming
Kludex Apr 25, 2025
bf03ecd
Add basic documentation
Kludex Apr 25, 2025
006d17a
Add more coverage
Kludex Apr 25, 2025
07cfe72
Merge remote-tracking branch 'origin/main' into support-thinking
Kludex Apr 25, 2025
19c275e
Apply changes
Kludex Apr 28, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,6 @@ repos:
rev: v2.3.0
hooks:
- id: codespell
args: ['--skip', 'tests/models/cassettes/*']
args: ['--skip', 'tests/models/cassettes/*', '--skip', 'tests/models/test_cohere.py']
additional_dependencies:
- tomli
3 changes: 3 additions & 0 deletions docs/thinking.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# Thinking

Also known as reasoning, "thinking" is the process of using a model's capabilities to reason about a task.
1 change: 1 addition & 0 deletions mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ nav:
- graph.md
- evals.md
- input.md
- thinking.md
- MCP:
- mcp/index.md
- mcp/client.md
Expand Down
8 changes: 7 additions & 1 deletion pydantic_ai_slim/pydantic_ai/_agent_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,7 +397,7 @@ async def stream(
async for _event in stream:
pass

async def _run_stream(
async def _run_stream( # noqa: C901
self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
) -> AsyncIterator[_messages.HandleResponseEvent]:
if self._events_iterator is None:
Expand All @@ -413,6 +413,12 @@ async def _run_stream() -> AsyncIterator[_messages.HandleResponseEvent]:
texts.append(part.content)
elif isinstance(part, _messages.ToolCallPart):
tool_calls.append(part)
elif isinstance(part, _messages.ThinkingPart):
# We don't need to do anything with thinking parts in this tool-calling node.
# We need to handle text parts in case there are no tool calls and/or the desired output comes
# from the text, but thinking parts should not directly influence the execution of tools or
# determination of the next node of graph execution here.
pass
else:
assert_never(part)

Expand Down
76 changes: 74 additions & 2 deletions pydantic_ai_slim/pydantic_ai/_parts_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
PartStartEvent,
TextPart,
TextPartDelta,
ThinkingPart,
ThinkingPartDelta,
ToolCallPart,
ToolCallPartDelta,
)
Expand Down Expand Up @@ -86,8 +88,7 @@ def handle_text_delta(
A `PartStartEvent` if a new part was created, or a `PartDeltaEvent` if an existing part was updated.

Raises:
UnexpectedModelBehavior: If attempting to apply text content to a part that is
not a TextPart.
UnexpectedModelBehavior: If attempting to apply text content to a part that is not a TextPart.
"""
existing_text_part_and_index: tuple[TextPart, int] | None = None

Expand Down Expand Up @@ -122,6 +123,77 @@ def handle_text_delta(
self._parts[part_index] = part_delta.apply(existing_text_part)
return PartDeltaEvent(index=part_index, delta=part_delta)

def handle_thinking_delta(
self,
*,
vendor_part_id: Hashable | None,
content: str | None = None,
signature: str | None = None,
) -> ModelResponseStreamEvent:
"""Handle incoming thinking content, creating or updating a ThinkingPart in the manager as appropriate.

When `vendor_part_id` is None, the latest part is updated if it exists and is a ThinkingPart;
otherwise, a new ThinkingPart is created. When a non-None ID is specified, the ThinkingPart corresponding
to that vendor ID is either created or updated.

Args:
vendor_part_id: The ID the vendor uses to identify this piece
of thinking. If None, a new part will be created unless the latest part is already
a ThinkingPart.
content: The thinking content to append to the appropriate ThinkingPart.
signature: An optional signature for the thinking content.

Returns:
A `PartStartEvent` if a new part was created, or a `PartDeltaEvent` if an existing part was updated.

Raises:
UnexpectedModelBehavior: If attempting to apply a thinking delta to a part that is not a ThinkingPart.
"""
existing_thinking_part_and_index: tuple[ThinkingPart, int] | None = None

if vendor_part_id is None:
# If the vendor_part_id is None, check if the latest part is a ThinkingPart to update
if self._parts:
part_index = len(self._parts) - 1
latest_part = self._parts[part_index]
if isinstance(latest_part, ThinkingPart):
existing_thinking_part_and_index = latest_part, part_index
else:
# Otherwise, attempt to look up an existing ThinkingPart by vendor_part_id
part_index = self._vendor_id_to_part_index.get(vendor_part_id)
if part_index is not None:
existing_part = self._parts[part_index]
if not isinstance(existing_part, ThinkingPart):
raise UnexpectedModelBehavior(f'Cannot apply a thinking delta to {existing_part=}')
existing_thinking_part_and_index = existing_part, part_index

if existing_thinking_part_and_index is None:
if content is not None:
# There is no existing thinking part that should be updated, so create a new one
new_part_index = len(self._parts)
part = ThinkingPart(content=content, signature=signature)
if vendor_part_id is not None:
self._vendor_id_to_part_index[vendor_part_id] = new_part_index
self._parts.append(part)
return PartStartEvent(index=new_part_index, part=part)
else:
raise UnexpectedModelBehavior('Cannot create a ThinkingPart with no content')
else:
if content is not None:
# Update the existing ThinkingPart with the new content delta
existing_thinking_part, part_index = existing_thinking_part_and_index
part_delta = ThinkingPartDelta(content_delta=content)
self._parts[part_index] = part_delta.apply(existing_thinking_part)
return PartDeltaEvent(index=part_index, delta=part_delta)
elif signature is not None:
# Update the existing ThinkingPart with the new signature delta
existing_thinking_part, part_index = existing_thinking_part_and_index
part_delta = ThinkingPartDelta(signature_delta=signature)
self._parts[part_index] = part_delta.apply(existing_thinking_part)
return PartDeltaEvent(index=part_index, delta=part_delta)
else:
raise UnexpectedModelBehavior('Cannot update a ThinkingPart with no content or signature')

def handle_tool_call_delta(
self,
*,
Expand Down
32 changes: 32 additions & 0 deletions pydantic_ai_slim/pydantic_ai/_thinking_part.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from __future__ import annotations as _annotations

from pydantic_ai.messages import TextPart, ThinkingPart

START_THINK_TAG = '<think>'
END_THINK_TAG = '</think>'


def split_content_into_text_and_thinking(content: str) -> list[ThinkingPart | TextPart]:
"""Split a string into text and thinking parts.

Some models don't return the thinking part as a separate part, but rather as a tag in the content.
This function splits the content into text and thinking parts.

We use the `<think>` tag because that's how Groq uses it in the `raw` format, so instead of using `<Thinking>` or
something else, we just match the tag to make it easier for other models that don't support the `ThinkingPart`.
"""
parts: list[ThinkingPart | TextPart] = []
while START_THINK_TAG in content:
before_think, content = content.split(START_THINK_TAG, 1)
if before_think.strip():
parts.append(TextPart(content=before_think))
if END_THINK_TAG in content:
think_content, content = content.split(END_THINK_TAG, 1)
parts.append(ThinkingPart(content=think_content))
else:
# We lose the `<think>` tag, but it shouldn't matter.
parts.append(TextPart(content=content))
content = ''
if content:
parts.append(TextPart(content=content))
return parts
72 changes: 70 additions & 2 deletions pydantic_ai_slim/pydantic_ai/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -486,6 +486,24 @@ def has_content(self) -> bool:
return bool(self.content)


@dataclass
class ThinkingPart:
"""A thinking response from a model."""

content: str
"""The thinking content of the response."""

signature: str | None = None
"""The signature of the thinking."""

part_kind: Literal['thinking'] = 'thinking'
"""Part type identifier, this is available on all parts as a discriminator."""

def has_content(self) -> bool:
"""Return `True` if the thinking content is non-empty."""
return bool(self.content)


@dataclass
class ToolCallPart:
"""A tool call from a model."""
Expand Down Expand Up @@ -540,7 +558,7 @@ def has_content(self) -> bool:
return bool(self.args)


ModelResponsePart = Annotated[Union[TextPart, ToolCallPart], pydantic.Discriminator('part_kind')]
ModelResponsePart = Annotated[Union[TextPart, ToolCallPart, ThinkingPart], pydantic.Discriminator('part_kind')]
"""A message part returned by a model."""


Expand Down Expand Up @@ -630,6 +648,54 @@ def apply(self, part: ModelResponsePart) -> TextPart:
return replace(part, content=part.content + self.content_delta)


@dataclass
class ThinkingPartDelta:
"""A partial update (delta) for a `ThinkingPart` to append new thinking content."""

content_delta: str | None = None
"""The incremental thinking content to add to the existing `ThinkingPart` content."""

signature_delta: str | None = None
"""Optional signature delta.

Note this is never treated as a delta — it can replace None.
"""

part_delta_kind: Literal['thinking'] = 'thinking'
"""Part delta type identifier, used as a discriminator."""

@overload
def apply(self, part: ModelResponsePart) -> ThinkingPart: ...

@overload
def apply(self, part: ModelResponsePart | ThinkingPartDelta) -> ThinkingPart | ThinkingPartDelta: ...

def apply(self, part: ModelResponsePart | ThinkingPartDelta) -> ThinkingPart | ThinkingPartDelta:
"""Apply this thinking delta to an existing `ThinkingPart`.

Args:
part: The existing model response part, which must be a `ThinkingPart`.

Returns:
A new `ThinkingPart` with updated thinking content.

Raises:
ValueError: If `part` is not a `ThinkingPart`.
"""
if isinstance(part, ThinkingPart):
return replace(part, content=part.content + self.content_delta if self.content_delta else None)
elif isinstance(part, ThinkingPartDelta):
if self.content_delta is None and self.signature_delta is None:
raise ValueError('Cannot apply ThinkingPartDelta with no content or signature')
if self.signature_delta is not None:
return replace(part, signature_delta=self.signature_delta)
if self.content_delta is not None:
return replace(part, content_delta=self.content_delta)
raise ValueError(
f'Cannot apply ThinkingPartDeltas to non-ThinkingParts or non-ThinkingPartDeltas ({part=}, {self=})'
)


@dataclass
class ToolCallPartDelta:
"""A partial update (delta) for a `ToolCallPart` to modify tool name, arguments, or tool call ID."""
Expand Down Expand Up @@ -745,7 +811,9 @@ def _apply_to_part(self, part: ToolCallPart) -> ToolCallPart:
return part


ModelResponsePartDelta = Annotated[Union[TextPartDelta, ToolCallPartDelta], pydantic.Discriminator('part_delta_kind')]
ModelResponsePartDelta = Annotated[
Union[TextPartDelta, ThinkingPartDelta, ToolCallPartDelta], pydantic.Discriminator('part_delta_kind')
]
"""A partial update (delta) for any model response part."""


Expand Down
Loading
Loading