fix(connlib): wait for sockets to be closed before rebinding (#9996)

Our `ThreadedUdpSocket` uses a background thread for the actual socket
operation. It merely represents a handle to send and receive from these
sockets but not the socket itself. Dropping the handle will shutdown the
background thread but that is an asynchronous operation.

In order to be sure that we can rebind the same port, we need to wait
for the background thread to stop.

We thus add a `Drop` implementation for the `ThreadedUdpSocket` that
waits for its background thread to disappear before it continues.

Resolves: #9992
This commit is contained in:
Thomas Eizinger
2025-07-25 13:09:13 +10:00
committed by GitHub
parent ccc736e63e
commit e5ee8e3572

View File

@@ -1,9 +1,10 @@
use crate::otel;
use anyhow::Result;
use anyhow::{Context as _, Result};
use futures::{SinkExt, StreamExt, ready};
use gat_lending_iterator::LendingIterator;
use socket_factory::DatagramOut;
use socket_factory::{DatagramIn, DatagramSegmentIter, SocketFactory, UdpSocket};
use std::time::{Duration, Instant};
use std::{
io,
net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6},
@@ -158,6 +159,12 @@ where
}
struct ThreadedUdpSocket {
thread_name: String,
join_handle: std::thread::JoinHandle<()>,
channels: Option<Channels>,
}
struct Channels {
outbound_tx: flume::r#async::SendSink<'static, DatagramOut>,
inbound_rx: flume::r#async::RecvStream<'static, Result<DatagramSegmentIter>>,
}
@@ -168,91 +175,59 @@ impl ThreadedUdpSocket {
let (inbound_tx, inbound_rx) = flume::bounded(10);
let (error_tx, error_rx) = flume::bounded(0);
std::thread::Builder::new()
.name(match preferred_addr {
SocketAddr::V4(_) => "UDP IPv4".to_owned(),
SocketAddr::V6(_) => "UDP IPv6".to_owned(),
})
.spawn(move || {
let runtime = match tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()
{
Ok(r) => r,
Err(e) => {
let _ = error_tx.send(Err(e));
return;
}
};
runtime.block_on(async move {
let mut socket = match listen(
sf,
// Listen on the preferred address, fall back to picking a free port if that doesn't work
&[preferred_addr, SocketAddr::new(preferred_addr.ip(), 0)],
) {
Ok(s) => s,
let thread_name = match preferred_addr {
SocketAddr::V4(_) => "UDP IPv4".to_owned(),
SocketAddr::V6(_) => "UDP IPv6".to_owned(),
};
let join_handle =
std::thread::Builder::new()
.name(thread_name.clone())
.spawn(move || {
let runtime = match tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()
{
Ok(r) => r,
Err(e) => {
let _ = error_tx.send(Err(e));
return;
}
};
let io_error_counter = opentelemetry::global::meter("connlib")
.u64_counter("system.network.errors")
.with_description("Number of IO errors encountered")
.with_unit("{error}")
.build();
if let Err(e) = socket.set_buffer_sizes(
socket_factory::SEND_BUFFER_SIZE,
socket_factory::RECV_BUFFER_SIZE,
) {
tracing::warn!("Failed to set socket buffer sizes: {e}");
runtime.block_on(async move {
let mut socket = match listen(
sf,
// Listen on the preferred address, fall back to picking a free port if that doesn't work
&[preferred_addr, SocketAddr::new(preferred_addr.ip(), 0)],
) {
Ok(s) => s,
Err(e) => {
let _ = error_tx.send(Err(e));
return;
}
};
let send = pin!(async {
while let Ok(datagram) = outbound_rx.recv_async().await {
if let Err(e) = socket.send(datagram).await {
if let Some(io) = e.downcast_ref::<io::Error>() {
io_error_counter.add(
1,
&[
otel::attr::network_io_direction_transmit(),
otel::attr::network_type_for_addr(preferred_addr),
otel::attr::io_error_type(io),
otel::attr::io_error_code(io),
],
);
}
let io_error_counter = opentelemetry::global::meter("connlib")
.u64_counter("system.network.errors")
.with_description("Number of IO errors encountered")
.with_unit("{error}")
.build();
// We use the inbound_tx channel to send the error back to the main thread.
if inbound_tx.send_async(Err(e)).await.is_err() {
tracing::debug!(
"Channel for inbound datagrams closed; exiting UDP thread"
);
break;
}
};
}
if let Err(e) = socket.set_buffer_sizes(
socket_factory::SEND_BUFFER_SIZE,
socket_factory::RECV_BUFFER_SIZE,
) {
tracing::warn!("Failed to set socket buffer sizes: {e}");
}
tracing::debug!(
"Channel for outbound datagrams closed; exiting UDP thread"
);
});
let receive = pin!(async {
loop {
let result = socket.recv_from().await;
if let Some(io) = result
.as_ref()
.err()
.and_then(|e| e.downcast_ref::<io::Error>())
{
let send = pin!(async {
while let Ok(datagram) = outbound_rx.recv_async().await {
if let Err(e) = socket.send(datagram).await {
if let Some(io) = e.downcast_ref::<io::Error>() {
io_error_counter.add(
1,
&[
otel::attr::network_io_direction_receive(),
otel::attr::network_io_direction_transmit(),
otel::attr::network_type_for_addr(preferred_addr),
otel::attr::io_error_type(io),
otel::attr::io_error_code(io),
@@ -260,37 +235,75 @@ impl ThreadedUdpSocket {
);
}
if inbound_tx.send_async(result).await.is_err() {
// We use the inbound_tx channel to send the error back to the main thread.
if inbound_tx.send_async(Err(e)).await.is_err() {
tracing::debug!(
"Channel for inbound datagrams closed; exiting UDP thread"
);
break;
}
};
}
tracing::debug!("Channel for outbound datagrams closed; exiting UDP thread");
});
let receive = pin!(async {
loop {
let result = socket.recv_from().await;
if let Some(io) = result
.as_ref()
.err()
.and_then(|e| e.downcast_ref::<io::Error>())
{
io_error_counter.add(
1,
&[
otel::attr::network_io_direction_receive(),
otel::attr::network_type_for_addr(preferred_addr),
otel::attr::io_error_type(io),
otel::attr::io_error_code(io),
],
);
}
});
let _ = error_tx.send(Ok(()));
if inbound_tx.send_async(result).await.is_err() {
tracing::debug!(
"Channel for inbound datagrams closed; exiting UDP thread"
);
break;
}
}
});
futures::future::select(send, receive).await;
})
})?;
let _ = error_tx.send(Ok(()));
futures::future::select(send, receive).await;
})
})?;
error_rx.recv().map_err(io::Error::other)??;
Ok(Self {
outbound_tx: outbound_tx.into_sink(),
inbound_rx: inbound_rx.into_stream(),
thread_name,
join_handle,
channels: Some(Channels {
outbound_tx: outbound_tx.into_sink(),
inbound_rx: inbound_rx.into_stream(),
}),
})
}
fn poll_send_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<()>> {
ready!(self.outbound_tx.poll_ready_unpin(cx)).map_err(|_| UdpSocketThreadStopped)?;
ready!(self.channels_mut()?.outbound_tx.poll_ready_unpin(cx))
.map_err(|_| UdpSocketThreadStopped)?;
Poll::Ready(Ok(()))
}
fn send(&mut self, datagram: DatagramOut) -> Result<()> {
self.outbound_tx
self.channels_mut()?
.outbound_tx
.start_send_unpin(datagram)
.map_err(|_| UdpSocketThreadStopped)?;
@@ -298,10 +311,36 @@ impl ThreadedUdpSocket {
}
fn poll_recv_from(&mut self, cx: &mut Context<'_>) -> Poll<Result<DatagramSegmentIter>> {
let iter = ready!(self.inbound_rx.poll_next_unpin(cx)).ok_or(UdpSocketThreadStopped)?;
let iter = ready!(self.channels_mut()?.inbound_rx.poll_next_unpin(cx))
.ok_or(UdpSocketThreadStopped)?;
Poll::Ready(iter)
}
fn channels_mut(&mut self) -> Result<&mut Channels> {
self.channels.as_mut().context("Missing channels")
}
}
impl Drop for ThreadedUdpSocket {
fn drop(&mut self) {
let start = Instant::now();
let _ = self.channels.take();
const TIMEOUT: Duration = Duration::from_millis(500);
while !self.join_handle.is_finished() {
let elapsed = start.elapsed();
if elapsed > TIMEOUT {
tracing::debug!(name = %self.thread_name, "Thread did not stop within {TIMEOUT:?}");
return;
}
}
tracing::debug!(name = %self.thread_name, duration = ?start.elapsed(), "Background thread stopped");
}
}
fn listen(