diff --git a/tonic-xds/Cargo.toml b/tonic-xds/Cargo.toml
index e081cf991..bf93d0223 100644
--- a/tonic-xds/Cargo.toml
+++ b/tonic-xds/Cargo.toml
@@ -24,12 +24,14 @@ exclude = ["proto/test/*"]
 tonic = "0.14"
 http = "1"
 http-body = "1"
+pin-project-lite = "0.2"
 tower = { version = "0.5", default-features = false, features = ["discover", "retry"] }
 arc-swap = "1"
 dashmap = "6.1"
 thiserror = "2.0.17"
 url = "2.5.8"
 futures-core = "0.3.31"
+futures-util = "0.3"
 bytes = "1"
 xds-client = { version = "0.1.0-alpha.1", path = "../xds-client" }
 serde = { version = "1", features = ["derive"] }
@@ -37,11 +39,12 @@ serde_json = "1"
 envoy-types = "0.7"
 prost = "0.14"
 regex = "1"
-tokio = { version = "1", features = ["sync"] }
+tokio = { version = "1", features = ["sync", "time"] }
 # Used for weighted cluster selection and fractional route matching — does not need
 # cryptographic security, only statistical uniformity for traffic distribution.
 fastrand = "2"
 tokio-stream = "0.1"
+tokio-util = "0.7"
 backoff = "0.4"
 shared_http_body = "0.1"
 tonic-prost = { version = "0.14", optional = true }
@@ -51,7 +54,7 @@ workspace = true
 
 [dev-dependencies]
 xds-client = { version = "0.1.0-alpha.1", path = "../xds-client", features = ["test-util"] }
-tokio = { version = "1", features = ["rt-multi-thread", "macros", "net"] }
+tokio = { version = "1", features = ["rt-multi-thread", "macros", "net", "test-util"] }
 tonic = { version = "0.14", features = [ "server", "channel", "tls-ring" ] }
 tonic-prost = "0.14"
 tonic-prost-build = "0.14"
diff --git a/tonic-xds/src/client/endpoint.rs b/tonic-xds/src/client/endpoint.rs
index 82612f496..fbe0869b3 100644
--- a/tonic-xds/src/client/endpoint.rs
+++ b/tonic-xds/src/client/endpoint.rs
@@ -153,3 +153,20 @@ impl Load for EndpointChannel {
         self.in_flight.load(Ordering::Relaxed)
     }
 }
