|
11 | 11 | import contextlib |
12 | 12 | import logging |
13 | 13 | from collections.abc import Callable |
| 14 | +from dataclasses import dataclass |
14 | 15 | from datetime import timedelta |
15 | 16 | from types import TracebackType |
16 | | -from typing import Any, TypeAlias |
| 17 | +from typing import Any, TypeAlias, overload |
17 | 18 |
|
18 | 19 | import anyio |
19 | 20 | from pydantic import BaseModel |
20 | | -from typing_extensions import Self |
| 21 | +from typing_extensions import Self, deprecated |
21 | 22 |
|
22 | 23 | import mcp |
23 | 24 | from mcp import types |
| 25 | +from mcp.client.session import ElicitationFnT, ListRootsFnT, LoggingFnT, MessageHandlerFnT, SamplingFnT |
24 | 26 | from mcp.client.sse import sse_client |
25 | 27 | from mcp.client.stdio import StdioServerParameters |
26 | 28 | from mcp.client.streamable_http import streamablehttp_client |
27 | 29 | from mcp.shared.exceptions import McpError |
| 30 | +from mcp.shared.session import ProgressFnT |
28 | 31 |
|
29 | 32 |
|
30 | 33 | class SseServerParameters(BaseModel): |
@@ -65,6 +68,21 @@ class StreamableHttpParameters(BaseModel): |
65 | 68 | ServerParameters: TypeAlias = StdioServerParameters | SseServerParameters | StreamableHttpParameters |
66 | 69 |
|
67 | 70 |
|
| 71 | +# Use dataclass instead of pydantic BaseModel |
| 72 | +# because pydantic BaseModel cannot handle Protocol fields. |
| 73 | +@dataclass |
| 74 | +class ClientSessionParameters: |
| 75 | + """Parameters for establishing a client session to an MCP server.""" |
| 76 | + |
| 77 | + read_timeout_seconds: timedelta | None = None |
| 78 | + sampling_callback: SamplingFnT | None = None |
| 79 | + elicitation_callback: ElicitationFnT | None = None |
| 80 | + list_roots_callback: ListRootsFnT | None = None |
| 81 | + logging_callback: LoggingFnT | None = None |
| 82 | + message_handler: MessageHandlerFnT | None = None |
| 83 | + client_info: types.Implementation | None = None |
| 84 | + |
| 85 | + |
68 | 86 | class ClientSessionGroup: |
69 | 87 | """Client for managing connections to multiple MCP servers. |
70 | 88 |
|
@@ -172,11 +190,49 @@ def tools(self) -> dict[str, types.Tool]: |
172 | 190 | """Returns the tools as a dictionary of names to tools.""" |
173 | 191 | return self._tools |
174 | 192 |
|
175 | | - async def call_tool(self, name: str, args: dict[str, Any]) -> types.CallToolResult: |
| 193 | + @overload |
| 194 | + async def call_tool( |
| 195 | + self, |
| 196 | + name: str, |
| 197 | + arguments: dict[str, Any], |
| 198 | + read_timeout_seconds: timedelta | None = None, |
| 199 | + progress_callback: ProgressFnT | None = None, |
| 200 | + *, |
| 201 | + meta: dict[str, Any] | None = None, |
| 202 | + ) -> types.CallToolResult: ... |
| 203 | + |
| 204 | + @overload |
| 205 | + @deprecated("The 'args' parameter is deprecated. Use 'arguments' instead.") |
| 206 | + async def call_tool( |
| 207 | + self, |
| 208 | + name: str, |
| 209 | + *, |
| 210 | + args: dict[str, Any], |
| 211 | + read_timeout_seconds: timedelta | None = None, |
| 212 | + progress_callback: ProgressFnT | None = None, |
| 213 | + meta: dict[str, Any] | None = None, |
| 214 | + ) -> types.CallToolResult: ... |
| 215 | + |
| 216 | + async def call_tool( |
| 217 | + self, |
| 218 | + name: str, |
| 219 | + arguments: dict[str, Any] | None = None, |
| 220 | + read_timeout_seconds: timedelta | None = None, |
| 221 | + progress_callback: ProgressFnT | None = None, |
| 222 | + *, |
| 223 | + meta: dict[str, Any] | None = None, |
| 224 | + args: dict[str, Any] | None = None, |
| 225 | + ) -> types.CallToolResult: |
176 | 226 | """Executes a tool given its name and arguments.""" |
177 | 227 | session = self._tool_to_session[name] |
178 | 228 | session_tool_name = self.tools[name].name |
179 | | - return await session.call_tool(session_tool_name, args) |
| 229 | + return await session.call_tool( |
| 230 | + session_tool_name, |
| 231 | + arguments if args is None else args, |
| 232 | + read_timeout_seconds=read_timeout_seconds, |
| 233 | + progress_callback=progress_callback, |
| 234 | + meta=meta, |
| 235 | + ) |
180 | 236 |
|
181 | 237 | async def disconnect_from_server(self, session: mcp.ClientSession) -> None: |
182 | 238 | """Disconnects from a single MCP server.""" |
@@ -225,13 +281,16 @@ async def connect_with_session( |
225 | 281 | async def connect_to_server( |
226 | 282 | self, |
227 | 283 | server_params: ServerParameters, |
| 284 | + session_params: ClientSessionParameters | None = None, |
228 | 285 | ) -> mcp.ClientSession: |
229 | 286 | """Connects to a single MCP server.""" |
230 | | - server_info, session = await self._establish_session(server_params) |
| 287 | + server_info, session = await self._establish_session(server_params, session_params or ClientSessionParameters()) |
231 | 288 | return await self.connect_with_session(server_info, session) |
232 | 289 |
|
233 | 290 | async def _establish_session( |
234 | | - self, server_params: ServerParameters |
| 291 | + self, |
| 292 | + server_params: ServerParameters, |
| 293 | + session_params: ClientSessionParameters, |
235 | 294 | ) -> tuple[types.Implementation, mcp.ClientSession]: |
236 | 295 | """Establish a client session to an MCP server.""" |
237 | 296 |
|
@@ -259,7 +318,20 @@ async def _establish_session( |
259 | 318 | ) |
260 | 319 | read, write, _ = await session_stack.enter_async_context(client) |
261 | 320 |
|
262 | | - session = await session_stack.enter_async_context(mcp.ClientSession(read, write)) |
| 321 | + session = await session_stack.enter_async_context( |
| 322 | + mcp.ClientSession( |
| 323 | + read, |
| 324 | + write, |
| 325 | + read_timeout_seconds=session_params.read_timeout_seconds, |
| 326 | + sampling_callback=session_params.sampling_callback, |
| 327 | + elicitation_callback=session_params.elicitation_callback, |
| 328 | + list_roots_callback=session_params.list_roots_callback, |
| 329 | + logging_callback=session_params.logging_callback, |
| 330 | + message_handler=session_params.message_handler, |
| 331 | + client_info=session_params.client_info, |
| 332 | + ) |
| 333 | + ) |
| 334 | + |
263 | 335 | result = await session.initialize() |
264 | 336 |
|
265 | 337 | # Session successfully initialized. |
|
0 commit comments