diff --git a/rust/gateway/src/main.rs b/rust/gateway/src/main.rs index e280bcc2e..b4f516206 100644 --- a/rust/gateway/src/main.rs +++ b/rust/gateway/src/main.rs @@ -54,7 +54,10 @@ async fn try_main() -> Result<()> { let ctrl_c = pin!(ctrl_c().map_err(anyhow::Error::new)); - tokio::spawn(http_health_check::serve(cli.health_check.health_check_addr)); + tokio::spawn(http_health_check::serve( + cli.health_check.health_check_addr, + || true, + )); match future::try_select(task, ctrl_c) .await diff --git a/rust/http-health-check/src/lib.rs b/rust/http-health-check/src/lib.rs index 796511649..dee477922 100644 --- a/rust/http-health-check/src/lib.rs +++ b/rust/http-health-check/src/lib.rs @@ -1,15 +1,26 @@ +use axum::http::StatusCode; use axum::routing::get; use axum::Router; use std::net::SocketAddr; -/// Runs an HTTP server that always responds to `GET /healthz` with 200 OK. -/// -/// To signal an unhealthy state, simply stop the task. -pub async fn serve(addr: impl Into) -> std::io::Result<()> { +/// Runs an HTTP server that responds to `GET /healthz` with 200 OK or 400 BAD REQUEST, depending on the return value of `is_healthy`. +pub async fn serve( + addr: impl Into, + is_healthy: impl Fn() -> bool + Clone + Send + Sync + 'static, +) -> std::io::Result<()> { let addr = addr.into(); let service = Router::new() - .route("/healthz", get(|| async { "" })) + .route( + "/healthz", + get(move || async move { + if is_healthy() { + StatusCode::OK + } else { + StatusCode::BAD_REQUEST + } + }), + ) .into_make_service(); axum::serve(tokio::net::TcpListener::bind(addr).await?, service).await?; diff --git a/rust/relay/src/main.rs b/rust/relay/src/main.rs index 7ab468cc8..af45cf7bd 100644 --- a/rust/relay/src/main.rs +++ b/rust/relay/src/main.rs @@ -15,6 +15,7 @@ use rand::{Rng, SeedableRng}; use secrecy::{Secret, SecretString}; use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr}; use std::pin::Pin; +use std::sync::{Arc, Mutex}; use std::task::Poll; use std::time::{Duration, Instant}; use tracing::{level_filters::LevelFilter, Instrument, Subscriber}; @@ -27,6 +28,8 @@ const STATS_LOG_INTERVAL: Duration = Duration::from_secs(10); const TURN_PORT: u16 = 3478; +const MAX_PARTITION_TIME: Duration = Duration::from_secs(60 * 15); + #[derive(Parser, Debug)] struct Args { /// The public (i.e. internet-reachable) IPv4 address of the relay server. @@ -112,6 +115,13 @@ async fn main() -> Result<()> { args.highest_port, ); + let last_heartbeat_sent = Arc::new(Mutex::new(Option::::None)); + + tokio::spawn(http_health_check::serve( + args.health_check.health_check_addr, + make_is_healthy(last_heartbeat_sent.clone()), + )); + let channel = if let Some(token) = args.token.as_ref() { let base_url = args.api_url.clone(); let stamp_secret = server.auth_secret(); @@ -127,11 +137,7 @@ async fn main() -> Result<()> { None }; - let mut eventloop = Eventloop::new(server, channel, public_addr)?; - - tokio::spawn(http_health_check::serve( - args.health_check.health_check_addr, - )); + let mut eventloop = Eventloop::new(server, channel, public_addr, last_heartbeat_sent)?; tracing::info!(target: "relay", "Listening for incoming traffic on UDP port {TURN_PORT}"); @@ -267,7 +273,7 @@ async fn connect_to_portal( stamp_secret: stamp_secret.expose_secret().to_string(), }, ExponentialBackoffBuilder::default() - .with_max_elapsed_time(None) + .with_max_elapsed_time(Some(MAX_PARTITION_TIME)) .build(), ) .await??; @@ -305,6 +311,8 @@ struct Eventloop { stats_log_interval: tokio::time::Interval, last_num_bytes_relayed: u64, + last_heartbeat_sent: Arc>>, + buffer: [u8; MAX_UDP_SIZE], } @@ -316,6 +324,7 @@ where server: Server, channel: Option>, public_address: IpStack, + last_heartbeat_sent: Arc>>, ) -> Result { let mut sockets = Sockets::new(); @@ -342,6 +351,7 @@ where last_num_bytes_relayed: 0, sockets, buffer: [0u8; MAX_UDP_SIZE], + last_heartbeat_sent, }) } @@ -518,6 +528,7 @@ where } Event::HeartbeatSent => { tracing::debug!(target: "relay", "Heartbeat sent to portal"); + *self.last_heartbeat_sent.lock().unwrap() = Some(Instant::now()); } Event::InboundMessage { msg: (), .. } => {} } @@ -538,6 +549,23 @@ fn fmt_human_throughput(mut throughput: f64) -> String { format!("{throughput:.2} TB/s") } +/// Factory fn for [`is_healthy`]. +fn make_is_healthy( + last_heartbeat_sent: Arc>>, +) -> impl Fn() -> bool + Clone + Send + Sync + 'static { + move || is_healthy(last_heartbeat_sent.clone()) +} + +fn is_healthy(last_heartbeat_sent: Arc>>) -> bool { + let guard = last_heartbeat_sent.lock().unwrap(); + + let Some(last_hearbeat_sent) = *guard else { + return true; // If we are not connected to the portal, we are always healthy. + }; + + last_hearbeat_sent.elapsed() < MAX_PARTITION_TIME +} + #[cfg(test)] mod tests { use super::*; @@ -549,4 +577,30 @@ mod tests { assert_eq!(fmt_human_throughput(955_333_999.0), "955.33 MB/s"); assert_eq!(fmt_human_throughput(100_000_000_000.0), "100.00 GB/s"); } + + // If we are running in standalone mode, we are always healthy. + #[test] + fn given_no_heartbeat_is_healthy() { + let is_healthy = is_healthy(Arc::new(Mutex::new(None))); + + assert!(is_healthy) + } + + #[test] + fn given_heartbeat_in_last_15_min_is_healthy() { + let is_healthy = is_healthy(Arc::new(Mutex::new(Some( + Instant::now() - Duration::from_secs(10), + )))); + + assert!(is_healthy) + } + + #[test] + fn given_last_heartbeat_older_than_15_min_is_not_healthy() { + let is_healthy = is_healthy(Arc::new(Mutex::new(Some( + Instant::now() - Duration::from_secs(60 * 15), + )))); + + assert!(!is_healthy) + } }