diff --git a/pydantic_ai_slim/pydantic_ai/agent.py b/pydantic_ai_slim/pydantic_ai/agent.py index 2871fa4e3..37854ae79 100644 --- a/pydantic_ai_slim/pydantic_ai/agent.py +++ b/pydantic_ai_slim/pydantic_ai/agent.py @@ -12,7 +12,7 @@ from opentelemetry.trace import NoOpTracer, use_span from pydantic.json_schema import GenerateJsonSchema -from typing_extensions import Literal, Never, TypeGuard, TypeVar, deprecated +from typing_extensions import Literal, Never, TypeIs, TypeVar, deprecated from pydantic_graph import End, Graph, GraphRun, GraphRunContext from pydantic_graph._utils import get_event_loop @@ -459,7 +459,7 @@ def iter( self, user_prompt: str | Sequence[_messages.UserContent] | None, *, - output_type: type[RunOutputDataT] | ToolOutput[RunOutputDataT] | None = None, + output_type: None = None, message_history: list[_messages.ModelMessage] | None = None, model: models.Model | models.KnownModelName | str | None = None, deps: AgentDepsT = None, @@ -468,7 +468,23 @@ def iter( usage: _usage.Usage | None = None, infer_name: bool = True, **_deprecated_kwargs: Never, - ) -> AbstractAsyncContextManager[AgentRun[AgentDepsT, Any]]: ... + ) -> AbstractAsyncContextManager[AgentRun[AgentDepsT, OutputDataT]]: ... + + @overload + def iter( + self, + user_prompt: str | Sequence[_messages.UserContent] | None, + *, + output_type: type[RunOutputDataT] | ToolOutput[RunOutputDataT], + message_history: list[_messages.ModelMessage] | None = None, + model: models.Model | models.KnownModelName | str | None = None, + deps: AgentDepsT = None, + model_settings: ModelSettings | None = None, + usage_limits: _usage.UsageLimits | None = None, + usage: _usage.Usage | None = None, + infer_name: bool = True, + **_deprecated_kwargs: Never, + ) -> AbstractAsyncContextManager[AgentRun[AgentDepsT, RunOutputDataT]]: ... @overload @deprecated('`result_type` is deprecated, use `output_type` instead.') @@ -1614,7 +1630,7 @@ def _prepare_output_schema( @staticmethod def is_model_request_node( node: _agent_graph.AgentNode[T, S] | End[result.FinalResult[S]], - ) -> TypeGuard[_agent_graph.ModelRequestNode[T, S]]: + ) -> TypeIs[_agent_graph.ModelRequestNode[T, S]]: """Check if the node is a `ModelRequestNode`, narrowing the type if it is. This method preserves the generic parameters while narrowing the type, unlike a direct call to `isinstance`. @@ -1624,7 +1640,7 @@ def is_model_request_node( @staticmethod def is_call_tools_node( node: _agent_graph.AgentNode[T, S] | End[result.FinalResult[S]], - ) -> TypeGuard[_agent_graph.CallToolsNode[T, S]]: + ) -> TypeIs[_agent_graph.CallToolsNode[T, S]]: """Check if the node is a `CallToolsNode`, narrowing the type if it is. This method preserves the generic parameters while narrowing the type, unlike a direct call to `isinstance`. @@ -1634,7 +1650,7 @@ def is_call_tools_node( @staticmethod def is_user_prompt_node( node: _agent_graph.AgentNode[T, S] | End[result.FinalResult[S]], - ) -> TypeGuard[_agent_graph.UserPromptNode[T, S]]: + ) -> TypeIs[_agent_graph.UserPromptNode[T, S]]: """Check if the node is a `UserPromptNode`, narrowing the type if it is. This method preserves the generic parameters while narrowing the type, unlike a direct call to `isinstance`. @@ -1644,7 +1660,7 @@ def is_user_prompt_node( @staticmethod def is_end_node( node: _agent_graph.AgentNode[T, S] | End[result.FinalResult[S]], - ) -> TypeGuard[End[result.FinalResult[S]]]: + ) -> TypeIs[End[result.FinalResult[S]]]: """Check if the node is a `End`, narrowing the type if it is. This method preserves the generic parameters while narrowing the type, unlike a direct call to `isinstance`.