mirror of
https://github.com/outbackdingo/firezone.git
synced 2026-01-27 10:18:54 +00:00
refactor(connlib): encapsulate UDP and TCP sockets (#6028)
As part of debugging full-route tunneling on Windows, we discovered that we need to always explicitly choose the interface through which we want to send packets, otherwise Windows may cause a routing loop by routing our packets back into the TUN device. We already have a `SocketFactory` abstraction in `connlib` that is used by each platform to customise the setup of each socket to prevent routing loops. So far, this abstraction directly returns tokio sockets which don't allow us to intercept the actual sending of packets. For some of our traffic, i.e. the UDP packets exchanged with relays, we don't specify a source address. To make full-route work on Windows, we need to intercept these packets and explicitly set the source address. To achieve that, we introduce dedicated `TcpSocket` and `UdpSocket` structs within `socket-factory`. With this in place, we will be able to add Windows-conditional code to looks up and sets the source address of outgoing UDP packets. For TCP sockets, the lookup will happen prior to connecting to the address and used to bind to the correct interface. Related: #2667. Related: #5955.
This commit is contained in:
9
rust/Cargo.lock
generated
9
rust/Cargo.lock
generated
@@ -2006,7 +2006,6 @@ dependencies = [
|
||||
"itertools 0.13.0",
|
||||
"proptest",
|
||||
"proptest-state-machine",
|
||||
"quinn-udp",
|
||||
"rand 0.8.5",
|
||||
"rand_core 0.6.4",
|
||||
"rangemap",
|
||||
@@ -4746,8 +4745,8 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "quinn-udp"
|
||||
version = "0.5.2"
|
||||
source = "git+https://github.com/quinn-rs/quinn?branch=main#3f489e2eab014ddd04de58e570ba56e9b027f0bc"
|
||||
version = "0.5.4"
|
||||
source = "git+https://github.com/quinn-rs/quinn?branch=main#061a74fb6ef67b12f78bc2a3cfc9906e54762eeb"
|
||||
dependencies = [
|
||||
"libc",
|
||||
"once_cell",
|
||||
@@ -5626,8 +5625,12 @@ dependencies = [
|
||||
name = "socket-factory"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"async-trait",
|
||||
"hickory-proto",
|
||||
"quinn-udp",
|
||||
"socket2",
|
||||
"tokio",
|
||||
"tracing",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
||||
@@ -19,7 +19,7 @@ use jni::{
|
||||
};
|
||||
use phoenix_channel::PhoenixChannel;
|
||||
use secrecy::{Secret, SecretString};
|
||||
use socket_factory::SocketFactory;
|
||||
use socket_factory::{SocketFactory, TcpSocket, UdpSocket};
|
||||
use std::{io, net::IpAddr, os::fd::AsRawFd, path::Path, sync::Arc};
|
||||
use std::{
|
||||
net::{Ipv4Addr, Ipv6Addr},
|
||||
@@ -532,9 +532,7 @@ pub unsafe extern "system" fn Java_dev_firezone_android_tunnel_ConnlibSession_se
|
||||
session.inner.set_tun(Box::new(tun));
|
||||
}
|
||||
|
||||
fn protected_tcp_socket_factory(
|
||||
callbacks: CallbackHandler,
|
||||
) -> impl SocketFactory<tokio::net::TcpSocket> {
|
||||
fn protected_tcp_socket_factory(callbacks: CallbackHandler) -> impl SocketFactory<TcpSocket> {
|
||||
move |addr| {
|
||||
let socket = socket_factory::tcp(addr)?;
|
||||
callbacks.protect(socket.as_raw_fd())?;
|
||||
@@ -542,9 +540,7 @@ fn protected_tcp_socket_factory(
|
||||
}
|
||||
}
|
||||
|
||||
fn protected_udp_socket_factory(
|
||||
callbacks: CallbackHandler,
|
||||
) -> impl SocketFactory<tokio::net::UdpSocket> {
|
||||
fn protected_udp_socket_factory(callbacks: CallbackHandler) -> impl SocketFactory<UdpSocket> {
|
||||
move |addr| {
|
||||
let socket = socket_factory::udp(addr)?;
|
||||
callbacks.protect(socket.as_raw_fd())?;
|
||||
|
||||
@@ -11,7 +11,7 @@ use eventloop::Command;
|
||||
use firezone_tunnel::ClientTunnel;
|
||||
use messages::{IngressMessages, ReplyMessages};
|
||||
use phoenix_channel::PhoenixChannel;
|
||||
use socket_factory::SocketFactory;
|
||||
use socket_factory::{SocketFactory, TcpSocket, UdpSocket};
|
||||
use std::collections::HashMap;
|
||||
use std::net::IpAddr;
|
||||
use std::sync::Arc;
|
||||
@@ -36,8 +36,8 @@ pub struct Session {
|
||||
|
||||
/// Arguments for `connect`, since Clippy said 8 args is too many
|
||||
pub struct ConnectArgs<CB> {
|
||||
pub tcp_socket_factory: Arc<dyn SocketFactory<tokio::net::TcpSocket>>,
|
||||
pub udp_socket_factory: Arc<dyn SocketFactory<tokio::net::UdpSocket>>,
|
||||
pub tcp_socket_factory: Arc<dyn SocketFactory<TcpSocket>>,
|
||||
pub udp_socket_factory: Arc<dyn SocketFactory<UdpSocket>>,
|
||||
pub private_key: StaticSecret,
|
||||
pub callbacks: CB,
|
||||
}
|
||||
|
||||
@@ -22,13 +22,12 @@ ip_network = { version = "0.4", default-features = false }
|
||||
ip_network_table = { version = "0.2", default-features = false }
|
||||
itertools = { version = "0.13", default-features = false, features = ["use_std"] }
|
||||
proptest = { version = "1", optional = true }
|
||||
quinn-udp = { git = "https://github.com/quinn-rs/quinn", branch = "main" }
|
||||
rand_core = { version = "0.6", default-features = false, features = ["getrandom"] }
|
||||
rangemap = "1.5.1"
|
||||
secrecy = { workspace = true }
|
||||
serde = { version = "1.0", default-features = false, features = ["derive", "std"] }
|
||||
snownet = { workspace = true }
|
||||
socket-factory = { workspace = true }
|
||||
socket-factory = { workspace = true, features = ["hickory"] }
|
||||
socket2 = { workspace = true }
|
||||
thiserror = { version = "1.0", default-features = false }
|
||||
tokio = { workspace = true }
|
||||
|
||||
@@ -1,8 +1,4 @@
|
||||
use crate::{
|
||||
device_channel::Device,
|
||||
dns::DnsQuery,
|
||||
sockets::{Received, Sockets},
|
||||
};
|
||||
use crate::{device_channel::Device, dns::DnsQuery, sockets::Sockets};
|
||||
use connlib_shared::messages::DnsServer;
|
||||
use futures::Future;
|
||||
use futures_bounded::FuturesTupleSet;
|
||||
@@ -15,7 +11,7 @@ use hickory_resolver::{
|
||||
AsyncResolver, TokioHandle,
|
||||
};
|
||||
use ip_packet::{IpPacket, MutableIpPacket};
|
||||
use socket_factory::SocketFactory;
|
||||
use socket_factory::{DatagramIn, DatagramOut, SocketFactory, TcpSocket, UdpSocket};
|
||||
use std::{
|
||||
collections::HashMap,
|
||||
io,
|
||||
@@ -25,7 +21,6 @@ use std::{
|
||||
task::{ready, Context, Poll},
|
||||
time::{Duration, Instant},
|
||||
};
|
||||
use tokio::net::{TcpSocket, UdpSocket};
|
||||
|
||||
const DNS_QUERIES_QUEUE_SIZE: usize = 100;
|
||||
|
||||
@@ -94,7 +89,7 @@ impl Io {
|
||||
ip4_buffer: &'b mut [u8],
|
||||
ip6_bffer: &'b mut [u8],
|
||||
device_buffer: &'b mut [u8],
|
||||
) -> Poll<io::Result<Input<'b, impl Iterator<Item = Received<'b>>>>> {
|
||||
) -> Poll<io::Result<Input<'b, impl Iterator<Item = DatagramIn<'b>>>>> {
|
||||
if let Poll::Ready((response, query)) = self.forwarded_dns_queries.poll_unpin(cx) {
|
||||
return Poll::Ready(Ok(Input::DnsResponse(query, response)));
|
||||
}
|
||||
@@ -185,7 +180,11 @@ impl Io {
|
||||
}
|
||||
|
||||
pub fn send_network(&mut self, transmit: snownet::Transmit) -> io::Result<()> {
|
||||
self.sockets.send(transmit)?;
|
||||
self.sockets.send(DatagramOut {
|
||||
src: transmit.src,
|
||||
dst: transmit.dst,
|
||||
packet: transmit.payload,
|
||||
})?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -13,6 +13,7 @@ use connlib_shared::{
|
||||
};
|
||||
use io::Io;
|
||||
use ip_network::{Ipv4Network, Ipv6Network};
|
||||
use socket_factory::{SocketFactory, TcpSocket, UdpSocket};
|
||||
use std::{
|
||||
collections::{BTreeSet, HashMap, HashSet},
|
||||
net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr},
|
||||
@@ -70,8 +71,8 @@ pub struct Tunnel<TRoleState> {
|
||||
impl ClientTunnel {
|
||||
pub fn new(
|
||||
private_key: StaticSecret,
|
||||
tcp_socket_factory: Arc<dyn socket_factory::SocketFactory<tokio::net::TcpSocket>>,
|
||||
udp_socket_factory: Arc<dyn socket_factory::SocketFactory<tokio::net::UdpSocket>>,
|
||||
tcp_socket_factory: Arc<dyn SocketFactory<TcpSocket>>,
|
||||
udp_socket_factory: Arc<dyn SocketFactory<UdpSocket>>,
|
||||
known_hosts: HashMap<String, Vec<IpAddr>>,
|
||||
) -> std::io::Result<Self> {
|
||||
Ok(Self {
|
||||
|
||||
@@ -1,29 +1,23 @@
|
||||
use core::slice;
|
||||
use quinn_udp::{RecvMeta, UdpSockRef, UdpSocketState};
|
||||
use socket_factory::SocketFactory;
|
||||
use socket_factory::{DatagramIn, DatagramOut, SocketFactory, UdpSocket};
|
||||
use std::{
|
||||
collections::VecDeque,
|
||||
io::{self, IoSliceMut},
|
||||
net::{Ipv4Addr, Ipv6Addr, SocketAddr},
|
||||
io,
|
||||
net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6},
|
||||
task::{ready, Context, Poll},
|
||||
};
|
||||
use tokio::{io::Interest, net::UdpSocket};
|
||||
|
||||
use crate::Result;
|
||||
const UNSPECIFIED_V4_SOCKET: SocketAddrV4 = SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0);
|
||||
const UNSPECIFIED_V6_SOCKET: SocketAddrV6 = SocketAddrV6::new(Ipv6Addr::UNSPECIFIED, 0, 0, 0);
|
||||
|
||||
#[derive(Default)]
|
||||
pub(crate) struct Sockets {
|
||||
socket_v4: Option<Socket>,
|
||||
socket_v6: Option<Socket>,
|
||||
socket_v4: Option<UdpSocket>,
|
||||
socket_v6: Option<UdpSocket>,
|
||||
}
|
||||
|
||||
impl Sockets {
|
||||
pub fn rebind(
|
||||
&mut self,
|
||||
socket_factory: &dyn SocketFactory<tokio::net::UdpSocket>,
|
||||
) -> io::Result<()> {
|
||||
let socket_v4 = Socket::ip4(socket_factory);
|
||||
let socket_v6 = Socket::ip6(socket_factory);
|
||||
pub fn rebind(&mut self, socket_factory: &dyn SocketFactory<UdpSocket>) -> io::Result<()> {
|
||||
let socket_v4 = socket_factory(&SocketAddr::V4(UNSPECIFIED_V4_SOCKET));
|
||||
let socket_v6 = socket_factory(&SocketAddr::V6(UNSPECIFIED_V6_SOCKET));
|
||||
|
||||
match (socket_v4.as_ref(), socket_v6.as_ref()) {
|
||||
(Err(e), Ok(_)) => {
|
||||
@@ -65,8 +59,8 @@ impl Sockets {
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
|
||||
pub fn send(&mut self, transmit: snownet::Transmit) -> io::Result<()> {
|
||||
let socket = match transmit.dst {
|
||||
pub fn send(&mut self, datagram: DatagramOut) -> io::Result<()> {
|
||||
let socket = match datagram.dst {
|
||||
SocketAddr::V4(dst) => self.socket_v4.as_mut().ok_or(io::Error::new(
|
||||
io::ErrorKind::NotConnected,
|
||||
format!("failed send packet to {dst}: no IPv4 socket"),
|
||||
@@ -76,7 +70,7 @@ impl Sockets {
|
||||
format!("failed send packet to {dst}: no IPv6 socket"),
|
||||
))?,
|
||||
};
|
||||
socket.send(transmit)?;
|
||||
socket.send(datagram)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -86,7 +80,7 @@ impl Sockets {
|
||||
ip4_buffer: &'b mut [u8],
|
||||
ip6_buffer: &'b mut [u8],
|
||||
cx: &mut Context<'_>,
|
||||
) -> Poll<io::Result<impl Iterator<Item = Received<'b>>>> {
|
||||
) -> Poll<io::Result<impl Iterator<Item = DatagramIn<'b>>>> {
|
||||
let mut iter = PacketIter::new();
|
||||
|
||||
if let Some(Poll::Ready(packets)) = self
|
||||
@@ -133,10 +127,10 @@ impl<T4, T6> PacketIter<T4, T6> {
|
||||
|
||||
impl<'a, T4, T6> Iterator for PacketIter<T4, T6>
|
||||
where
|
||||
T4: Iterator<Item = Received<'a>>,
|
||||
T6: Iterator<Item = Received<'a>>,
|
||||
T4: Iterator<Item = DatagramIn<'a>>,
|
||||
T6: Iterator<Item = DatagramIn<'a>>,
|
||||
{
|
||||
type Item = Received<'a>;
|
||||
type Item = DatagramIn<'a>;
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
if let Some(packet) = self.ip4.as_mut().and_then(|i| i.next()) {
|
||||
@@ -150,160 +144,3 @@ where
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
pub struct Received<'a> {
|
||||
pub local: SocketAddr,
|
||||
pub from: SocketAddr,
|
||||
pub packet: &'a [u8],
|
||||
}
|
||||
|
||||
struct Socket {
|
||||
state: UdpSocketState,
|
||||
port: u16,
|
||||
socket: UdpSocket,
|
||||
|
||||
buffered_transmits: VecDeque<snownet::Transmit<'static>>,
|
||||
}
|
||||
|
||||
impl Socket {
|
||||
fn ip(
|
||||
socket_factory: &dyn SocketFactory<tokio::net::UdpSocket>,
|
||||
addr: &SocketAddr,
|
||||
) -> Result<Socket> {
|
||||
let socket = socket_factory(addr)?;
|
||||
let port = socket.local_addr()?.port();
|
||||
|
||||
Ok(Socket {
|
||||
state: UdpSocketState::new(UdpSockRef::from(&socket))?,
|
||||
port,
|
||||
socket,
|
||||
buffered_transmits: VecDeque::new(),
|
||||
})
|
||||
}
|
||||
|
||||
fn ip4(socket_factory: &dyn SocketFactory<tokio::net::UdpSocket>) -> Result<Socket> {
|
||||
Self::ip(
|
||||
socket_factory,
|
||||
&SocketAddr::from((Ipv4Addr::UNSPECIFIED, 0)),
|
||||
)
|
||||
}
|
||||
|
||||
fn ip6(socket_factory: &dyn SocketFactory<tokio::net::UdpSocket>) -> Result<Socket> {
|
||||
Self::ip(
|
||||
socket_factory,
|
||||
&SocketAddr::from((Ipv6Addr::UNSPECIFIED, 0)),
|
||||
)
|
||||
}
|
||||
|
||||
#[allow(clippy::type_complexity)]
|
||||
fn poll_recv_from<'b>(
|
||||
&self,
|
||||
buffer: &'b mut [u8],
|
||||
cx: &mut Context<'_>,
|
||||
) -> Poll<io::Result<impl Iterator<Item = Received<'b>>>> {
|
||||
let Socket {
|
||||
port,
|
||||
socket,
|
||||
state,
|
||||
..
|
||||
} = self;
|
||||
|
||||
let bufs = &mut [IoSliceMut::new(buffer)];
|
||||
let mut meta = RecvMeta::default();
|
||||
|
||||
loop {
|
||||
ready!(socket.poll_recv_ready(cx))?;
|
||||
|
||||
if let Ok(len) = socket.try_io(Interest::READABLE, || {
|
||||
state.recv((&socket).into(), bufs, slice::from_mut(&mut meta))
|
||||
}) {
|
||||
debug_assert_eq!(len, 1);
|
||||
|
||||
if meta.len == 0 {
|
||||
continue;
|
||||
}
|
||||
|
||||
let Some(local_ip) = meta.dst_ip else {
|
||||
tracing::warn!("Skipping packet without local IP");
|
||||
continue;
|
||||
};
|
||||
|
||||
let local = SocketAddr::new(local_ip, *port);
|
||||
|
||||
let iter = buffer[..meta.len]
|
||||
.chunks(meta.stride)
|
||||
.map(move |packet| Received {
|
||||
local,
|
||||
from: meta.addr,
|
||||
packet,
|
||||
})
|
||||
.inspect(|r| {
|
||||
tracing::trace!(target: "wire::net::recv", src = %r.from, dst = %r.local, num_bytes = %r.packet.len());
|
||||
});
|
||||
|
||||
return Poll::Ready(Ok(iter));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn poll_flush(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
||||
loop {
|
||||
ready!(self.socket.poll_send_ready(cx))?; // Ensure we are ready to send.
|
||||
|
||||
let Some(transmit) = self.buffered_transmits.pop_front() else {
|
||||
break;
|
||||
};
|
||||
|
||||
match self.try_send(&transmit) {
|
||||
Ok(()) => continue, // Try to send another packet.
|
||||
Err(e) => {
|
||||
self.buffered_transmits.push_front(transmit); // Don't lose the packet if we fail.
|
||||
|
||||
if e.kind() == io::ErrorKind::WouldBlock {
|
||||
continue; // False positive send-readiness: Loop to `poll_send_ready` and return `Pending`.
|
||||
}
|
||||
|
||||
return Poll::Ready(Err(e));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
assert!(self.buffered_transmits.is_empty());
|
||||
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
|
||||
fn send(&mut self, transmit: snownet::Transmit) -> io::Result<()> {
|
||||
tracing::trace!(target: "wire::net::send", src = ?transmit.src, dst = %transmit.dst, num_bytes = %transmit.payload.len());
|
||||
|
||||
debug_assert!(
|
||||
self.buffered_transmits.len() < 10_000,
|
||||
"We are not flushing the packets for some reason"
|
||||
);
|
||||
|
||||
match self.try_send(&transmit) {
|
||||
Ok(()) => Ok(()),
|
||||
Err(e) if e.kind() == io::ErrorKind::WouldBlock => {
|
||||
tracing::trace!("Buffering packet because socket is busy");
|
||||
|
||||
self.buffered_transmits.push_back(transmit.into_owned());
|
||||
Ok(())
|
||||
}
|
||||
Err(e) => Err(e),
|
||||
}
|
||||
}
|
||||
|
||||
fn try_send(&self, transmit: &snownet::Transmit) -> io::Result<()> {
|
||||
let transmit = quinn_udp::Transmit {
|
||||
destination: transmit.dst,
|
||||
ecn: None,
|
||||
contents: &transmit.payload,
|
||||
segment_size: None,
|
||||
src_ip: transmit.src.map(|s| s.ip()),
|
||||
};
|
||||
|
||||
self.socket.try_io(Interest::WRITABLE, || {
|
||||
self.state.send((&self.socket).into(), &transmit)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ use super::TOKEN_ENV_KEY;
|
||||
use anyhow::{bail, Result};
|
||||
use firezone_bin_shared::FIREZONE_MARK;
|
||||
use nix::sys::socket::{setsockopt, sockopt};
|
||||
use socket_factory::{TcpSocket, UdpSocket};
|
||||
use std::{
|
||||
io,
|
||||
net::SocketAddr,
|
||||
@@ -15,13 +16,13 @@ use std::{
|
||||
const ROOT_GROUP: u32 = 0;
|
||||
const ROOT_USER: u32 = 0;
|
||||
|
||||
pub(crate) fn tcp_socket_factory(socket_addr: &SocketAddr) -> io::Result<tokio::net::TcpSocket> {
|
||||
pub(crate) fn tcp_socket_factory(socket_addr: &SocketAddr) -> io::Result<TcpSocket> {
|
||||
let socket = socket_factory::tcp(socket_addr)?;
|
||||
setsockopt(&socket, sockopt::Mark, &FIREZONE_MARK)?;
|
||||
Ok(socket)
|
||||
}
|
||||
|
||||
pub(crate) fn udp_socket_factory(socket_addr: &SocketAddr) -> io::Result<tokio::net::UdpSocket> {
|
||||
pub(crate) fn udp_socket_factory(socket_addr: &SocketAddr) -> io::Result<UdpSocket> {
|
||||
let socket = socket_factory::udp(socket_addr)?;
|
||||
setsockopt(&socket, sockopt::Mark, &FIREZONE_MARK)?;
|
||||
Ok(socket)
|
||||
|
||||
@@ -17,7 +17,7 @@ use heartbeat::{Heartbeat, MissedLastHeartbeat};
|
||||
use rand_core::{OsRng, RngCore};
|
||||
use secrecy::{ExposeSecret, Secret};
|
||||
use serde::{de::DeserializeOwned, Deserialize, Serialize};
|
||||
use socket_factory::SocketFactory;
|
||||
use socket_factory::{SocketFactory, TcpSocket};
|
||||
use std::task::{Context, Poll, Waker};
|
||||
use tokio::net::TcpStream;
|
||||
use tokio_tungstenite::client_async_tls;
|
||||
@@ -35,7 +35,7 @@ pub struct PhoenixChannel<TInitReq, TInboundMsg, TOutboundRes> {
|
||||
waker: Option<Waker>,
|
||||
pending_messages: VecDeque<String>,
|
||||
next_request_id: Arc<AtomicU64>,
|
||||
socket_factory: Arc<dyn SocketFactory<tokio::net::TcpSocket>>,
|
||||
socket_factory: Arc<dyn SocketFactory<TcpSocket>>,
|
||||
|
||||
heartbeat: Heartbeat,
|
||||
|
||||
@@ -67,7 +67,7 @@ impl State {
|
||||
fn connect(
|
||||
url: Secret<LoginUrl>,
|
||||
user_agent: String,
|
||||
socket_factory: Arc<dyn SocketFactory<tokio::net::TcpSocket>>,
|
||||
socket_factory: Arc<dyn SocketFactory<TcpSocket>>,
|
||||
) -> Self {
|
||||
Self::Connecting(create_and_connect_websocket(url, user_agent, socket_factory).boxed())
|
||||
}
|
||||
@@ -76,7 +76,7 @@ impl State {
|
||||
async fn create_and_connect_websocket(
|
||||
url: Secret<LoginUrl>,
|
||||
user_agent: String,
|
||||
socket_factory: Arc<dyn SocketFactory<tokio::net::TcpSocket>>,
|
||||
socket_factory: Arc<dyn SocketFactory<TcpSocket>>,
|
||||
) -> Result<WebSocketStream<MaybeTlsStream<TcpStream>>, InternalError> {
|
||||
let socket = make_socket(url.expose_secret().inner(), &*socket_factory).await?;
|
||||
|
||||
@@ -89,7 +89,7 @@ async fn create_and_connect_websocket(
|
||||
|
||||
async fn make_socket(
|
||||
url: &Url,
|
||||
socket_factory: &dyn SocketFactory<tokio::net::TcpSocket>,
|
||||
socket_factory: &dyn SocketFactory<TcpSocket>,
|
||||
) -> Result<TcpStream, InternalError> {
|
||||
let port = url
|
||||
.port_or_known_default()
|
||||
@@ -229,7 +229,7 @@ where
|
||||
login: &'static str,
|
||||
init_req: TInitReq,
|
||||
reconnect_backoff: ExponentialBackoff,
|
||||
socket_factory: Arc<dyn SocketFactory<tokio::net::TcpSocket>>,
|
||||
socket_factory: Arc<dyn SocketFactory<TcpSocket>>,
|
||||
) -> io::Result<Self> {
|
||||
let next_request_id = Arc::new(AtomicU64::new(0));
|
||||
|
||||
|
||||
@@ -4,5 +4,12 @@ version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
[dependencies]
|
||||
async-trait = { version = "0.1", optional = true }
|
||||
hickory-proto = { workspace = true, optional = true }
|
||||
quinn-udp = { git = "https://github.com/quinn-rs/quinn", branch = "main" }
|
||||
socket2 = { workspace = true }
|
||||
tokio = { version = "1.38", features = ["net"] }
|
||||
tracing = "0.1"
|
||||
|
||||
[features]
|
||||
hickory = ["dep:hickory-proto", "dep:async-trait"]
|
||||
|
||||
@@ -1,15 +1,20 @@
|
||||
use std::net::SocketAddr;
|
||||
use std::{
|
||||
borrow::Cow,
|
||||
collections::VecDeque,
|
||||
io::{self, IoSliceMut},
|
||||
net::SocketAddr,
|
||||
slice,
|
||||
task::{ready, Context, Poll},
|
||||
};
|
||||
|
||||
use socket2::SockAddr;
|
||||
use tokio::io::Interest;
|
||||
|
||||
pub trait SocketFactory<S>: Fn(&SocketAddr) -> std::io::Result<S> + Send + Sync + 'static {}
|
||||
pub trait SocketFactory<S>: Fn(&SocketAddr) -> io::Result<S> + Send + Sync + 'static {}
|
||||
|
||||
impl<F, S> SocketFactory<S> for F where
|
||||
F: Fn(&SocketAddr) -> std::io::Result<S> + Send + Sync + 'static
|
||||
{
|
||||
}
|
||||
impl<F, S> SocketFactory<S> for F where F: Fn(&SocketAddr) -> io::Result<S> + Send + Sync + 'static {}
|
||||
|
||||
pub fn tcp(addr: &SocketAddr) -> std::io::Result<tokio::net::TcpSocket> {
|
||||
pub fn tcp(addr: &SocketAddr) -> io::Result<TcpSocket> {
|
||||
let socket = match addr {
|
||||
SocketAddr::V4(_) => tokio::net::TcpSocket::new_v4()?,
|
||||
SocketAddr::V6(_) => tokio::net::TcpSocket::new_v6()?,
|
||||
@@ -17,9 +22,10 @@ pub fn tcp(addr: &SocketAddr) -> std::io::Result<tokio::net::TcpSocket> {
|
||||
|
||||
socket.set_nodelay(true)?;
|
||||
|
||||
Ok(socket)
|
||||
Ok(TcpSocket { inner: socket })
|
||||
}
|
||||
pub fn udp(addr: &SocketAddr) -> std::io::Result<tokio::net::UdpSocket> {
|
||||
|
||||
pub fn udp(addr: &SocketAddr) -> io::Result<UdpSocket> {
|
||||
let addr: SockAddr = (*addr).into();
|
||||
let socket = socket2::Socket::new(addr.domain(), socket2::Type::DGRAM, None)?;
|
||||
|
||||
@@ -31,5 +37,263 @@ pub fn udp(addr: &SocketAddr) -> std::io::Result<tokio::net::UdpSocket> {
|
||||
socket.set_nonblocking(true)?;
|
||||
socket.bind(&addr)?;
|
||||
|
||||
std::net::UdpSocket::from(socket).try_into()
|
||||
let socket = std::net::UdpSocket::from(socket);
|
||||
let socket = tokio::net::UdpSocket::try_from(socket)?;
|
||||
let socket = UdpSocket::new(socket)?;
|
||||
|
||||
Ok(socket)
|
||||
}
|
||||
|
||||
pub struct TcpSocket {
|
||||
inner: tokio::net::TcpSocket,
|
||||
}
|
||||
|
||||
impl TcpSocket {
|
||||
pub async fn connect(self, addr: SocketAddr) -> io::Result<tokio::net::TcpStream> {
|
||||
self.inner.connect(addr).await
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
impl std::os::fd::AsRawFd for TcpSocket {
|
||||
fn as_raw_fd(&self) -> std::os::fd::RawFd {
|
||||
self.inner.as_raw_fd()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
impl std::os::fd::AsFd for TcpSocket {
|
||||
fn as_fd(&self) -> std::os::fd::BorrowedFd<'_> {
|
||||
self.inner.as_fd()
|
||||
}
|
||||
}
|
||||
|
||||
pub struct UdpSocket {
|
||||
inner: tokio::net::UdpSocket,
|
||||
state: quinn_udp::UdpSocketState,
|
||||
|
||||
port: u16,
|
||||
|
||||
buffered_datagrams: VecDeque<DatagramOut<'static>>,
|
||||
}
|
||||
|
||||
impl UdpSocket {
|
||||
fn new(inner: tokio::net::UdpSocket) -> io::Result<Self> {
|
||||
let port = inner.local_addr()?.port();
|
||||
|
||||
Ok(UdpSocket {
|
||||
state: quinn_udp::UdpSocketState::new(quinn_udp::UdpSockRef::from(&inner))?,
|
||||
port,
|
||||
inner,
|
||||
buffered_datagrams: VecDeque::new(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
impl std::os::fd::AsRawFd for UdpSocket {
|
||||
fn as_raw_fd(&self) -> std::os::fd::RawFd {
|
||||
self.inner.as_raw_fd()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
impl std::os::fd::AsFd for UdpSocket {
|
||||
fn as_fd(&self) -> std::os::fd::BorrowedFd<'_> {
|
||||
self.inner.as_fd()
|
||||
}
|
||||
}
|
||||
|
||||
/// An inbound UDP datagram.
|
||||
pub struct DatagramIn<'a> {
|
||||
pub local: SocketAddr,
|
||||
pub from: SocketAddr,
|
||||
pub packet: &'a [u8],
|
||||
}
|
||||
|
||||
/// An outbound UDP datagram.
|
||||
pub struct DatagramOut<'a> {
|
||||
pub src: Option<SocketAddr>,
|
||||
pub dst: SocketAddr,
|
||||
pub packet: Cow<'a, [u8]>,
|
||||
}
|
||||
|
||||
impl<'a> DatagramOut<'a> {
|
||||
fn into_owned(self) -> DatagramOut<'static> {
|
||||
DatagramOut {
|
||||
src: self.src,
|
||||
dst: self.dst,
|
||||
packet: Cow::Owned(self.packet.into_owned()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl UdpSocket {
|
||||
#[allow(clippy::type_complexity)]
|
||||
pub fn poll_recv_from<'b>(
|
||||
&self,
|
||||
buffer: &'b mut [u8],
|
||||
cx: &mut Context<'_>,
|
||||
) -> Poll<io::Result<impl Iterator<Item = DatagramIn<'b>>>> {
|
||||
let Self {
|
||||
port, inner, state, ..
|
||||
} = self;
|
||||
|
||||
let bufs = &mut [IoSliceMut::new(buffer)];
|
||||
let mut meta = quinn_udp::RecvMeta::default();
|
||||
|
||||
loop {
|
||||
ready!(inner.poll_recv_ready(cx))?;
|
||||
|
||||
if let Ok(len) = inner.try_io(Interest::READABLE, || {
|
||||
state.recv((&inner).into(), bufs, slice::from_mut(&mut meta))
|
||||
}) {
|
||||
debug_assert_eq!(len, 1);
|
||||
|
||||
if meta.len == 0 {
|
||||
continue;
|
||||
}
|
||||
|
||||
let Some(local_ip) = meta.dst_ip else {
|
||||
tracing::warn!("Skipping packet without local IP");
|
||||
continue;
|
||||
};
|
||||
|
||||
let local = SocketAddr::new(local_ip, *port);
|
||||
|
||||
let iter = buffer[..meta.len]
|
||||
.chunks(meta.stride)
|
||||
.map(move |packet| DatagramIn {
|
||||
local,
|
||||
from: meta.addr,
|
||||
packet,
|
||||
})
|
||||
.inspect(|r| {
|
||||
tracing::trace!(target: "wire::net::recv", src = %r.from, dst = %r.local, num_bytes = %r.packet.len());
|
||||
});
|
||||
|
||||
return Poll::Ready(Ok(iter));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn poll_flush(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
||||
loop {
|
||||
ready!(self.inner.poll_send_ready(cx))?; // Ensure we are ready to send.
|
||||
|
||||
let Some(transmit) = self.buffered_datagrams.pop_front() else {
|
||||
break;
|
||||
};
|
||||
|
||||
match self.try_send(&transmit) {
|
||||
Ok(()) => continue, // Try to send another packet.
|
||||
Err(e) => {
|
||||
self.buffered_datagrams.push_front(transmit); // Don't lose the packet if we fail.
|
||||
|
||||
if e.kind() == io::ErrorKind::WouldBlock {
|
||||
continue; // False positive send-readiness: Loop to `poll_send_ready` and return `Pending`.
|
||||
}
|
||||
|
||||
return Poll::Ready(Err(e));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
assert!(self.buffered_datagrams.is_empty());
|
||||
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
|
||||
pub fn send(&mut self, datagram: DatagramOut) -> io::Result<()> {
|
||||
tracing::trace!(target: "wire::net::send", src = ?datagram.src, dst = %datagram.dst, num_bytes = %datagram.packet.len());
|
||||
|
||||
debug_assert!(
|
||||
self.buffered_datagrams.len() < 10_000,
|
||||
"We are not flushing the packets for some reason"
|
||||
);
|
||||
|
||||
match self.try_send(&datagram) {
|
||||
Ok(()) => Ok(()),
|
||||
Err(e) if e.kind() == io::ErrorKind::WouldBlock => {
|
||||
tracing::trace!("Buffering packet because socket is busy");
|
||||
|
||||
self.buffered_datagrams.push_back(datagram.into_owned());
|
||||
Ok(())
|
||||
}
|
||||
Err(e) => Err(e),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn try_send(&self, transmit: &DatagramOut) -> io::Result<()> {
|
||||
let transmit = quinn_udp::Transmit {
|
||||
destination: transmit.dst,
|
||||
ecn: None,
|
||||
contents: &transmit.packet,
|
||||
segment_size: None,
|
||||
src_ip: transmit.src.map(|s| s.ip()),
|
||||
};
|
||||
|
||||
self.inner.try_io(Interest::WRITABLE, || {
|
||||
self.state.send((&self.inner).into(), &transmit)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "hickory")]
|
||||
mod hickory {
|
||||
use super::*;
|
||||
use hickory_proto::{
|
||||
udp::DnsUdpSocket as DnsUdpSocketTrait, udp::UdpSocket as UdpSocketTrait, TokioTime,
|
||||
};
|
||||
use tokio::net::UdpSocket as TokioUdpSocket;
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl UdpSocketTrait for crate::UdpSocket {
|
||||
/// setups up a "client" udp connection that will only receive packets from the associated address
|
||||
async fn connect(addr: SocketAddr) -> io::Result<Self> {
|
||||
let inner = <TokioUdpSocket as UdpSocketTrait>::connect(addr).await?;
|
||||
let socket = Self::new(inner)?;
|
||||
|
||||
Ok(socket)
|
||||
}
|
||||
|
||||
/// same as connect, but binds to the specified local address for sending address
|
||||
async fn connect_with_bind(addr: SocketAddr, bind_addr: SocketAddr) -> io::Result<Self> {
|
||||
let inner =
|
||||
<TokioUdpSocket as UdpSocketTrait>::connect_with_bind(addr, bind_addr).await?;
|
||||
let socket = Self::new(inner)?;
|
||||
|
||||
Ok(socket)
|
||||
}
|
||||
|
||||
/// a "server" UDP socket, that bind to the local listening address, and unbound remote address (can receive from anything)
|
||||
async fn bind(addr: SocketAddr) -> io::Result<Self> {
|
||||
let inner = <TokioUdpSocket as UdpSocketTrait>::bind(addr).await?;
|
||||
let socket = Self::new(inner)?;
|
||||
|
||||
Ok(socket)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "hickory")]
|
||||
impl DnsUdpSocketTrait for crate::UdpSocket {
|
||||
type Time = TokioTime;
|
||||
|
||||
fn poll_recv_from(
|
||||
&self,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &mut [u8],
|
||||
) -> Poll<io::Result<(usize, SocketAddr)>> {
|
||||
<TokioUdpSocket as DnsUdpSocketTrait>::poll_recv_from(&self.inner, cx, buf)
|
||||
}
|
||||
|
||||
fn poll_send_to(
|
||||
&self,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &[u8],
|
||||
target: SocketAddr,
|
||||
) -> Poll<io::Result<usize>> {
|
||||
<TokioUdpSocket as DnsUdpSocketTrait>::poll_send_to(&self.inner, cx, buf, target)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user