diff --git a/Cargo.lock b/Cargo.lock index ab3f24e89..ef7cadc5a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5889,7 +5889,7 @@ dependencies = [ [[package]] name = "shell-tool-approvals" -version = "0.3.68" +version = "0.3.71" dependencies = [ "globset", "log", diff --git a/libs/mcp/proxy/src/server/mod.rs b/libs/mcp/proxy/src/server/mod.rs index aa7059366..900ffd658 100644 --- a/libs/mcp/proxy/src/server/mod.rs +++ b/libs/mcp/proxy/src/server/mod.rs @@ -1,16 +1,16 @@ use rmcp::model::ServerCapabilities; -use rmcp::service::{NotificationContext, Peer, RequestContext}; +use rmcp::service::{NotificationContext, Peer, PeerRequestOptions, RequestContext}; use rmcp::transport::streamable_http_server::{ StreamableHttpService, session::local::LocalSessionManager, }; use rmcp::{ RoleClient, RoleServer, ServerHandler, ServiceError, model::{ - CallToolRequestParam, CallToolResult, CancelledNotificationParam, Content, ErrorData, - GetPromptRequestParam, GetPromptResult, Implementation, InitializeRequestParam, + CallToolRequestParam, CallToolResult, CancelledNotificationParam, ClientRequest, Content, + ErrorData, GetPromptRequestParam, GetPromptResult, Implementation, InitializeRequestParam, InitializeResult, ListPromptsResult, ListResourceTemplatesResult, ListResourcesResult, ListToolsResult, PaginatedRequestParam, ProtocolVersion, ReadResourceRequestParam, - ReadResourceResult, RequestId, + ReadResourceResult, Request, RequestId, ServerResult, }, }; @@ -125,10 +125,16 @@ fn restore_secrets_in_json_value( } } +#[derive(Debug, Clone)] +struct RequestTracking { + client_name: String, + upstream_request_id: Option, +} + pub struct ProxyServer { pool: Arc, - // Map downstream request IDs to upstream client names - request_id_to_client: Arc>>, + // Map downstream request IDs to upstream request tracking data. + request_tracking: Arc>>, // Configuration for upstream clients client_config: Arc>>, // Track if upstream clients have been initialized @@ -141,7 +147,7 @@ impl ProxyServer { pub fn new(config: ClientPoolConfig, redact_secrets: bool, privacy_mode: bool) -> Self { Self { pool: Arc::new(ClientPool::new()), - request_id_to_client: Arc::new(Mutex::new(HashMap::new())), + request_tracking: Arc::new(Mutex::new(HashMap::new())), client_config: Arc::new(Mutex::new(Some(config))), clients_initialized: Arc::new(Mutex::new(false)), secret_manager: SecretManager::new(redact_secrets, privacy_mode), @@ -154,17 +160,37 @@ impl ProxyServer { *stored_config = Some(config); } - /// Track a request ID to client mapping for cancellation forwarding + /// Track a downstream request for cancellation forwarding. async fn track_request(&self, request_id: RequestId, client_name: String) { - self.request_id_to_client - .lock() - .await - .insert(request_id, client_name); + self.request_tracking.lock().await.insert( + request_id, + RequestTracking { + client_name, + upstream_request_id: None, + }, + ); + } + + /// Set the upstream request ID once it is known. + async fn set_upstream_request_id( + &self, + downstream_request_id: &RequestId, + upstream_request_id: RequestId, + ) { + let mut tracking = self.request_tracking.lock().await; + if let Some(entry) = tracking.get_mut(downstream_request_id) { + entry.upstream_request_id = Some(upstream_request_id); + } else { + tracing::debug!( + "No request tracking entry found while setting upstream request ID for downstream request: {:?}", + downstream_request_id + ); + } } - /// Remove and return the client name for a request ID - async fn untrack_request(&self, request_id: &RequestId) -> Option { - self.request_id_to_client.lock().await.remove(request_id) + /// Remove and return tracking data for a request ID. + async fn untrack_request(&self, request_id: &RequestId) -> Option { + self.request_tracking.lock().await.remove(request_id) } /// Aggregate results from all clients using a provided async operation. @@ -280,6 +306,21 @@ impl ProxyServer { client_peer: &Peer, tool_params: CallToolRequestParam, ) -> Result { + let request_handle = client_peer + .send_cancellable_request( + ClientRequest::CallToolRequest(Request::new(tool_params)), + PeerRequestOptions { + meta: Some(ctx.meta.clone()), + ..Default::default() + }, + ) + .await?; + + let request_handle_id = request_handle.id.clone(); + + self.set_upstream_request_id(&ctx.id, request_handle_id.clone()) + .await; + tokio::select! { biased; @@ -287,7 +328,7 @@ impl ProxyServer { // Forward cancellation to upstream server let _ = client_peer .notify_cancelled(CancelledNotificationParam { - request_id: ctx.id.clone(), + request_id: request_handle_id, reason: Some("Request cancelled by downstream client".to_string()), }) .await; @@ -297,7 +338,12 @@ impl ProxyServer { }) } - result = client_peer.call_tool(tool_params) => result + result = request_handle.await_response() => { + match result? { + ServerResult::CallToolResult(result) => Ok(result), + _ => Err(ServiceError::UnexpectedResponse), + } + } } } @@ -716,8 +762,8 @@ impl ServerHandler for ProxyServer { ) { let request_id = notification.request_id.clone(); - // Atomically get and remove the mapping - let Some(client_name) = self.untrack_request(&request_id).await else { + // Atomically get and remove the mapping. + let Some(tracking) = self.untrack_request(&request_id).await else { tracing::debug!( "Cancellation notification received but no request ID mapping found for: {:?}", request_id @@ -725,26 +771,45 @@ impl ServerHandler for ProxyServer { return; }; - // Get a cloned peer and forward cancellation - let Some(client_peer) = self.pool.get_client_peer(&client_name).await else { + // If cancellation arrives before upstream request ID assignment, + // execute_with_cancellation will still forward using request_handle.id when + // ctx.ct is observed as cancelled. + let Some(upstream_request_id) = tracking.upstream_request_id else { + tracing::debug!( + "Cancellation notification received before upstream request ID assignment for downstream request: {:?}", + request_id + ); + return; + }; + + // Get a cloned peer and forward cancellation with the upstream request ID. + let Some(client_peer) = self.pool.get_client_peer(&tracking.client_name).await else { tracing::warn!( "Cancellation notification received for unknown client: {}", - client_name + tracking.client_name ); return; }; - if let Err(e) = client_peer.notify_cancelled(notification).await { + let upstream_notification = CancelledNotificationParam { + request_id: upstream_request_id.clone(), + reason: notification.reason, + }; + + if let Err(e) = client_peer.notify_cancelled(upstream_notification).await { tracing::warn!( - "Failed to forward cancellation to upstream server {}: {:?}", - client_name, + "Failed to forward cancellation to upstream server {} (downstream id: {:?}, upstream id: {:?}): {:?}", + tracking.client_name, + request_id, + upstream_request_id, e ); } else { tracing::debug!( - "Forwarded cancellation for request {:?} to client {}", + "Forwarded cancellation for downstream request {:?} to client {} with upstream request {:?}", request_id, - client_name + tracking.client_name, + upstream_request_id ); } }