diff --git a/iris-mpc-common/src/helpers/shutdown_handler.rs b/iris-mpc-common/src/helpers/shutdown_handler.rs index e596335d7..24513bac6 100644 --- a/iris-mpc-common/src/helpers/shutdown_handler.rs +++ b/iris-mpc-common/src/helpers/shutdown_handler.rs @@ -29,6 +29,10 @@ impl ShutdownHandler { self.shutdown_received.load(Ordering::Relaxed) } + pub fn trigger_manual_shutdown(&self) { + self.shutdown_received.store(true, Ordering::Relaxed); + } + pub async fn wait_for_shutdown_signal(&self) { let shutdown_flag = self.shutdown_received.clone(); tokio::spawn(async move { diff --git a/iris-mpc/src/bin/server.rs b/iris-mpc/src/bin/server.rs index 2b1ed4f54..dc7b26566 100644 --- a/iris-mpc/src/bin/server.rs +++ b/iris-mpc/src/bin/server.rs @@ -663,7 +663,9 @@ async fn main() -> eyre::Result<()> { } async fn server_main(config: Config) -> eyre::Result<()> { - let shutdown_handler = ShutdownHandler::new(config.shutdown_last_results_sync_timeout_secs); + let shutdown_handler = Arc::new(ShutdownHandler::new( + config.shutdown_last_results_sync_timeout_secs, + )); shutdown_handler.wait_for_shutdown_signal().await; // Load batch_size config @@ -777,27 +779,47 @@ async fn server_main(config: Config) -> eyre::Result<()> { let is_ready_flag = Arc::new(AtomicBool::new(false)); let is_ready_flag_cloned = Arc::clone(&is_ready_flag); - #[derive(Serialize, Deserialize)] + #[derive(Debug, Serialize, Deserialize, Clone)] struct ReadyProbeResponse { image_name: String, uuid: String, + shutdown: bool, } + let health_shutdown_handler = Arc::clone(&shutdown_handler); + let _health_check_abort = background_tasks.spawn({ let uuid = uuid::Uuid::new_v4().to_string(); let ready_probe_response = ReadyProbeResponse { image_name: config.image_name.clone(), - uuid, + shutdown: false, + uuid: uuid.clone(), + }; + let ready_probe_response_shutdown = ReadyProbeResponse { + image_name: config.image_name.clone(), + shutdown: true, + uuid: uuid.clone(), }; let serialized_response = serde_json::to_string(&ready_probe_response) .expect("Serialization to JSON to probe response failed"); + let serialized_response_shutdown = serde_json::to_string(&ready_probe_response_shutdown) + .expect("Serialization to JSON to probe response failed"); tracing::info!("Healthcheck probe response: {}", serialized_response); async move { // Generate a random UUID for each run. let app = Router::new() .route( "/health", - get(move || async move { serialized_response.clone() }), + get(move || { + let shutdown_handler_clone = Arc::clone(&health_shutdown_handler); + async move { + if shutdown_handler_clone.is_shutting_down() { + serialized_response_shutdown.clone() + } else { + serialized_response.clone() + } + } + }), ) .route( "/ready", @@ -831,6 +853,7 @@ async fn server_main(config: Config) -> eyre::Result<()> { let mut heartbeat_tx = Some(heartbeat_tx); let all_nodes = config.node_hostnames.clone(); let image_name = config.image_name.clone(); + let heartbeat_shutdown_handler = Arc::clone(&shutdown_handler); let _heartbeat = background_tasks.spawn(async move { let next_node = &all_nodes[(config.party_id + 1) % 3]; let prev_node = &all_nodes[(config.party_id + 2) % 3]; @@ -863,6 +886,14 @@ async fn server_main(config: Config) -> eyre::Result<()> { .json::() .await .expect("Deserialization of probe response failed"); + if probe_response.shutdown { + tracing::error!( + "Node {} has starting graceful shutdown. Therefore starting graceful \ + shutdown", + host + ); + heartbeat_shutdown_handler.trigger_manual_shutdown(); + } if probe_response.image_name != image_name { // Do not create a panic as we still can continue to process before its // updated