diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java b/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java index 4739e231a..e9947a111 100644 --- a/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java @@ -210,6 +210,7 @@ public Mono notifyClients(String method, Object params) { @Override protected void doGet(HttpServletRequest request, HttpServletResponse response) throws ServletException, IOException { + final McpTransportContext transportContext = this.contextExtractor.extract(request); String requestURI = request.getRequestURI(); if (!requestURI.endsWith(sseEndpoint)) { @@ -239,7 +240,7 @@ protected void doGet(HttpServletRequest request, HttpServletResponse response) writer); // Create a new session using the session factory - McpServerSession session = sessionFactory.create(sessionTransport); + McpServerSession session = sessionFactory.create(transportContext, sessionTransport); this.sessions.put(sessionId, session); // Send initial endpoint event diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStreamableServerTransportProvider.java b/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStreamableServerTransportProvider.java index 34671c105..b384c1c66 100644 --- a/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStreamableServerTransportProvider.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStreamableServerTransportProvider.java @@ -408,7 +408,7 @@ protected void doPost(HttpServletRequest request, HttpServletResponse response) new TypeRef() { }); McpStreamableServerSession.McpStreamableServerSessionInit init = this.sessionFactory - .startSession(initializeRequest); + .startSession(transportContext, initializeRequest); this.sessions.put(init.session().getId(), init.session()); try { diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/StdioServerTransportProvider.java b/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/StdioServerTransportProvider.java index 68be62931..7f2dea720 100644 --- a/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/StdioServerTransportProvider.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/StdioServerTransportProvider.java @@ -91,7 +91,7 @@ public List protocolVersions() { public void setSessionFactory(McpServerSession.Factory sessionFactory) { // Create a single session for the stdio connection var transport = new StdioMcpSessionTransport(); - this.session = sessionFactory.create(transport); + this.session = sessionFactory.create(null, transport); transport.initProcessing(); } diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/spec/DefaultMcpStreamableServerSessionFactory.java b/mcp-core/src/main/java/io/modelcontextprotocol/spec/DefaultMcpStreamableServerSessionFactory.java index f497afd43..89b0c0fe9 100644 --- a/mcp-core/src/main/java/io/modelcontextprotocol/spec/DefaultMcpStreamableServerSessionFactory.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/spec/DefaultMcpStreamableServerSessionFactory.java @@ -4,6 +4,7 @@ package io.modelcontextprotocol.spec; +import io.modelcontextprotocol.common.McpTransportContext; import io.modelcontextprotocol.server.McpNotificationHandler; import io.modelcontextprotocol.server.McpRequestHandler; @@ -45,11 +46,33 @@ public DefaultMcpStreamableServerSessionFactory(Duration requestTimeout, @Override public McpStreamableServerSession.McpStreamableServerSessionInit startSession( - McpSchema.InitializeRequest initializeRequest) { + final McpSchema.InitializeRequest initializeRequest) { + final String sessionId = generateSessionId(null, initializeRequest); return new McpStreamableServerSession.McpStreamableServerSessionInit( - new McpStreamableServerSession(UUID.randomUUID().toString(), initializeRequest.capabilities(), + new McpStreamableServerSession(sessionId, initializeRequest.capabilities(), initializeRequest.clientInfo(), requestTimeout, requestHandlers, notificationHandlers), this.initRequestHandler.handle(initializeRequest)); } + @Override + public McpStreamableServerSession.McpStreamableServerSessionInit startSession( + final McpTransportContext mcpTransportContext, final McpSchema.InitializeRequest initializeRequest) { + final String sessionId = generateSessionId(mcpTransportContext, initializeRequest); + return new McpStreamableServerSession.McpStreamableServerSessionInit( + new McpStreamableServerSession(sessionId, initializeRequest.capabilities(), + initializeRequest.clientInfo(), requestTimeout, requestHandlers, notificationHandlers), + this.initRequestHandler.handle(initializeRequest)); + } + + /** + * An extensibility point to generate session IDs differently. + * @param mcpTransportContext transport context + * @param initializeRequest initialization request + * @return generated session ID + */ + protected String generateSessionId(McpTransportContext mcpTransportContext, + McpSchema.InitializeRequest initializeRequest) { + return UUID.randomUUID().toString(); + } + } diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java index 241f7d8b5..3b9e15a89 100644 --- a/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java @@ -439,6 +439,16 @@ public interface Factory { */ McpServerSession create(McpServerTransport sessionTransport); + /** + * Creates a new 1:1 representation of the client-server interaction. + * @param mcpTransportContext the transport context associated with the client. + * @param sessionTransport the transport to use for communication with the client. + * @return a new server session. + */ + default McpServerSession create(McpTransportContext mcpTransportContext, McpServerTransport sessionTransport) { + return create(sessionTransport); + } + } } diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpStreamableServerSession.java b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpStreamableServerSession.java index 95f8959f5..59a0389c8 100644 --- a/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpStreamableServerSession.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpStreamableServerSession.java @@ -298,8 +298,20 @@ public interface Factory { * @param initializeRequest the initialization request from the client * @return a composite allowing the session to start */ + @Deprecated McpStreamableServerSessionInit startSession(McpSchema.InitializeRequest initializeRequest); + /** + * Given an initialize request, create a composite for the session initialization + * @param mcpTransportContext the transport context for the initialization request + * @param initializeRequest the initialization request from the client + * @return a composite allowing the session to start + */ + default McpStreamableServerSessionInit startSession(McpTransportContext mcpTransportContext, + McpSchema.InitializeRequest initializeRequest) { + return startSession(initializeRequest); + } + } /** diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/MockMcpServerTransportProvider.java b/mcp-core/src/test/java/io/modelcontextprotocol/MockMcpServerTransportProvider.java index e955be89f..b43f11e7f 100644 --- a/mcp-core/src/test/java/io/modelcontextprotocol/MockMcpServerTransportProvider.java +++ b/mcp-core/src/test/java/io/modelcontextprotocol/MockMcpServerTransportProvider.java @@ -30,7 +30,7 @@ public MockMcpServerTransport getTransport() { @Override public void setSessionFactory(Factory sessionFactory) { - session = sessionFactory.create(transport); + session = sessionFactory.create(null, transport); } @Override diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/server/transport/StdioServerTransportProviderTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/server/transport/StdioServerTransportProviderTests.java index 6a70af33d..595a58ded 100644 --- a/mcp-core/src/test/java/io/modelcontextprotocol/server/transport/StdioServerTransportProviderTests.java +++ b/mcp-core/src/test/java/io/modelcontextprotocol/server/transport/StdioServerTransportProviderTests.java @@ -14,6 +14,7 @@ import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicReference; +import io.modelcontextprotocol.common.McpTransportContext; import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpServerSession; @@ -67,7 +68,8 @@ void setUp() { sessionFactory = mock(McpServerSession.Factory.class); // Configure mock behavior - when(sessionFactory.create(any(McpServerTransport.class))).thenReturn(mockSession); + when(sessionFactory.create(any(McpTransportContext.class), any(McpServerTransport.class))) + .thenReturn(mockSession); when(mockSession.closeGracefully()).thenReturn(Mono.empty()); when(mockSession.sendNotification(any(), any())).thenReturn(Mono.empty()); diff --git a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java index 95355c0f2..76c9df179 100644 --- a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java +++ b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java @@ -278,7 +278,7 @@ private Mono handleSseConnection(ServerRequest request) { .body(Flux.>create(sink -> { WebFluxMcpSessionTransport sessionTransport = new WebFluxMcpSessionTransport(sink); - McpServerSession session = sessionFactory.create(sessionTransport); + McpServerSession session = sessionFactory.create(transportContext, sessionTransport); String sessionId = session.getId(); logger.debug("Created new SSE connection for session: {}", sessionId); diff --git a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxStreamableServerTransportProvider.java b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxStreamableServerTransportProvider.java index deebfc616..86c949721 100644 --- a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxStreamableServerTransportProvider.java +++ b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxStreamableServerTransportProvider.java @@ -239,7 +239,7 @@ private Mono handlePost(ServerRequest request) { McpSchema.InitializeRequest initializeRequest = jsonMapper.convertValue(jsonrpcRequest.params(), typeReference); McpStreamableServerSession.McpStreamableServerSessionInit init = this.sessionFactory - .startSession(initializeRequest); + .startSession(transportContext, initializeRequest); sessions.put(init.session().getId(), init.session()); return init.initResult().map(initializeResult -> { McpSchema.JSONRPCResponse jsonrpcResponse = new McpSchema.JSONRPCResponse( diff --git a/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java b/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java index 0b71ddc1f..ece96d415 100644 --- a/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java +++ b/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java @@ -255,6 +255,7 @@ private ServerResponse handleSseConnection(ServerRequest request) { return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).body("Server is shutting down"); } + McpTransportContext mcpTransportContext = this.contextExtractor.extract(request); String sessionId = UUID.randomUUID().toString(); logger.debug("Creating new SSE connection for session: {}", sessionId); @@ -271,7 +272,7 @@ private ServerResponse handleSseConnection(ServerRequest request) { }); WebMvcMcpSessionTransport sessionTransport = new WebMvcMcpSessionTransport(sessionId, sseBuilder); - McpServerSession session = sessionFactory.create(sessionTransport); + McpServerSession session = sessionFactory.create(mcpTransportContext, sessionTransport); this.sessions.put(sessionId, session); try { diff --git a/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcStreamableServerTransportProvider.java b/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcStreamableServerTransportProvider.java index f2a58d4d8..6d9dc1857 100644 --- a/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcStreamableServerTransportProvider.java +++ b/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcStreamableServerTransportProvider.java @@ -335,7 +335,7 @@ private ServerResponse handlePost(ServerRequest request) { new TypeRef() { }); McpStreamableServerSession.McpStreamableServerSessionInit init = this.sessionFactory - .startSession(initializeRequest); + .startSession(transportContext, initializeRequest); this.sessions.put(init.session().getId(), init.session()); try {