From e3688a475ed1c6740ac37ba08a2dcde72f2ea20b Mon Sep 17 00:00:00 2001 From: Thomas Eizinger Date: Wed, 4 Sep 2024 09:59:33 -0700 Subject: [PATCH] refactor(connlib): only buffer 1 unsent packet if socket is busy (#6563) Currently, we buffer UDP packets whenever the socket is busy and try to flush them out at a later point. This requires allocations and is tricky to get right. In order to solve both of these problems, we refactor `snownet` to return us an `EncryptedPacket` instead of a `Transmit`. An `EncryptedPacket` is an indirection-abstraction that can be turned into a `Transmit` given an `EncryptBuffer`. This combination of types allows us to hold on to the `EncryptedPacket` (which does not contain any references itself) in the `io` component whilst we are waiting for the socket to be ready to send again. This means we will immediately suspend the event loop in case the socket is no longer ready for sending and resend the datagram in the `EncryptBuffer` once we get re-polled. --- rust/bin-shared/src/tun_device_manager.rs | 9 ++- rust/connlib/clients/shared/src/eventloop.rs | 4 ++ rust/connlib/snownet/src/allocation.rs | 21 +++--- rust/connlib/snownet/src/lib.rs | 4 +- rust/connlib/snownet/src/node.rs | 73 ++++++++++++++------ rust/connlib/tunnel/src/client.rs | 11 +-- rust/connlib/tunnel/src/gateway.rs | 11 +-- rust/connlib/tunnel/src/io.rs | 47 ++++++++++++- rust/connlib/tunnel/src/lib.rs | 47 +++++++++---- rust/connlib/tunnel/src/sockets.rs | 13 ++-- rust/connlib/tunnel/src/tests/sim_client.rs | 11 ++- rust/connlib/tunnel/src/tests/sim_gateway.rs | 16 ++++- rust/socket-factory/src/lib.rs | 58 ++-------------- 13 files changed, 192 insertions(+), 133 deletions(-) diff --git a/rust/bin-shared/src/tun_device_manager.rs b/rust/bin-shared/src/tun_device_manager.rs index ef2305883..5619f0a4f 100644 --- a/rust/bin-shared/src/tun_device_manager.rs +++ b/rust/bin-shared/src/tun_device_manager.rs @@ -59,6 +59,10 @@ mod tests { ))) .unwrap(); + std::future::poll_fn(|cx| socket.poll_send_ready(cx)) + .await + .unwrap(); + // Send a STUN request. socket .send(DatagramOut { @@ -70,11 +74,6 @@ mod tests { }) .unwrap(); - // First send seems to always result as would block - std::future::poll_fn(|cx| socket.poll_flush(cx)) - .await - .unwrap(); - let task = std::future::poll_fn(|cx| { let mut buf = [0u8; 1000]; let result = std::task::ready!(socket.poll_recv_from(&mut buf, cx)); diff --git a/rust/connlib/clients/shared/src/eventloop.rs b/rust/connlib/clients/shared/src/eventloop.rs index 457de12b3..ff2d60ed4 100644 --- a/rust/connlib/clients/shared/src/eventloop.rs +++ b/rust/connlib/clients/shared/src/eventloop.rs @@ -15,6 +15,7 @@ use firezone_tunnel::ClientTunnel; use phoenix_channel::{ErrorReply, OutboundRequestId, PhoenixChannel}; use std::{ collections::{BTreeMap, BTreeSet}, + io, net::IpAddr, task::{Context, Poll}, }; @@ -91,6 +92,9 @@ where self.handle_tunnel_event(event); continue; } + Poll::Ready(Err(e)) if e.kind() == io::ErrorKind::WouldBlock => { + continue; + } Poll::Ready(Err(e)) => { tracing::warn!("Tunnel error: {e}"); continue; diff --git a/rust/connlib/snownet/src/allocation.rs b/rust/connlib/snownet/src/allocation.rs index c7d529051..4fa3ffbd5 100644 --- a/rust/connlib/snownet/src/allocation.rs +++ b/rust/connlib/snownet/src/allocation.rs @@ -3,6 +3,7 @@ use crate::{ node::{CandidateEvent, SessionId, Transmit}, ringbuffer::RingBuffer, utils::earliest, + EncryptedPacket, }; use ::backoff::backoff::Backoff; use bytecodec::{DecodeExt as _, EncodeExt as _}; @@ -721,29 +722,23 @@ impl Allocation { ); } - /// Encodes the packet contained in the given buffer into a [`Transmit`]. - /// - /// This function assumes that the first 4 bytes of `buffer` have been reserved for the header of the channel-data message. - pub fn encode_to_borrowed_transmit<'b>( + pub fn encode_to_encrypted_packet( &self, peer: SocketAddr, - buffer: &'b mut [u8], + buffer: &mut [u8], now: Instant, - ) -> Option> { + ) -> Option { let buffer_len = buffer.len(); let packet_len = buffer_len - 4; let channel_number = self.channel_bindings.connected_channel_to_peer(peer, now)?; - let total_length = crate::channel_data::encode_header_to_slice( - &mut buffer[..4], - channel_number, - packet_len, - ); + crate::channel_data::encode_header_to_slice(&mut buffer[..4], channel_number, packet_len); - Some(Transmit { + Some(EncryptedPacket { src: None, dst: self.active_socket?, - payload: Cow::Borrowed(&buffer[..total_length]), + packet_start: 0, + packet_len: buffer_len, }) } diff --git a/rust/connlib/snownet/src/lib.rs b/rust/connlib/snownet/src/lib.rs index b944a89b8..454b9d715 100644 --- a/rust/connlib/snownet/src/lib.rs +++ b/rust/connlib/snownet/src/lib.rs @@ -11,7 +11,7 @@ mod utils; pub use allocation::RelaySocket; pub use node::{ - Answer, Client, ClientNode, Credentials, Error, Event, Node, Offer, Server, ServerNode, - Transmit, HANDSHAKE_TIMEOUT, + Answer, Client, ClientNode, Credentials, EncryptBuffer, EncryptedPacket, Error, Event, Node, + Offer, Server, ServerNode, Transmit, HANDSHAKE_TIMEOUT, }; pub use stats::{ConnectionStats, NodeStats}; diff --git a/rust/connlib/snownet/src/node.rs b/rust/connlib/snownet/src/node.rs index 5683f3f2c..ab418e38a 100644 --- a/rust/connlib/snownet/src/node.rs +++ b/rust/connlib/snownet/src/node.rs @@ -287,14 +287,14 @@ where /// - `Ok(None)` if the packet was handled internally, for example, a response from a TURN server. /// - `Ok(Some)` if the packet was an encrypted wireguard packet from a peer. /// The `Option` contains the connection on which the packet was decrypted. - pub fn decapsulate<'s>( + pub fn decapsulate<'b>( &mut self, local: SocketAddr, from: SocketAddr, packet: &[u8], now: Instant, - buffer: &'s mut [u8], - ) -> Result)>, Error> { + buffer: &'b mut [u8], + ) -> Result)>, Error> { self.add_local_as_host_candidate(local)?; let (from, packet, relayed) = match self.allocations_try_handle(from, local, packet, now) { @@ -325,12 +325,13 @@ where /// Wireguard is an IP tunnel, so we "enforce" that only IP packets are sent through it. /// We say "enforce" an [`IpPacket`] can be created from an (almost) arbitrary byte buffer at virtually no cost. /// Nevertheless, using [`IpPacket`] in our API has good documentation value. - pub fn encapsulate<'s>( - &'s mut self, + pub fn encapsulate( + &mut self, connection: TId, packet: IpPacket<'_>, now: Instant, - ) -> Result>, Error> { + buffer: &mut EncryptBuffer, + ) -> Result, Error> { let conn = self .connections .get_established_mut(&connection) @@ -341,7 +342,7 @@ where // Encode the packet with an offset of 4 bytes, in case we need to wrap it in a channel-data message. let Some(packet_len) = conn - .encapsulate(packet.packet(), &mut self.buffer[4..], now)? + .encapsulate(packet.packet(), &mut buffer.inner[4..], now)? .map(|p| p.len()) // Mapping to len() here terminate the mutable borrow of buffer, allowing re-borrowing further down. else { @@ -355,30 +356,26 @@ where PeerSocket::Direct { dest: remote, source, - } => { - // Re-borrow the actual packet. - let packet = &self.buffer[packet_start..packet_end]; - - Ok(Some(Transmit { - src: Some(source), - dst: remote, - payload: Cow::Borrowed(packet), - })) - } + } => Ok(Some(EncryptedPacket { + src: Some(source), + dst: remote, + packet_start, + packet_len, + })), PeerSocket::Relay { relay, dest: peer } => { let Some(allocation) = self.allocations.get(&relay) else { tracing::warn!(%relay, "No allocation"); return Ok(None); }; - let packet = &mut self.buffer[..packet_end]; + let packet = &mut buffer.inner[..packet_end]; - let Some(transmit) = allocation.encode_to_borrowed_transmit(peer, packet, now) + let Some(enc_packet) = allocation.encode_to_encrypted_packet(peer, packet, now) else { tracing::warn!(%peer, "No channel"); return Ok(None); }; - Ok(Some(transmit)) + Ok(Some(enc_packet)) } } } @@ -1240,6 +1237,42 @@ pub enum Event { ConnectionClosed(TId), } +pub struct EncryptBuffer { + inner: Vec, +} + +impl EncryptBuffer { + pub fn new(len: usize) -> Self { + Self { + inner: vec![0u8; len], + } + } +} + +#[derive(Debug, Clone, Copy)] +pub struct EncryptedPacket { + pub(crate) src: Option, + pub(crate) dst: SocketAddr, + pub(crate) packet_start: usize, + pub(crate) packet_len: usize, +} + +impl EncryptedPacket { + pub fn to_transmit(self, buf: &EncryptBuffer) -> Transmit<'_> { + Transmit { + src: self.src, + dst: self.dst, + payload: Cow::Borrowed( + &buf.inner[self.packet_start..(self.packet_start + self.packet_len)], + ), + } + } + + pub fn dst(&self) -> SocketAddr { + self.dst + } +} + #[derive(Clone, PartialEq, PartialOrd, Eq, Ord)] pub struct Transmit<'a> { /// The local interface from which this packet should be sent. diff --git a/rust/connlib/tunnel/src/client.rs b/rust/connlib/tunnel/src/client.rs index 616510af7..1916f600a 100644 --- a/rust/connlib/tunnel/src/client.rs +++ b/rust/connlib/tunnel/src/client.rs @@ -22,7 +22,7 @@ use crate::{ClientEvent, ClientTunnel, Tun}; use domain::base::Message; use lru::LruCache; use secrecy::{ExposeSecret as _, Secret}; -use snownet::{ClientNode, RelaySocket, Transmit}; +use snownet::{ClientNode, EncryptBuffer, RelaySocket, Transmit}; use std::borrow::Cow; use std::collections::hash_map::Entry; use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet, VecDeque}; @@ -419,11 +419,12 @@ impl ClientState { ) } - pub(crate) fn encapsulate<'s>( - &'s mut self, + pub(crate) fn encapsulate( + &mut self, packet: MutableIpPacket<'_>, now: Instant, - ) -> Option> { + buffer: &mut EncryptBuffer, + ) -> Option { let (packet, dst) = match self.try_handle_dns_query(packet, now) { Ok(response) => { self.buffered_packets.push_back(response?.to_owned()); @@ -469,7 +470,7 @@ impl ClientState { let transmit = self .node - .encapsulate(gid, packet.as_immutable(), now) + .encapsulate(gid, packet.as_immutable(), now, buffer) .inspect_err(|e| tracing::debug!(%gid, "Failed to encapsulate: {e}")) .ok()??; diff --git a/rust/connlib/tunnel/src/gateway.rs b/rust/connlib/tunnel/src/gateway.rs index 2a891ec98..61f844176 100644 --- a/rust/connlib/tunnel/src/gateway.rs +++ b/rust/connlib/tunnel/src/gateway.rs @@ -13,7 +13,7 @@ use connlib_shared::{DomainName, StaticSecret}; use ip_network::{Ipv4Network, Ipv6Network}; use ip_packet::{IpPacket, MutableIpPacket}; use secrecy::{ExposeSecret as _, Secret}; -use snownet::{RelaySocket, ServerNode}; +use snownet::{EncryptBuffer, RelaySocket, ServerNode}; use std::collections::{BTreeMap, BTreeSet, VecDeque}; use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}; use std::time::{Duration, Instant}; @@ -155,11 +155,12 @@ impl GatewayState { self.node.public_key() } - pub(crate) fn encapsulate<'s>( - &'s mut self, + pub(crate) fn encapsulate( + &mut self, packet: MutableIpPacket<'_>, now: Instant, - ) -> Option> { + buffer: &mut EncryptBuffer, + ) -> Option { let dst = packet.destination(); if !is_client(dst) { @@ -180,7 +181,7 @@ impl GatewayState { let transmit = self .node - .encapsulate(peer.id(), packet.as_immutable(), now) + .encapsulate(peer.id(), packet.as_immutable(), now, buffer) .inspect_err(|e| tracing::debug!(%cid, "Failed to encapsulate: {e}")) .ok()??; diff --git a/rust/connlib/tunnel/src/io.rs b/rust/connlib/tunnel/src/io.rs index 30f63c6c1..f11dbc27c 100644 --- a/rust/connlib/tunnel/src/io.rs +++ b/rust/connlib/tunnel/src/io.rs @@ -1,6 +1,7 @@ use crate::{device_channel::Device, sockets::Sockets, BUF_SIZE}; use futures_util::FutureExt as _; use ip_packet::{IpPacket, MutableIpPacket}; +use snownet::{EncryptBuffer, EncryptedPacket}; use socket_factory::{DatagramIn, DatagramOut, SocketFactory, TcpSocket, UdpSocket}; use std::{ io, @@ -18,6 +19,7 @@ pub struct Io { device: Device, /// The UDP sockets used to send & receive packets from the network. sockets: Sockets, + unwritten_packet: Option, _tcp_socket_factory: Arc>, udp_socket_factory: Arc>, @@ -48,6 +50,7 @@ impl Io { sockets, _tcp_socket_factory: tcp_socket_factory, udp_socket_factory, + unwritten_packet: None, } } @@ -61,13 +64,14 @@ impl Io { ip4_buffer: &'b1 mut [u8], ip6_bffer: &'b1 mut [u8], device_buffer: &'b2 mut [u8], + encrypt_buffer: &EncryptBuffer, ) -> Poll>>>> { + ready!(self.poll_send_unwritten(cx, encrypt_buffer)?); + if let Poll::Ready(network) = self.sockets.poll_recv_from(ip4_buffer, ip6_bffer, cx)? { return Poll::Ready(Ok(Input::Network(network.filter(is_max_wg_packet_size)))); } - ready!(self.sockets.poll_flush(cx))?; - if let Poll::Ready(packet) = self.device.poll_read(device_buffer, cx)? { return Poll::Ready(Ok(Input::Device(packet))); } @@ -84,6 +88,23 @@ impl Io { Poll::Pending } + fn poll_send_unwritten( + &mut self, + cx: &mut Context<'_>, + buf: &EncryptBuffer, + ) -> Poll> { + ready!(self.sockets.poll_send_ready(cx))?; + + // If the `unwritten_packet` is set, `EncryptBuffer` is still holding a packet that we need so send. + let Some(unwritten_packet) = self.unwritten_packet.take() else { + return Poll::Ready(Ok(())); + }; + + self.send_encrypted_packet(unwritten_packet, buf)?; + + Poll::Ready(Ok(())) + } + pub fn device_mut(&mut self) -> &mut Device { &mut self.device } @@ -114,6 +135,28 @@ impl Io { Ok(()) } + pub fn send_encrypted_packet( + &mut self, + packet: EncryptedPacket, + buf: &EncryptBuffer, + ) -> io::Result<()> { + let transmit = packet.to_transmit(buf); + let res = self.send_network(transmit); + + if res + .as_ref() + .err() + .is_some_and(|e| e.kind() == io::ErrorKind::WouldBlock) + { + tracing::debug!(dst = %packet.dst(), "Socket busy"); + self.unwritten_packet = Some(packet); + } + + res?; + + Ok(()) + } + pub fn send_device(&self, packet: IpPacket<'_>) -> io::Result<()> { self.device.write(packet)?; diff --git a/rust/connlib/tunnel/src/lib.rs b/rust/connlib/tunnel/src/lib.rs index 8ea3f70e8..3252c5a92 100644 --- a/rust/connlib/tunnel/src/lib.rs +++ b/rust/connlib/tunnel/src/lib.rs @@ -64,6 +64,7 @@ pub type ClientTunnel = Tunnel; pub use client::ClientState; pub use gateway::{GatewayState, IPV4_PEERS, IPV6_PEERS}; +use snownet::EncryptBuffer; /// [`Tunnel`] glues together connlib's [`Io`] component and the respective (pure) state of a client or gateway. /// @@ -81,8 +82,12 @@ pub struct Tunnel { ip4_read_buf: Box<[u8; MAX_UDP_SIZE]>, ip6_read_buf: Box<[u8; MAX_UDP_SIZE]>, - /// Buffer for processing a single IP packet. - packet_buffer: Box<[u8; BUF_SIZE]>, + /// Buffer for reading a single IP packet. + device_read_buf: Box<[u8; BUF_SIZE]>, + /// Buffer for decryping a single packet. + decrypt_buf: Box<[u8; BUF_SIZE]>, + /// Buffer for encrypting a single packet. + encrypt_buf: EncryptBuffer, } impl ClientTunnel { @@ -95,9 +100,11 @@ impl ClientTunnel { Self { io: Io::new(tcp_socket_factory, udp_socket_factory), role_state: ClientState::new(private_key, known_hosts, rand::random()), - packet_buffer: Box::new([0u8; BUF_SIZE]), + device_read_buf: Box::new([0u8; BUF_SIZE]), ip4_read_buf: Box::new([0u8; MAX_UDP_SIZE]), ip6_read_buf: Box::new([0u8; MAX_UDP_SIZE]), + encrypt_buf: EncryptBuffer::new(BUF_SIZE), + decrypt_buf: Box::new([0u8; BUF_SIZE]), } } @@ -132,18 +139,23 @@ impl ClientTunnel { cx, self.ip4_read_buf.as_mut(), self.ip6_read_buf.as_mut(), - self.packet_buffer.as_mut(), + self.device_read_buf.as_mut(), + &self.encrypt_buf, )? { Poll::Ready(io::Input::Timeout(timeout)) => { self.role_state.handle_timeout(timeout); continue; } Poll::Ready(io::Input::Device(packet)) => { - let Some(transmit) = self.role_state.encapsulate(packet, Instant::now()) else { + let Some(enc_packet) = + self.role_state + .encapsulate(packet, Instant::now(), &mut self.encrypt_buf) + else { continue; }; - self.io.send_network(transmit)?; + self.io + .send_encrypted_packet(enc_packet, &self.encrypt_buf)?; continue; } @@ -154,7 +166,7 @@ impl ClientTunnel { received.from, received.packet, std::time::Instant::now(), - self.packet_buffer.as_mut(), + self.decrypt_buf.as_mut(), ) else { continue; }; @@ -185,9 +197,11 @@ impl GatewayTunnel { Self { io: Io::new(tcp_socket_factory, udp_socket_factory), role_state: GatewayState::new(private_key, rand::random()), - packet_buffer: Box::new([0u8; BUF_SIZE]), + device_read_buf: Box::new([0u8; BUF_SIZE]), ip4_read_buf: Box::new([0u8; MAX_UDP_SIZE]), ip6_read_buf: Box::new([0u8; MAX_UDP_SIZE]), + encrypt_buf: EncryptBuffer::new(BUF_SIZE), + decrypt_buf: Box::new([0u8; BUF_SIZE]), } } @@ -217,21 +231,24 @@ impl GatewayTunnel { cx, self.ip4_read_buf.as_mut(), self.ip6_read_buf.as_mut(), - self.packet_buffer.as_mut(), + self.device_read_buf.as_mut(), + &self.encrypt_buf, )? { Poll::Ready(io::Input::Timeout(timeout)) => { self.role_state.handle_timeout(timeout, Utc::now()); continue; } Poll::Ready(io::Input::Device(packet)) => { - let Some(transmit) = self - .role_state - .encapsulate(packet, std::time::Instant::now()) - else { + let Some(enc_packet) = self.role_state.encapsulate( + packet, + std::time::Instant::now(), + &mut self.encrypt_buf, + ) else { continue; }; - self.io.send_network(transmit)?; + self.io + .send_encrypted_packet(enc_packet, &self.encrypt_buf)?; continue; } @@ -242,7 +259,7 @@ impl GatewayTunnel { received.from, received.packet, std::time::Instant::now(), - self.packet_buffer.as_mut(), + self.device_read_buf.as_mut(), ) else { continue; }; diff --git a/rust/connlib/tunnel/src/sockets.rs b/rust/connlib/tunnel/src/sockets.rs index 7d1326571..87efb6744 100644 --- a/rust/connlib/tunnel/src/sockets.rs +++ b/rust/connlib/tunnel/src/sockets.rs @@ -45,16 +45,13 @@ impl Sockets { Poll::Ready(()) } - /// Flushes all buffered data on the sockets. - /// - /// Returns `Ready` if the socket is able to accept more data. - pub fn poll_flush(&mut self, cx: &mut Context<'_>) -> Poll> { - if let Some(socket) = self.socket_v4.as_mut() { - ready!(socket.poll_flush(cx))?; + pub fn poll_send_ready(&self, cx: &mut Context<'_>) -> Poll> { + if let Some(socket) = self.socket_v4.as_ref() { + ready!(socket.poll_send_ready(cx))?; } - if let Some(socket) = self.socket_v6.as_mut() { - ready!(socket.poll_flush(cx))?; + if let Some(socket) = self.socket_v6.as_ref() { + ready!(socket.poll_send_ready(cx))?; } Poll::Ready(Ok(())) diff --git a/rust/connlib/tunnel/src/tests/sim_client.rs b/rust/connlib/tunnel/src/tests/sim_client.rs index 9141451e0..2bd00b46b 100644 --- a/rust/connlib/tunnel/src/tests/sim_client.rs +++ b/rust/connlib/tunnel/src/tests/sim_client.rs @@ -28,7 +28,7 @@ use ip_packet::{IpPacket, MutableIpPacket, Packet as _}; use itertools::Itertools as _; use prop::collection; use proptest::prelude::*; -use snownet::Transmit; +use snownet::{EncryptBuffer, Transmit}; use std::{ collections::{BTreeMap, BTreeSet, HashMap, HashSet, VecDeque}, mem, @@ -57,6 +57,7 @@ pub(crate) struct SimClient { pub(crate) received_icmp_replies: BTreeMap<(u16, u16), IpPacket<'static>>, buffer: Vec, + enc_buffer: EncryptBuffer, } impl SimClient { @@ -71,6 +72,7 @@ impl SimClient { sent_icmp_requests: Default::default(), received_icmp_replies: Default::default(), buffer: vec![0u8; (1 << 16) - 1], + enc_buffer: EncryptBuffer::new((1 << 16) - 1), } } @@ -147,7 +149,12 @@ impl SimClient { } } - Some(self.sut.encapsulate(packet, now)?.into_owned()) + Some( + self.sut + .encapsulate(packet, now, &mut self.enc_buffer)? + .to_transmit(&self.enc_buffer) + .into_owned(), + ) } pub(crate) fn receive(&mut self, transmit: Transmit, now: Instant) { diff --git a/rust/connlib/tunnel/src/tests/sim_gateway.rs b/rust/connlib/tunnel/src/tests/sim_gateway.rs index f7797fe9d..c05c01c05 100644 --- a/rust/connlib/tunnel/src/tests/sim_gateway.rs +++ b/rust/connlib/tunnel/src/tests/sim_gateway.rs @@ -11,7 +11,7 @@ use connlib_shared::{ }; use ip_packet::IpPacket; use proptest::prelude::*; -use snownet::Transmit; +use snownet::{EncryptBuffer, Transmit}; use std::{ collections::{BTreeMap, BTreeSet}, net::IpAddr, @@ -27,6 +27,7 @@ pub(crate) struct SimGateway { pub(crate) received_icmp_requests: BTreeMap>, buffer: Vec, + enc_buffer: EncryptBuffer, } impl SimGateway { @@ -36,6 +37,7 @@ impl SimGateway { sut, received_icmp_requests: Default::default(), buffer: vec![0u8; (1 << 16) - 1], + enc_buffer: EncryptBuffer::new((1 << 16) - 1), } } @@ -78,7 +80,11 @@ impl SimGateway { self.received_icmp_requests.insert(payload, packet.clone()); let echo_response = ip_packet::make::icmp_response_packet(packet); - let transmit = self.sut.encapsulate(echo_response, now)?.into_owned(); + let transmit = self + .sut + .encapsulate(echo_response, now, &mut self.enc_buffer)? + .to_transmit(&self.enc_buffer) + .into_owned(); return Some(transmit); } @@ -89,7 +95,11 @@ impl SimGateway { global_dns_records.get(name).cloned().into_iter().flatten() }); - let transmit = self.sut.encapsulate(response, now)?.into_owned(); + let transmit = self + .sut + .encapsulate(response, now, &mut self.enc_buffer)? + .to_transmit(&self.enc_buffer) + .into_owned(); return Some(transmit); } diff --git a/rust/socket-factory/src/lib.rs b/rust/socket-factory/src/lib.rs index f17d42103..0415513fb 100644 --- a/rust/socket-factory/src/lib.rs +++ b/rust/socket-factory/src/lib.rs @@ -1,7 +1,7 @@ use std::collections::HashMap; use std::{ borrow::Cow, - collections::VecDeque, + // collections::VecDeque, io::{self, IoSliceMut}, net::{IpAddr, SocketAddr}, slice, @@ -84,8 +84,6 @@ pub struct UdpSocket { src_by_dst_cache: HashMap, port: u16, - - buffered_datagrams: VecDeque>, } impl UdpSocket { @@ -97,7 +95,6 @@ impl UdpSocket { port, inner, source_ip_resolver: Box::new(|_| Ok(None)), - buffered_datagrams: VecDeque::new(), src_by_dst_cache: Default::default(), }) } @@ -146,16 +143,6 @@ pub struct DatagramOut<'a> { pub packet: Cow<'a, [u8]>, } -impl<'a> DatagramOut<'a> { - fn into_owned(self) -> DatagramOut<'static> { - DatagramOut { - src: self.src, - dst: self.dst, - packet: Cow::Owned(self.packet.into_owned()), - } - } -} - impl UdpSocket { #[allow(clippy::type_complexity)] pub fn poll_recv_from<'b>( @@ -205,51 +192,16 @@ impl UdpSocket { } } - pub fn poll_flush(&mut self, cx: &mut Context<'_>) -> Poll> { - loop { - ready!(self.inner.poll_send_ready(cx))?; // Ensure we are ready to send. - - let Some(transmit) = self.buffered_datagrams.pop_front() else { - break; - }; - - match self.try_send(&transmit) { - Ok(()) => continue, // Try to send another packet. - Err(e) => { - self.buffered_datagrams.push_front(transmit); // Don't lose the packet if we fail. - - if e.kind() == io::ErrorKind::WouldBlock { - continue; // False positive send-readiness: Loop to `poll_send_ready` and return `Pending`. - } - - return Poll::Ready(Err(e)); - } - } - } - - assert!(self.buffered_datagrams.is_empty()); - - Poll::Ready(Ok(())) + pub fn poll_send_ready(&self, cx: &mut Context<'_>) -> Poll> { + self.inner.poll_send_ready(cx) } pub fn send(&mut self, datagram: DatagramOut) -> io::Result<()> { tracing::trace!(target: "wire::net::send", src = ?datagram.src, dst = %datagram.dst, num_bytes = %datagram.packet.len()); - debug_assert!( - self.buffered_datagrams.len() < 10_000, - "We are not flushing the packets for some reason" - ); + self.try_send(&datagram)?; - match self.try_send(&datagram) { - Ok(()) => Ok(()), - Err(e) if e.kind() == io::ErrorKind::WouldBlock => { - tracing::trace!("Buffering packet because socket is busy"); - - self.buffered_datagrams.push_back(datagram.into_owned()); - Ok(()) - } - Err(e) => Err(e), - } + Ok(()) } pub fn try_send(&mut self, transmit: &DatagramOut) -> io::Result<()> {