diff --git a/rust/apple-client-ffi/src/lib.rs b/rust/apple-client-ffi/src/lib.rs
index 11686b515..b1b483cf1 100644
--- a/rust/apple-client-ffi/src/lib.rs
+++ b/rust/apple-client-ffi/src/lib.rs
@@ -29,6 +29,7 @@ use std::{
     time::Duration,
 };
 use tokio::runtime::Runtime;
+use tokio::task::JoinHandle;
 use tracing_subscriber::prelude::*;
 use tun::Tun;
 
@@ -119,6 +120,7 @@ mod ffi {
 pub struct WrappedSession {
     inner: Session,
     runtime: Option<Runtime>,
+    event_stream_handler: Option<JoinHandle<()>>,
     telemetry: Telemetry,
 }
 
@@ -305,7 +307,7 @@ impl WrappedSession {
         analytics::new_session(device_id, api_url.to_string());
 
-        runtime.spawn(async move {
+        let event_stream_handler = runtime.spawn(async move {
             let callback_handler = CallbackHandler {
                 inner: callback_handler,
             };
 
@@ -335,6 +337,7 @@ impl WrappedSession {
         Ok(Self {
             inner: session,
            runtime: Some(runtime),
+            event_stream_handler: Some(event_stream_handler),
             telemetry,
         })
     }
@@ -380,7 +383,19 @@ impl Drop for WrappedSession {
             return;
         };
 
-        runtime.block_on(self.telemetry.stop());
+        self.inner.stop(); // Instruct the event-loop to shut down.
+        runtime.block_on(async {
+            self.telemetry.stop().await;
+
+            // The `event_stream_handler` task will exit once the stream is drained.
+            // That only happens once the event-loop has fully shut down.
+            // Hence, awaiting this task lets us wait for the graceful shutdown to complete.
+            let Some(event_stream_handler) = self.event_stream_handler.take() else {
+                return;
+            };
+
+            let _ = tokio::time::timeout(Duration::from_secs(1), event_stream_handler).await;
+        });
         runtime.shutdown_timeout(Duration::from_secs(1)); // Ensure we don't block forever on a task in the blocking pool.
     }
 }
diff --git a/rust/client-ffi/src/lib.rs b/rust/client-ffi/src/lib.rs
index b2e50a06d..0eb340724 100644
--- a/rust/client-ffi/src/lib.rs
+++ b/rust/client-ffi/src/lib.rs
@@ -208,7 +208,16 @@ impl Drop for Session {
             return;
         };
 
-        runtime.block_on(async { self.telemetry.lock().await.stop_on_crash().await });
+        self.inner.stop(); // Instruct the event-loop to shut down.
+
+        runtime.block_on(async {
+            self.telemetry.lock().await.stop_on_crash().await;
+
+            // Draining the event-stream allows us to wait for the event-loop to finish its graceful shutdown.
+            let drain = async { self.events.lock().await.drain().await };
+            let _ = tokio::time::timeout(Duration::from_secs(1), drain).await;
+        });
+
         runtime.shutdown_timeout(Duration::from_secs(1)); // Ensure we don't block forever on a task in the blocking pool.
     }
 }
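Both `Drop` implementations above follow the same three-step sequence: signal the event-loop to stop, wait with a bound for it to wind down while the runtime is still alive, and only then tear the runtime down. A minimal, self-contained sketch of that ordering (the `Session` shape and the oneshot stop signal are illustrative stand-ins, not this crate's API):

    use std::time::Duration;

    use tokio::runtime::Runtime;
    use tokio::sync::oneshot;
    use tokio::task::JoinHandle;

    // Illustrative stand-in for a session whose background task must finish
    // before the runtime is torn down.
    struct Session {
        runtime: Option<Runtime>,
        stop_tx: Option<oneshot::Sender<()>>,
        task: Option<JoinHandle<()>>,
    }

    impl Drop for Session {
        fn drop(&mut self) {
            let Some(runtime) = self.runtime.take() else {
                return;
            };

            // 1. Signal the background task to shut down.
            if let Some(stop_tx) = self.stop_tx.take() {
                let _ = stop_tx.send(());
            }

            // 2. Wait for it to finish while the runtime is still alive,
            //    bounded so a stuck task cannot block `drop` forever.
            if let Some(task) = self.task.take() {
                let _ = runtime.block_on(tokio::time::timeout(Duration::from_secs(1), task));
            }

            // 3. Only now tear down the runtime itself, again with a bound.
            runtime.shutdown_timeout(Duration::from_secs(1));
        }
    }

    fn main() {
        let runtime = Runtime::new().expect("failed to build runtime");
        let (stop_tx, stop_rx) = oneshot::channel();
        let task = runtime.spawn(async move {
            let _ = stop_rx.await; // Runs until told to stop.
        });

        drop(Session {
            runtime: Some(runtime),
            stop_tx: Some(stop_tx),
            task: Some(task),
        });
    }

The order matters: shutting down or dropping the runtime cancels outstanding tasks at their next yield point rather than running them to completion, so the bounded `block_on` is what actually gives the event-loop a chance to drain.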
diff --git a/rust/client-shared/src/eventloop.rs b/rust/client-shared/src/eventloop.rs
index a37f66f65..5b2a157fc 100644
--- a/rust/client-shared/src/eventloop.rs
+++ b/rust/client-shared/src/eventloop.rs
@@ -47,7 +47,7 @@ use tun::Tun;
 static DNS_RESOURCE_RECORDS_CACHE: Mutex<BTreeSet<…>> = Mutex::new(BTreeSet::new());
 
 pub struct Eventloop {
-    tunnel: ClientTunnel,
+    tunnel: Option<ClientTunnel>,
 
     cmd_rx: mpsc::UnboundedReceiver<Command>,
     resource_list_sender: watch::Sender<Vec<ResourceView>>,
@@ -121,7 +121,7 @@ impl Eventloop {
         ));
 
         Self {
-            tunnel,
+            tunnel: Some(tunnel),
             cmd_rx,
             logged_permission_denied: false,
             portal_event_rx,
@@ -141,45 +141,83 @@ enum CombinedEvent {
 impl Eventloop {
     pub async fn run(mut self) -> Result<(), DisconnectError> {
         loop {
-            match future::poll_fn(|cx| self.next_event(cx)).await {
-                CombinedEvent::Command(None) => return Ok(()),
-                CombinedEvent::Command(Some(cmd)) => {
-                    match self.handle_eventloop_command(cmd).await? {
-                        ControlFlow::Continue(()) => {}
-                        ControlFlow::Break(()) => return Ok(()),
-                    }
-                }
-                CombinedEvent::Tunnel(event) => self.handle_tunnel_event(event).await?,
-                CombinedEvent::Portal(Some(event)) => {
-                    let msg = event.context("Connection to portal failed")?;
+            match self.tick().await {
+                Ok(ControlFlow::Continue(())) => continue,
+                Ok(ControlFlow::Break(())) => {
+                    self.shutdown_tunnel().await?;
 
-                    self.handle_portal_message(msg).await?;
+                    return Ok(());
                 }
-                CombinedEvent::Portal(None) => {
-                    return Err(DisconnectError(anyhow::Error::msg(
-                        "portal task exited unexpectedly",
-                    )));
+                Err(e) => {
+                    // Ignore errors from shutdown so we don't obscure the original error.
+                    let _ = self.shutdown_tunnel().await;
+
+                    return Err(e);
                 }
             }
         }
     }
 
+    async fn tick(&mut self) -> Result<ControlFlow<()>, DisconnectError> {
+        match future::poll_fn(|cx| self.next_event(cx)).await {
+            CombinedEvent::Command(None) => Ok(ControlFlow::Break(())),
+            CombinedEvent::Command(Some(cmd)) => {
+                let cf = self.handle_eventloop_command(cmd).await?;
+
+                Ok(cf)
+            }
+            CombinedEvent::Tunnel(event) => {
+                self.handle_tunnel_event(event).await?;
+
+                Ok(ControlFlow::Continue(()))
+            }
+            CombinedEvent::Portal(Some(event)) => {
+                let msg = event.context("Connection to portal failed")?;
+                self.handle_portal_message(msg).await?;
+
+                Ok(ControlFlow::Continue(()))
+            }
+            CombinedEvent::Portal(None) => Err(DisconnectError(anyhow::Error::msg(
+                "portal task exited unexpectedly",
+            ))),
+        }
+    }
+
     async fn handle_eventloop_command(&mut self, command: Command) -> Result<ControlFlow<()>> {
         match command {
             Command::Stop => return Ok(ControlFlow::Break(())),
-            Command::SetDns(dns) => self.tunnel.state_mut().update_system_resolvers(dns),
-            Command::SetDisabledResources(resources) => self
-                .tunnel
-                .state_mut()
-                .set_disabled_resources(resources, Instant::now()),
+            Command::SetDns(dns) => {
+                let Some(tunnel) = self.tunnel.as_mut() else {
+                    return Ok(ControlFlow::Continue(()));
+                };
+
+                tunnel.state_mut().update_system_resolvers(dns);
+            }
+            Command::SetDisabledResources(resources) => {
+                let Some(tunnel) = self.tunnel.as_mut() else {
+                    return Ok(ControlFlow::Continue(()));
+                };
+
+                tunnel
+                    .state_mut()
+                    .set_disabled_resources(resources, Instant::now())
+            }
             Command::SetTun(tun) => {
-                self.tunnel.set_tun(tun);
+                let Some(tunnel) = self.tunnel.as_mut() else {
+                    return Ok(ControlFlow::Continue(()));
+                };
+
+                tunnel.set_tun(tun);
             }
             Command::Reset(reason) => {
-                self.tunnel.reset(&reason);
+                let Some(tunnel) = self.tunnel.as_mut() else {
+                    return Ok(ControlFlow::Continue(()));
+                };
+
+                tunnel.reset(&reason);
 
                 self.portal_cmd_tx
                     .send(PortalCommand::Connect(PublicKeyParam(
-                        self.tunnel.public_key().to_bytes(),
+                        tunnel.public_key().to_bytes(),
                     )))
                     .await
                     .context("Failed to connect phoenix-channel")?;
@@ -299,17 +337,20 @@ impl Eventloop {
     }
 
     async fn handle_portal_message(&mut self, msg: IngressMessages) -> Result<()> {
+        let Some(tunnel) = self.tunnel.as_mut() else {
+            return Ok(());
+        };
+
         match msg {
-            IngressMessages::ConfigChanged(config) => self
-                .tunnel
-                .state_mut()
-                .update_interface_config(config.interface),
+            IngressMessages::ConfigChanged(config) => {
+                tunnel.state_mut().update_interface_config(config.interface)
+            }
             IngressMessages::IceCandidates(GatewayIceCandidates {
                 gateway_id,
                 candidates,
             }) => {
                 for candidate in candidates {
-                    self.tunnel
+                    tunnel
                         .state_mut()
                         .add_ice_candidate(gateway_id, candidate, Instant::now())
                 }
@@ -319,7 +360,7 @@
                 resources,
                 relays,
             }) => {
-                let state = self.tunnel.state_mut();
+                let state = tunnel.state_mut();
 
                 state.update_interface_config(interface);
                 state.set_resources(resources, Instant::now());
@@ -330,19 +371,15 @@
                 );
             }
             IngressMessages::ResourceCreatedOrUpdated(resource) => {
-                self.tunnel
-                    .state_mut()
-                    .add_resource(resource, Instant::now());
+                tunnel.state_mut().add_resource(resource, Instant::now());
             }
             IngressMessages::ResourceDeleted(resource) => {
-                self.tunnel
-                    .state_mut()
-                    .remove_resource(resource, Instant::now());
+                tunnel.state_mut().remove_resource(resource, Instant::now());
             }
             IngressMessages::RelaysPresence(RelaysPresence {
                 disconnected_ids,
                 connected,
-            }) => self.tunnel.state_mut().update_relays(
+            }) => tunnel.state_mut().update_relays(
                 BTreeSet::from_iter(disconnected_ids),
                 firezone_tunnel::turn(&connected),
                 Instant::now(),
@@ -352,11 +389,9 @@
                 candidates,
             }) => {
                 for candidate in candidates {
-                    self.tunnel.state_mut().remove_ice_candidate(
-                        gateway_id,
-                        candidate,
-                        Instant::now(),
-                    )
+                    tunnel
+                        .state_mut()
+                        .remove_ice_candidate(gateway_id, candidate, Instant::now())
                 }
             }
             IngressMessages::FlowCreated(FlowCreated {
@@ -370,7 +405,7 @@
                 client_ice_credentials,
                 gateway_ice_credentials,
             }) => {
-                match self.tunnel.state_mut().handle_flow_created(
+                match tunnel.state_mut().handle_flow_created(
                     resource_id,
                     gateway_id,
                     PublicKey::from(gateway_public_key.0),
@@ -393,7 +428,7 @@
                     // Re-connecting to the portal means we will receive another `init` and thus new TURN servers.
                     self.portal_cmd_tx
                         .send(PortalCommand::Connect(PublicKeyParam(
-                            self.tunnel.public_key().to_bytes(),
+                            tunnel.public_key().to_bytes(),
                         )))
                         .await
                         .context("Failed to connect phoenix-channel")?;
@@ -408,7 +443,7 @@
                 reason: FailReason::Offline,
                 ..
             }) => {
-                self.tunnel.state_mut().set_resource_offline(resource_id);
+                tunnel.state_mut().set_resource_offline(resource_id);
             }
             IngressMessages::FlowCreationFailed(FlowCreationFailed { reason, .. }) => {
                 tracing::debug!("Failed to create flow: {reason:?}")
@@ -427,12 +462,26 @@ impl Eventloop {
             return Poll::Ready(CombinedEvent::Portal(event));
         }
 
-        if let Poll::Ready(event) = self.tunnel.poll_next_event(cx) {
+        if let Some(Poll::Ready(event)) = self.tunnel.as_mut().map(|t| t.poll_next_event(cx)) {
             return Poll::Ready(CombinedEvent::Tunnel(event));
         }
 
         Poll::Pending
     }
+
+    async fn shutdown_tunnel(&mut self) -> Result<()> {
+        let Some(tunnel) = self.tunnel.take() else {
+            tracing::debug!("Tunnel has already been shut down");
+            return Ok(());
+        };
+
+        tunnel
+            .shutdown()
+            .await
+            .context("Failed to shutdown tunnel")?;
+
+        Ok(())
+    }
 }
 
 async fn phoenix_channel_event_loop(
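Two changes carry the `Eventloop` refactor above: the loop body moves into `tick()` returning `ControlFlow`, so both the `Break` and `Err` arms funnel through a single `shutdown_tunnel()` path, and the tunnel becomes an `Option` so shutdown can `take()` it by value while later calls degrade to no-ops. A distilled sketch of those two patterns (synchronous, with toy types, unlike the real event-loop):

    use std::ops::ControlFlow;

    // Illustrative stand-in for a resource that must be consumed by value to shut down.
    struct Tunnel;

    impl Tunnel {
        fn shutdown(self) {
            // Consuming `self` makes "use after shutdown" unrepresentable.
        }
    }

    struct Eventloop {
        tunnel: Option<Tunnel>,
        remaining: u32, // Stand-in for "events left to process".
    }

    impl Eventloop {
        fn run(mut self) -> Result<(), &'static str> {
            loop {
                match self.tick() {
                    Ok(ControlFlow::Continue(())) => continue,
                    Ok(ControlFlow::Break(())) => {
                        self.shutdown_tunnel();
                        return Ok(());
                    }
                    Err(e) => {
                        // Shut down on the error path too, without letting a
                        // shutdown failure obscure the original error.
                        self.shutdown_tunnel();
                        return Err(e);
                    }
                }
            }
        }

        fn tick(&mut self) -> Result<ControlFlow<()>, &'static str> {
            if self.remaining == 0 {
                return Ok(ControlFlow::Break(()));
            }
            self.remaining -= 1;
            Ok(ControlFlow::Continue(()))
        }

        fn shutdown_tunnel(&mut self) {
            // `take()` hands us the tunnel by value; a second call is a no-op.
            let Some(tunnel) = self.tunnel.take() else {
                return;
            };
            tunnel.shutdown();
        }
    }

    fn main() {
        let eventloop = Eventloop { tunnel: Some(Tunnel), remaining: 3 };
        eventloop.run().expect("event loop failed");
    }

Taking the tunnel by value in `shutdown` turns "use after shutdown" into a compile error rather than a runtime bug, which is why every remaining call site in the patch needs a `let Some(tunnel) = ... else` guard.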
diff --git a/rust/client-shared/src/lib.rs b/rust/client-shared/src/lib.rs
index 6804a7379..a550c7c56 100644
--- a/rust/client-shared/src/lib.rs
+++ b/rust/client-shared/src/lib.rs
@@ -166,6 +166,12 @@ impl EventStream {
     pub async fn next(&mut self) -> Option<Event> {
         future::poll_fn(|cx| self.poll_next(cx)).await
     }
+
+    pub async fn drain(&mut self) -> Vec<Event> {
+        futures::stream::poll_fn(|cx| self.poll_next(cx))
+            .collect()
+            .await
+    }
 }
 
 impl Drop for Session {
diff --git a/rust/connlib/tunnel/src/client.rs b/rust/connlib/tunnel/src/client.rs
index 47d8dd03e..a1ff98d59 100644
--- a/rust/connlib/tunnel/src/client.rs
+++ b/rust/connlib/tunnel/src/client.rs
@@ -280,6 +280,13 @@ impl ClientState {
         self.node.public_key()
     }
 
+    pub fn shutdown(&mut self, now: Instant) {
+        tracing::info!("Initiating graceful shutdown");
+
+        self.peers.clear();
+        self.node.close_all(p2p_control::goodbye(), now);
+    }
+
     /// Updates the NAT for all domains resolved by the stub resolver on the corresponding gateway.
     ///
     /// In order to route traffic for DNS resources, the designated gateway needs to set up NAT from
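`EventStream::drain` works because `futures::stream::poll_fn` adapts the existing `poll_next` into a `Stream`, and `StreamExt::collect` only resolves once that stream yields `Poll::Ready(None)` — i.e. once the event-loop side has hung up. That is what makes `drain().await`, under a timeout, usable as a shutdown barrier in the `Drop` impls above. A small sketch of the adapter (the always-ready source is illustrative; a real `poll_next` would also return `Poll::Pending`):

    use std::task::Poll;

    use futures::StreamExt as _; // for `.collect()`

    #[tokio::main]
    async fn main() {
        let mut items = vec![1, 2, 3].into_iter();

        // Drain a poll-based source into a Vec by wrapping it in a stream and
        // collecting until it yields `Poll::Ready(None)`.
        let drained: Vec<i32> = futures::stream::poll_fn(|_cx| Poll::Ready(items.next()))
            .collect()
            .await;

        assert_eq!(drained, vec![1, 2, 3]);
    }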
diff --git a/rust/connlib/tunnel/src/lib.rs b/rust/connlib/tunnel/src/lib.rs
index f8319d68c..e00badfeb 100644
--- a/rust/connlib/tunnel/src/lib.rs
+++ b/rust/connlib/tunnel/src/lib.rs
@@ -140,6 +140,31 @@ impl ClientTunnel {
         self.io.reset();
     }
 
+    /// Shut down the client tunnel.
+    pub fn shutdown(mut self) -> BoxFuture<'static, Result<()>> {
+        // Initiate shutdown.
+        self.role_state.shutdown(Instant::now());
+
+        // Drain all UDP packets that need to be sent.
+        while let Some(trans) = self.role_state.poll_transmit() {
+            self.io
+                .send_network(trans.src, trans.dst, &trans.payload, Ecn::NonEct);
+        }
+
+        // Return a future that "owns" our IO, polling it until all packets have been flushed.
+        async move {
+            tokio::time::timeout(
+                Duration::from_secs(1),
+                future::poll_fn(move |cx| self.io.flush(cx)),
+            )
+            .await
+            .context("Failed to flush within 1s")??;
+
+            Ok(())
+        }
+        .boxed()
+    }
+
     pub fn poll_next_event(&mut self, cx: &mut Context<'_>) -> Poll<ClientEvent> {
         for _ in 0..MAX_EVENTLOOP_ITERS {
             let mut ready = false;
diff --git a/rust/headless-client/src/main.rs b/rust/headless-client/src/main.rs
index f4fb68621..8257c8ed5 100644
--- a/rust/headless-client/src/main.rs
+++ b/rust/headless-client/src/main.rs
@@ -420,6 +420,9 @@ fn try_main() -> Result<()> {
         drop(session);
 
+        // Drain the event-stream to allow the event-loop to gracefully shut down.
+        let _ = tokio::time::timeout(Duration::from_secs(1), event_stream.drain()).await;
+
         result
     })?;
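The noteworthy move in `ClientTunnel::shutdown` is consuming `self` and returning a `BoxFuture` that owns the IO object, so flushing can continue after the tunnel value itself is gone, bounded by a timeout. A compact sketch of that ownership hand-off (the `Io` and `Tunnel` types are stand-ins, not connlib's):

    use std::time::Duration;

    use futures::future::{BoxFuture, FutureExt as _};

    // Illustrative stand-in for an IO object with buffered, unsent data.
    struct Io {
        queued: Vec<&'static str>,
    }

    impl Io {
        async fn flush(&mut self) {
            while let Some(payload) = self.queued.pop() {
                println!("sending {payload}");
                tokio::task::yield_now().await; // Pretend the socket applies backpressure.
            }
        }
    }

    struct Tunnel {
        io: Io,
    }

    impl Tunnel {
        // Consume `self` and return a future that owns the IO, so the caller
        // can keep flushing after the tunnel is gone, bounded by a timeout.
        fn shutdown(mut self) -> BoxFuture<'static, Result<(), &'static str>> {
            async move {
                tokio::time::timeout(Duration::from_secs(1), self.io.flush())
                    .await
                    .map_err(|_| "failed to flush within 1s")?;
                Ok(())
            }
            .boxed()
        }
    }

    #[tokio::main]
    async fn main() {
        let tunnel = Tunnel {
            io: Io { queued: vec!["goodbye", "close_all"] },
        };

        tunnel.shutdown().await.expect("graceful shutdown failed");
    }

The headless client then gets the same shutdown barrier without a `Drop` impl by draining the event-stream inline after `drop(session)`, as the last hunk shows.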