diff --git a/rust/bin-shared/src/tun_device_manager.rs b/rust/bin-shared/src/tun_device_manager.rs index ef2305883..5619f0a4f 100644 --- a/rust/bin-shared/src/tun_device_manager.rs +++ b/rust/bin-shared/src/tun_device_manager.rs @@ -59,6 +59,10 @@ mod tests { ))) .unwrap(); + std::future::poll_fn(|cx| socket.poll_send_ready(cx)) + .await + .unwrap(); + // Send a STUN request. socket .send(DatagramOut { @@ -70,11 +74,6 @@ mod tests { }) .unwrap(); - // First send seems to always result as would block - std::future::poll_fn(|cx| socket.poll_flush(cx)) - .await - .unwrap(); - let task = std::future::poll_fn(|cx| { let mut buf = [0u8; 1000]; let result = std::task::ready!(socket.poll_recv_from(&mut buf, cx)); diff --git a/rust/connlib/clients/shared/src/eventloop.rs b/rust/connlib/clients/shared/src/eventloop.rs index 457de12b3..ff2d60ed4 100644 --- a/rust/connlib/clients/shared/src/eventloop.rs +++ b/rust/connlib/clients/shared/src/eventloop.rs @@ -15,6 +15,7 @@ use firezone_tunnel::ClientTunnel; use phoenix_channel::{ErrorReply, OutboundRequestId, PhoenixChannel}; use std::{ collections::{BTreeMap, BTreeSet}, + io, net::IpAddr, task::{Context, Poll}, }; @@ -91,6 +92,9 @@ where self.handle_tunnel_event(event); continue; } + Poll::Ready(Err(e)) if e.kind() == io::ErrorKind::WouldBlock => { + continue; + } Poll::Ready(Err(e)) => { tracing::warn!("Tunnel error: {e}"); continue; diff --git a/rust/connlib/snownet/src/allocation.rs b/rust/connlib/snownet/src/allocation.rs index c7d529051..4fa3ffbd5 100644 --- a/rust/connlib/snownet/src/allocation.rs +++ b/rust/connlib/snownet/src/allocation.rs @@ -3,6 +3,7 @@ use crate::{ node::{CandidateEvent, SessionId, Transmit}, ringbuffer::RingBuffer, utils::earliest, + EncryptedPacket, }; use ::backoff::backoff::Backoff; use bytecodec::{DecodeExt as _, EncodeExt as _}; @@ -721,29 +722,23 @@ impl Allocation { ); } - /// Encodes the packet contained in the given buffer into a [`Transmit`]. - /// - /// This function assumes that the first 4 bytes of `buffer` have been reserved for the header of the channel-data message. - pub fn encode_to_borrowed_transmit<'b>( + pub fn encode_to_encrypted_packet( &self, peer: SocketAddr, - buffer: &'b mut [u8], + buffer: &mut [u8], now: Instant, - ) -> Option> { + ) -> Option { let buffer_len = buffer.len(); let packet_len = buffer_len - 4; let channel_number = self.channel_bindings.connected_channel_to_peer(peer, now)?; - let total_length = crate::channel_data::encode_header_to_slice( - &mut buffer[..4], - channel_number, - packet_len, - ); + crate::channel_data::encode_header_to_slice(&mut buffer[..4], channel_number, packet_len); - Some(Transmit { + Some(EncryptedPacket { src: None, dst: self.active_socket?, - payload: Cow::Borrowed(&buffer[..total_length]), + packet_start: 0, + packet_len: buffer_len, }) } diff --git a/rust/connlib/snownet/src/lib.rs b/rust/connlib/snownet/src/lib.rs index b944a89b8..454b9d715 100644 --- a/rust/connlib/snownet/src/lib.rs +++ b/rust/connlib/snownet/src/lib.rs @@ -11,7 +11,7 @@ mod utils; pub use allocation::RelaySocket; pub use node::{ - Answer, Client, ClientNode, Credentials, Error, Event, Node, Offer, Server, ServerNode, - Transmit, HANDSHAKE_TIMEOUT, + Answer, Client, ClientNode, Credentials, EncryptBuffer, EncryptedPacket, Error, Event, Node, + Offer, Server, ServerNode, Transmit, HANDSHAKE_TIMEOUT, }; pub use stats::{ConnectionStats, NodeStats}; diff --git a/rust/connlib/snownet/src/node.rs b/rust/connlib/snownet/src/node.rs index 5683f3f2c..ab418e38a 100644 --- a/rust/connlib/snownet/src/node.rs +++ b/rust/connlib/snownet/src/node.rs @@ -287,14 +287,14 @@ where /// - `Ok(None)` if the packet was handled internally, for example, a response from a TURN server. /// - `Ok(Some)` if the packet was an encrypted wireguard packet from a peer. /// The `Option` contains the connection on which the packet was decrypted. - pub fn decapsulate<'s>( + pub fn decapsulate<'b>( &mut self, local: SocketAddr, from: SocketAddr, packet: &[u8], now: Instant, - buffer: &'s mut [u8], - ) -> Result)>, Error> { + buffer: &'b mut [u8], + ) -> Result)>, Error> { self.add_local_as_host_candidate(local)?; let (from, packet, relayed) = match self.allocations_try_handle(from, local, packet, now) { @@ -325,12 +325,13 @@ where /// Wireguard is an IP tunnel, so we "enforce" that only IP packets are sent through it. /// We say "enforce" an [`IpPacket`] can be created from an (almost) arbitrary byte buffer at virtually no cost. /// Nevertheless, using [`IpPacket`] in our API has good documentation value. - pub fn encapsulate<'s>( - &'s mut self, + pub fn encapsulate( + &mut self, connection: TId, packet: IpPacket<'_>, now: Instant, - ) -> Result>, Error> { + buffer: &mut EncryptBuffer, + ) -> Result, Error> { let conn = self .connections .get_established_mut(&connection) @@ -341,7 +342,7 @@ where // Encode the packet with an offset of 4 bytes, in case we need to wrap it in a channel-data message. let Some(packet_len) = conn - .encapsulate(packet.packet(), &mut self.buffer[4..], now)? + .encapsulate(packet.packet(), &mut buffer.inner[4..], now)? .map(|p| p.len()) // Mapping to len() here terminate the mutable borrow of buffer, allowing re-borrowing further down. else { @@ -355,30 +356,26 @@ where PeerSocket::Direct { dest: remote, source, - } => { - // Re-borrow the actual packet. - let packet = &self.buffer[packet_start..packet_end]; - - Ok(Some(Transmit { - src: Some(source), - dst: remote, - payload: Cow::Borrowed(packet), - })) - } + } => Ok(Some(EncryptedPacket { + src: Some(source), + dst: remote, + packet_start, + packet_len, + })), PeerSocket::Relay { relay, dest: peer } => { let Some(allocation) = self.allocations.get(&relay) else { tracing::warn!(%relay, "No allocation"); return Ok(None); }; - let packet = &mut self.buffer[..packet_end]; + let packet = &mut buffer.inner[..packet_end]; - let Some(transmit) = allocation.encode_to_borrowed_transmit(peer, packet, now) + let Some(enc_packet) = allocation.encode_to_encrypted_packet(peer, packet, now) else { tracing::warn!(%peer, "No channel"); return Ok(None); }; - Ok(Some(transmit)) + Ok(Some(enc_packet)) } } } @@ -1240,6 +1237,42 @@ pub enum Event { ConnectionClosed(TId), } +pub struct EncryptBuffer { + inner: Vec, +} + +impl EncryptBuffer { + pub fn new(len: usize) -> Self { + Self { + inner: vec![0u8; len], + } + } +} + +#[derive(Debug, Clone, Copy)] +pub struct EncryptedPacket { + pub(crate) src: Option, + pub(crate) dst: SocketAddr, + pub(crate) packet_start: usize, + pub(crate) packet_len: usize, +} + +impl EncryptedPacket { + pub fn to_transmit(self, buf: &EncryptBuffer) -> Transmit<'_> { + Transmit { + src: self.src, + dst: self.dst, + payload: Cow::Borrowed( + &buf.inner[self.packet_start..(self.packet_start + self.packet_len)], + ), + } + } + + pub fn dst(&self) -> SocketAddr { + self.dst + } +} + #[derive(Clone, PartialEq, PartialOrd, Eq, Ord)] pub struct Transmit<'a> { /// The local interface from which this packet should be sent. diff --git a/rust/connlib/tunnel/src/client.rs b/rust/connlib/tunnel/src/client.rs index 616510af7..1916f600a 100644 --- a/rust/connlib/tunnel/src/client.rs +++ b/rust/connlib/tunnel/src/client.rs @@ -22,7 +22,7 @@ use crate::{ClientEvent, ClientTunnel, Tun}; use domain::base::Message; use lru::LruCache; use secrecy::{ExposeSecret as _, Secret}; -use snownet::{ClientNode, RelaySocket, Transmit}; +use snownet::{ClientNode, EncryptBuffer, RelaySocket, Transmit}; use std::borrow::Cow; use std::collections::hash_map::Entry; use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet, VecDeque}; @@ -419,11 +419,12 @@ impl ClientState { ) } - pub(crate) fn encapsulate<'s>( - &'s mut self, + pub(crate) fn encapsulate( + &mut self, packet: MutableIpPacket<'_>, now: Instant, - ) -> Option> { + buffer: &mut EncryptBuffer, + ) -> Option { let (packet, dst) = match self.try_handle_dns_query(packet, now) { Ok(response) => { self.buffered_packets.push_back(response?.to_owned()); @@ -469,7 +470,7 @@ impl ClientState { let transmit = self .node - .encapsulate(gid, packet.as_immutable(), now) + .encapsulate(gid, packet.as_immutable(), now, buffer) .inspect_err(|e| tracing::debug!(%gid, "Failed to encapsulate: {e}")) .ok()??; diff --git a/rust/connlib/tunnel/src/gateway.rs b/rust/connlib/tunnel/src/gateway.rs index 2a891ec98..61f844176 100644 --- a/rust/connlib/tunnel/src/gateway.rs +++ b/rust/connlib/tunnel/src/gateway.rs @@ -13,7 +13,7 @@ use connlib_shared::{DomainName, StaticSecret}; use ip_network::{Ipv4Network, Ipv6Network}; use ip_packet::{IpPacket, MutableIpPacket}; use secrecy::{ExposeSecret as _, Secret}; -use snownet::{RelaySocket, ServerNode}; +use snownet::{EncryptBuffer, RelaySocket, ServerNode}; use std::collections::{BTreeMap, BTreeSet, VecDeque}; use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}; use std::time::{Duration, Instant}; @@ -155,11 +155,12 @@ impl GatewayState { self.node.public_key() } - pub(crate) fn encapsulate<'s>( - &'s mut self, + pub(crate) fn encapsulate( + &mut self, packet: MutableIpPacket<'_>, now: Instant, - ) -> Option> { + buffer: &mut EncryptBuffer, + ) -> Option { let dst = packet.destination(); if !is_client(dst) { @@ -180,7 +181,7 @@ impl GatewayState { let transmit = self .node - .encapsulate(peer.id(), packet.as_immutable(), now) + .encapsulate(peer.id(), packet.as_immutable(), now, buffer) .inspect_err(|e| tracing::debug!(%cid, "Failed to encapsulate: {e}")) .ok()??; diff --git a/rust/connlib/tunnel/src/io.rs b/rust/connlib/tunnel/src/io.rs index 30f63c6c1..f11dbc27c 100644 --- a/rust/connlib/tunnel/src/io.rs +++ b/rust/connlib/tunnel/src/io.rs @@ -1,6 +1,7 @@ use crate::{device_channel::Device, sockets::Sockets, BUF_SIZE}; use futures_util::FutureExt as _; use ip_packet::{IpPacket, MutableIpPacket}; +use snownet::{EncryptBuffer, EncryptedPacket}; use socket_factory::{DatagramIn, DatagramOut, SocketFactory, TcpSocket, UdpSocket}; use std::{ io, @@ -18,6 +19,7 @@ pub struct Io { device: Device, /// The UDP sockets used to send & receive packets from the network. sockets: Sockets, + unwritten_packet: Option, _tcp_socket_factory: Arc>, udp_socket_factory: Arc>, @@ -48,6 +50,7 @@ impl Io { sockets, _tcp_socket_factory: tcp_socket_factory, udp_socket_factory, + unwritten_packet: None, } } @@ -61,13 +64,14 @@ impl Io { ip4_buffer: &'b1 mut [u8], ip6_bffer: &'b1 mut [u8], device_buffer: &'b2 mut [u8], + encrypt_buffer: &EncryptBuffer, ) -> Poll>>>> { + ready!(self.poll_send_unwritten(cx, encrypt_buffer)?); + if let Poll::Ready(network) = self.sockets.poll_recv_from(ip4_buffer, ip6_bffer, cx)? { return Poll::Ready(Ok(Input::Network(network.filter(is_max_wg_packet_size)))); } - ready!(self.sockets.poll_flush(cx))?; - if let Poll::Ready(packet) = self.device.poll_read(device_buffer, cx)? { return Poll::Ready(Ok(Input::Device(packet))); } @@ -84,6 +88,23 @@ impl Io { Poll::Pending } + fn poll_send_unwritten( + &mut self, + cx: &mut Context<'_>, + buf: &EncryptBuffer, + ) -> Poll> { + ready!(self.sockets.poll_send_ready(cx))?; + + // If the `unwritten_packet` is set, `EncryptBuffer` is still holding a packet that we need so send. + let Some(unwritten_packet) = self.unwritten_packet.take() else { + return Poll::Ready(Ok(())); + }; + + self.send_encrypted_packet(unwritten_packet, buf)?; + + Poll::Ready(Ok(())) + } + pub fn device_mut(&mut self) -> &mut Device { &mut self.device } @@ -114,6 +135,28 @@ impl Io { Ok(()) } + pub fn send_encrypted_packet( + &mut self, + packet: EncryptedPacket, + buf: &EncryptBuffer, + ) -> io::Result<()> { + let transmit = packet.to_transmit(buf); + let res = self.send_network(transmit); + + if res + .as_ref() + .err() + .is_some_and(|e| e.kind() == io::ErrorKind::WouldBlock) + { + tracing::debug!(dst = %packet.dst(), "Socket busy"); + self.unwritten_packet = Some(packet); + } + + res?; + + Ok(()) + } + pub fn send_device(&self, packet: IpPacket<'_>) -> io::Result<()> { self.device.write(packet)?; diff --git a/rust/connlib/tunnel/src/lib.rs b/rust/connlib/tunnel/src/lib.rs index 8ea3f70e8..3252c5a92 100644 --- a/rust/connlib/tunnel/src/lib.rs +++ b/rust/connlib/tunnel/src/lib.rs @@ -64,6 +64,7 @@ pub type ClientTunnel = Tunnel; pub use client::ClientState; pub use gateway::{GatewayState, IPV4_PEERS, IPV6_PEERS}; +use snownet::EncryptBuffer; /// [`Tunnel`] glues together connlib's [`Io`] component and the respective (pure) state of a client or gateway. /// @@ -81,8 +82,12 @@ pub struct Tunnel { ip4_read_buf: Box<[u8; MAX_UDP_SIZE]>, ip6_read_buf: Box<[u8; MAX_UDP_SIZE]>, - /// Buffer for processing a single IP packet. - packet_buffer: Box<[u8; BUF_SIZE]>, + /// Buffer for reading a single IP packet. + device_read_buf: Box<[u8; BUF_SIZE]>, + /// Buffer for decryping a single packet. + decrypt_buf: Box<[u8; BUF_SIZE]>, + /// Buffer for encrypting a single packet. + encrypt_buf: EncryptBuffer, } impl ClientTunnel { @@ -95,9 +100,11 @@ impl ClientTunnel { Self { io: Io::new(tcp_socket_factory, udp_socket_factory), role_state: ClientState::new(private_key, known_hosts, rand::random()), - packet_buffer: Box::new([0u8; BUF_SIZE]), + device_read_buf: Box::new([0u8; BUF_SIZE]), ip4_read_buf: Box::new([0u8; MAX_UDP_SIZE]), ip6_read_buf: Box::new([0u8; MAX_UDP_SIZE]), + encrypt_buf: EncryptBuffer::new(BUF_SIZE), + decrypt_buf: Box::new([0u8; BUF_SIZE]), } } @@ -132,18 +139,23 @@ impl ClientTunnel { cx, self.ip4_read_buf.as_mut(), self.ip6_read_buf.as_mut(), - self.packet_buffer.as_mut(), + self.device_read_buf.as_mut(), + &self.encrypt_buf, )? { Poll::Ready(io::Input::Timeout(timeout)) => { self.role_state.handle_timeout(timeout); continue; } Poll::Ready(io::Input::Device(packet)) => { - let Some(transmit) = self.role_state.encapsulate(packet, Instant::now()) else { + let Some(enc_packet) = + self.role_state + .encapsulate(packet, Instant::now(), &mut self.encrypt_buf) + else { continue; }; - self.io.send_network(transmit)?; + self.io + .send_encrypted_packet(enc_packet, &self.encrypt_buf)?; continue; } @@ -154,7 +166,7 @@ impl ClientTunnel { received.from, received.packet, std::time::Instant::now(), - self.packet_buffer.as_mut(), + self.decrypt_buf.as_mut(), ) else { continue; }; @@ -185,9 +197,11 @@ impl GatewayTunnel { Self { io: Io::new(tcp_socket_factory, udp_socket_factory), role_state: GatewayState::new(private_key, rand::random()), - packet_buffer: Box::new([0u8; BUF_SIZE]), + device_read_buf: Box::new([0u8; BUF_SIZE]), ip4_read_buf: Box::new([0u8; MAX_UDP_SIZE]), ip6_read_buf: Box::new([0u8; MAX_UDP_SIZE]), + encrypt_buf: EncryptBuffer::new(BUF_SIZE), + decrypt_buf: Box::new([0u8; BUF_SIZE]), } } @@ -217,21 +231,24 @@ impl GatewayTunnel { cx, self.ip4_read_buf.as_mut(), self.ip6_read_buf.as_mut(), - self.packet_buffer.as_mut(), + self.device_read_buf.as_mut(), + &self.encrypt_buf, )? { Poll::Ready(io::Input::Timeout(timeout)) => { self.role_state.handle_timeout(timeout, Utc::now()); continue; } Poll::Ready(io::Input::Device(packet)) => { - let Some(transmit) = self - .role_state - .encapsulate(packet, std::time::Instant::now()) - else { + let Some(enc_packet) = self.role_state.encapsulate( + packet, + std::time::Instant::now(), + &mut self.encrypt_buf, + ) else { continue; }; - self.io.send_network(transmit)?; + self.io + .send_encrypted_packet(enc_packet, &self.encrypt_buf)?; continue; } @@ -242,7 +259,7 @@ impl GatewayTunnel { received.from, received.packet, std::time::Instant::now(), - self.packet_buffer.as_mut(), + self.device_read_buf.as_mut(), ) else { continue; }; diff --git a/rust/connlib/tunnel/src/sockets.rs b/rust/connlib/tunnel/src/sockets.rs index 7d1326571..87efb6744 100644 --- a/rust/connlib/tunnel/src/sockets.rs +++ b/rust/connlib/tunnel/src/sockets.rs @@ -45,16 +45,13 @@ impl Sockets { Poll::Ready(()) } - /// Flushes all buffered data on the sockets. - /// - /// Returns `Ready` if the socket is able to accept more data. - pub fn poll_flush(&mut self, cx: &mut Context<'_>) -> Poll> { - if let Some(socket) = self.socket_v4.as_mut() { - ready!(socket.poll_flush(cx))?; + pub fn poll_send_ready(&self, cx: &mut Context<'_>) -> Poll> { + if let Some(socket) = self.socket_v4.as_ref() { + ready!(socket.poll_send_ready(cx))?; } - if let Some(socket) = self.socket_v6.as_mut() { - ready!(socket.poll_flush(cx))?; + if let Some(socket) = self.socket_v6.as_ref() { + ready!(socket.poll_send_ready(cx))?; } Poll::Ready(Ok(())) diff --git a/rust/connlib/tunnel/src/tests/sim_client.rs b/rust/connlib/tunnel/src/tests/sim_client.rs index 9141451e0..2bd00b46b 100644 --- a/rust/connlib/tunnel/src/tests/sim_client.rs +++ b/rust/connlib/tunnel/src/tests/sim_client.rs @@ -28,7 +28,7 @@ use ip_packet::{IpPacket, MutableIpPacket, Packet as _}; use itertools::Itertools as _; use prop::collection; use proptest::prelude::*; -use snownet::Transmit; +use snownet::{EncryptBuffer, Transmit}; use std::{ collections::{BTreeMap, BTreeSet, HashMap, HashSet, VecDeque}, mem, @@ -57,6 +57,7 @@ pub(crate) struct SimClient { pub(crate) received_icmp_replies: BTreeMap<(u16, u16), IpPacket<'static>>, buffer: Vec, + enc_buffer: EncryptBuffer, } impl SimClient { @@ -71,6 +72,7 @@ impl SimClient { sent_icmp_requests: Default::default(), received_icmp_replies: Default::default(), buffer: vec![0u8; (1 << 16) - 1], + enc_buffer: EncryptBuffer::new((1 << 16) - 1), } } @@ -147,7 +149,12 @@ impl SimClient { } } - Some(self.sut.encapsulate(packet, now)?.into_owned()) + Some( + self.sut + .encapsulate(packet, now, &mut self.enc_buffer)? + .to_transmit(&self.enc_buffer) + .into_owned(), + ) } pub(crate) fn receive(&mut self, transmit: Transmit, now: Instant) { diff --git a/rust/connlib/tunnel/src/tests/sim_gateway.rs b/rust/connlib/tunnel/src/tests/sim_gateway.rs index f7797fe9d..c05c01c05 100644 --- a/rust/connlib/tunnel/src/tests/sim_gateway.rs +++ b/rust/connlib/tunnel/src/tests/sim_gateway.rs @@ -11,7 +11,7 @@ use connlib_shared::{ }; use ip_packet::IpPacket; use proptest::prelude::*; -use snownet::Transmit; +use snownet::{EncryptBuffer, Transmit}; use std::{ collections::{BTreeMap, BTreeSet}, net::IpAddr, @@ -27,6 +27,7 @@ pub(crate) struct SimGateway { pub(crate) received_icmp_requests: BTreeMap>, buffer: Vec, + enc_buffer: EncryptBuffer, } impl SimGateway { @@ -36,6 +37,7 @@ impl SimGateway { sut, received_icmp_requests: Default::default(), buffer: vec![0u8; (1 << 16) - 1], + enc_buffer: EncryptBuffer::new((1 << 16) - 1), } } @@ -78,7 +80,11 @@ impl SimGateway { self.received_icmp_requests.insert(payload, packet.clone()); let echo_response = ip_packet::make::icmp_response_packet(packet); - let transmit = self.sut.encapsulate(echo_response, now)?.into_owned(); + let transmit = self + .sut + .encapsulate(echo_response, now, &mut self.enc_buffer)? + .to_transmit(&self.enc_buffer) + .into_owned(); return Some(transmit); } @@ -89,7 +95,11 @@ impl SimGateway { global_dns_records.get(name).cloned().into_iter().flatten() }); - let transmit = self.sut.encapsulate(response, now)?.into_owned(); + let transmit = self + .sut + .encapsulate(response, now, &mut self.enc_buffer)? + .to_transmit(&self.enc_buffer) + .into_owned(); return Some(transmit); } diff --git a/rust/socket-factory/src/lib.rs b/rust/socket-factory/src/lib.rs index f17d42103..0415513fb 100644 --- a/rust/socket-factory/src/lib.rs +++ b/rust/socket-factory/src/lib.rs @@ -1,7 +1,7 @@ use std::collections::HashMap; use std::{ borrow::Cow, - collections::VecDeque, + // collections::VecDeque, io::{self, IoSliceMut}, net::{IpAddr, SocketAddr}, slice, @@ -84,8 +84,6 @@ pub struct UdpSocket { src_by_dst_cache: HashMap, port: u16, - - buffered_datagrams: VecDeque>, } impl UdpSocket { @@ -97,7 +95,6 @@ impl UdpSocket { port, inner, source_ip_resolver: Box::new(|_| Ok(None)), - buffered_datagrams: VecDeque::new(), src_by_dst_cache: Default::default(), }) } @@ -146,16 +143,6 @@ pub struct DatagramOut<'a> { pub packet: Cow<'a, [u8]>, } -impl<'a> DatagramOut<'a> { - fn into_owned(self) -> DatagramOut<'static> { - DatagramOut { - src: self.src, - dst: self.dst, - packet: Cow::Owned(self.packet.into_owned()), - } - } -} - impl UdpSocket { #[allow(clippy::type_complexity)] pub fn poll_recv_from<'b>( @@ -205,51 +192,16 @@ impl UdpSocket { } } - pub fn poll_flush(&mut self, cx: &mut Context<'_>) -> Poll> { - loop { - ready!(self.inner.poll_send_ready(cx))?; // Ensure we are ready to send. - - let Some(transmit) = self.buffered_datagrams.pop_front() else { - break; - }; - - match self.try_send(&transmit) { - Ok(()) => continue, // Try to send another packet. - Err(e) => { - self.buffered_datagrams.push_front(transmit); // Don't lose the packet if we fail. - - if e.kind() == io::ErrorKind::WouldBlock { - continue; // False positive send-readiness: Loop to `poll_send_ready` and return `Pending`. - } - - return Poll::Ready(Err(e)); - } - } - } - - assert!(self.buffered_datagrams.is_empty()); - - Poll::Ready(Ok(())) + pub fn poll_send_ready(&self, cx: &mut Context<'_>) -> Poll> { + self.inner.poll_send_ready(cx) } pub fn send(&mut self, datagram: DatagramOut) -> io::Result<()> { tracing::trace!(target: "wire::net::send", src = ?datagram.src, dst = %datagram.dst, num_bytes = %datagram.packet.len()); - debug_assert!( - self.buffered_datagrams.len() < 10_000, - "We are not flushing the packets for some reason" - ); + self.try_send(&datagram)?; - match self.try_send(&datagram) { - Ok(()) => Ok(()), - Err(e) if e.kind() == io::ErrorKind::WouldBlock => { - tracing::trace!("Buffering packet because socket is busy"); - - self.buffered_datagrams.push_back(datagram.into_owned()); - Ok(()) - } - Err(e) => Err(e), - } + Ok(()) } pub fn try_send(&mut self, transmit: &DatagramOut) -> io::Result<()> {