+
+/// Factory for creating connections to endpoints.
+///
+/// Implementations capture cluster-level config (TLS, HTTP/2 settings, timeouts)
+/// at construction time. The implementation handles retries and concurrency
+/// internally — the returned future resolves when a connection is established
+/// (or is cancelled by dropping).
+pub(crate) trait Connector {
+    /// The service type produced by this connector.
+    type Service;
+
+    /// Connect to the given endpoint address.
+    fn connect(
+        &self,
+        addr: &EndpointAddress,
+    ) -> crate::common::async_util::BoxFuture<Self::Service>;
+}
diff --git a/tonic-xds/src/client/loadbalance/channel.rs b/tonic-xds/src/client/loadbalance/channel.rs
new file mode 100644
index 000000000..85368d28f
--- /dev/null
+++ b/tonic-xds/src/client/loadbalance/channel.rs
@@ -0,0 +1,199 @@
+//! LbChannel: an instrumented channel wrapper with in-flight request tracking.
+
+use std::future::Future;
+use std::pin::Pin;
+use std::sync::Arc;
+use std::sync::atomic::{AtomicU64, Ordering};
+use std::task::{Context, Poll};
+
+use pin_project_lite::pin_project;
+use tower::Service;
+use tower::load::Load;
+
+use crate::client::endpoint::EndpointAddress;
+
+/// RAII guard that increments an in-flight counter on creation and decrements on drop.
+/// Ensures accurate tracking even when futures are cancelled.
+struct InFlightGuard {
+    counter: Arc<AtomicU64>,
+}
+
+impl InFlightGuard {
+    fn acquire(counter: Arc<AtomicU64>) -> Self {
+        counter.fetch_add(1, Ordering::Relaxed);
+        Self { counter }
+    }
+}
+
+impl Drop for InFlightGuard {
+    fn drop(&mut self) {
+        self.counter.fetch_sub(1, Ordering::Relaxed);
+    }
+}
+
+pin_project! {
+    /// A future that holds an [`InFlightGuard`] for the duration of a request.
+    ///
+    /// Preserves the inner future's output type — no boxing or error mapping.
+    pub(crate) struct InFlightFuture<F> {
+        #[pin]
+        inner: F,
+        _guard: InFlightGuard,
+    }
+}
+
+impl<F: Future> Future for InFlightFuture<F> {
+    type Output = F::Output;
+
+    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
+        self.project().inner.poll(cx)
+    }
+}
+
+/// A channel wrapper that tracks in-flight requests for load balancing.
+///
+/// `LbChannel` wraps an inner service `S` and maintains an atomic counter of
+/// in-flight requests. This counter is used by P2C load balancers (via the
+/// [`Load`] trait) to prefer endpoints with fewer active requests.
+///
+/// All clones of an `LbChannel` share the same in-flight counter.
+pub(crate) struct LbChannel<S> {
+    addr: EndpointAddress,
+    inner: S,
+    in_flight: Arc<AtomicU64>,
+}
+
+impl<S> LbChannel<S> {
+    /// Create a new `LbChannel` wrapping the given service.
+    pub(crate) fn new(addr: EndpointAddress, inner: S) -> Self {
+        Self {
+            addr,
+            inner,
+            in_flight: Arc::new(AtomicU64::new(0)),
+        }
+    }
+
+    /// Returns the endpoint address.
+    pub(crate) fn addr(&self) -> &EndpointAddress {
+        &self.addr
+    }
+
+    /// Returns the current number of in-flight requests.
+    #[cfg(test)]
+    pub(crate) fn in_flight(&self) -> u64 {
+        self.in_flight.load(Ordering::Relaxed)
+    }
+}
+
+impl<S: Clone> Clone for LbChannel<S> {
+    fn clone(&self) -> Self {
+        Self {
+            addr: self.addr.clone(),
+            inner: self.inner.clone(),
+            in_flight: self.in_flight.clone(),
+        }
+    }
+}
+
+impl<S, Req> Service<Req> for LbChannel<S>
+where
+    S: Service<Req>,
+{
+    type Response = S::Response;
+    type Error = S::Error;
+    type Future = InFlightFuture<S::Future>;
+
+    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
+        self.inner.poll_ready(cx)
+    }
+
+    fn call(&mut self, req: Req) -> Self::Future {
+        // Guard is acquired eagerly so the request counts as in-flight from
+        // the moment `call` returns, and is released when the future drops
+        // (including cancellation).
+        let guard = InFlightGuard::acquire(self.in_flight.clone());
+        InFlightFuture {
+            inner: self.inner.call(req),
+            _guard: guard,
+        }
+    }
+}
+
+impl<S> Load for LbChannel<S> {
+    type Metric = u64;
+
+    fn load(&self) -> Self::Metric {
+        self.in_flight.load(Ordering::Relaxed)
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use std::future;
+    use std::task::Poll;
+
+    fn test_addr() -> EndpointAddress {
+        EndpointAddress::new("127.0.0.1", 8080)
+    }
+
+    #[derive(Clone)]
+    struct MockService;
+
+    impl Service<&'static str> for MockService {
+        type Response = &'static str;
+        type Error = &'static str;
+        type Future = future::Ready<Result<Self::Response, Self::Error>>;
+
+        fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
+            Poll::Ready(Ok(()))
+        }
+
+        fn call(&mut self, _req: &'static str) -> Self::Future {
+            future::ready(Ok("ok"))
+        }
+    }
+
+    #[tokio::test]
+    async fn test_in_flight_increments_and_decrements() {
+        let mut ch = LbChannel::new(test_addr(), MockService);
+        assert_eq!(ch.in_flight(), 0);
+
+        let fut = ch.call("hello");
+        assert_eq!(ch.in_flight(), 1);
+
+        let resp = fut.await.unwrap();
+        assert_eq!(resp, "ok");
+        assert_eq!(ch.in_flight(), 0);
+    }
+
+    #[tokio::test]
+    async fn test_in_flight_on_future_drop() {
+        let mut ch = LbChannel::new(test_addr(), MockService);
+        let fut = ch.call("hello");
+        assert_eq!(ch.in_flight(), 1);
+
+        drop(fut);
+        assert_eq!(ch.in_flight(), 0);
+    }
+
+    #[tokio::test]
+    async fn test_clone_shares_in_flight() {
+        let mut ch1 = LbChannel::new(test_addr(), MockService);
+        let ch2 = ch1.clone();
+
+        let fut = ch1.call("hello");
+        assert_eq!(ch1.in_flight(), 1);
+        assert_eq!(ch2.in_flight(), 1);
+
+        let _ = fut.await;
+        assert_eq!(ch1.in_flight(), 0);
+        assert_eq!(ch2.in_flight(), 0);
+    }
+
+    #[test]
+    fn test_load_returns_in_flight() {
+        let ch = LbChannel::new(test_addr(), MockService);
+        assert_eq!(Load::load(&ch), 0);
+
+        ch.in_flight.fetch_add(3, Ordering::Relaxed);
+        assert_eq!(Load::load(&ch), 3);
+    }
+}
diff --git a/tonic-xds/src/client/loadbalance/channel_state.rs b/tonic-xds/src/client/loadbalance/channel_state.rs
new file mode 100644
index 000000000..8f55d516a
--- /dev/null
+++ b/tonic-xds/src/client/loadbalance/channel_state.rs
@@ -0,0 +1,390 @@
+//! Type-state wrappers for LbChannel lifecycle management.
+//!
+//! Each state is a separate struct, and transitions consume the old state (move semantics).
+//! This prevents using a channel in an invalid state at compile time.
+//!
+//! ```text
+//!              +-----------+
+//!              |           |
+//!              v           |
+//! Idle --> Connecting --> Ready <--+--> Ejected
+//!              ^                   |
+//!              |                   |
+//!              +-----------------------+
+//! ```
+//!
+//! State changes are all one-shot.
+//! [`ConnectingChannel`] and [`EjectedChannel`] are
+//! [`Future`]. The caller (typically a pool) uses [`KeyedFutures`] to
+//! manage multiple in-flight state changes and handle cancellation by key.
+//!
+//! The state types hold the raw service `S` directly. In-flight tracking and
+//! load reporting are handled separately by [`LbChannel`] at the pool level.
+//!
+//! [`KeyedFutures`]: crate::client::loadbalance::keyed_futures::KeyedFutures
+//! [`LbChannel`]: crate::client::loadbalance::channel::LbChannel
+
+use std::future::Future;
+use std::pin::Pin;
+use std::sync::Arc;
+use std::task::{Context, Poll};
+use std::time::Duration;
+
+use pin_project_lite::pin_project;
+use tower::Service;
+
+use crate::client::endpoint::{Connector, EndpointAddress};
+use crate::common::async_util::BoxFuture;
+
+/// Configuration for an ejected channel.
+#[derive(Debug, Clone)]
+pub(crate) struct EjectionConfig {
+    /// How long the channel is ejected before it can return to service.
+    pub timeout: Duration,
+    /// Whether the channel needs a fresh connection after ejection expires (e.g. after consecutive timeouts).
+    pub needs_reconnect: bool,
+}
+
+/// Result of an ejection expiring.
+pub(crate) enum UnejectedChannel<S> {
+    /// The channel is ready to serve again (ejection expired, no reconnect needed).
+    Ready(ReadyChannel<S>),
+    /// A fresh connection has been started.
+    Connecting(ConnectingChannel<S>),
+}
+
+// ---------------------------------------------------------------------------
+// IdleChannel
+// ---------------------------------------------------------------------------
+
+/// An idle channel that only stores an address. It is the entry point for
+/// starting a connection attempt.
+pub(crate) struct IdleChannel {
+    addr: EndpointAddress,
+}
+
+impl IdleChannel {
+    pub(crate) fn new(addr: EndpointAddress) -> Self {
+        Self { addr }
+    }
+
+    /// Start connecting to the endpoint. Consumes the idle channel.
+    pub(crate) fn connect<C: Connector>(self, connector: Arc<C>) -> ConnectingChannel<C::Service>
+    where
+        C::Service: Send + 'static,
+    {
+        ConnectingChannel::new(connector.connect(&self.addr), self.addr)
+    }
+}
+
+// ---------------------------------------------------------------------------
+// ConnectingChannel
+// ---------------------------------------------------------------------------
+
+/// A channel that is in the process of connecting.
+///
+/// Implements [`Future`] -- resolves to [`ReadyChannel`] when connected.
+/// Cancellation is handled externally via [`KeyedFutures::cancel`].
+///
+/// [`KeyedFutures::cancel`]: crate::client::loadbalance::keyed_futures::KeyedFutures::cancel
+pub(crate) struct ConnectingChannel<S> {
+    inner: Pin<Box<dyn Future<Output = ReadyChannel<S>> + Send>>,
+}
+
+impl<S: Send + 'static> ConnectingChannel<S> {
+    pub(crate) fn new(fut: BoxFuture<S>, addr: EndpointAddress) -> Self {
+        Self {
+            inner: Box::pin(async move {
+                ReadyChannel {
+                    addr,
+                    inner: fut.await,
+                }
+            }),
+        }
+    }
+}
+
+impl<S> Future for ConnectingChannel<S> {
+    type Output = ReadyChannel<S>;
+
+    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
+        self.get_mut().inner.as_mut().poll(cx)
+    }
+}
+
+// ---------------------------------------------------------------------------
+// ReadyChannel
+// ---------------------------------------------------------------------------
+
+/// A channel that is connected and ready to serve requests.
+///
+/// Holds the raw service `S` and delegates [`Service`] calls directly,
+/// preserving `S::Future` and `S::Error` with no wrapping or type erasure.
+#[derive(Clone)]
+pub(crate) struct ReadyChannel<S> {
+    addr: EndpointAddress,
+    inner: S,
+}
+
+impl<S> ReadyChannel<S> {
+    /// Eject this channel (e.g., due to outlier detection). Consumes self.
+    pub(crate) fn eject<C>(self, config: EjectionConfig, connector: Arc<C>) -> EjectedChannel<S>
+    where
+        C: Connector<Service = S> + Send + Sync + 'static,
+    {
+        let ejection_timer = tokio::time::sleep(config.timeout);
+        EjectedChannel {
+            addr: self.addr,
+            inner: self.inner,
+            config,
+            connector,
+            ejection_timer,
+        }
+    }
+
+    /// Start reconnecting. Consumes self, dropping the old connection.
+    pub(crate) fn reconnect<C: Connector<Service = S>>(
+        self,
+        connector: Arc<C>,
+    ) -> ConnectingChannel<S>
+    where
+        S: Send + 'static,
+    {
+        ConnectingChannel::new(connector.connect(&self.addr), self.addr)
+    }
+}
+
+impl<S, Req> Service<Req> for ReadyChannel<S>
+where
+    S: Service<Req>,
+{
+    type Response = S::Response;
+    type Error = S::Error;
+    type Future = S::Future;
+
+    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
+        self.inner.poll_ready(cx)
+    }
+
+    fn call(&mut self, req: Req) -> Self::Future {
+        self.inner.call(req)
+    }
+}
+
+// ---------------------------------------------------------------------------
+// EjectedChannel
+// ---------------------------------------------------------------------------
+
+pin_project! {
+    /// A channel that has been ejected and is cooling down.
+    ///
+    /// The underlying connection is kept alive but cannot serve requests.
+    /// Implements [`Future`] -- resolves once the ejection timer expires to either:
+    /// - [`UnejectedChannel::Ready`] if no reconnect is needed
+    /// - [`UnejectedChannel::Connecting`] if a fresh connection is required
+    pub(crate) struct EjectedChannel<S> {
+        addr: EndpointAddress,
+        inner: S,
+        config: EjectionConfig,
+        connector: Arc<dyn Connector<Service = S> + Send + Sync>,
+        #[pin]
+        ejection_timer: tokio::time::Sleep,
+    }
+}
+
+impl<S: Clone + Send + 'static> Future for EjectedChannel<S> {
+    type Output = UnejectedChannel<S>;
+
+    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
+        let this = self.project();
+        match this.ejection_timer.poll(cx) {
+            Poll::Ready(()) => {
+                if this.config.needs_reconnect {
+                    let fut = this.connector.connect(this.addr);
+                    Poll::Ready(UnejectedChannel::Connecting(ConnectingChannel::new(
+                        fut,
+                        this.addr.clone(),
+                    )))
+                } else {
+                    Poll::Ready(UnejectedChannel::Ready(ReadyChannel {
+                        addr: this.addr.clone(),
+                        inner: this.inner.clone(),
+                    }))
+                }
+            }
+            Poll::Pending => Poll::Pending,
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::client::loadbalance::keyed_futures::KeyedFutures;
+    use futures_util::task::noop_waker;
+    use std::future;
+    use std::sync::atomic::{AtomicU32, Ordering};
+
+    #[derive(Clone, Debug)]
+    struct MockService;
+
+    impl Service<&'static str> for MockService {
+        type Response = &'static str;
+        type Error = &'static str;
+        type Future = future::Ready<Result<Self::Response, Self::Error>>;
+
+        fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
+            Poll::Ready(Ok(()))
+        }
+
+        fn call(&mut self, _req: &'static str) -> Self::Future {
+            future::ready(Ok("ok"))
+        }
+    }
+
+    struct MockConnector {
+        connect_count: Arc<AtomicU32>,
+    }
+
+    impl MockConnector {
+        fn new() -> Arc<Self> {
+            Arc::new(Self {
+                connect_count: Arc::new(AtomicU32::new(0)),
+            })
+        }
+    }
+
+    impl Connector for MockConnector {
+        type Service = MockService;
+
+        fn connect(&self, _addr: &EndpointAddress) -> BoxFuture<Self::Service> {
+            self.connect_count.fetch_add(1, Ordering::SeqCst);
+            Box::pin(future::ready(MockService))
+        }
+    }
+
+    fn test_addr() -> EndpointAddress {
+        EndpointAddress::new("127.0.0.1", 8080)
+    }
+
+    fn noop_cx() -> Context<'static> {
+        Context::from_waker(Box::leak(Box::new(noop_waker())))
+    }
+
+    #[tokio::test]
+    async fn test_idle_to_connecting() {
+        let connector = MockConnector::new();
+        let _connecting = IdleChannel::new(test_addr()).connect(connector.clone());
+        assert_eq!(connector.connect_count.load(Ordering::SeqCst), 1);
+    }
+
+    #[tokio::test]
+    async fn test_connecting_future_yields_ready() {
+        let connector = MockConnector::new();
+        let ready = IdleChannel::new(test_addr()).connect(connector).await;
+        assert_eq!(ready.addr, test_addr());
+    }
+
+    #[tokio::test]
+    async fn test_ready_service_delegates() {
+        let connector = MockConnector::new();
+        let mut ready = IdleChannel::new(test_addr()).connect(connector).await;
+        let resp: &str = ready.call("hello").await.unwrap();
+        assert_eq!(resp, "ok");
+    }
+
+    #[tokio::test]
+    async fn test_ready_to_connecting_via_reconnect() {
+        let connector = MockConnector::new();
+        let ready = IdleChannel::new(test_addr())
+            .connect(connector.clone())
+            .await;
+        let _reconnecting = ready.reconnect(connector.clone());
+        assert_eq!(connector.connect_count.load(Ordering::SeqCst), 2);
+    }
+
+    // --- KeyedFutures integration ---
+
+    #[tokio::test]
+    async fn test_connecting_in_keyed_futures() {
+        let (tx, rx) = tokio::sync::oneshot::channel::<MockService>();
+        let connecting =
+            ConnectingChannel::new(Box::pin(async move { rx.await.unwrap() }), test_addr());
+
+        let mut set: KeyedFutures<EndpointAddress, ReadyChannel<MockService>> = KeyedFutures::new();
+        set.add(test_addr(), connecting).unwrap();
+
+        assert!(matches!(set.poll_next(&mut noop_cx()), Poll::Pending));
+
+        tx.send(MockService).unwrap();
+
+        match set.poll_next(&mut noop_cx()) {
+            Poll::Ready(Some((addr, _))) => assert_eq!(addr, test_addr()),
+            _ => panic!("expected Ready"),
+        }
+    }
+
+    #[tokio::test]
+    async fn test_connecting_cancelled_via_keyed_futures() {
+        let connecting =
+            ConnectingChannel::new(Box::pin(future::pending::<MockService>()), test_addr());
+
+        let mut set: KeyedFutures<EndpointAddress, ReadyChannel<MockService>> = KeyedFutures::new();
+        set.add(test_addr(), connecting).unwrap();
+
+        assert!(matches!(set.poll_next(&mut noop_cx()), Poll::Pending));
+
+        set.cancel(&test_addr()).unwrap();
+        assert!(matches!(set.poll_next(&mut noop_cx()), Poll::Ready(None)));
+    }
+
+    #[tokio::test(start_paused = true)]
+    async fn test_ejected_in_keyed_futures_ready() {
+        let connector = MockConnector::new();
+        let ready = IdleChannel::new(test_addr())
+            .connect(connector.clone())
+            .await;
+        let ejected = ready.eject(
+            EjectionConfig {
+                timeout: Duration::from_secs(5),
+                needs_reconnect: false,
+            },
+            connector,
+        );
+
+        let mut set: KeyedFutures<EndpointAddress, UnejectedChannel<MockService>> =
+            KeyedFutures::new();
+        set.add(test_addr(), ejected).unwrap();
+
+        let (addr, result) = futures_util::future::poll_fn(|cx| set.poll_next(cx))
+            .await
+            .unwrap();
+        assert_eq!(addr, test_addr());
+        assert!(matches!(result, UnejectedChannel::Ready(_)));
+    }
+
+    #[tokio::test(start_paused = true)]
+    async fn test_ejected_in_keyed_futures_needs_reconnect() {
+        let connector = MockConnector::new();
+        let ready = IdleChannel::new(test_addr())
+            .connect(connector.clone())
+            .await;
+        let ejected = ready.eject(
+            EjectionConfig {
+                timeout: Duration::from_secs(5),
+                needs_reconnect: true,
+            },
+            connector.clone(),
+        );
+
+        let mut set: KeyedFutures<EndpointAddress, UnejectedChannel<MockService>> =
+            KeyedFutures::new();
+        set.add(test_addr(), ejected).unwrap();
+
+        let (addr, result) = futures_util::future::poll_fn(|cx| set.poll_next(cx))
+            .await
+            .unwrap();
+        assert_eq!(addr, test_addr());
+        assert!(matches!(result, UnejectedChannel::Connecting(_)));
+        assert_eq!(connector.connect_count.load(Ordering::SeqCst), 2);
+    }
+}
diff --git a/tonic-xds/src/client/loadbalance/keyed_futures.rs b/tonic-xds/src/client/loadbalance/keyed_futures.rs
new file mode 100644
index 000000000..74319c6f3
--- /dev/null
+++ b/tonic-xds/src/client/loadbalance/keyed_futures.rs
@@ -0,0 +1,213 @@
+//! [`KeyedFutures`]: a cancellable, keyed set of futures.
+
+use std::collections::HashMap;
+use std::future::Future;
+use std::hash::Hash;
+use std::pin::Pin;
+use std::task::{Context, Poll};
+
+use futures_core::Stream;
+use futures_util::stream::FuturesUnordered;
+use tokio_util::sync::CancellationToken;
+
+use crate::common::async_util::BoxFuture;
+
+/// Errors returned by [`KeyedFutures`].
+#[derive(Debug, thiserror::Error)]
+pub(crate) enum KeyedFuturesError<K> {
+    /// A future for this key is already running.
+    #[error("key {0:?} already exists")]
+    DuplicateKey(K),
+    /// No future is running for the given key.
+    #[error("key {0:?} not found")]
+    KeyNotFound(K),
+}
+
+/// A cancellable, keyed set of futures.
+///
+/// Each future is associated with a key `K` and produces a value `T`.
+/// Futures can be cancelled individually by key. [`poll_next`] drives all
+/// futures concurrently and yields `(K, T)` when one completes; cancelled
+/// futures are silently skipped.
+///
+/// Intended for use inside [`tower::Service::poll_ready`] to manage a large number of
+/// concurrent, cancellable operations (e.g. pending connection attempts).
+pub(crate) struct KeyedFutures<K, T> {
+    cancellations: HashMap<K, CancellationToken>,
+    futures: FuturesUnordered<BoxFuture<(K, Option<T>)>>,
+}
+
+impl<K, T> KeyedFutures<K, T>
+where
+    K: Hash + Eq + Clone + Send + std::fmt::Debug + 'static,
+    T: Send + 'static,
+{
+    pub(crate) fn new() -> Self {
+        Self {
+            cancellations: HashMap::new(),
+            futures: FuturesUnordered::new(),
+        }
+    }
+
+    /// Add a future keyed by `key`. Returns `Err(DuplicateKey)` if a future
+    /// for this key is already running.
+    pub(crate) fn add<F>(&mut self, key: K, fut: F) -> Result<(), KeyedFuturesError<K>>
+    where
+        F: Future<Output = T> + Send + 'static,
+    {
+        if self.cancellations.contains_key(&key) {
+            return Err(KeyedFuturesError::DuplicateKey(key));
+        }
+        let token = CancellationToken::new();
+        self.cancellations.insert(key.clone(), token.clone());
+
+        self.futures.push(Box::pin(async move {
+            tokio::select! {
+                biased;
+                _ = token.cancelled() => (key, None),
+                t = fut => (key, Some(t)),
+            }
+        }));
+        Ok(())
+    }
+
+    /// Cancel the future for `key`. Returns `Err(KeyNotFound)` if no future
+    /// is running for the given key.
+    pub(crate) fn cancel(&mut self, key: &K) -> Result<(), KeyedFuturesError<K>> {
+        match self.cancellations.remove(key) {
+            Some(token) => {
+                token.cancel();
+                Ok(())
+            }
+            None => Err(KeyedFuturesError::KeyNotFound(key.clone())),
+        }
+    }
+
+    /// Returns the number of futures currently running (including cancelled
+    /// ones not yet polled to completion).
+    pub(crate) fn len(&self) -> usize {
+        self.futures.len()
+    }
+
+    /// Advance the internal futures. Yields `(K, T)` when a future completes,
+    /// skipping cancelled futures silently.
+    ///
+    /// Returns:
+    /// - `Poll::Ready(Some((key, output)))` — a future completed successfully.
+    /// - `Poll::Pending` — no futures ready yet; the waker will be notified.
+    /// - `Poll::Ready(None)` — all futures have completed or been cancelled.
+    pub(crate) fn poll_next(&mut self, cx: &mut Context<'_>) -> Poll<Option<(K, T)>> {
+        loop {
+            match Pin::new(&mut self.futures).poll_next(cx) {
+                Poll::Ready(Some((key, Some(output)))) => {
+                    self.cancellations.remove(&key);
+                    return Poll::Ready(Some((key, output)));
+                }
+                Poll::Ready(Some((_, None))) => continue, // skip cancelled futures
+                Poll::Ready(None) => return Poll::Ready(None),
+                Poll::Pending => return Poll::Pending,
+            }
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use futures_util::task::noop_waker;
+
+    fn noop_cx() -> Context<'static> {
+        // SAFETY: the waker is never dereferenced; used only to satisfy the
+        // Context API. FuturesUnordered manages internal task wakeups
+        // independently of this outer waker.
+        Context::from_waker(Box::leak(Box::new(noop_waker())))
+    }
+
+    #[tokio::test]
+    async fn test_add_and_poll() {
+        let mut set: KeyedFutures<&str, u32> = KeyedFutures::new();
+        set.add("a", async { 1 }).unwrap();
+        set.add("b", async { 2 }).unwrap();
+
+        let mut results = vec![];
+        while let Poll::Ready(Some(item)) = set.poll_next(&mut noop_cx()) {
+            results.push(item);
+        }
+        results.sort();
+        assert_eq!(results, vec![("a", 1), ("b", 2)]);
+    }
+
+    #[tokio::test]
+    async fn test_poll_pending_then_ready() {
+        // Use a oneshot channel so the future is pending until we send.
+        // FuturesUnordered's internal TaskWaker is woken by tx.send(),
+        // so the next poll_next sees the result without needing yield_now().
+        let mut set: KeyedFutures<&str, u32> = KeyedFutures::new();
+        let (tx, rx) = tokio::sync::oneshot::channel::<u32>();
+        set.add("a", async move { rx.await.unwrap() }).unwrap();
+
+        // Before send: pending.
+        assert!(matches!(set.poll_next(&mut noop_cx()), Poll::Pending));
+
+        // Signal the future to complete.
+        tx.send(42).unwrap();
+
+        // FuturesUnordered's internal waker was notified; next poll sees result.
+        assert_eq!(set.poll_next(&mut noop_cx()), Poll::Ready(Some(("a", 42))));
+    }
+
+    #[tokio::test]
+    async fn test_duplicate_key_rejected() {
+        let mut set: KeyedFutures<&str, u32> = KeyedFutures::new();
+        set.add("a", async { 1 }).unwrap();
+        assert!(matches!(
+            set.add("a", async { 2 }),
+            Err(KeyedFuturesError::DuplicateKey("a"))
+        ));
+    }
+
+    #[tokio::test]
+    async fn test_cancel_skipped_in_poll() {
+        let mut set: KeyedFutures<&str, u32> = KeyedFutures::new();
+        let (tx_a, rx_a) = tokio::sync::oneshot::channel::<u32>();
+        let (tx_b, rx_b) = tokio::sync::oneshot::channel::<u32>();
+
+        set.add("a", async move { rx_a.await.unwrap() }).unwrap();
+        set.add("b", async move { rx_b.await.unwrap() }).unwrap();
+
+        // Both pending.
+        assert!(matches!(set.poll_next(&mut noop_cx()), Poll::Pending));
+
+        // Cancel "a", complete "b".
+        set.cancel(&"a").unwrap();
+        tx_b.send(42).unwrap();
+        drop(tx_a);
+
+        // "a" is silently skipped; only "b" is yielded.
+        assert_eq!(set.poll_next(&mut noop_cx()), Poll::Ready(Some(("b", 42))));
+        assert_eq!(set.poll_next(&mut noop_cx()), Poll::Ready(None));
+    }
+
+    #[tokio::test]
+    async fn test_cancel_nonexistent_returns_error() {
+        let mut set: KeyedFutures<&str, u32> = KeyedFutures::new();
+        assert!(matches!(
+            set.cancel(&"missing"),
+            Err(KeyedFuturesError::KeyNotFound("missing"))
+        ));
+    }
+
+    #[tokio::test]
+    async fn test_reuse_key_after_completion() {
+        let mut set: KeyedFutures<&str, u32> = KeyedFutures::new();
+        let (tx, rx) = tokio::sync::oneshot::channel::<u32>();
+        set.add("a", async move { rx.await.unwrap() }).unwrap();
+
+        tx.send(1).unwrap();
+        assert_eq!(set.poll_next(&mut noop_cx()), Poll::Ready(Some(("a", 1))));
+
+        // Key is free after completion — can be re-added.
+        set.add("a", async { 2 }).unwrap();
+        assert_eq!(set.poll_next(&mut noop_cx()), Poll::Ready(Some(("a", 2))));
+    }
+}
diff --git a/tonic-xds/src/client/loadbalance/mod.rs b/tonic-xds/src/client/loadbalance/mod.rs
new file mode 100644
index 000000000..217efb2b3
--- /dev/null
+++ b/tonic-xds/src/client/loadbalance/mod.rs
@@ -0,0 +1,3 @@
+pub(crate) mod channel;
+pub(crate) mod channel_state;
+pub(crate) mod keyed_futures;
diff --git a/tonic-xds/src/client/mod.rs b/tonic-xds/src/client/mod.rs
index 8d22f3029..3e02c9b29 100644
--- a/tonic-xds/src/client/mod.rs
+++ b/tonic-xds/src/client/mod.rs
@@ -3,5 +3,7 @@ pub(crate) mod cluster;
 pub(crate) mod endpoint;
 pub(crate) mod lb;
 #[allow(dead_code)]
+pub(crate) mod loadbalance;
+#[allow(dead_code)]
 pub(crate) mod retry;
 pub(crate) mod route;