diff --git a/rust/Cargo.lock b/rust/Cargo.lock index fa422bf93..a743cb90b 100644 --- a/rust/Cargo.lock +++ b/rust/Cargo.lock @@ -4908,10 +4908,9 @@ dependencies = [ [[package]] name = "quinn-udp" -version = "0.5.0" -source = "git+https://github.com/quinn-rs/quinn?branch=main#88f48b0179f358cea0b76fc550e629007bfc957d" +version = "0.5.2" +source = "git+https://github.com/quinn-rs/quinn?branch=main#3f489e2eab014ddd04de58e570ba56e9b027f0bc" dependencies = [ - "bytes", "libc", "once_cell", "socket2 0.5.7", diff --git a/rust/connlib/tunnel/src/io.rs b/rust/connlib/tunnel/src/io.rs index bab1ea803..742a3290d 100644 --- a/rust/connlib/tunnel/src/io.rs +++ b/rust/connlib/tunnel/src/io.rs @@ -3,7 +3,6 @@ use crate::{ dns::DnsQuery, sockets::{Received, Sockets}, }; -use bytes::Bytes; use connlib_shared::messages::DnsServer; use futures::Future; use futures_bounded::FuturesTupleSet; @@ -16,7 +15,6 @@ use hickory_resolver::{ AsyncResolver, TokioHandle, }; use ip_packet::{IpPacket, MutableIpPacket}; -use quinn_udp::Transmit; use socket_factory::SocketFactory; use std::{ collections::HashMap, @@ -187,13 +185,7 @@ impl Io { } pub fn send_network(&mut self, transmit: snownet::Transmit) -> io::Result<()> { - self.sockets.try_send(Transmit { - destination: transmit.dst, - ecn: None, - contents: Bytes::copy_from_slice(&transmit.payload), - segment_size: None, - src_ip: transmit.src.map(|s| s.ip()), - })?; + self.sockets.send(transmit)?; Ok(()) } diff --git a/rust/connlib/tunnel/src/sockets.rs b/rust/connlib/tunnel/src/sockets.rs index 0e9b4790a..77b20d852 100644 --- a/rust/connlib/tunnel/src/sockets.rs +++ b/rust/connlib/tunnel/src/sockets.rs @@ -2,6 +2,7 @@ use core::slice; use quinn_udp::{RecvMeta, UdpSockRef, UdpSocketState}; use socket_factory::SocketFactory; use std::{ + collections::VecDeque, io::{self, IoSliceMut}, net::{Ipv4Addr, Ipv6Addr, SocketAddr}, task::{ready, Context, Poll}, @@ -64,23 +65,18 @@ impl Sockets { Poll::Ready(Ok(())) } - pub fn try_send(&mut self, transmit: quinn_udp::Transmit) -> io::Result<()> { - match transmit.destination { - SocketAddr::V4(dst) => { - let socket = self.socket_v4.as_mut().ok_or(io::Error::new( - io::ErrorKind::NotConnected, - format!("failed send packet to {dst}: no IPv4 socket"), - ))?; - socket.send(transmit); - } - SocketAddr::V6(dst) => { - let socket = self.socket_v6.as_mut().ok_or(io::Error::new( - io::ErrorKind::NotConnected, - format!("failed send packet to {dst}: no IPv6 socket"), - ))?; - socket.send(transmit); - } - } + pub fn send(&mut self, transmit: snownet::Transmit) -> io::Result<()> { + let socket = match transmit.dst { + SocketAddr::V4(dst) => self.socket_v4.as_mut().ok_or(io::Error::new( + io::ErrorKind::NotConnected, + format!("failed send packet to {dst}: no IPv4 socket"), + ))?, + SocketAddr::V6(dst) => self.socket_v6.as_mut().ok_or(io::Error::new( + io::ErrorKind::NotConnected, + format!("failed send packet to {dst}: no IPv6 socket"), + ))?, + }; + socket.send(transmit)?; Ok(()) } @@ -166,7 +162,7 @@ struct Socket { port: u16, socket: UdpSocket, - buffered_transmits: Vec, + buffered_transmits: VecDeque>, } impl Socket { @@ -181,7 +177,7 @@ impl Socket { state: UdpSocketState::new(UdpSockRef::from(&socket))?, port, socket, - buffered_transmits: Vec::new(), + buffered_transmits: VecDeque::new(), }) } @@ -252,42 +248,62 @@ impl Socket { fn poll_flush(&mut self, cx: &mut Context<'_>) -> Poll> { loop { - match self.socket.try_io(Interest::WRITABLE, || { - self.state - .send((&self.socket).into(), &self.buffered_transmits) - }) { - Ok(0) => break, - Err(e) if e.kind() == io::ErrorKind::WouldBlock => break, - Err(e) => return Poll::Ready(Err(e)), + ready!(self.socket.poll_send_ready(cx))?; // Ensure we are ready to send. - Ok(num_sent) => { - self.buffered_transmits.drain(..num_sent); - - // I am not sure if we'd ever send less than what is in `buffered_transmits`. - // loop once more to be sure we `break` on either an empty buffer or on `WouldBlock`. - } + let Some(transmit) = self.buffered_transmits.pop_front() else { + break; }; + + match self.try_send(&transmit) { + Ok(()) => continue, // Try to send another packet. + Err(e) => { + self.buffered_transmits.push_front(transmit); // Don't lose the packet if we fail. + + if e.kind() == io::ErrorKind::WouldBlock { + continue; // False positive send-readiness: Loop to `poll_send_ready` and return `Pending`. + } + + return Poll::Ready(Err(e)); + } + } } - // Ensure we are ready to send more data. - ready!(self.socket.poll_send_ready(cx)?); - - assert!( - self.buffered_transmits.is_empty(), - "buffer must be empty if we are ready to send more data" - ); + assert!(self.buffered_transmits.is_empty()); Poll::Ready(Ok(())) } - fn send(&mut self, transmit: quinn_udp::Transmit) { - tracing::trace!(target: "wire::net::send", src = ?transmit.src_ip, dst = %transmit.destination, num_bytes = %transmit.contents.len()); - - self.buffered_transmits.push(transmit); + fn send(&mut self, transmit: snownet::Transmit) -> io::Result<()> { + tracing::trace!(target: "wire::net::send", src = ?transmit.src, dst = %transmit.dst, num_bytes = %transmit.payload.len()); debug_assert!( self.buffered_transmits.len() < 10_000, "We are not flushing the packets for some reason" ); + + match self.try_send(&transmit) { + Ok(()) => Ok(()), + Err(e) if e.kind() == io::ErrorKind::WouldBlock => { + tracing::trace!("Buffering packet because socket is busy"); + + self.buffered_transmits.push_back(transmit.into_owned()); + Ok(()) + } + Err(e) => Err(e), + } + } + + fn try_send(&self, transmit: &snownet::Transmit) -> io::Result<()> { + let transmit = quinn_udp::Transmit { + destination: transmit.dst, + ecn: None, + contents: &transmit.payload, + segment_size: None, + src_ip: transmit.src.map(|s| s.ip()), + }; + + self.socket.try_io(Interest::WRITABLE, || { + self.state.send((&self.socket).into(), &transmit) + }) } }