Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

119 changes: 92 additions & 27 deletions libs/mcp/proxy/src/server/mod.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
use rmcp::model::ServerCapabilities;
use rmcp::service::{NotificationContext, Peer, RequestContext};
use rmcp::service::{NotificationContext, Peer, PeerRequestOptions, RequestContext};
use rmcp::transport::streamable_http_server::{
StreamableHttpService, session::local::LocalSessionManager,
};
use rmcp::{
RoleClient, RoleServer, ServerHandler, ServiceError,
model::{
CallToolRequestParam, CallToolResult, CancelledNotificationParam, Content, ErrorData,
GetPromptRequestParam, GetPromptResult, Implementation, InitializeRequestParam,
CallToolRequestParam, CallToolResult, CancelledNotificationParam, ClientRequest, Content,
ErrorData, GetPromptRequestParam, GetPromptResult, Implementation, InitializeRequestParam,
InitializeResult, ListPromptsResult, ListResourceTemplatesResult, ListResourcesResult,
ListToolsResult, PaginatedRequestParam, ProtocolVersion, ReadResourceRequestParam,
ReadResourceResult, RequestId,
ReadResourceResult, Request, RequestId, ServerResult,
},
};

Expand Down Expand Up @@ -125,10 +125,16 @@ fn restore_secrets_in_json_value(
}
}

#[derive(Debug, Clone)]
struct RequestTracking {
client_name: String,
upstream_request_id: Option<RequestId>,
}

pub struct ProxyServer {
pool: Arc<ClientPool>,
// Map downstream request IDs to upstream client names
request_id_to_client: Arc<Mutex<HashMap<RequestId, String>>>,
// Map downstream request IDs to upstream request tracking data.
request_tracking: Arc<Mutex<HashMap<RequestId, RequestTracking>>>,
// Configuration for upstream clients
client_config: Arc<Mutex<Option<ClientPoolConfig>>>,
// Track if upstream clients have been initialized
Expand All @@ -141,7 +147,7 @@ impl ProxyServer {
pub fn new(config: ClientPoolConfig, redact_secrets: bool, privacy_mode: bool) -> Self {
Self {
pool: Arc::new(ClientPool::new()),
request_id_to_client: Arc::new(Mutex::new(HashMap::new())),
request_tracking: Arc::new(Mutex::new(HashMap::new())),
client_config: Arc::new(Mutex::new(Some(config))),
clients_initialized: Arc::new(Mutex::new(false)),
secret_manager: SecretManager::new(redact_secrets, privacy_mode),
Expand All @@ -154,17 +160,37 @@ impl ProxyServer {
*stored_config = Some(config);
}

/// Track a request ID to client mapping for cancellation forwarding
/// Track a downstream request for cancellation forwarding.
async fn track_request(&self, request_id: RequestId, client_name: String) {
self.request_id_to_client
.lock()
.await
.insert(request_id, client_name);
self.request_tracking.lock().await.insert(
request_id,
RequestTracking {
client_name,
upstream_request_id: None,
},
);
}

/// Set the upstream request ID once it is known.
async fn set_upstream_request_id(
&self,
downstream_request_id: &RequestId,
upstream_request_id: RequestId,
) {
let mut tracking = self.request_tracking.lock().await;
if let Some(entry) = tracking.get_mut(downstream_request_id) {
entry.upstream_request_id = Some(upstream_request_id);
} else {
tracing::debug!(
"No request tracking entry found while setting upstream request ID for downstream request: {:?}",
downstream_request_id
);
}
}

/// Remove and return the client name for a request ID
async fn untrack_request(&self, request_id: &RequestId) -> Option<String> {
self.request_id_to_client.lock().await.remove(request_id)
/// Remove and return tracking data for a request ID.
async fn untrack_request(&self, request_id: &RequestId) -> Option<RequestTracking> {
self.request_tracking.lock().await.remove(request_id)
}

/// Aggregate results from all clients using a provided async operation.
Expand Down Expand Up @@ -280,14 +306,29 @@ impl ProxyServer {
client_peer: &Peer<RoleClient>,
tool_params: CallToolRequestParam,
) -> Result<CallToolResult, ServiceError> {
let request_handle = client_peer
.send_cancellable_request(
ClientRequest::CallToolRequest(Request::new(tool_params)),
PeerRequestOptions {
meta: Some(ctx.meta.clone()),
..Default::default()
},
)
.await?;

let request_handle_id = request_handle.id.clone();

self.set_upstream_request_id(&ctx.id, request_handle_id.clone())
.await;

tokio::select! {
biased;

_ = ctx.ct.cancelled() => {
// Forward cancellation to upstream server
let _ = client_peer
.notify_cancelled(CancelledNotificationParam {
request_id: ctx.id.clone(),
request_id: request_handle_id,
reason: Some("Request cancelled by downstream client".to_string()),
})
.await;
Expand All @@ -297,7 +338,12 @@ impl ProxyServer {
})
}

result = client_peer.call_tool(tool_params) => result
result = request_handle.await_response() => {
match result? {
ServerResult::CallToolResult(result) => Ok(result),
_ => Err(ServiceError::UnexpectedResponse),
}
}
}
}

Expand Down Expand Up @@ -716,35 +762,54 @@ impl ServerHandler for ProxyServer {
) {
let request_id = notification.request_id.clone();

// Atomically get and remove the mapping
let Some(client_name) = self.untrack_request(&request_id).await else {
// Atomically get and remove the mapping.
let Some(tracking) = self.untrack_request(&request_id).await else {
tracing::debug!(
"Cancellation notification received but no request ID mapping found for: {:?}",
request_id
);
return;
};

// Get a cloned peer and forward cancellation
let Some(client_peer) = self.pool.get_client_peer(&client_name).await else {
// If cancellation arrives before upstream request ID assignment,
// execute_with_cancellation will still forward using request_handle.id when
// ctx.ct is observed as cancelled.
let Some(upstream_request_id) = tracking.upstream_request_id else {
tracing::debug!(
"Cancellation notification received before upstream request ID assignment for downstream request: {:?}",
request_id
);
return;
};

// Get a cloned peer and forward cancellation with the upstream request ID.
let Some(client_peer) = self.pool.get_client_peer(&tracking.client_name).await else {
tracing::warn!(
"Cancellation notification received for unknown client: {}",
client_name
tracking.client_name
);
return;
};

if let Err(e) = client_peer.notify_cancelled(notification).await {
let upstream_notification = CancelledNotificationParam {
request_id: upstream_request_id.clone(),
reason: notification.reason,
};

if let Err(e) = client_peer.notify_cancelled(upstream_notification).await {
tracing::warn!(
"Failed to forward cancellation to upstream server {}: {:?}",
client_name,
"Failed to forward cancellation to upstream server {} (downstream id: {:?}, upstream id: {:?}): {:?}",
tracking.client_name,
request_id,
upstream_request_id,
e
);
} else {
tracing::debug!(
"Forwarded cancellation for request {:?} to client {}",
"Forwarded cancellation for downstream request {:?} to client {} with upstream request {:?}",
request_id,
client_name
tracking.client_name,
upstream_request_id
);
}
}
Expand Down
Loading