refactor(connlib): encapsulate UDP and TCP sockets (#6028)

As part of debugging full-route tunneling on Windows, we discovered that
we need to always explicitly choose the interface through which we want
to send packets, otherwise Windows may cause a routing loop by routing
our packets back into the TUN device.

We already have a `SocketFactory` abstraction in `connlib` that is used
by each platform to customise the setup of each socket to prevent
routing loops.

So far, this abstraction directly returns tokio sockets which don't
allow us to intercept the actual sending of packets. For some of our
traffic, i.e. the UDP packets exchanged with relays, we don't specify a
source address. To make full-route work on Windows, we need to intercept
these packets and explicitly set the source address.

To achieve that, we introduce dedicated `TcpSocket` and `UdpSocket`
structs within `socket-factory`. With this in place, we will be able to
add Windows-conditional code to looks up and sets the source address of
outgoing UDP packets. For TCP sockets, the lookup will happen prior to
connecting to the address and used to bind to the correct interface.

Related: #2667.
Related: #5955.
This commit is contained in:
Thomas Eizinger
2024-07-25 14:28:46 +10:00
committed by GitHub
parent b2298392e6
commit 59014a9622
11 changed files with 331 additions and 224 deletions

9
rust/Cargo.lock generated
View File

@@ -2006,7 +2006,6 @@ dependencies = [
"itertools 0.13.0",
"proptest",
"proptest-state-machine",
"quinn-udp",
"rand 0.8.5",
"rand_core 0.6.4",
"rangemap",
@@ -4746,8 +4745,8 @@ dependencies = [
[[package]]
name = "quinn-udp"
version = "0.5.2"
source = "git+https://github.com/quinn-rs/quinn?branch=main#3f489e2eab014ddd04de58e570ba56e9b027f0bc"
version = "0.5.4"
source = "git+https://github.com/quinn-rs/quinn?branch=main#061a74fb6ef67b12f78bc2a3cfc9906e54762eeb"
dependencies = [
"libc",
"once_cell",
@@ -5626,8 +5625,12 @@ dependencies = [
name = "socket-factory"
version = "0.1.0"
dependencies = [
"async-trait",
"hickory-proto",
"quinn-udp",
"socket2",
"tokio",
"tracing",
]
[[package]]

View File

@@ -19,7 +19,7 @@ use jni::{
};
use phoenix_channel::PhoenixChannel;
use secrecy::{Secret, SecretString};
use socket_factory::SocketFactory;
use socket_factory::{SocketFactory, TcpSocket, UdpSocket};
use std::{io, net::IpAddr, os::fd::AsRawFd, path::Path, sync::Arc};
use std::{
net::{Ipv4Addr, Ipv6Addr},
@@ -532,9 +532,7 @@ pub unsafe extern "system" fn Java_dev_firezone_android_tunnel_ConnlibSession_se
session.inner.set_tun(Box::new(tun));
}
fn protected_tcp_socket_factory(
callbacks: CallbackHandler,
) -> impl SocketFactory<tokio::net::TcpSocket> {
fn protected_tcp_socket_factory(callbacks: CallbackHandler) -> impl SocketFactory<TcpSocket> {
move |addr| {
let socket = socket_factory::tcp(addr)?;
callbacks.protect(socket.as_raw_fd())?;
@@ -542,9 +540,7 @@ fn protected_tcp_socket_factory(
}
}
fn protected_udp_socket_factory(
callbacks: CallbackHandler,
) -> impl SocketFactory<tokio::net::UdpSocket> {
fn protected_udp_socket_factory(callbacks: CallbackHandler) -> impl SocketFactory<UdpSocket> {
move |addr| {
let socket = socket_factory::udp(addr)?;
callbacks.protect(socket.as_raw_fd())?;

View File

@@ -11,7 +11,7 @@ use eventloop::Command;
use firezone_tunnel::ClientTunnel;
use messages::{IngressMessages, ReplyMessages};
use phoenix_channel::PhoenixChannel;
use socket_factory::SocketFactory;
use socket_factory::{SocketFactory, TcpSocket, UdpSocket};
use std::collections::HashMap;
use std::net::IpAddr;
use std::sync::Arc;
@@ -36,8 +36,8 @@ pub struct Session {
/// Arguments for `connect`, since Clippy said 8 args is too many
pub struct ConnectArgs<CB> {
pub tcp_socket_factory: Arc<dyn SocketFactory<tokio::net::TcpSocket>>,
pub udp_socket_factory: Arc<dyn SocketFactory<tokio::net::UdpSocket>>,
pub tcp_socket_factory: Arc<dyn SocketFactory<TcpSocket>>,
pub udp_socket_factory: Arc<dyn SocketFactory<UdpSocket>>,
pub private_key: StaticSecret,
pub callbacks: CB,
}

View File

@@ -22,13 +22,12 @@ ip_network = { version = "0.4", default-features = false }
ip_network_table = { version = "0.2", default-features = false }
itertools = { version = "0.13", default-features = false, features = ["use_std"] }
proptest = { version = "1", optional = true }
quinn-udp = { git = "https://github.com/quinn-rs/quinn", branch = "main" }
rand_core = { version = "0.6", default-features = false, features = ["getrandom"] }
rangemap = "1.5.1"
secrecy = { workspace = true }
serde = { version = "1.0", default-features = false, features = ["derive", "std"] }
snownet = { workspace = true }
socket-factory = { workspace = true }
socket-factory = { workspace = true, features = ["hickory"] }
socket2 = { workspace = true }
thiserror = { version = "1.0", default-features = false }
tokio = { workspace = true }

View File

@@ -1,8 +1,4 @@
use crate::{
device_channel::Device,
dns::DnsQuery,
sockets::{Received, Sockets},
};
use crate::{device_channel::Device, dns::DnsQuery, sockets::Sockets};
use connlib_shared::messages::DnsServer;
use futures::Future;
use futures_bounded::FuturesTupleSet;
@@ -15,7 +11,7 @@ use hickory_resolver::{
AsyncResolver, TokioHandle,
};
use ip_packet::{IpPacket, MutableIpPacket};
use socket_factory::SocketFactory;
use socket_factory::{DatagramIn, DatagramOut, SocketFactory, TcpSocket, UdpSocket};
use std::{
collections::HashMap,
io,
@@ -25,7 +21,6 @@ use std::{
task::{ready, Context, Poll},
time::{Duration, Instant},
};
use tokio::net::{TcpSocket, UdpSocket};
const DNS_QUERIES_QUEUE_SIZE: usize = 100;
@@ -94,7 +89,7 @@ impl Io {
ip4_buffer: &'b mut [u8],
ip6_bffer: &'b mut [u8],
device_buffer: &'b mut [u8],
) -> Poll<io::Result<Input<'b, impl Iterator<Item = Received<'b>>>>> {
) -> Poll<io::Result<Input<'b, impl Iterator<Item = DatagramIn<'b>>>>> {
if let Poll::Ready((response, query)) = self.forwarded_dns_queries.poll_unpin(cx) {
return Poll::Ready(Ok(Input::DnsResponse(query, response)));
}
@@ -185,7 +180,11 @@ impl Io {
}
pub fn send_network(&mut self, transmit: snownet::Transmit) -> io::Result<()> {
self.sockets.send(transmit)?;
self.sockets.send(DatagramOut {
src: transmit.src,
dst: transmit.dst,
packet: transmit.payload,
})?;
Ok(())
}

View File

@@ -13,6 +13,7 @@ use connlib_shared::{
};
use io::Io;
use ip_network::{Ipv4Network, Ipv6Network};
use socket_factory::{SocketFactory, TcpSocket, UdpSocket};
use std::{
collections::{BTreeSet, HashMap, HashSet},
net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr},
@@ -70,8 +71,8 @@ pub struct Tunnel<TRoleState> {
impl ClientTunnel {
pub fn new(
private_key: StaticSecret,
tcp_socket_factory: Arc<dyn socket_factory::SocketFactory<tokio::net::TcpSocket>>,
udp_socket_factory: Arc<dyn socket_factory::SocketFactory<tokio::net::UdpSocket>>,
tcp_socket_factory: Arc<dyn SocketFactory<TcpSocket>>,
udp_socket_factory: Arc<dyn SocketFactory<UdpSocket>>,
known_hosts: HashMap<String, Vec<IpAddr>>,
) -> std::io::Result<Self> {
Ok(Self {

View File

@@ -1,29 +1,23 @@
use core::slice;
use quinn_udp::{RecvMeta, UdpSockRef, UdpSocketState};
use socket_factory::SocketFactory;
use socket_factory::{DatagramIn, DatagramOut, SocketFactory, UdpSocket};
use std::{
collections::VecDeque,
io::{self, IoSliceMut},
net::{Ipv4Addr, Ipv6Addr, SocketAddr},
io,
net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6},
task::{ready, Context, Poll},
};
use tokio::{io::Interest, net::UdpSocket};
use crate::Result;
const UNSPECIFIED_V4_SOCKET: SocketAddrV4 = SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0);
const UNSPECIFIED_V6_SOCKET: SocketAddrV6 = SocketAddrV6::new(Ipv6Addr::UNSPECIFIED, 0, 0, 0);
#[derive(Default)]
pub(crate) struct Sockets {
socket_v4: Option<Socket>,
socket_v6: Option<Socket>,
socket_v4: Option<UdpSocket>,
socket_v6: Option<UdpSocket>,
}
impl Sockets {
pub fn rebind(
&mut self,
socket_factory: &dyn SocketFactory<tokio::net::UdpSocket>,
) -> io::Result<()> {
let socket_v4 = Socket::ip4(socket_factory);
let socket_v6 = Socket::ip6(socket_factory);
pub fn rebind(&mut self, socket_factory: &dyn SocketFactory<UdpSocket>) -> io::Result<()> {
let socket_v4 = socket_factory(&SocketAddr::V4(UNSPECIFIED_V4_SOCKET));
let socket_v6 = socket_factory(&SocketAddr::V6(UNSPECIFIED_V6_SOCKET));
match (socket_v4.as_ref(), socket_v6.as_ref()) {
(Err(e), Ok(_)) => {
@@ -65,8 +59,8 @@ impl Sockets {
Poll::Ready(Ok(()))
}
pub fn send(&mut self, transmit: snownet::Transmit) -> io::Result<()> {
let socket = match transmit.dst {
pub fn send(&mut self, datagram: DatagramOut) -> io::Result<()> {
let socket = match datagram.dst {
SocketAddr::V4(dst) => self.socket_v4.as_mut().ok_or(io::Error::new(
io::ErrorKind::NotConnected,
format!("failed send packet to {dst}: no IPv4 socket"),
@@ -76,7 +70,7 @@ impl Sockets {
format!("failed send packet to {dst}: no IPv6 socket"),
))?,
};
socket.send(transmit)?;
socket.send(datagram)?;
Ok(())
}
@@ -86,7 +80,7 @@ impl Sockets {
ip4_buffer: &'b mut [u8],
ip6_buffer: &'b mut [u8],
cx: &mut Context<'_>,
) -> Poll<io::Result<impl Iterator<Item = Received<'b>>>> {
) -> Poll<io::Result<impl Iterator<Item = DatagramIn<'b>>>> {
let mut iter = PacketIter::new();
if let Some(Poll::Ready(packets)) = self
@@ -133,10 +127,10 @@ impl<T4, T6> PacketIter<T4, T6> {
impl<'a, T4, T6> Iterator for PacketIter<T4, T6>
where
T4: Iterator<Item = Received<'a>>,
T6: Iterator<Item = Received<'a>>,
T4: Iterator<Item = DatagramIn<'a>>,
T6: Iterator<Item = DatagramIn<'a>>,
{
type Item = Received<'a>;
type Item = DatagramIn<'a>;
fn next(&mut self) -> Option<Self::Item> {
if let Some(packet) = self.ip4.as_mut().and_then(|i| i.next()) {
@@ -150,160 +144,3 @@ where
None
}
}
pub struct Received<'a> {
pub local: SocketAddr,
pub from: SocketAddr,
pub packet: &'a [u8],
}
struct Socket {
state: UdpSocketState,
port: u16,
socket: UdpSocket,
buffered_transmits: VecDeque<snownet::Transmit<'static>>,
}
impl Socket {
fn ip(
socket_factory: &dyn SocketFactory<tokio::net::UdpSocket>,
addr: &SocketAddr,
) -> Result<Socket> {
let socket = socket_factory(addr)?;
let port = socket.local_addr()?.port();
Ok(Socket {
state: UdpSocketState::new(UdpSockRef::from(&socket))?,
port,
socket,
buffered_transmits: VecDeque::new(),
})
}
fn ip4(socket_factory: &dyn SocketFactory<tokio::net::UdpSocket>) -> Result<Socket> {
Self::ip(
socket_factory,
&SocketAddr::from((Ipv4Addr::UNSPECIFIED, 0)),
)
}
fn ip6(socket_factory: &dyn SocketFactory<tokio::net::UdpSocket>) -> Result<Socket> {
Self::ip(
socket_factory,
&SocketAddr::from((Ipv6Addr::UNSPECIFIED, 0)),
)
}
#[allow(clippy::type_complexity)]
fn poll_recv_from<'b>(
&self,
buffer: &'b mut [u8],
cx: &mut Context<'_>,
) -> Poll<io::Result<impl Iterator<Item = Received<'b>>>> {
let Socket {
port,
socket,
state,
..
} = self;
let bufs = &mut [IoSliceMut::new(buffer)];
let mut meta = RecvMeta::default();
loop {
ready!(socket.poll_recv_ready(cx))?;
if let Ok(len) = socket.try_io(Interest::READABLE, || {
state.recv((&socket).into(), bufs, slice::from_mut(&mut meta))
}) {
debug_assert_eq!(len, 1);
if meta.len == 0 {
continue;
}
let Some(local_ip) = meta.dst_ip else {
tracing::warn!("Skipping packet without local IP");
continue;
};
let local = SocketAddr::new(local_ip, *port);
let iter = buffer[..meta.len]
.chunks(meta.stride)
.map(move |packet| Received {
local,
from: meta.addr,
packet,
})
.inspect(|r| {
tracing::trace!(target: "wire::net::recv", src = %r.from, dst = %r.local, num_bytes = %r.packet.len());
});
return Poll::Ready(Ok(iter));
}
}
}
fn poll_flush(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
loop {
ready!(self.socket.poll_send_ready(cx))?; // Ensure we are ready to send.
let Some(transmit) = self.buffered_transmits.pop_front() else {
break;
};
match self.try_send(&transmit) {
Ok(()) => continue, // Try to send another packet.
Err(e) => {
self.buffered_transmits.push_front(transmit); // Don't lose the packet if we fail.
if e.kind() == io::ErrorKind::WouldBlock {
continue; // False positive send-readiness: Loop to `poll_send_ready` and return `Pending`.
}
return Poll::Ready(Err(e));
}
}
}
assert!(self.buffered_transmits.is_empty());
Poll::Ready(Ok(()))
}
fn send(&mut self, transmit: snownet::Transmit) -> io::Result<()> {
tracing::trace!(target: "wire::net::send", src = ?transmit.src, dst = %transmit.dst, num_bytes = %transmit.payload.len());
debug_assert!(
self.buffered_transmits.len() < 10_000,
"We are not flushing the packets for some reason"
);
match self.try_send(&transmit) {
Ok(()) => Ok(()),
Err(e) if e.kind() == io::ErrorKind::WouldBlock => {
tracing::trace!("Buffering packet because socket is busy");
self.buffered_transmits.push_back(transmit.into_owned());
Ok(())
}
Err(e) => Err(e),
}
}
fn try_send(&self, transmit: &snownet::Transmit) -> io::Result<()> {
let transmit = quinn_udp::Transmit {
destination: transmit.dst,
ecn: None,
contents: &transmit.payload,
segment_size: None,
src_ip: transmit.src.map(|s| s.ip()),
};
self.socket.try_io(Interest::WRITABLE, || {
self.state.send((&self.socket).into(), &transmit)
})
}
}

View File

@@ -4,6 +4,7 @@ use super::TOKEN_ENV_KEY;
use anyhow::{bail, Result};
use firezone_bin_shared::FIREZONE_MARK;
use nix::sys::socket::{setsockopt, sockopt};
use socket_factory::{TcpSocket, UdpSocket};
use std::{
io,
net::SocketAddr,
@@ -15,13 +16,13 @@ use std::{
const ROOT_GROUP: u32 = 0;
const ROOT_USER: u32 = 0;
pub(crate) fn tcp_socket_factory(socket_addr: &SocketAddr) -> io::Result<tokio::net::TcpSocket> {
pub(crate) fn tcp_socket_factory(socket_addr: &SocketAddr) -> io::Result<TcpSocket> {
let socket = socket_factory::tcp(socket_addr)?;
setsockopt(&socket, sockopt::Mark, &FIREZONE_MARK)?;
Ok(socket)
}
pub(crate) fn udp_socket_factory(socket_addr: &SocketAddr) -> io::Result<tokio::net::UdpSocket> {
pub(crate) fn udp_socket_factory(socket_addr: &SocketAddr) -> io::Result<UdpSocket> {
let socket = socket_factory::udp(socket_addr)?;
setsockopt(&socket, sockopt::Mark, &FIREZONE_MARK)?;
Ok(socket)

View File

@@ -17,7 +17,7 @@ use heartbeat::{Heartbeat, MissedLastHeartbeat};
use rand_core::{OsRng, RngCore};
use secrecy::{ExposeSecret, Secret};
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use socket_factory::SocketFactory;
use socket_factory::{SocketFactory, TcpSocket};
use std::task::{Context, Poll, Waker};
use tokio::net::TcpStream;
use tokio_tungstenite::client_async_tls;
@@ -35,7 +35,7 @@ pub struct PhoenixChannel<TInitReq, TInboundMsg, TOutboundRes> {
waker: Option<Waker>,
pending_messages: VecDeque<String>,
next_request_id: Arc<AtomicU64>,
socket_factory: Arc<dyn SocketFactory<tokio::net::TcpSocket>>,
socket_factory: Arc<dyn SocketFactory<TcpSocket>>,
heartbeat: Heartbeat,
@@ -67,7 +67,7 @@ impl State {
fn connect(
url: Secret<LoginUrl>,
user_agent: String,
socket_factory: Arc<dyn SocketFactory<tokio::net::TcpSocket>>,
socket_factory: Arc<dyn SocketFactory<TcpSocket>>,
) -> Self {
Self::Connecting(create_and_connect_websocket(url, user_agent, socket_factory).boxed())
}
@@ -76,7 +76,7 @@ impl State {
async fn create_and_connect_websocket(
url: Secret<LoginUrl>,
user_agent: String,
socket_factory: Arc<dyn SocketFactory<tokio::net::TcpSocket>>,
socket_factory: Arc<dyn SocketFactory<TcpSocket>>,
) -> Result<WebSocketStream<MaybeTlsStream<TcpStream>>, InternalError> {
let socket = make_socket(url.expose_secret().inner(), &*socket_factory).await?;
@@ -89,7 +89,7 @@ async fn create_and_connect_websocket(
async fn make_socket(
url: &Url,
socket_factory: &dyn SocketFactory<tokio::net::TcpSocket>,
socket_factory: &dyn SocketFactory<TcpSocket>,
) -> Result<TcpStream, InternalError> {
let port = url
.port_or_known_default()
@@ -229,7 +229,7 @@ where
login: &'static str,
init_req: TInitReq,
reconnect_backoff: ExponentialBackoff,
socket_factory: Arc<dyn SocketFactory<tokio::net::TcpSocket>>,
socket_factory: Arc<dyn SocketFactory<TcpSocket>>,
) -> io::Result<Self> {
let next_request_id = Arc::new(AtomicU64::new(0));

View File

@@ -4,5 +4,12 @@ version = "0.1.0"
edition = "2021"
[dependencies]
async-trait = { version = "0.1", optional = true }
hickory-proto = { workspace = true, optional = true }
quinn-udp = { git = "https://github.com/quinn-rs/quinn", branch = "main" }
socket2 = { workspace = true }
tokio = { version = "1.38", features = ["net"] }
tracing = "0.1"
[features]
hickory = ["dep:hickory-proto", "dep:async-trait"]

View File

@@ -1,15 +1,20 @@
use std::net::SocketAddr;
use std::{
borrow::Cow,
collections::VecDeque,
io::{self, IoSliceMut},
net::SocketAddr,
slice,
task::{ready, Context, Poll},
};
use socket2::SockAddr;
use tokio::io::Interest;
pub trait SocketFactory<S>: Fn(&SocketAddr) -> std::io::Result<S> + Send + Sync + 'static {}
pub trait SocketFactory<S>: Fn(&SocketAddr) -> io::Result<S> + Send + Sync + 'static {}
impl<F, S> SocketFactory<S> for F where
F: Fn(&SocketAddr) -> std::io::Result<S> + Send + Sync + 'static
{
}
impl<F, S> SocketFactory<S> for F where F: Fn(&SocketAddr) -> io::Result<S> + Send + Sync + 'static {}
pub fn tcp(addr: &SocketAddr) -> std::io::Result<tokio::net::TcpSocket> {
pub fn tcp(addr: &SocketAddr) -> io::Result<TcpSocket> {
let socket = match addr {
SocketAddr::V4(_) => tokio::net::TcpSocket::new_v4()?,
SocketAddr::V6(_) => tokio::net::TcpSocket::new_v6()?,
@@ -17,9 +22,10 @@ pub fn tcp(addr: &SocketAddr) -> std::io::Result<tokio::net::TcpSocket> {
socket.set_nodelay(true)?;
Ok(socket)
Ok(TcpSocket { inner: socket })
}
pub fn udp(addr: &SocketAddr) -> std::io::Result<tokio::net::UdpSocket> {
pub fn udp(addr: &SocketAddr) -> io::Result<UdpSocket> {
let addr: SockAddr = (*addr).into();
let socket = socket2::Socket::new(addr.domain(), socket2::Type::DGRAM, None)?;
@@ -31,5 +37,263 @@ pub fn udp(addr: &SocketAddr) -> std::io::Result<tokio::net::UdpSocket> {
socket.set_nonblocking(true)?;
socket.bind(&addr)?;
std::net::UdpSocket::from(socket).try_into()
let socket = std::net::UdpSocket::from(socket);
let socket = tokio::net::UdpSocket::try_from(socket)?;
let socket = UdpSocket::new(socket)?;
Ok(socket)
}
pub struct TcpSocket {
inner: tokio::net::TcpSocket,
}
impl TcpSocket {
pub async fn connect(self, addr: SocketAddr) -> io::Result<tokio::net::TcpStream> {
self.inner.connect(addr).await
}
}
#[cfg(unix)]
impl std::os::fd::AsRawFd for TcpSocket {
fn as_raw_fd(&self) -> std::os::fd::RawFd {
self.inner.as_raw_fd()
}
}
#[cfg(unix)]
impl std::os::fd::AsFd for TcpSocket {
fn as_fd(&self) -> std::os::fd::BorrowedFd<'_> {
self.inner.as_fd()
}
}
pub struct UdpSocket {
inner: tokio::net::UdpSocket,
state: quinn_udp::UdpSocketState,
port: u16,
buffered_datagrams: VecDeque<DatagramOut<'static>>,
}
impl UdpSocket {
fn new(inner: tokio::net::UdpSocket) -> io::Result<Self> {
let port = inner.local_addr()?.port();
Ok(UdpSocket {
state: quinn_udp::UdpSocketState::new(quinn_udp::UdpSockRef::from(&inner))?,
port,
inner,
buffered_datagrams: VecDeque::new(),
})
}
}
#[cfg(unix)]
impl std::os::fd::AsRawFd for UdpSocket {
fn as_raw_fd(&self) -> std::os::fd::RawFd {
self.inner.as_raw_fd()
}
}
#[cfg(unix)]
impl std::os::fd::AsFd for UdpSocket {
fn as_fd(&self) -> std::os::fd::BorrowedFd<'_> {
self.inner.as_fd()
}
}
/// An inbound UDP datagram.
pub struct DatagramIn<'a> {
pub local: SocketAddr,
pub from: SocketAddr,
pub packet: &'a [u8],
}
/// An outbound UDP datagram.
pub struct DatagramOut<'a> {
pub src: Option<SocketAddr>,
pub dst: SocketAddr,
pub packet: Cow<'a, [u8]>,
}
impl<'a> DatagramOut<'a> {
fn into_owned(self) -> DatagramOut<'static> {
DatagramOut {
src: self.src,
dst: self.dst,
packet: Cow::Owned(self.packet.into_owned()),
}
}
}
impl UdpSocket {
#[allow(clippy::type_complexity)]
pub fn poll_recv_from<'b>(
&self,
buffer: &'b mut [u8],
cx: &mut Context<'_>,
) -> Poll<io::Result<impl Iterator<Item = DatagramIn<'b>>>> {
let Self {
port, inner, state, ..
} = self;
let bufs = &mut [IoSliceMut::new(buffer)];
let mut meta = quinn_udp::RecvMeta::default();
loop {
ready!(inner.poll_recv_ready(cx))?;
if let Ok(len) = inner.try_io(Interest::READABLE, || {
state.recv((&inner).into(), bufs, slice::from_mut(&mut meta))
}) {
debug_assert_eq!(len, 1);
if meta.len == 0 {
continue;
}
let Some(local_ip) = meta.dst_ip else {
tracing::warn!("Skipping packet without local IP");
continue;
};
let local = SocketAddr::new(local_ip, *port);
let iter = buffer[..meta.len]
.chunks(meta.stride)
.map(move |packet| DatagramIn {
local,
from: meta.addr,
packet,
})
.inspect(|r| {
tracing::trace!(target: "wire::net::recv", src = %r.from, dst = %r.local, num_bytes = %r.packet.len());
});
return Poll::Ready(Ok(iter));
}
}
}
pub fn poll_flush(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
loop {
ready!(self.inner.poll_send_ready(cx))?; // Ensure we are ready to send.
let Some(transmit) = self.buffered_datagrams.pop_front() else {
break;
};
match self.try_send(&transmit) {
Ok(()) => continue, // Try to send another packet.
Err(e) => {
self.buffered_datagrams.push_front(transmit); // Don't lose the packet if we fail.
if e.kind() == io::ErrorKind::WouldBlock {
continue; // False positive send-readiness: Loop to `poll_send_ready` and return `Pending`.
}
return Poll::Ready(Err(e));
}
}
}
assert!(self.buffered_datagrams.is_empty());
Poll::Ready(Ok(()))
}
pub fn send(&mut self, datagram: DatagramOut) -> io::Result<()> {
tracing::trace!(target: "wire::net::send", src = ?datagram.src, dst = %datagram.dst, num_bytes = %datagram.packet.len());
debug_assert!(
self.buffered_datagrams.len() < 10_000,
"We are not flushing the packets for some reason"
);
match self.try_send(&datagram) {
Ok(()) => Ok(()),
Err(e) if e.kind() == io::ErrorKind::WouldBlock => {
tracing::trace!("Buffering packet because socket is busy");
self.buffered_datagrams.push_back(datagram.into_owned());
Ok(())
}
Err(e) => Err(e),
}
}
pub fn try_send(&self, transmit: &DatagramOut) -> io::Result<()> {
let transmit = quinn_udp::Transmit {
destination: transmit.dst,
ecn: None,
contents: &transmit.packet,
segment_size: None,
src_ip: transmit.src.map(|s| s.ip()),
};
self.inner.try_io(Interest::WRITABLE, || {
self.state.send((&self.inner).into(), &transmit)
})
}
}
#[cfg(feature = "hickory")]
mod hickory {
use super::*;
use hickory_proto::{
udp::DnsUdpSocket as DnsUdpSocketTrait, udp::UdpSocket as UdpSocketTrait, TokioTime,
};
use tokio::net::UdpSocket as TokioUdpSocket;
#[async_trait::async_trait]
impl UdpSocketTrait for crate::UdpSocket {
/// setups up a "client" udp connection that will only receive packets from the associated address
async fn connect(addr: SocketAddr) -> io::Result<Self> {
let inner = <TokioUdpSocket as UdpSocketTrait>::connect(addr).await?;
let socket = Self::new(inner)?;
Ok(socket)
}
/// same as connect, but binds to the specified local address for sending address
async fn connect_with_bind(addr: SocketAddr, bind_addr: SocketAddr) -> io::Result<Self> {
let inner =
<TokioUdpSocket as UdpSocketTrait>::connect_with_bind(addr, bind_addr).await?;
let socket = Self::new(inner)?;
Ok(socket)
}
/// a "server" UDP socket, that bind to the local listening address, and unbound remote address (can receive from anything)
async fn bind(addr: SocketAddr) -> io::Result<Self> {
let inner = <TokioUdpSocket as UdpSocketTrait>::bind(addr).await?;
let socket = Self::new(inner)?;
Ok(socket)
}
}
#[cfg(feature = "hickory")]
impl DnsUdpSocketTrait for crate::UdpSocket {
type Time = TokioTime;
fn poll_recv_from(
&self,
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<io::Result<(usize, SocketAddr)>> {
<TokioUdpSocket as DnsUdpSocketTrait>::poll_recv_from(&self.inner, cx, buf)
}
fn poll_send_to(
&self,
cx: &mut Context<'_>,
buf: &[u8],
target: SocketAddr,
) -> Poll<io::Result<usize>> {
<TokioUdpSocket as DnsUdpSocketTrait>::poll_send_to(&self.inner, cx, buf, target)
}
}
}