diff --git a/components/spider-storage/Cargo.toml b/components/spider-storage/Cargo.toml index c4fce35c..c8381ce2 100644 --- a/components/spider-storage/Cargo.toml +++ b/components/spider-storage/Cargo.toml @@ -31,6 +31,7 @@ tokio = { version = "1.50.0", features = [ "sync", "time" ] } +tokio-util = { version = "0.7.18", features = ["rt"] } tracing = { version = "0.1.44", features = ["attributes"] } uuid = { version = "1.19.0", features = ["serde"] } @@ -41,5 +42,4 @@ rand = "0.9.1" serial_test = { version = "3.2.0", features = ["file_locks"] } tabled = "0.20.0" tokio = { version = "1.50.0", features = ["macros", "rt-multi-thread", "sync"] } -tokio-util = { version = "0.7", features = ["rt"] } uuid = { version = "1.19.0", features = ["v4"] } diff --git a/components/spider-storage/src/cache/error.rs b/components/spider-storage/src/cache/error.rs index 8bfcadf3..5f2f2fb2 100644 --- a/components/spider-storage/src/cache/error.rs +++ b/components/spider-storage/src/cache/error.rs @@ -78,11 +78,17 @@ pub enum InternalError { #[error("invalid config: {0}")] ReadyQueueInvalidConfig(&'static str), + #[error("invalid config: {0}")] + TaskInstancePoolInvalidConfig(&'static str), + #[error("ready queue channel is closed")] ReadyQueueChannelClosed, #[error(transparent)] WireError(#[from] WireError), + + #[error(transparent)] + Db(#[from] crate::db::DbError), } /// Enums for all errors representing operations that are rejected due to stale cache state. diff --git a/components/spider-storage/src/state.rs b/components/spider-storage/src/state.rs index 4d573778..7610878c 100644 --- a/components/spider-storage/src/state.rs +++ b/components/spider-storage/src/state.rs @@ -1,9 +1,11 @@ pub mod error; pub mod job_cache; +pub mod server; pub mod service; pub use error::StorageServerError; pub use job_cache::JobCache; +pub use server::{ServerRuntime, create_server_runtime}; pub use service::ServiceState; #[cfg(test)] diff --git a/components/spider-storage/src/state/server.rs b/components/spider-storage/src/state/server.rs new file mode 100644 index 00000000..bc9aa829 --- /dev/null +++ b/components/spider-storage/src/state/server.rs @@ -0,0 +1,242 @@ +use std::time::Duration; + +use tokio::task::JoinHandle; +use tokio_util::sync::CancellationToken; + +use crate::{ + cache::error::{CacheError, InternalError}, + config::DatabaseConfig, + db::{DbStorage, MariaDbStorageConnector, SessionManagement}, + ready_queue::{ReadyQueueConfig, ReadyQueueSender, ReadyQueueSenderHandle, create_ready_queue}, + state::{JobCache, ServiceState, StorageServerError}, + task_instance_pool::{ + TaskInstancePoolConfig, + TaskInstancePoolConnector, + TaskInstancePoolHandle, + create_task_instance_pool, + }, +}; + +/// Per-process storage server runtime. +/// +/// # Type Parameters +/// +/// * `ReadyQueueSenderType` - The ready queue sender type. +/// * `DbConnectorType` - The database connector type. +/// * `TaskInstancePoolConnectorType` - The task instance pool connector type. +pub struct ServerRuntime< + ReadyQueueSenderType: ReadyQueueSender, + DbConnectorType: DbStorage, + TaskInstancePoolConnectorType: TaskInstancePoolConnector, +> { + service_state: + ServiceState, + cancellation_token: CancellationToken, + task_instance_pool_join_handle: JoinHandle>, + stop_timeout_sec: u64, +} + +/// Creates a storage server runtime from the database configuration. +/// +/// # Returns +/// +/// A newly created [`ServerRuntime`] on success. +/// +/// # Errors +/// +/// Returns an error if: +/// +/// * Forwards [`MariaDbStorageConnector::connect`]'s return values on failure. +/// * Forwards [`create_ready_queue`]'s return values on failure. +pub async fn create_server_runtime( + db_config: &DatabaseConfig, +) -> Result< + ServerRuntime, + StorageServerError, +> { + let cancellation_token = CancellationToken::new(); + let db = MariaDbStorageConnector::connect(db_config).await?; + let session_id = db.session_id(); + let (ready_queue_sender, ready_queue_receiver) = + create_ready_queue(ReadyQueueConfig::default()).map_err(CacheError::from)?; + let (task_instance_pool_connector, task_instance_pool_join_handle) = create_task_instance_pool( + ready_queue_sender.clone(), + db.clone(), + cancellation_token.clone(), + TaskInstancePoolConfig::default(), + ); + let service_state = ServiceState::new( + db, + session_id, + JobCache::new(), + ready_queue_sender, + ready_queue_receiver, + task_instance_pool_connector, + ); + + Ok(ServerRuntime { + service_state, + cancellation_token, + task_instance_pool_join_handle, + stop_timeout_sec: STOP_BACKGROUND_TASKS_TIMEOUT_SEC, + }) +} + +impl + ServerRuntime +where + ReadyQueueSenderType: ReadyQueueSender, + DbConnectorType: DbStorage, + TaskInstancePoolConnectorType: TaskInstancePoolConnector, +{ + /// Stops background tasks owned by the runtime. + /// + /// # Errors + /// + /// Returns an error if: + /// + /// * [`StorageServerError::Stopping`] if the task instance pool does not stop before timeout. + /// * [`StorageServerError::Cache`] if the task instance pool task fails or cannot be joined. + pub async fn stop_background_tasks(mut self) -> Result<(), StorageServerError> { + self.cancellation_token.cancel(); + let result = tokio::select! { + result = &mut self.task_instance_pool_join_handle => result, + () = tokio::time::sleep(Duration::from_secs(self.stop_timeout_sec)) => { + self.task_instance_pool_join_handle.abort(); + return Err(StorageServerError::Stopping( + "task instance pool stop timed out".to_owned(), + )); + } + }; + let pool_result = result.map_err(|e| { + StorageServerError::Cache(CacheError::Internal( + InternalError::TaskInstancePoolCorrupted(format!("task join error: {e}")), + )) + })?; + pool_result.map_err(|e| StorageServerError::Cache(CacheError::Internal(e))) + } + + /// # Returns + /// + /// A clone of the runtime's [`ServiceState`]. + #[must_use] + pub fn service_state( + &self, + ) -> ServiceState { + self.service_state.clone() + } +} + +const STOP_BACKGROUND_TASKS_TIMEOUT_SEC: u64 = 30; + +#[cfg(test)] +mod tests { + use std::time::Duration; + + use tokio::task::JoinHandle; + use tokio_util::sync::CancellationToken; + + use super::ServerRuntime; + use crate::{ + cache::error::InternalError, + db::SessionManagement, + ready_queue::{ReadyQueueConfig, ReadyQueueSenderHandle, create_ready_queue}, + state::{ + JobCache, + ServiceState, + StorageServerError, + test_utils::{MockDbConnector, MockTaskInstancePoolConnector}, + }, + }; + + type TestServerRuntime = + ServerRuntime; + + fn create_test_server_runtime( + cancellation_token: CancellationToken, + task: JoinHandle>, + stop_timeout_sec: u64, + ) -> TestServerRuntime { + let db = MockDbConnector::default(); + let session_id = db.session_id(); + let (sender, receiver) = + create_ready_queue(ReadyQueueConfig::default()).expect("ready queue creation"); + let service_state = ServiceState::new( + db, + session_id, + JobCache::new(), + sender, + receiver, + MockTaskInstancePoolConnector, + ); + + ServerRuntime { + service_state, + cancellation_token, + task_instance_pool_join_handle: task, + stop_timeout_sec, + } + } + + #[tokio::test] + async fn stop_background_tasks_cancels_and_joins_task() -> anyhow::Result<()> { + let cancellation_token = CancellationToken::new(); + let task_cancellation_token = cancellation_token.clone(); + let task: JoinHandle> = tokio::spawn(async move { + task_cancellation_token.cancelled().await; + Ok(()) + }); + + let runtime = create_test_server_runtime( + cancellation_token, + task, + super::STOP_BACKGROUND_TASKS_TIMEOUT_SEC, + ); + runtime + .stop_background_tasks() + .await + .expect("stop_background_tasks should succeed"); + Ok(()) + } + + #[tokio::test] + async fn stop_background_tasks_returns_stopping_on_timeout() -> anyhow::Result<()> { + let cancellation_token = CancellationToken::new(); + let task: JoinHandle> = tokio::spawn(async move { + tokio::time::sleep(Duration::from_secs(10)).await; + Ok(()) + }); + + let runtime = create_test_server_runtime(cancellation_token, task, 0); + let result = runtime.stop_background_tasks().await; + + assert!( + matches!(result, Err(StorageServerError::Stopping(_))), + "timeout should return Stopping" + ); + Ok(()) + } + + #[tokio::test] + async fn stop_background_tasks_returns_cache_error_on_pool_error() -> anyhow::Result<()> { + let cancellation_token = CancellationToken::new(); + let task: JoinHandle> = tokio::spawn(async move { + Err(InternalError::TaskInstancePoolCorrupted( + "test failure".to_owned(), + )) + }); + + let runtime = create_test_server_runtime( + cancellation_token, + task, + super::STOP_BACKGROUND_TASKS_TIMEOUT_SEC, + ); + let result = runtime.stop_background_tasks().await; + + assert!( + matches!(result, Err(StorageServerError::Cache(_))), + "pool task failure should return Cache error" + ); + Ok(()) + } +} diff --git a/components/spider-storage/src/task_instance_pool.rs b/components/spider-storage/src/task_instance_pool.rs index ace45ce6..24c914b7 100644 --- a/components/spider-storage/src/task_instance_pool.rs +++ b/components/spider-storage/src/task_instance_pool.rs @@ -7,7 +7,7 @@ //! the task so a new instance can be scheduled, while the original instance remains live until it //! completes or is force-removed. //! * **Dead-execution-manager recovery**: During each GC cycle, the pool queries the -//! [`ExecutionManagerLivenessStore`] to detect dead execution managers, force-removes their +//! [`ExecutionManagerLivenessManagement`] to detect dead execution managers, force-removes their //! instances from the task control blocks, and re-enqueues the corresponding tasks. //! //! Internally, the pool runs as a single-owner coroutine: a tokio task owns the mutable state @@ -24,7 +24,11 @@ use std::{ use async_trait::async_trait; use spider_core::types::id::{ExecutionManagerId, JobId, ResourceGroupId, TaskInstanceId}; -use tokio::sync::mpsc; +use tokio::{ + sync::{mpsc, mpsc::error::TryRecvError}, + task::JoinHandle, +}; +use tokio_util::sync::CancellationToken; use crate::{ cache::{ @@ -32,6 +36,7 @@ use crate::{ error::InternalError, task::{SharedTaskControlBlock, SharedTerminationTaskControlBlock}, }, + db::ExecutionManagerLivenessManagement, ready_queue::ReadyQueueSender, }; @@ -46,58 +51,6 @@ pub struct TaskInstanceMetadata { pub soft_timeout_ddl: Option, } -/// Store for tracking execution manager liveness state. -/// -/// Implementations persist execution manager heartbeat state durably and provide an atomic -/// operation to detect and mark disconnected execution managers as dead. -#[async_trait] -pub trait ExecutionManagerLivenessStore: Clone + Send + Sync { - /// Checks whether the execution manager with the given ID is alive. - /// - /// # Parameters - /// - /// * `id` - The execution manager ID to check. - /// - /// # Returns - /// - /// Whether the execution manager is alive on success. - /// - /// # Errors - /// - /// Returns an error if: - /// - /// * Forwards the underlying store's return values on failure. - async fn is_execution_manager_alive( - &self, - id: &ExecutionManagerId, - ) -> Result; - - /// Returns the IDs of execution managers whose last heartbeat is before `stale_before`, after - /// marking them dead. - /// - /// This operation is atomic: once an execution manager is returned by this method, it will not - /// be returned again in subsequent calls. - /// - /// # Parameters - /// - /// * `stale_before` - The cutoff time; execution managers with no heartbeat after this time are - /// considered dead. - /// - /// # Returns - /// - /// A vector of dead execution manager IDs on success. - /// - /// # Errors - /// - /// Returns an error if: - /// - /// * Forwards the underlying store's return values on failure. - async fn get_dead_execution_managers( - &self, - stale_before: SystemTime, - ) -> Result, InternalError>; -} - /// Connector for creating and registering task instances in the task instance pool. /// /// This trait is invoked by the cache layer to allocate task instance IDs and register newly @@ -167,53 +120,110 @@ pub struct TaskInstancePoolHandle { sender: mpsc::Sender, } -impl TaskInstancePoolHandle { - /// Creates a new task instance pool and returns a handle to it. - /// - /// # Type Parameters +/// Configuration for a task instance pool actor. +/// +/// Controls GC timing, channel buffering, and execution manager staleness detection. +#[derive(Debug, Clone, Copy)] +pub struct TaskInstancePoolConfig { + /// Seconds without a heartbeat after which an execution manager is considered stale. + pub execution_manager_stale_after_sec: u64, + /// Interval in seconds between GC cycles that check for dead execution managers. + pub gc_interval: u64, + /// Maximum number of pending registration messages in the pool channel. + pub channel_size: usize, +} + +impl TaskInstancePoolConfig { + /// Creates a new [`TaskInstancePoolConfig`] with validation. /// - /// * `ReadyQueueSenderType` - The ready queue sender implementation for re-enqueue operations. - /// * `LivenessStoreType` - The execution manager liveness store implementation. + /// # Errors /// - /// # Returns + /// Returns an error if: /// - /// A [`TaskInstancePoolHandle`] connected to the newly spawned pool coroutine. - #[must_use] - pub fn create< - ReadyQueueSenderType: ReadyQueueSender + 'static, - LivenessStoreType: ExecutionManagerLivenessStore + 'static, - >( - ready_queue_sender: ReadyQueueSenderType, - execution_manager_liveness_store: LivenessStoreType, - execution_manager_stale_cutoff: Duration, - gc_interval: Duration, + /// * `execution_manager_stale_after_sec` is zero. + /// * `gc_interval` is zero. + /// * `channel_size` is zero. + pub const fn new( + execution_manager_stale_after_sec: u64, + gc_interval: u64, channel_size: usize, - ) -> Self { - let next_task_instance_id = Arc::new(AtomicU64::new(1)); - let (sender, receiver) = mpsc::channel(channel_size); - - let pool = TaskInstancePool { - ready_queue_sender, - execution_manager_liveness_store, - execution_manager_stale_cutoff, - instances: Vec::new(), - execution_manager_pool: HashSet::new(), - receiver, - }; - tokio::spawn(async move { - match pool.run(gc_interval).await { - Ok(()) => {} - Err(_e) => todo!("log this error and terminate the storage service"), - } - }); + ) -> Result { + if execution_manager_stale_after_sec == 0 { + return Err(InternalError::TaskInstancePoolInvalidConfig( + "execution_manager_stale_after_sec must be greater than zero", + )); + } + if gc_interval == 0 { + return Err(InternalError::TaskInstancePoolInvalidConfig( + "gc_interval must be greater than zero", + )); + } + if channel_size == 0 { + return Err(InternalError::TaskInstancePoolInvalidConfig( + "channel_size must be greater than zero", + )); + } + Ok(Self { + execution_manager_stale_after_sec, + gc_interval, + channel_size, + }) + } +} +impl Default for TaskInstancePoolConfig { + fn default() -> Self { Self { - next_task_instance_id, - sender, + execution_manager_stale_after_sec: 60, + gc_interval: 30, + channel_size: 128, } } } +/// Creates a task instance pool and returns the handle plus the spawned actor task. +/// +/// # Type Parameters +/// +/// * `ReadyQueueSenderType` - The ready queue sender implementation for re-enqueue operations. +/// * `LivenessStoreType` - The execution manager liveness store implementation. +/// +/// # Returns +/// +/// A [`TaskInstancePoolHandle`] and the spawned actor's [`JoinHandle`]. +pub fn create_task_instance_pool< + ReadyQueueSenderType: ReadyQueueSender + 'static, + LivenessStoreType: ExecutionManagerLivenessManagement + 'static, +>( + ready_queue_sender: ReadyQueueSenderType, + execution_manager_liveness_store: LivenessStoreType, + cancellation_token: CancellationToken, + config: TaskInstancePoolConfig, +) -> ( + TaskInstancePoolHandle, + JoinHandle>, +) { + let next_task_instance_id = Arc::new(AtomicU64::new(1)); + let (sender, receiver) = mpsc::channel(config.channel_size); + + let pool = TaskInstancePool { + ready_queue_sender, + execution_manager_liveness_store, + execution_manager_stale_after_sec: config.execution_manager_stale_after_sec, + instances: Vec::new(), + execution_manager_pool: HashSet::new(), + receiver, + }; + let pool_join_handle = + tokio::spawn(async move { pool.run(cancellation_token, config.gc_interval).await }); + let handle = TaskInstancePoolHandle { + next_task_instance_id, + sender, + }; + + (handle, pool_join_handle) +} + #[async_trait] impl TaskInstancePoolConnector for TaskInstancePoolHandle { fn get_next_available_task_instance_id(&self) -> TaskInstanceId { @@ -314,17 +324,17 @@ enum PoolMessage { /// * `LivenessStoreType` - The execution manager liveness store implementation. struct TaskInstancePool< ReadyQueueSenderType: ReadyQueueSender, - LivenessStoreType: ExecutionManagerLivenessStore, + LivenessStoreType: ExecutionManagerLivenessManagement, > { ready_queue_sender: ReadyQueueSenderType, execution_manager_liveness_store: LivenessStoreType, execution_manager_pool: HashSet, - execution_manager_stale_cutoff: Duration, + execution_manager_stale_after_sec: u64, instances: Vec, receiver: mpsc::Receiver, } -impl +impl TaskInstancePool { /// Runs the coroutine loop, processing messages and GC timer ticks. @@ -335,13 +345,21 @@ impl Result<(), InternalError> { - let mut gc_interval = tokio::time::interval(gc_interval); + async fn run( + mut self, + cancellation_token: CancellationToken, + gc_interval: u64, + ) -> Result<(), InternalError> { + let mut gc_interval = tokio::time::interval(Duration::from_secs(gc_interval)); // The first tick completes immediately; skip it so we don't GC right at startup. gc_interval.tick().await; loop { tokio::select! { + () = cancellation_token.cancelled() => { + self.drain_received_messages().await?; + return Ok(()); + } message = self.receiver.recv() => { let Some(message) = message else { // TODO: log this exit @@ -356,21 +374,37 @@ impl Result<(), InternalError> { + loop { + match self.receiver.try_recv() { + Ok(message) => self.handle_message(message).await?, + Err(TryRecvError::Empty | TryRecvError::Disconnected) => return Ok(()), + } + } + } + /// Handles a single pool message. /// /// # Errors /// /// Returns an error if: /// - /// * Forwards [`ExecutionManagerLivenessStore::is_execution_manager_alive`]'s return values on - /// failure. + /// * Forwards [`ExecutionManagerLivenessManagement::is_execution_manager_alive`]'s return + /// values on failure. /// * Forwards [`Self::re_enqueue_task`]'s return values on failure. #[allow(clippy::set_contains_or_insert)] async fn handle_message(&mut self, message: PoolMessage) -> Result<(), InternalError> { match message { PoolMessage::Register { tcb, metadata } => { - let em_id = &metadata.execution_manager_id; - if !self.execution_manager_pool.contains(em_id) { + let em_id = metadata.execution_manager_id; + if !self.execution_manager_pool.contains(&em_id) { if !self .execution_manager_liveness_store .is_execution_manager_alive(em_id) @@ -390,7 +424,7 @@ impl Result<(), InternalError> { let dead_em_ids: Vec = self .execution_manager_liveness_store - .get_dead_execution_managers( - gc_started_at - .checked_sub(self.execution_manager_stale_cutoff) - .unwrap_or(SystemTime::UNIX_EPOCH), - ) + .get_dead_execution_managers(self.execution_manager_stale_after_sec) .await?; for execution_manager_id in &dead_em_ids { @@ -520,7 +550,10 @@ impl>>, alive_call_count: Arc, } #[async_trait] - impl ExecutionManagerLivenessStore for MockExecutionManagerLivenessStore { + impl ExecutionManagerLivenessManagement for MockExecutionManagerLivenessManagement { + async fn register_execution_manager( + &self, + _ip_address: IpAddr, + ) -> Result { + unimplemented!("not needed by pool tests") + } + + async fn update_execution_manager_heartbeat( + &self, + _execution_manager_id: ExecutionManagerId, + ) -> Result<(), DbError> { + unimplemented!("not needed by pool tests") + } + async fn is_execution_manager_alive( &self, - _id: &ExecutionManagerId, - ) -> Result { + _execution_manager_id: ExecutionManagerId, + ) -> Result { self.alive_call_count .fetch_add(1, std::sync::atomic::Ordering::Relaxed); Ok(true) @@ -565,8 +612,8 @@ mod tests { async fn get_dead_execution_managers( &self, - _stale_before: SystemTime, - ) -> Result, InternalError> { + _stale_after_sec: u64, + ) -> Result, DbError> { Ok(self.dead_execution_managers.lock().await.clone()) } } @@ -625,23 +672,37 @@ mod tests { } } - /// A [`ExecutionManagerLivenessStore`] where all EMs are reported as dead. + /// A [`ExecutionManagerLivenessManagement`] where all EMs are reported as dead. #[derive(Clone, Default)] struct RejectAllLivenessStore; #[async_trait] - impl ExecutionManagerLivenessStore for RejectAllLivenessStore { + impl ExecutionManagerLivenessManagement for RejectAllLivenessStore { + async fn register_execution_manager( + &self, + _ip_address: IpAddr, + ) -> Result { + unimplemented!("not needed by pool tests") + } + + async fn update_execution_manager_heartbeat( + &self, + _execution_manager_id: ExecutionManagerId, + ) -> Result<(), DbError> { + unimplemented!("not needed by pool tests") + } + async fn is_execution_manager_alive( &self, - _id: &ExecutionManagerId, - ) -> Result { + _execution_manager_id: ExecutionManagerId, + ) -> Result { Ok(false) } async fn get_dead_execution_managers( &self, - _stale_before: SystemTime, - ) -> Result, InternalError> { + _stale_after_sec: u64, + ) -> Result, DbError> { Ok(Vec::new()) } } @@ -701,15 +762,15 @@ mod tests { /// sender is dropped immediately. fn build_test_pool( ready_queue_sender: MockReadyQueueSender, - liveness_store: MockExecutionManagerLivenessStore, + liveness_store: MockExecutionManagerLivenessManagement, execution_manager_stale_cutoff: Duration, - ) -> TaskInstancePool { + ) -> TaskInstancePool { let (_sender, receiver) = mpsc::channel(1); TaskInstancePool { ready_queue_sender, execution_manager_liveness_store: liveness_store, execution_manager_pool: HashSet::new(), - execution_manager_stale_cutoff, + execution_manager_stale_after_sec: execution_manager_stale_cutoff.as_secs(), instances: Vec::new(), receiver, } @@ -722,7 +783,7 @@ mod tests { /// /// The job ID assigned to the task, so callers can match it against re-enqueue messages. async fn register_task_in_pool( - pool: &mut TaskInstancePool, + pool: &mut TaskInstancePool, tcb: &SharedTaskControlBlock, task_id: TaskId, task_instance_id: TaskInstanceId, @@ -752,12 +813,16 @@ mod tests { #[tokio::test] async fn dead_execution_manager_registration_triggers_recovery() { let ready_queue_sender = MockReadyQueueSender::default(); - let pool = TaskInstancePoolHandle::create( + let cancellation_token = CancellationToken::new(); + let (pool, pool_join_handle) = create_task_instance_pool( ready_queue_sender.clone(), RejectAllLivenessStore, - Duration::from_mins(1), - Duration::from_mins(1), - DEFAULT_CHANNEL_SIZE, + cancellation_token.clone(), + TaskInstancePoolConfig { + execution_manager_stale_after_sec: 60, + gc_interval: 60, + channel_size: DEFAULT_CHANNEL_SIZE, + }, ); let tcb = build_single_task_tcb().await; let task_instance_id = 1; @@ -775,7 +840,7 @@ mod tests { pool.register_task_instance(tcb.clone(), metadata) .await - .unwrap(); + .expect("registration should be sent"); // Give the pool coroutine time to process the message. tokio::time::sleep(Duration::from_millis(100)).await; @@ -785,18 +850,28 @@ mod tests { messages.contains(&ReadyMessage::Task(job_id, 0)), "task should be re-enqueued for dead EM, got: {messages:?}" ); + cancellation_token.cancel(); + tokio::time::timeout(Duration::from_secs(1), pool_join_handle) + .await + .expect("pool task should exit before timeout") + .expect("pool task should join successfully") + .expect("pool task should return success"); } #[tokio::test] async fn valid_em_is_cached_and_subsequent_registrations_skip_verify() { let ready_queue_sender = MockReadyQueueSender::default(); - let liveness_store = MockExecutionManagerLivenessStore::default(); - let pool = TaskInstancePoolHandle::create( + let liveness_store = MockExecutionManagerLivenessManagement::default(); + let cancellation_token = CancellationToken::new(); + let (pool, pool_join_handle) = create_task_instance_pool( ready_queue_sender, liveness_store.clone(), - Duration::from_mins(1), - Duration::from_mins(1), - DEFAULT_CHANNEL_SIZE, + cancellation_token.clone(), + TaskInstancePoolConfig { + execution_manager_stale_after_sec: 60, + gc_interval: 60, + channel_size: DEFAULT_CHANNEL_SIZE, + }, ); let execution_manager_id = ExecutionManagerId::new(); @@ -807,7 +882,9 @@ mod tests { execution_manager_id, SystemTime::now(), ); - pool.register_task_instance(tcb1, metadata1).await.unwrap(); + pool.register_task_instance(tcb1, metadata1) + .await + .expect("first registration should succeed"); let tcb2 = build_single_task_tcb().await; let metadata2 = make_task_instance_metadata( @@ -816,7 +893,9 @@ mod tests { execution_manager_id, SystemTime::now(), ); - pool.register_task_instance(tcb2, metadata2).await.unwrap(); + pool.register_task_instance(tcb2, metadata2) + .await + .expect("second registration should succeed"); // Give the pool coroutine time to process both messages. tokio::time::sleep(Duration::from_millis(100)).await; @@ -828,6 +907,136 @@ mod tests { 1, "liveness store should be called exactly once for two registrations with the same EM" ); + cancellation_token.cancel(); + tokio::time::timeout(Duration::from_secs(1), pool_join_handle) + .await + .expect("pool task should exit before timeout") + .expect("pool task should join successfully") + .expect("pool task should return success"); + } + + #[tokio::test] + async fn spawned_pool_exits_when_cancelled() { + let cancellation_token = CancellationToken::new(); + let (_pool, pool_join_handle) = create_task_instance_pool( + MockReadyQueueSender::default(), + MockExecutionManagerLivenessManagement::default(), + cancellation_token.clone(), + TaskInstancePoolConfig { + execution_manager_stale_after_sec: 60, + gc_interval: 60, + channel_size: DEFAULT_CHANNEL_SIZE, + }, + ); + + cancellation_token.cancel(); + + tokio::time::timeout(Duration::from_secs(1), pool_join_handle) + .await + .expect("pool task should exit before timeout") + .expect("pool task should join successfully") + .expect("pool task should return success"); + } + + #[tokio::test] + async fn spawned_pool_processes_registration_before_shutdown() { + let ready_queue_sender = MockReadyQueueSender::default(); + let cancellation_token = CancellationToken::new(); + let (pool, pool_join_handle) = create_task_instance_pool( + ready_queue_sender.clone(), + RejectAllLivenessStore, + cancellation_token.clone(), + TaskInstancePoolConfig { + execution_manager_stale_after_sec: 60, + gc_interval: 60, + channel_size: DEFAULT_CHANNEL_SIZE, + }, + ); + let tcb = build_single_task_tcb().await; + let task_instance_id = 1; + let _ = tcb + .register_task_instance(task_instance_id) + .await + .expect("TCB registration should succeed"); + let metadata = make_task_instance_metadata( + TaskId::Index(0), + task_instance_id, + ExecutionManagerId::new(), + SystemTime::now(), + ); + let job_id = metadata.job_id; + + pool.register_task_instance(tcb, metadata) + .await + .expect("registration should be sent"); + tokio::time::sleep(Duration::from_millis(100)).await; + cancellation_token.cancel(); + tokio::time::timeout(Duration::from_secs(1), pool_join_handle) + .await + .expect("pool task should exit before timeout") + .expect("pool task should join successfully") + .expect("pool task should return success"); + + let messages = ready_queue_sender.sent_messages.lock().await.clone(); + assert!( + messages.contains(&ReadyMessage::Task(job_id, 0)), + "registration should be processed before shutdown, got: {messages:?}" + ); + } + + #[tokio::test] + async fn run_drains_queued_registrations_when_already_cancelled() { + let ready_queue_sender = MockReadyQueueSender::default(); + let cancellation_token = CancellationToken::new(); + cancellation_token.cancel(); + let (sender, receiver) = mpsc::channel(DEFAULT_CHANNEL_SIZE); + let mut expected_messages: Vec = Vec::new(); + + for task_index in 0..3 { + let tcb = build_single_task_tcb().await; + let task_instance_id = task_index as TaskInstanceId + 1; + let _ = tcb + .register_task_instance(task_instance_id) + .await + .expect("TCB registration should succeed"); + let metadata = make_task_instance_metadata( + TaskId::Index(task_index), + task_instance_id, + ExecutionManagerId::new(), + SystemTime::now(), + ); + expected_messages.push(ReadyMessage::Task(metadata.job_id, task_index)); + sender + .send(PoolMessage::Register { + tcb: Tcb::Task(tcb), + metadata, + }) + .await + .expect("pool message should be queued"); + } + + drop(sender); + let pool = TaskInstancePool { + ready_queue_sender: ready_queue_sender.clone(), + execution_manager_liveness_store: RejectAllLivenessStore, + execution_manager_pool: HashSet::new(), + execution_manager_stale_after_sec: 60, + instances: Vec::new(), + receiver, + }; + + pool.run(cancellation_token, 60) + .await + .expect("pool should stop cleanly"); + + let messages = ready_queue_sender.sent_messages.lock().await.clone(); + assert_eq!(messages.len(), expected_messages.len(), "got: {messages:?}"); + for expected in &expected_messages { + assert!( + messages.contains(expected), + "missing drained registration {expected:?}, got: {messages:?}" + ); + } } #[tokio::test] @@ -835,7 +1044,7 @@ mod tests { const NUM_TASKS: usize = 10; let ready_queue_sender = MockReadyQueueSender::default(); - let liveness_store = MockExecutionManagerLivenessStore::default(); + let liveness_store = MockExecutionManagerLivenessManagement::default(); let mut pool = build_test_pool( ready_queue_sender.clone(), liveness_store, @@ -879,7 +1088,7 @@ mod tests { const NUM_TASKS: usize = 10; let ready_queue_sender = MockReadyQueueSender::default(); - let liveness_store = MockExecutionManagerLivenessStore::default(); + let liveness_store = MockExecutionManagerLivenessManagement::default(); let mut pool = build_test_pool( ready_queue_sender.clone(), liveness_store, @@ -937,7 +1146,7 @@ mod tests { const NUM_TASKS: usize = 10; let ready_queue_sender = MockReadyQueueSender::default(); - let liveness_store = MockExecutionManagerLivenessStore::default(); + let liveness_store = MockExecutionManagerLivenessManagement::default(); let mut pool = build_test_pool( ready_queue_sender.clone(), liveness_store.clone(), @@ -995,7 +1204,7 @@ mod tests { const NUM_TASKS: usize = 10; let ready_queue_sender = MockReadyQueueSender::default(); - let liveness_store = MockExecutionManagerLivenessStore::default(); + let liveness_store = MockExecutionManagerLivenessManagement::default(); let mut pool = build_test_pool( ready_queue_sender.clone(), liveness_store.clone(), @@ -1053,7 +1262,7 @@ mod tests { // index 3: dead EM, terminated -> removed, no re-enqueue (terminal wins) // index 4: dead EM, on-going -> removed, re-enqueued let ready_queue_sender = MockReadyQueueSender::default(); - let liveness_store = MockExecutionManagerLivenessStore::default(); + let liveness_store = MockExecutionManagerLivenessManagement::default(); let mut pool = build_test_pool( ready_queue_sender.clone(), liveness_store.clone(),