From e5ee8e3572d6ca755f84c5cadaa6c122c11892ca Mon Sep 17 00:00:00 2001 From: Thomas Eizinger Date: Fri, 25 Jul 2025 13:09:13 +1000 Subject: [PATCH] fix(connlib): wait for sockets to be closed before rebinding (#9996) Our `ThreadedUdpSocket` uses a background thread for the actual socket operation. It merely represents a handle to send and receive from these sockets but not the socket itself. Dropping the handle will shutdown the background thread but that is an asynchronous operation. In order to be sure that we can rebind the same port, we need to wait for the background thread to stop. We thus add a `Drop` implementation for the `ThreadedUdpSocket` that waits for its background thread to disappear before it continues. Resolves: #9992 --- rust/connlib/tunnel/src/sockets.rs | 207 +++++++++++++++++------------ 1 file changed, 123 insertions(+), 84 deletions(-) diff --git a/rust/connlib/tunnel/src/sockets.rs b/rust/connlib/tunnel/src/sockets.rs index c07b6d717..21d294d85 100644 --- a/rust/connlib/tunnel/src/sockets.rs +++ b/rust/connlib/tunnel/src/sockets.rs @@ -1,9 +1,10 @@ use crate::otel; -use anyhow::Result; +use anyhow::{Context as _, Result}; use futures::{SinkExt, StreamExt, ready}; use gat_lending_iterator::LendingIterator; use socket_factory::DatagramOut; use socket_factory::{DatagramIn, DatagramSegmentIter, SocketFactory, UdpSocket}; +use std::time::{Duration, Instant}; use std::{ io, net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6}, @@ -158,6 +159,12 @@ where } struct ThreadedUdpSocket { + thread_name: String, + join_handle: std::thread::JoinHandle<()>, + channels: Option, +} + +struct Channels { outbound_tx: flume::r#async::SendSink<'static, DatagramOut>, inbound_rx: flume::r#async::RecvStream<'static, Result>, } @@ -168,91 +175,59 @@ impl ThreadedUdpSocket { let (inbound_tx, inbound_rx) = flume::bounded(10); let (error_tx, error_rx) = flume::bounded(0); - std::thread::Builder::new() - .name(match preferred_addr { - SocketAddr::V4(_) => "UDP IPv4".to_owned(), - SocketAddr::V6(_) => "UDP IPv6".to_owned(), - }) - .spawn(move || { - let runtime = match tokio::runtime::Builder::new_current_thread() - .enable_all() - .build() - { - Ok(r) => r, - Err(e) => { - let _ = error_tx.send(Err(e)); - return; - } - }; - - runtime.block_on(async move { - let mut socket = match listen( - sf, - // Listen on the preferred address, fall back to picking a free port if that doesn't work - &[preferred_addr, SocketAddr::new(preferred_addr.ip(), 0)], - ) { - Ok(s) => s, + let thread_name = match preferred_addr { + SocketAddr::V4(_) => "UDP IPv4".to_owned(), + SocketAddr::V6(_) => "UDP IPv6".to_owned(), + }; + let join_handle = + std::thread::Builder::new() + .name(thread_name.clone()) + .spawn(move || { + let runtime = match tokio::runtime::Builder::new_current_thread() + .enable_all() + .build() + { + Ok(r) => r, Err(e) => { let _ = error_tx.send(Err(e)); return; } }; - let io_error_counter = opentelemetry::global::meter("connlib") - .u64_counter("system.network.errors") - .with_description("Number of IO errors encountered") - .with_unit("{error}") - .build(); - - if let Err(e) = socket.set_buffer_sizes( - socket_factory::SEND_BUFFER_SIZE, - socket_factory::RECV_BUFFER_SIZE, - ) { - tracing::warn!("Failed to set socket buffer sizes: {e}"); + runtime.block_on(async move { + let mut socket = match listen( + sf, + // Listen on the preferred address, fall back to picking a free port if that doesn't work + &[preferred_addr, SocketAddr::new(preferred_addr.ip(), 0)], + ) { + Ok(s) => s, + Err(e) => { + let _ = error_tx.send(Err(e)); + return; } + }; - let send = pin!(async { - while let Ok(datagram) = outbound_rx.recv_async().await { - if let Err(e) = socket.send(datagram).await { - if let Some(io) = e.downcast_ref::() { - io_error_counter.add( - 1, - &[ - otel::attr::network_io_direction_transmit(), - otel::attr::network_type_for_addr(preferred_addr), - otel::attr::io_error_type(io), - otel::attr::io_error_code(io), - ], - ); - } + let io_error_counter = opentelemetry::global::meter("connlib") + .u64_counter("system.network.errors") + .with_description("Number of IO errors encountered") + .with_unit("{error}") + .build(); - // We use the inbound_tx channel to send the error back to the main thread. - if inbound_tx.send_async(Err(e)).await.is_err() { - tracing::debug!( - "Channel for inbound datagrams closed; exiting UDP thread" - ); - break; - } - }; - } + if let Err(e) = socket.set_buffer_sizes( + socket_factory::SEND_BUFFER_SIZE, + socket_factory::RECV_BUFFER_SIZE, + ) { + tracing::warn!("Failed to set socket buffer sizes: {e}"); + } - tracing::debug!( - "Channel for outbound datagrams closed; exiting UDP thread" - ); - }); - let receive = pin!(async { - loop { - let result = socket.recv_from().await; - - if let Some(io) = result - .as_ref() - .err() - .and_then(|e| e.downcast_ref::()) - { + let send = pin!(async { + while let Ok(datagram) = outbound_rx.recv_async().await { + if let Err(e) = socket.send(datagram).await { + if let Some(io) = e.downcast_ref::() { io_error_counter.add( 1, &[ - otel::attr::network_io_direction_receive(), + otel::attr::network_io_direction_transmit(), otel::attr::network_type_for_addr(preferred_addr), otel::attr::io_error_type(io), otel::attr::io_error_code(io), @@ -260,37 +235,75 @@ impl ThreadedUdpSocket { ); } - if inbound_tx.send_async(result).await.is_err() { + // We use the inbound_tx channel to send the error back to the main thread. + if inbound_tx.send_async(Err(e)).await.is_err() { tracing::debug!( "Channel for inbound datagrams closed; exiting UDP thread" ); break; } + }; + } + + tracing::debug!("Channel for outbound datagrams closed; exiting UDP thread"); + }); + let receive = pin!(async { + loop { + let result = socket.recv_from().await; + + if let Some(io) = result + .as_ref() + .err() + .and_then(|e| e.downcast_ref::()) + { + io_error_counter.add( + 1, + &[ + otel::attr::network_io_direction_receive(), + otel::attr::network_type_for_addr(preferred_addr), + otel::attr::io_error_type(io), + otel::attr::io_error_code(io), + ], + ); } - }); - let _ = error_tx.send(Ok(())); + if inbound_tx.send_async(result).await.is_err() { + tracing::debug!( + "Channel for inbound datagrams closed; exiting UDP thread" + ); + break; + } + } + }); - futures::future::select(send, receive).await; - }) - })?; + let _ = error_tx.send(Ok(())); + + futures::future::select(send, receive).await; + }) + })?; error_rx.recv().map_err(io::Error::other)??; Ok(Self { - outbound_tx: outbound_tx.into_sink(), - inbound_rx: inbound_rx.into_stream(), + thread_name, + join_handle, + channels: Some(Channels { + outbound_tx: outbound_tx.into_sink(), + inbound_rx: inbound_rx.into_stream(), + }), }) } fn poll_send_ready(&mut self, cx: &mut Context<'_>) -> Poll> { - ready!(self.outbound_tx.poll_ready_unpin(cx)).map_err(|_| UdpSocketThreadStopped)?; + ready!(self.channels_mut()?.outbound_tx.poll_ready_unpin(cx)) + .map_err(|_| UdpSocketThreadStopped)?; Poll::Ready(Ok(())) } fn send(&mut self, datagram: DatagramOut) -> Result<()> { - self.outbound_tx + self.channels_mut()? + .outbound_tx .start_send_unpin(datagram) .map_err(|_| UdpSocketThreadStopped)?; @@ -298,10 +311,36 @@ impl ThreadedUdpSocket { } fn poll_recv_from(&mut self, cx: &mut Context<'_>) -> Poll> { - let iter = ready!(self.inbound_rx.poll_next_unpin(cx)).ok_or(UdpSocketThreadStopped)?; + let iter = ready!(self.channels_mut()?.inbound_rx.poll_next_unpin(cx)) + .ok_or(UdpSocketThreadStopped)?; Poll::Ready(iter) } + + fn channels_mut(&mut self) -> Result<&mut Channels> { + self.channels.as_mut().context("Missing channels") + } +} + +impl Drop for ThreadedUdpSocket { + fn drop(&mut self) { + let start = Instant::now(); + + let _ = self.channels.take(); + + const TIMEOUT: Duration = Duration::from_millis(500); + + while !self.join_handle.is_finished() { + let elapsed = start.elapsed(); + + if elapsed > TIMEOUT { + tracing::debug!(name = %self.thread_name, "Thread did not stop within {TIMEOUT:?}"); + return; + } + } + + tracing::debug!(name = %self.thread_name, duration = ?start.elapsed(), "Background thread stopped"); + } } fn listen(