diff --git a/rust/connlib/tunnel/src/sockets.rs b/rust/connlib/tunnel/src/sockets.rs index c07b6d717..21d294d85 100644 --- a/rust/connlib/tunnel/src/sockets.rs +++ b/rust/connlib/tunnel/src/sockets.rs @@ -1,9 +1,10 @@ use crate::otel; -use anyhow::Result; +use anyhow::{Context as _, Result}; use futures::{SinkExt, StreamExt, ready}; use gat_lending_iterator::LendingIterator; use socket_factory::DatagramOut; use socket_factory::{DatagramIn, DatagramSegmentIter, SocketFactory, UdpSocket}; +use std::time::{Duration, Instant}; use std::{ io, net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6}, @@ -158,6 +159,12 @@ where } struct ThreadedUdpSocket { + thread_name: String, + join_handle: std::thread::JoinHandle<()>, + channels: Option, +} + +struct Channels { outbound_tx: flume::r#async::SendSink<'static, DatagramOut>, inbound_rx: flume::r#async::RecvStream<'static, Result>, } @@ -168,91 +175,59 @@ impl ThreadedUdpSocket { let (inbound_tx, inbound_rx) = flume::bounded(10); let (error_tx, error_rx) = flume::bounded(0); - std::thread::Builder::new() - .name(match preferred_addr { - SocketAddr::V4(_) => "UDP IPv4".to_owned(), - SocketAddr::V6(_) => "UDP IPv6".to_owned(), - }) - .spawn(move || { - let runtime = match tokio::runtime::Builder::new_current_thread() - .enable_all() - .build() - { - Ok(r) => r, - Err(e) => { - let _ = error_tx.send(Err(e)); - return; - } - }; - - runtime.block_on(async move { - let mut socket = match listen( - sf, - // Listen on the preferred address, fall back to picking a free port if that doesn't work - &[preferred_addr, SocketAddr::new(preferred_addr.ip(), 0)], - ) { - Ok(s) => s, + let thread_name = match preferred_addr { + SocketAddr::V4(_) => "UDP IPv4".to_owned(), + SocketAddr::V6(_) => "UDP IPv6".to_owned(), + }; + let join_handle = + std::thread::Builder::new() + .name(thread_name.clone()) + .spawn(move || { + let runtime = match tokio::runtime::Builder::new_current_thread() + .enable_all() + .build() + { + Ok(r) => r, Err(e) => { let _ = error_tx.send(Err(e)); return; } }; - let io_error_counter = opentelemetry::global::meter("connlib") - .u64_counter("system.network.errors") - .with_description("Number of IO errors encountered") - .with_unit("{error}") - .build(); - - if let Err(e) = socket.set_buffer_sizes( - socket_factory::SEND_BUFFER_SIZE, - socket_factory::RECV_BUFFER_SIZE, - ) { - tracing::warn!("Failed to set socket buffer sizes: {e}"); + runtime.block_on(async move { + let mut socket = match listen( + sf, + // Listen on the preferred address, fall back to picking a free port if that doesn't work + &[preferred_addr, SocketAddr::new(preferred_addr.ip(), 0)], + ) { + Ok(s) => s, + Err(e) => { + let _ = error_tx.send(Err(e)); + return; } + }; - let send = pin!(async { - while let Ok(datagram) = outbound_rx.recv_async().await { - if let Err(e) = socket.send(datagram).await { - if let Some(io) = e.downcast_ref::() { - io_error_counter.add( - 1, - &[ - otel::attr::network_io_direction_transmit(), - otel::attr::network_type_for_addr(preferred_addr), - otel::attr::io_error_type(io), - otel::attr::io_error_code(io), - ], - ); - } + let io_error_counter = opentelemetry::global::meter("connlib") + .u64_counter("system.network.errors") + .with_description("Number of IO errors encountered") + .with_unit("{error}") + .build(); - // We use the inbound_tx channel to send the error back to the main thread. - if inbound_tx.send_async(Err(e)).await.is_err() { - tracing::debug!( - "Channel for inbound datagrams closed; exiting UDP thread" - ); - break; - } - }; - } + if let Err(e) = socket.set_buffer_sizes( + socket_factory::SEND_BUFFER_SIZE, + socket_factory::RECV_BUFFER_SIZE, + ) { + tracing::warn!("Failed to set socket buffer sizes: {e}"); + } - tracing::debug!( - "Channel for outbound datagrams closed; exiting UDP thread" - ); - }); - let receive = pin!(async { - loop { - let result = socket.recv_from().await; - - if let Some(io) = result - .as_ref() - .err() - .and_then(|e| e.downcast_ref::()) - { + let send = pin!(async { + while let Ok(datagram) = outbound_rx.recv_async().await { + if let Err(e) = socket.send(datagram).await { + if let Some(io) = e.downcast_ref::() { io_error_counter.add( 1, &[ - otel::attr::network_io_direction_receive(), + otel::attr::network_io_direction_transmit(), otel::attr::network_type_for_addr(preferred_addr), otel::attr::io_error_type(io), otel::attr::io_error_code(io), @@ -260,37 +235,75 @@ impl ThreadedUdpSocket { ); } - if inbound_tx.send_async(result).await.is_err() { + // We use the inbound_tx channel to send the error back to the main thread. + if inbound_tx.send_async(Err(e)).await.is_err() { tracing::debug!( "Channel for inbound datagrams closed; exiting UDP thread" ); break; } + }; + } + + tracing::debug!("Channel for outbound datagrams closed; exiting UDP thread"); + }); + let receive = pin!(async { + loop { + let result = socket.recv_from().await; + + if let Some(io) = result + .as_ref() + .err() + .and_then(|e| e.downcast_ref::()) + { + io_error_counter.add( + 1, + &[ + otel::attr::network_io_direction_receive(), + otel::attr::network_type_for_addr(preferred_addr), + otel::attr::io_error_type(io), + otel::attr::io_error_code(io), + ], + ); } - }); - let _ = error_tx.send(Ok(())); + if inbound_tx.send_async(result).await.is_err() { + tracing::debug!( + "Channel for inbound datagrams closed; exiting UDP thread" + ); + break; + } + } + }); - futures::future::select(send, receive).await; - }) - })?; + let _ = error_tx.send(Ok(())); + + futures::future::select(send, receive).await; + }) + })?; error_rx.recv().map_err(io::Error::other)??; Ok(Self { - outbound_tx: outbound_tx.into_sink(), - inbound_rx: inbound_rx.into_stream(), + thread_name, + join_handle, + channels: Some(Channels { + outbound_tx: outbound_tx.into_sink(), + inbound_rx: inbound_rx.into_stream(), + }), }) } fn poll_send_ready(&mut self, cx: &mut Context<'_>) -> Poll> { - ready!(self.outbound_tx.poll_ready_unpin(cx)).map_err(|_| UdpSocketThreadStopped)?; + ready!(self.channels_mut()?.outbound_tx.poll_ready_unpin(cx)) + .map_err(|_| UdpSocketThreadStopped)?; Poll::Ready(Ok(())) } fn send(&mut self, datagram: DatagramOut) -> Result<()> { - self.outbound_tx + self.channels_mut()? + .outbound_tx .start_send_unpin(datagram) .map_err(|_| UdpSocketThreadStopped)?; @@ -298,10 +311,36 @@ impl ThreadedUdpSocket { } fn poll_recv_from(&mut self, cx: &mut Context<'_>) -> Poll> { - let iter = ready!(self.inbound_rx.poll_next_unpin(cx)).ok_or(UdpSocketThreadStopped)?; + let iter = ready!(self.channels_mut()?.inbound_rx.poll_next_unpin(cx)) + .ok_or(UdpSocketThreadStopped)?; Poll::Ready(iter) } + + fn channels_mut(&mut self) -> Result<&mut Channels> { + self.channels.as_mut().context("Missing channels") + } +} + +impl Drop for ThreadedUdpSocket { + fn drop(&mut self) { + let start = Instant::now(); + + let _ = self.channels.take(); + + const TIMEOUT: Duration = Duration::from_millis(500); + + while !self.join_handle.is_finished() { + let elapsed = start.elapsed(); + + if elapsed > TIMEOUT { + tracing::debug!(name = %self.thread_name, "Thread did not stop within {TIMEOUT:?}"); + return; + } + } + + tracing::debug!(name = %self.thread_name, duration = ?start.elapsed(), "Background thread stopped"); + } } fn listen(