Skip to content

Commit a357380

Browse files
authored
feat: Pass through and expose additional parameters in ClientSessionGroup.call_tool and .connect_to_server (#1576)
1 parent 9724ad1 commit a357380

File tree

2 files changed

+101
-11
lines changed

2 files changed

+101
-11
lines changed

src/mcp/client/session_group.py

Lines changed: 79 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,20 +11,23 @@
1111
import contextlib
1212
import logging
1313
from collections.abc import Callable
14+
from dataclasses import dataclass
1415
from datetime import timedelta
1516
from types import TracebackType
16-
from typing import Any, TypeAlias
17+
from typing import Any, TypeAlias, overload
1718

1819
import anyio
1920
from pydantic import BaseModel
20-
from typing_extensions import Self
21+
from typing_extensions import Self, deprecated
2122

2223
import mcp
2324
from mcp import types
25+
from mcp.client.session import ElicitationFnT, ListRootsFnT, LoggingFnT, MessageHandlerFnT, SamplingFnT
2426
from mcp.client.sse import sse_client
2527
from mcp.client.stdio import StdioServerParameters
2628
from mcp.client.streamable_http import streamablehttp_client
2729
from mcp.shared.exceptions import McpError
30+
from mcp.shared.session import ProgressFnT
2831

2932

3033
class SseServerParameters(BaseModel):
@@ -65,6 +68,21 @@ class StreamableHttpParameters(BaseModel):
6568
ServerParameters: TypeAlias = StdioServerParameters | SseServerParameters | StreamableHttpParameters
6669

6770

71+
# Use dataclass instead of pydantic BaseModel
72+
# because pydantic BaseModel cannot handle Protocol fields.
73+
@dataclass
74+
class ClientSessionParameters:
75+
"""Parameters for establishing a client session to an MCP server."""
76+
77+
read_timeout_seconds: timedelta | None = None
78+
sampling_callback: SamplingFnT | None = None
79+
elicitation_callback: ElicitationFnT | None = None
80+
list_roots_callback: ListRootsFnT | None = None
81+
logging_callback: LoggingFnT | None = None
82+
message_handler: MessageHandlerFnT | None = None
83+
client_info: types.Implementation | None = None
84+
85+
6886
class ClientSessionGroup:
6987
"""Client for managing connections to multiple MCP servers.
7088
@@ -172,11 +190,49 @@ def tools(self) -> dict[str, types.Tool]:
172190
"""Returns the tools as a dictionary of names to tools."""
173191
return self._tools
174192

175-
async def call_tool(self, name: str, args: dict[str, Any]) -> types.CallToolResult:
193+
@overload
194+
async def call_tool(
195+
self,
196+
name: str,
197+
arguments: dict[str, Any],
198+
read_timeout_seconds: timedelta | None = None,
199+
progress_callback: ProgressFnT | None = None,
200+
*,
201+
meta: dict[str, Any] | None = None,
202+
) -> types.CallToolResult: ...
203+
204+
@overload
205+
@deprecated("The 'args' parameter is deprecated. Use 'arguments' instead.")
206+
async def call_tool(
207+
self,
208+
name: str,
209+
*,
210+
args: dict[str, Any],
211+
read_timeout_seconds: timedelta | None = None,
212+
progress_callback: ProgressFnT | None = None,
213+
meta: dict[str, Any] | None = None,
214+
) -> types.CallToolResult: ...
215+
216+
async def call_tool(
217+
self,
218+
name: str,
219+
arguments: dict[str, Any] | None = None,
220+
read_timeout_seconds: timedelta | None = None,
221+
progress_callback: ProgressFnT | None = None,
222+
*,
223+
meta: dict[str, Any] | None = None,
224+
args: dict[str, Any] | None = None,
225+
) -> types.CallToolResult:
176226
"""Executes a tool given its name and arguments."""
177227
session = self._tool_to_session[name]
178228
session_tool_name = self.tools[name].name
179-
return await session.call_tool(session_tool_name, args)
229+
return await session.call_tool(
230+
session_tool_name,
231+
arguments if args is None else args,
232+
read_timeout_seconds=read_timeout_seconds,
233+
progress_callback=progress_callback,
234+
meta=meta,
235+
)
180236

181237
async def disconnect_from_server(self, session: mcp.ClientSession) -> None:
182238
"""Disconnects from a single MCP server."""
@@ -225,13 +281,16 @@ async def connect_with_session(
225281
async def connect_to_server(
226282
self,
227283
server_params: ServerParameters,
284+
session_params: ClientSessionParameters | None = None,
228285
) -> mcp.ClientSession:
229286
"""Connects to a single MCP server."""
230-
server_info, session = await self._establish_session(server_params)
287+
server_info, session = await self._establish_session(server_params, session_params or ClientSessionParameters())
231288
return await self.connect_with_session(server_info, session)
232289

233290
async def _establish_session(
234-
self, server_params: ServerParameters
291+
self,
292+
server_params: ServerParameters,
293+
session_params: ClientSessionParameters,
235294
) -> tuple[types.Implementation, mcp.ClientSession]:
236295
"""Establish a client session to an MCP server."""
237296

@@ -259,7 +318,20 @@ async def _establish_session(
259318
)
260319
read, write, _ = await session_stack.enter_async_context(client)
261320

262-
session = await session_stack.enter_async_context(mcp.ClientSession(read, write))
321+
session = await session_stack.enter_async_context(
322+
mcp.ClientSession(
323+
read,
324+
write,
325+
read_timeout_seconds=session_params.read_timeout_seconds,
326+
sampling_callback=session_params.sampling_callback,
327+
elicitation_callback=session_params.elicitation_callback,
328+
list_roots_callback=session_params.list_roots_callback,
329+
logging_callback=session_params.logging_callback,
330+
message_handler=session_params.message_handler,
331+
client_info=session_params.client_info,
332+
)
333+
)
334+
263335
result = await session.initialize()
264336

265337
# Session successfully initialized.

tests/client/test_session_group.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,12 @@
55

66
import mcp
77
from mcp import types
8-
from mcp.client.session_group import ClientSessionGroup, SseServerParameters, StreamableHttpParameters
8+
from mcp.client.session_group import (
9+
ClientSessionGroup,
10+
ClientSessionParameters,
11+
SseServerParameters,
12+
StreamableHttpParameters,
13+
)
914
from mcp.client.stdio import StdioServerParameters
1015
from mcp.shared.exceptions import McpError
1116

@@ -62,7 +67,7 @@ def hook(name: str, server_info: types.Implementation) -> str: # pragma: no cov
6267
# --- Test Execution ---
6368
result = await mcp_session_group.call_tool(
6469
name="server1-my_tool",
65-
args={
70+
arguments={
6671
"name": "value1",
6772
"args": {},
6873
},
@@ -73,6 +78,9 @@ def hook(name: str, server_info: types.Implementation) -> str: # pragma: no cov
7378
mock_session.call_tool.assert_called_once_with(
7479
"my_tool",
7580
{"name": "value1", "args": {}},
81+
read_timeout_seconds=None,
82+
progress_callback=None,
83+
meta=None,
7684
)
7785

7886
async def test_connect_to_server(self, mock_exit_stack: contextlib.AsyncExitStack):
@@ -329,7 +337,7 @@ async def test_establish_session_parameterized(
329337
(
330338
returned_server_info,
331339
returned_session,
332-
) = await group._establish_session(server_params_instance)
340+
) = await group._establish_session(server_params_instance, ClientSessionParameters())
333341

334342
# --- Assertions ---
335343
# 1. Assert the correct specific client function was called
@@ -357,7 +365,17 @@ async def test_establish_session_parameterized(
357365
mock_client_cm_instance.__aenter__.assert_awaited_once()
358366

359367
# 2. Assert ClientSession was called correctly
360-
mock_ClientSession_class.assert_called_once_with(mock_read_stream, mock_write_stream)
368+
mock_ClientSession_class.assert_called_once_with(
369+
mock_read_stream,
370+
mock_write_stream,
371+
read_timeout_seconds=None,
372+
sampling_callback=None,
373+
elicitation_callback=None,
374+
list_roots_callback=None,
375+
logging_callback=None,
376+
message_handler=None,
377+
client_info=None,
378+
)
361379
mock_raw_session_cm.__aenter__.assert_awaited_once()
362380
mock_entered_session.initialize.assert_awaited_once()
363381

0 commit comments

Comments
 (0)