diff --git a/rust/Cargo.lock b/rust/Cargo.lock index 6fd44868b..9b808eeae 100644 --- a/rust/Cargo.lock +++ b/rust/Cargo.lock @@ -2006,7 +2006,6 @@ dependencies = [ "itertools 0.13.0", "proptest", "proptest-state-machine", - "quinn-udp", "rand 0.8.5", "rand_core 0.6.4", "rangemap", @@ -4746,8 +4745,8 @@ dependencies = [ [[package]] name = "quinn-udp" -version = "0.5.2" -source = "git+https://github.com/quinn-rs/quinn?branch=main#3f489e2eab014ddd04de58e570ba56e9b027f0bc" +version = "0.5.4" +source = "git+https://github.com/quinn-rs/quinn?branch=main#061a74fb6ef67b12f78bc2a3cfc9906e54762eeb" dependencies = [ "libc", "once_cell", @@ -5626,8 +5625,12 @@ dependencies = [ name = "socket-factory" version = "0.1.0" dependencies = [ + "async-trait", + "hickory-proto", + "quinn-udp", "socket2", "tokio", + "tracing", ] [[package]] diff --git a/rust/connlib/clients/android/src/lib.rs b/rust/connlib/clients/android/src/lib.rs index c55aae6ab..2860e065a 100644 --- a/rust/connlib/clients/android/src/lib.rs +++ b/rust/connlib/clients/android/src/lib.rs @@ -19,7 +19,7 @@ use jni::{ }; use phoenix_channel::PhoenixChannel; use secrecy::{Secret, SecretString}; -use socket_factory::SocketFactory; +use socket_factory::{SocketFactory, TcpSocket, UdpSocket}; use std::{io, net::IpAddr, os::fd::AsRawFd, path::Path, sync::Arc}; use std::{ net::{Ipv4Addr, Ipv6Addr}, @@ -532,9 +532,7 @@ pub unsafe extern "system" fn Java_dev_firezone_android_tunnel_ConnlibSession_se session.inner.set_tun(Box::new(tun)); } -fn protected_tcp_socket_factory( - callbacks: CallbackHandler, -) -> impl SocketFactory { +fn protected_tcp_socket_factory(callbacks: CallbackHandler) -> impl SocketFactory { move |addr| { let socket = socket_factory::tcp(addr)?; callbacks.protect(socket.as_raw_fd())?; @@ -542,9 +540,7 @@ fn protected_tcp_socket_factory( } } -fn protected_udp_socket_factory( - callbacks: CallbackHandler, -) -> impl SocketFactory { +fn protected_udp_socket_factory(callbacks: CallbackHandler) -> impl SocketFactory { move |addr| { let socket = socket_factory::udp(addr)?; callbacks.protect(socket.as_raw_fd())?; diff --git a/rust/connlib/clients/shared/src/lib.rs b/rust/connlib/clients/shared/src/lib.rs index 61a96f752..36d61820f 100644 --- a/rust/connlib/clients/shared/src/lib.rs +++ b/rust/connlib/clients/shared/src/lib.rs @@ -11,7 +11,7 @@ use eventloop::Command; use firezone_tunnel::ClientTunnel; use messages::{IngressMessages, ReplyMessages}; use phoenix_channel::PhoenixChannel; -use socket_factory::SocketFactory; +use socket_factory::{SocketFactory, TcpSocket, UdpSocket}; use std::collections::HashMap; use std::net::IpAddr; use std::sync::Arc; @@ -36,8 +36,8 @@ pub struct Session { /// Arguments for `connect`, since Clippy said 8 args is too many pub struct ConnectArgs { - pub tcp_socket_factory: Arc>, - pub udp_socket_factory: Arc>, + pub tcp_socket_factory: Arc>, + pub udp_socket_factory: Arc>, pub private_key: StaticSecret, pub callbacks: CB, } diff --git a/rust/connlib/tunnel/Cargo.toml b/rust/connlib/tunnel/Cargo.toml index 93db08f80..c62201166 100644 --- a/rust/connlib/tunnel/Cargo.toml +++ b/rust/connlib/tunnel/Cargo.toml @@ -22,13 +22,12 @@ ip_network = { version = "0.4", default-features = false } ip_network_table = { version = "0.2", default-features = false } itertools = { version = "0.13", default-features = false, features = ["use_std"] } proptest = { version = "1", optional = true } -quinn-udp = { git = "https://github.com/quinn-rs/quinn", branch = "main" } rand_core = { version = "0.6", default-features = false, features = ["getrandom"] } rangemap = "1.5.1" secrecy = { workspace = true } serde = { version = "1.0", default-features = false, features = ["derive", "std"] } snownet = { workspace = true } -socket-factory = { workspace = true } +socket-factory = { workspace = true, features = ["hickory"] } socket2 = { workspace = true } thiserror = { version = "1.0", default-features = false } tokio = { workspace = true } diff --git a/rust/connlib/tunnel/src/io.rs b/rust/connlib/tunnel/src/io.rs index 742a3290d..33111b315 100644 --- a/rust/connlib/tunnel/src/io.rs +++ b/rust/connlib/tunnel/src/io.rs @@ -1,8 +1,4 @@ -use crate::{ - device_channel::Device, - dns::DnsQuery, - sockets::{Received, Sockets}, -}; +use crate::{device_channel::Device, dns::DnsQuery, sockets::Sockets}; use connlib_shared::messages::DnsServer; use futures::Future; use futures_bounded::FuturesTupleSet; @@ -15,7 +11,7 @@ use hickory_resolver::{ AsyncResolver, TokioHandle, }; use ip_packet::{IpPacket, MutableIpPacket}; -use socket_factory::SocketFactory; +use socket_factory::{DatagramIn, DatagramOut, SocketFactory, TcpSocket, UdpSocket}; use std::{ collections::HashMap, io, @@ -25,7 +21,6 @@ use std::{ task::{ready, Context, Poll}, time::{Duration, Instant}, }; -use tokio::net::{TcpSocket, UdpSocket}; const DNS_QUERIES_QUEUE_SIZE: usize = 100; @@ -94,7 +89,7 @@ impl Io { ip4_buffer: &'b mut [u8], ip6_bffer: &'b mut [u8], device_buffer: &'b mut [u8], - ) -> Poll>>>> { + ) -> Poll>>>> { if let Poll::Ready((response, query)) = self.forwarded_dns_queries.poll_unpin(cx) { return Poll::Ready(Ok(Input::DnsResponse(query, response))); } @@ -185,7 +180,11 @@ impl Io { } pub fn send_network(&mut self, transmit: snownet::Transmit) -> io::Result<()> { - self.sockets.send(transmit)?; + self.sockets.send(DatagramOut { + src: transmit.src, + dst: transmit.dst, + packet: transmit.payload, + })?; Ok(()) } diff --git a/rust/connlib/tunnel/src/lib.rs b/rust/connlib/tunnel/src/lib.rs index 846b171a8..e1956eb8d 100644 --- a/rust/connlib/tunnel/src/lib.rs +++ b/rust/connlib/tunnel/src/lib.rs @@ -13,6 +13,7 @@ use connlib_shared::{ }; use io::Io; use ip_network::{Ipv4Network, Ipv6Network}; +use socket_factory::{SocketFactory, TcpSocket, UdpSocket}; use std::{ collections::{BTreeSet, HashMap, HashSet}, net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}, @@ -70,8 +71,8 @@ pub struct Tunnel { impl ClientTunnel { pub fn new( private_key: StaticSecret, - tcp_socket_factory: Arc>, - udp_socket_factory: Arc>, + tcp_socket_factory: Arc>, + udp_socket_factory: Arc>, known_hosts: HashMap>, ) -> std::io::Result { Ok(Self { diff --git a/rust/connlib/tunnel/src/sockets.rs b/rust/connlib/tunnel/src/sockets.rs index 77b20d852..4435b3641 100644 --- a/rust/connlib/tunnel/src/sockets.rs +++ b/rust/connlib/tunnel/src/sockets.rs @@ -1,29 +1,23 @@ -use core::slice; -use quinn_udp::{RecvMeta, UdpSockRef, UdpSocketState}; -use socket_factory::SocketFactory; +use socket_factory::{DatagramIn, DatagramOut, SocketFactory, UdpSocket}; use std::{ - collections::VecDeque, - io::{self, IoSliceMut}, - net::{Ipv4Addr, Ipv6Addr, SocketAddr}, + io, + net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6}, task::{ready, Context, Poll}, }; -use tokio::{io::Interest, net::UdpSocket}; -use crate::Result; +const UNSPECIFIED_V4_SOCKET: SocketAddrV4 = SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0); +const UNSPECIFIED_V6_SOCKET: SocketAddrV6 = SocketAddrV6::new(Ipv6Addr::UNSPECIFIED, 0, 0, 0); #[derive(Default)] pub(crate) struct Sockets { - socket_v4: Option, - socket_v6: Option, + socket_v4: Option, + socket_v6: Option, } impl Sockets { - pub fn rebind( - &mut self, - socket_factory: &dyn SocketFactory, - ) -> io::Result<()> { - let socket_v4 = Socket::ip4(socket_factory); - let socket_v6 = Socket::ip6(socket_factory); + pub fn rebind(&mut self, socket_factory: &dyn SocketFactory) -> io::Result<()> { + let socket_v4 = socket_factory(&SocketAddr::V4(UNSPECIFIED_V4_SOCKET)); + let socket_v6 = socket_factory(&SocketAddr::V6(UNSPECIFIED_V6_SOCKET)); match (socket_v4.as_ref(), socket_v6.as_ref()) { (Err(e), Ok(_)) => { @@ -65,8 +59,8 @@ impl Sockets { Poll::Ready(Ok(())) } - pub fn send(&mut self, transmit: snownet::Transmit) -> io::Result<()> { - let socket = match transmit.dst { + pub fn send(&mut self, datagram: DatagramOut) -> io::Result<()> { + let socket = match datagram.dst { SocketAddr::V4(dst) => self.socket_v4.as_mut().ok_or(io::Error::new( io::ErrorKind::NotConnected, format!("failed send packet to {dst}: no IPv4 socket"), @@ -76,7 +70,7 @@ impl Sockets { format!("failed send packet to {dst}: no IPv6 socket"), ))?, }; - socket.send(transmit)?; + socket.send(datagram)?; Ok(()) } @@ -86,7 +80,7 @@ impl Sockets { ip4_buffer: &'b mut [u8], ip6_buffer: &'b mut [u8], cx: &mut Context<'_>, - ) -> Poll>>> { + ) -> Poll>>> { let mut iter = PacketIter::new(); if let Some(Poll::Ready(packets)) = self @@ -133,10 +127,10 @@ impl PacketIter { impl<'a, T4, T6> Iterator for PacketIter where - T4: Iterator>, - T6: Iterator>, + T4: Iterator>, + T6: Iterator>, { - type Item = Received<'a>; + type Item = DatagramIn<'a>; fn next(&mut self) -> Option { if let Some(packet) = self.ip4.as_mut().and_then(|i| i.next()) { @@ -150,160 +144,3 @@ where None } } - -pub struct Received<'a> { - pub local: SocketAddr, - pub from: SocketAddr, - pub packet: &'a [u8], -} - -struct Socket { - state: UdpSocketState, - port: u16, - socket: UdpSocket, - - buffered_transmits: VecDeque>, -} - -impl Socket { - fn ip( - socket_factory: &dyn SocketFactory, - addr: &SocketAddr, - ) -> Result { - let socket = socket_factory(addr)?; - let port = socket.local_addr()?.port(); - - Ok(Socket { - state: UdpSocketState::new(UdpSockRef::from(&socket))?, - port, - socket, - buffered_transmits: VecDeque::new(), - }) - } - - fn ip4(socket_factory: &dyn SocketFactory) -> Result { - Self::ip( - socket_factory, - &SocketAddr::from((Ipv4Addr::UNSPECIFIED, 0)), - ) - } - - fn ip6(socket_factory: &dyn SocketFactory) -> Result { - Self::ip( - socket_factory, - &SocketAddr::from((Ipv6Addr::UNSPECIFIED, 0)), - ) - } - - #[allow(clippy::type_complexity)] - fn poll_recv_from<'b>( - &self, - buffer: &'b mut [u8], - cx: &mut Context<'_>, - ) -> Poll>>> { - let Socket { - port, - socket, - state, - .. - } = self; - - let bufs = &mut [IoSliceMut::new(buffer)]; - let mut meta = RecvMeta::default(); - - loop { - ready!(socket.poll_recv_ready(cx))?; - - if let Ok(len) = socket.try_io(Interest::READABLE, || { - state.recv((&socket).into(), bufs, slice::from_mut(&mut meta)) - }) { - debug_assert_eq!(len, 1); - - if meta.len == 0 { - continue; - } - - let Some(local_ip) = meta.dst_ip else { - tracing::warn!("Skipping packet without local IP"); - continue; - }; - - let local = SocketAddr::new(local_ip, *port); - - let iter = buffer[..meta.len] - .chunks(meta.stride) - .map(move |packet| Received { - local, - from: meta.addr, - packet, - }) - .inspect(|r| { - tracing::trace!(target: "wire::net::recv", src = %r.from, dst = %r.local, num_bytes = %r.packet.len()); - }); - - return Poll::Ready(Ok(iter)); - } - } - } - - fn poll_flush(&mut self, cx: &mut Context<'_>) -> Poll> { - loop { - ready!(self.socket.poll_send_ready(cx))?; // Ensure we are ready to send. - - let Some(transmit) = self.buffered_transmits.pop_front() else { - break; - }; - - match self.try_send(&transmit) { - Ok(()) => continue, // Try to send another packet. - Err(e) => { - self.buffered_transmits.push_front(transmit); // Don't lose the packet if we fail. - - if e.kind() == io::ErrorKind::WouldBlock { - continue; // False positive send-readiness: Loop to `poll_send_ready` and return `Pending`. - } - - return Poll::Ready(Err(e)); - } - } - } - - assert!(self.buffered_transmits.is_empty()); - - Poll::Ready(Ok(())) - } - - fn send(&mut self, transmit: snownet::Transmit) -> io::Result<()> { - tracing::trace!(target: "wire::net::send", src = ?transmit.src, dst = %transmit.dst, num_bytes = %transmit.payload.len()); - - debug_assert!( - self.buffered_transmits.len() < 10_000, - "We are not flushing the packets for some reason" - ); - - match self.try_send(&transmit) { - Ok(()) => Ok(()), - Err(e) if e.kind() == io::ErrorKind::WouldBlock => { - tracing::trace!("Buffering packet because socket is busy"); - - self.buffered_transmits.push_back(transmit.into_owned()); - Ok(()) - } - Err(e) => Err(e), - } - } - - fn try_send(&self, transmit: &snownet::Transmit) -> io::Result<()> { - let transmit = quinn_udp::Transmit { - destination: transmit.dst, - ecn: None, - contents: &transmit.payload, - segment_size: None, - src_ip: transmit.src.map(|s| s.ip()), - }; - - self.socket.try_io(Interest::WRITABLE, || { - self.state.send((&self.socket).into(), &transmit) - }) - } -} diff --git a/rust/headless-client/src/linux.rs b/rust/headless-client/src/linux.rs index 7b9267739..22ca47393 100644 --- a/rust/headless-client/src/linux.rs +++ b/rust/headless-client/src/linux.rs @@ -4,6 +4,7 @@ use super::TOKEN_ENV_KEY; use anyhow::{bail, Result}; use firezone_bin_shared::FIREZONE_MARK; use nix::sys::socket::{setsockopt, sockopt}; +use socket_factory::{TcpSocket, UdpSocket}; use std::{ io, net::SocketAddr, @@ -15,13 +16,13 @@ use std::{ const ROOT_GROUP: u32 = 0; const ROOT_USER: u32 = 0; -pub(crate) fn tcp_socket_factory(socket_addr: &SocketAddr) -> io::Result { +pub(crate) fn tcp_socket_factory(socket_addr: &SocketAddr) -> io::Result { let socket = socket_factory::tcp(socket_addr)?; setsockopt(&socket, sockopt::Mark, &FIREZONE_MARK)?; Ok(socket) } -pub(crate) fn udp_socket_factory(socket_addr: &SocketAddr) -> io::Result { +pub(crate) fn udp_socket_factory(socket_addr: &SocketAddr) -> io::Result { let socket = socket_factory::udp(socket_addr)?; setsockopt(&socket, sockopt::Mark, &FIREZONE_MARK)?; Ok(socket) diff --git a/rust/phoenix-channel/src/lib.rs b/rust/phoenix-channel/src/lib.rs index a44f0918f..971af921d 100644 --- a/rust/phoenix-channel/src/lib.rs +++ b/rust/phoenix-channel/src/lib.rs @@ -17,7 +17,7 @@ use heartbeat::{Heartbeat, MissedLastHeartbeat}; use rand_core::{OsRng, RngCore}; use secrecy::{ExposeSecret, Secret}; use serde::{de::DeserializeOwned, Deserialize, Serialize}; -use socket_factory::SocketFactory; +use socket_factory::{SocketFactory, TcpSocket}; use std::task::{Context, Poll, Waker}; use tokio::net::TcpStream; use tokio_tungstenite::client_async_tls; @@ -35,7 +35,7 @@ pub struct PhoenixChannel { waker: Option, pending_messages: VecDeque, next_request_id: Arc, - socket_factory: Arc>, + socket_factory: Arc>, heartbeat: Heartbeat, @@ -67,7 +67,7 @@ impl State { fn connect( url: Secret, user_agent: String, - socket_factory: Arc>, + socket_factory: Arc>, ) -> Self { Self::Connecting(create_and_connect_websocket(url, user_agent, socket_factory).boxed()) } @@ -76,7 +76,7 @@ impl State { async fn create_and_connect_websocket( url: Secret, user_agent: String, - socket_factory: Arc>, + socket_factory: Arc>, ) -> Result>, InternalError> { let socket = make_socket(url.expose_secret().inner(), &*socket_factory).await?; @@ -89,7 +89,7 @@ async fn create_and_connect_websocket( async fn make_socket( url: &Url, - socket_factory: &dyn SocketFactory, + socket_factory: &dyn SocketFactory, ) -> Result { let port = url .port_or_known_default() @@ -229,7 +229,7 @@ where login: &'static str, init_req: TInitReq, reconnect_backoff: ExponentialBackoff, - socket_factory: Arc>, + socket_factory: Arc>, ) -> io::Result { let next_request_id = Arc::new(AtomicU64::new(0)); diff --git a/rust/socket-factory/Cargo.toml b/rust/socket-factory/Cargo.toml index 13712112f..5522c49fe 100644 --- a/rust/socket-factory/Cargo.toml +++ b/rust/socket-factory/Cargo.toml @@ -4,5 +4,12 @@ version = "0.1.0" edition = "2021" [dependencies] +async-trait = { version = "0.1", optional = true } +hickory-proto = { workspace = true, optional = true } +quinn-udp = { git = "https://github.com/quinn-rs/quinn", branch = "main" } socket2 = { workspace = true } tokio = { version = "1.38", features = ["net"] } +tracing = "0.1" + +[features] +hickory = ["dep:hickory-proto", "dep:async-trait"] diff --git a/rust/socket-factory/src/lib.rs b/rust/socket-factory/src/lib.rs index db18ef752..225895ed3 100644 --- a/rust/socket-factory/src/lib.rs +++ b/rust/socket-factory/src/lib.rs @@ -1,15 +1,20 @@ -use std::net::SocketAddr; +use std::{ + borrow::Cow, + collections::VecDeque, + io::{self, IoSliceMut}, + net::SocketAddr, + slice, + task::{ready, Context, Poll}, +}; use socket2::SockAddr; +use tokio::io::Interest; -pub trait SocketFactory: Fn(&SocketAddr) -> std::io::Result + Send + Sync + 'static {} +pub trait SocketFactory: Fn(&SocketAddr) -> io::Result + Send + Sync + 'static {} -impl SocketFactory for F where - F: Fn(&SocketAddr) -> std::io::Result + Send + Sync + 'static -{ -} +impl SocketFactory for F where F: Fn(&SocketAddr) -> io::Result + Send + Sync + 'static {} -pub fn tcp(addr: &SocketAddr) -> std::io::Result { +pub fn tcp(addr: &SocketAddr) -> io::Result { let socket = match addr { SocketAddr::V4(_) => tokio::net::TcpSocket::new_v4()?, SocketAddr::V6(_) => tokio::net::TcpSocket::new_v6()?, @@ -17,9 +22,10 @@ pub fn tcp(addr: &SocketAddr) -> std::io::Result { socket.set_nodelay(true)?; - Ok(socket) + Ok(TcpSocket { inner: socket }) } -pub fn udp(addr: &SocketAddr) -> std::io::Result { + +pub fn udp(addr: &SocketAddr) -> io::Result { let addr: SockAddr = (*addr).into(); let socket = socket2::Socket::new(addr.domain(), socket2::Type::DGRAM, None)?; @@ -31,5 +37,263 @@ pub fn udp(addr: &SocketAddr) -> std::io::Result { socket.set_nonblocking(true)?; socket.bind(&addr)?; - std::net::UdpSocket::from(socket).try_into() + let socket = std::net::UdpSocket::from(socket); + let socket = tokio::net::UdpSocket::try_from(socket)?; + let socket = UdpSocket::new(socket)?; + + Ok(socket) +} + +pub struct TcpSocket { + inner: tokio::net::TcpSocket, +} + +impl TcpSocket { + pub async fn connect(self, addr: SocketAddr) -> io::Result { + self.inner.connect(addr).await + } +} + +#[cfg(unix)] +impl std::os::fd::AsRawFd for TcpSocket { + fn as_raw_fd(&self) -> std::os::fd::RawFd { + self.inner.as_raw_fd() + } +} + +#[cfg(unix)] +impl std::os::fd::AsFd for TcpSocket { + fn as_fd(&self) -> std::os::fd::BorrowedFd<'_> { + self.inner.as_fd() + } +} + +pub struct UdpSocket { + inner: tokio::net::UdpSocket, + state: quinn_udp::UdpSocketState, + + port: u16, + + buffered_datagrams: VecDeque>, +} + +impl UdpSocket { + fn new(inner: tokio::net::UdpSocket) -> io::Result { + let port = inner.local_addr()?.port(); + + Ok(UdpSocket { + state: quinn_udp::UdpSocketState::new(quinn_udp::UdpSockRef::from(&inner))?, + port, + inner, + buffered_datagrams: VecDeque::new(), + }) + } +} + +#[cfg(unix)] +impl std::os::fd::AsRawFd for UdpSocket { + fn as_raw_fd(&self) -> std::os::fd::RawFd { + self.inner.as_raw_fd() + } +} + +#[cfg(unix)] +impl std::os::fd::AsFd for UdpSocket { + fn as_fd(&self) -> std::os::fd::BorrowedFd<'_> { + self.inner.as_fd() + } +} + +/// An inbound UDP datagram. +pub struct DatagramIn<'a> { + pub local: SocketAddr, + pub from: SocketAddr, + pub packet: &'a [u8], +} + +/// An outbound UDP datagram. +pub struct DatagramOut<'a> { + pub src: Option, + pub dst: SocketAddr, + pub packet: Cow<'a, [u8]>, +} + +impl<'a> DatagramOut<'a> { + fn into_owned(self) -> DatagramOut<'static> { + DatagramOut { + src: self.src, + dst: self.dst, + packet: Cow::Owned(self.packet.into_owned()), + } + } +} + +impl UdpSocket { + #[allow(clippy::type_complexity)] + pub fn poll_recv_from<'b>( + &self, + buffer: &'b mut [u8], + cx: &mut Context<'_>, + ) -> Poll>>> { + let Self { + port, inner, state, .. + } = self; + + let bufs = &mut [IoSliceMut::new(buffer)]; + let mut meta = quinn_udp::RecvMeta::default(); + + loop { + ready!(inner.poll_recv_ready(cx))?; + + if let Ok(len) = inner.try_io(Interest::READABLE, || { + state.recv((&inner).into(), bufs, slice::from_mut(&mut meta)) + }) { + debug_assert_eq!(len, 1); + + if meta.len == 0 { + continue; + } + + let Some(local_ip) = meta.dst_ip else { + tracing::warn!("Skipping packet without local IP"); + continue; + }; + + let local = SocketAddr::new(local_ip, *port); + + let iter = buffer[..meta.len] + .chunks(meta.stride) + .map(move |packet| DatagramIn { + local, + from: meta.addr, + packet, + }) + .inspect(|r| { + tracing::trace!(target: "wire::net::recv", src = %r.from, dst = %r.local, num_bytes = %r.packet.len()); + }); + + return Poll::Ready(Ok(iter)); + } + } + } + + pub fn poll_flush(&mut self, cx: &mut Context<'_>) -> Poll> { + loop { + ready!(self.inner.poll_send_ready(cx))?; // Ensure we are ready to send. + + let Some(transmit) = self.buffered_datagrams.pop_front() else { + break; + }; + + match self.try_send(&transmit) { + Ok(()) => continue, // Try to send another packet. + Err(e) => { + self.buffered_datagrams.push_front(transmit); // Don't lose the packet if we fail. + + if e.kind() == io::ErrorKind::WouldBlock { + continue; // False positive send-readiness: Loop to `poll_send_ready` and return `Pending`. + } + + return Poll::Ready(Err(e)); + } + } + } + + assert!(self.buffered_datagrams.is_empty()); + + Poll::Ready(Ok(())) + } + + pub fn send(&mut self, datagram: DatagramOut) -> io::Result<()> { + tracing::trace!(target: "wire::net::send", src = ?datagram.src, dst = %datagram.dst, num_bytes = %datagram.packet.len()); + + debug_assert!( + self.buffered_datagrams.len() < 10_000, + "We are not flushing the packets for some reason" + ); + + match self.try_send(&datagram) { + Ok(()) => Ok(()), + Err(e) if e.kind() == io::ErrorKind::WouldBlock => { + tracing::trace!("Buffering packet because socket is busy"); + + self.buffered_datagrams.push_back(datagram.into_owned()); + Ok(()) + } + Err(e) => Err(e), + } + } + + pub fn try_send(&self, transmit: &DatagramOut) -> io::Result<()> { + let transmit = quinn_udp::Transmit { + destination: transmit.dst, + ecn: None, + contents: &transmit.packet, + segment_size: None, + src_ip: transmit.src.map(|s| s.ip()), + }; + + self.inner.try_io(Interest::WRITABLE, || { + self.state.send((&self.inner).into(), &transmit) + }) + } +} + +#[cfg(feature = "hickory")] +mod hickory { + use super::*; + use hickory_proto::{ + udp::DnsUdpSocket as DnsUdpSocketTrait, udp::UdpSocket as UdpSocketTrait, TokioTime, + }; + use tokio::net::UdpSocket as TokioUdpSocket; + + #[async_trait::async_trait] + impl UdpSocketTrait for crate::UdpSocket { + /// setups up a "client" udp connection that will only receive packets from the associated address + async fn connect(addr: SocketAddr) -> io::Result { + let inner = ::connect(addr).await?; + let socket = Self::new(inner)?; + + Ok(socket) + } + + /// same as connect, but binds to the specified local address for sending address + async fn connect_with_bind(addr: SocketAddr, bind_addr: SocketAddr) -> io::Result { + let inner = + ::connect_with_bind(addr, bind_addr).await?; + let socket = Self::new(inner)?; + + Ok(socket) + } + + /// a "server" UDP socket, that bind to the local listening address, and unbound remote address (can receive from anything) + async fn bind(addr: SocketAddr) -> io::Result { + let inner = ::bind(addr).await?; + let socket = Self::new(inner)?; + + Ok(socket) + } + } + + #[cfg(feature = "hickory")] + impl DnsUdpSocketTrait for crate::UdpSocket { + type Time = TokioTime; + + fn poll_recv_from( + &self, + cx: &mut Context<'_>, + buf: &mut [u8], + ) -> Poll> { + ::poll_recv_from(&self.inner, cx, buf) + } + + fn poll_send_to( + &self, + cx: &mut Context<'_>, + buf: &[u8], + target: SocketAddr, + ) -> Poll> { + ::poll_send_to(&self.inner, cx, buf, target) + } + } }