diff --git a/rust/connlib/tunnel/src/io.rs b/rust/connlib/tunnel/src/io.rs index 4720ba4b0..fd904f83d 100644 --- a/rust/connlib/tunnel/src/io.rs +++ b/rust/connlib/tunnel/src/io.rs @@ -4,6 +4,7 @@ use crate::{ ip_packet::{IpPacket, MutableIpPacket}, sockets::{Received, Sockets}, }; +use bytes::Bytes; use connlib_shared::messages::DnsServer; use futures_bounded::FuturesTupleSet; use futures_util::FutureExt as _; @@ -11,7 +12,7 @@ use hickory_resolver::{ config::{NameServerConfig, Protocol, ResolverConfig}, TokioAsyncResolver, }; -use snownet::Transmit; +use quinn_udp::Transmit; use std::{ collections::HashMap, io, @@ -96,7 +97,7 @@ impl Io { return Poll::Ready(Ok(Input::Network(network))); } - ready!(self.sockets.poll_send_ready(cx))?; // Packets read from the device need to be written to a socket, let's make sure the socket can take more packets. + ready!(self.sockets.poll_flush(cx))?; if let Poll::Ready(packet) = self.device.poll_read(device_buffer, cx)? { return Poll::Ready(Ok(Input::Device(packet))); @@ -159,8 +160,14 @@ impl Io { } } - pub fn send_network(&self, transmit: Transmit) -> io::Result<()> { - self.sockets.try_send(&transmit)?; + pub fn send_network(&mut self, transmit: snownet::Transmit) -> io::Result<()> { + self.sockets.try_send(Transmit { + destination: transmit.dst, + ecn: None, + contents: Bytes::copy_from_slice(&transmit.payload), + segment_size: None, + src_ip: transmit.src.map(|s| s.ip()), + })?; Ok(()) } diff --git a/rust/connlib/tunnel/src/sockets.rs b/rust/connlib/tunnel/src/sockets.rs index d0bc547d2..4ddf4bd0d 100644 --- a/rust/connlib/tunnel/src/sockets.rs +++ b/rust/connlib/tunnel/src/sockets.rs @@ -1,4 +1,3 @@ -use bytes::Bytes; use core::slice; use quinn_udp::{RecvMeta, UdpSockRef, UdpSocketState}; use socket2::{SockAddr, Type}; @@ -10,7 +9,6 @@ use std::{ use tokio::{io::Interest, net::UdpSocket}; use crate::Result; -use snownet::Transmit; pub struct Sockets { socket_v4: Option, @@ -68,35 +66,40 @@ impl Sockets { self.socket_v6.as_ref().map(|s| s.socket.as_raw_fd()) } - pub fn poll_send_ready(&self, cx: &mut Context<'_>) -> Poll> { - if let Some(socket) = self.socket_v4.as_ref() { - ready!(socket.poll_send_ready(cx))?; + /// Flushes all buffered data on the sockets. + /// + /// Returns `Ready` if the socket is able to accept more data. + pub fn poll_flush(&mut self, cx: &mut Context<'_>) -> Poll> { + if let Some(socket) = self.socket_v4.as_mut() { + ready!(socket.poll_flush(cx))?; } - if let Some(socket) = self.socket_v6.as_ref() { - ready!(socket.poll_send_ready(cx))?; + if let Some(socket) = self.socket_v6.as_mut() { + ready!(socket.poll_flush(cx))?; } Poll::Ready(Ok(())) } - pub fn try_send(&self, transmit: &Transmit) -> io::Result { - match transmit.dst { + pub fn try_send(&mut self, transmit: quinn_udp::Transmit) -> io::Result<()> { + match transmit.destination { SocketAddr::V4(_) => { - let socket = self.socket_v4.as_ref().ok_or(io::Error::new( + let socket = self.socket_v4.as_mut().ok_or(io::Error::new( io::ErrorKind::NotConnected, "no IPv4 socket", ))?; - Ok(socket.try_send_to(transmit.src, transmit.dst, &transmit.payload)?) + socket.send(transmit); } SocketAddr::V6(_) => { - let socket = self.socket_v6.as_ref().ok_or(io::Error::new( + let socket = self.socket_v6.as_mut().ok_or(io::Error::new( io::ErrorKind::NotConnected, "no IPv6 socket", ))?; - Ok(socket.try_send_to(transmit.src, transmit.dst, &transmit.payload)?) + socket.send(transmit); } } + + Ok(()) } pub fn poll_recv_from<'b>( @@ -179,6 +182,8 @@ struct Socket { state: UdpSocketState, port: u16, socket: UdpSocket, + + buffered_transmits: Vec, } impl Socket { @@ -190,6 +195,7 @@ impl Socket { state: UdpSocketState::new(UdpSockRef::from(&socket))?, port, socket: tokio::net::UdpSocket::from_std(socket)?, + buffered_transmits: Vec::new(), }) } @@ -201,6 +207,7 @@ impl Socket { state: UdpSocketState::new(UdpSockRef::from(&socket))?, port, socket: tokio::net::UdpSocket::from_std(socket)?, + buffered_transmits: Vec::new(), }) } @@ -214,6 +221,7 @@ impl Socket { port, socket, state, + .. } = self; let bufs = &mut [IoSliceMut::new(buffer)]; @@ -254,28 +262,45 @@ impl Socket { } } - fn poll_send_ready(&self, cx: &mut Context<'_>) -> Poll> { - self.socket.poll_send_ready(cx) + fn poll_flush(&mut self, cx: &mut Context<'_>) -> Poll> { + loop { + match self.socket.try_io(Interest::WRITABLE, || { + self.state + .send((&self.socket).into(), &self.buffered_transmits) + }) { + Ok(0) => break, + Err(e) if e.kind() == io::ErrorKind::WouldBlock => break, + Err(e) => return Poll::Ready(Err(e)), + + Ok(num_sent) => { + self.buffered_transmits.drain(..num_sent); + + // I am not sure if we'd ever send less than what is in `buffered_transmits`. + // loop once more to be sure we `break` on either an empty buffer or on `WouldBlock`. + } + }; + } + + // Ensure we are ready to send more data. + ready!(self.socket.poll_send_ready(cx)?); + + assert!( + self.buffered_transmits.is_empty(), + "buffer must be empty if we are ready to send more data" + ); + + Poll::Ready(Ok(())) } - fn try_send_to( - &self, - local: Option, - dest: SocketAddr, - buf: &[u8], - ) -> io::Result { - tracing::trace!(target: "wire", to = "network", src = ?local, dst = %dest, num_bytes = %buf.len()); + fn send(&mut self, transmit: quinn_udp::Transmit) { + tracing::trace!(target: "wire", to = "network", src = ?transmit.src_ip, dst = %transmit.destination, num_bytes = %transmit.contents.len()); - self.state.send( - (&self.socket).into(), - &[quinn_udp::Transmit { - destination: dest, - ecn: None, - contents: Bytes::copy_from_slice(buf), - segment_size: None, - src_ip: local.map(|s| s.ip()), - }], - ) + self.buffered_transmits.push(transmit); + + debug_assert!( + self.buffered_transmits.len() < 10_000, + "We are not flushing the packets for some reason" + ); } }