mirror of
https://github.com/outbackdingo/firezone.git
synced 2026-01-27 10:18:54 +00:00
fix(connlib): protect all sockets from routing loops (#5797)
Currently, only connlib's UDP sockets for sending and receiving STUN & WireGuard traffic are protected from routing loops. This is was done via the `Sockets::with_protect` function. Connlib has additional sockets though: - A TCP socket to the portal. - UDP & TCP sockets for DNS resolution via hickory. Both of these can incur routing loops on certain platforms which becomes evident as we try to implement #2667. To fix this, we generalise the idea of "protecting" a socket via a `SocketFactory` abstraction. By allowing the different platforms to provide a specialised `SocketFactory`, anything Linux-based can give special treatment to the socket before handing it to connlib. As an additional benefit, this allows us to remove the `Sockets` abstraction from connlib's API again because we can now initialise it internally via the provided `SocketFactory` for UDP sockets. --------- Signed-off-by: Gabi <gabrielalejandro7@gmail.com> Co-authored-by: Thomas Eizinger <thomas@eizinger.io>
This commit is contained in:
17
rust/Cargo.lock
generated
17
rust/Cargo.lock
generated
@@ -1066,6 +1066,7 @@ dependencies = [
|
||||
"log",
|
||||
"secrecy",
|
||||
"serde_json",
|
||||
"socket-factory",
|
||||
"thiserror",
|
||||
"tokio",
|
||||
"tracing",
|
||||
@@ -1084,6 +1085,7 @@ dependencies = [
|
||||
"oslog",
|
||||
"secrecy",
|
||||
"serde_json",
|
||||
"socket-factory",
|
||||
"swift-bridge",
|
||||
"swift-bridge-build",
|
||||
"tokio",
|
||||
@@ -1109,6 +1111,7 @@ dependencies = [
|
||||
"secrecy",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"socket-factory",
|
||||
"time",
|
||||
"tokio",
|
||||
"tokio-tungstenite",
|
||||
@@ -1873,6 +1876,7 @@ dependencies = [
|
||||
"serde",
|
||||
"serde_json",
|
||||
"snownet",
|
||||
"socket-factory",
|
||||
"static_assertions",
|
||||
"tokio",
|
||||
"tokio-tungstenite",
|
||||
@@ -1962,6 +1966,7 @@ dependencies = [
|
||||
"secrecy",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"socket-factory",
|
||||
"tempfile",
|
||||
"thiserror",
|
||||
"tokio",
|
||||
@@ -2003,6 +2008,7 @@ dependencies = [
|
||||
"secrecy",
|
||||
"serde",
|
||||
"sha2",
|
||||
"socket-factory",
|
||||
"socket2 0.5.7",
|
||||
"stun_codec",
|
||||
"test-strategy",
|
||||
@@ -2054,6 +2060,7 @@ dependencies = [
|
||||
"serde",
|
||||
"serde_json",
|
||||
"snownet",
|
||||
"socket-factory",
|
||||
"socket2 0.5.7",
|
||||
"test-strategy",
|
||||
"thiserror",
|
||||
@@ -3901,6 +3908,7 @@ dependencies = [
|
||||
"cfg-if",
|
||||
"cfg_aliases",
|
||||
"libc",
|
||||
"memoffset 0.9.1",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -4561,6 +4569,7 @@ dependencies = [
|
||||
"serde",
|
||||
"serde_json",
|
||||
"sha2",
|
||||
"socket-factory",
|
||||
"thiserror",
|
||||
"tokio",
|
||||
"tokio-tungstenite",
|
||||
@@ -5812,6 +5821,14 @@ dependencies = [
|
||||
"tracing-subscriber",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "socket-factory"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"socket2 0.5.7",
|
||||
"tokio",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "socket2"
|
||||
version = "0.4.10"
|
||||
|
||||
@@ -17,6 +17,7 @@ members = [
|
||||
"phoenix-channel",
|
||||
"relay",
|
||||
"snownet-tests",
|
||||
"socket-factory",
|
||||
]
|
||||
|
||||
resolver = "2"
|
||||
@@ -53,6 +54,8 @@ firezone-tunnel = { path = "connlib/tunnel" }
|
||||
phoenix-channel = { path = "phoenix-channel" }
|
||||
http-health-check = { path = "http-health-check" }
|
||||
ip-packet = { path = "ip-packet" }
|
||||
socket-factory = { path = "socket-factory" }
|
||||
socket2 = { version = "0.5" }
|
||||
|
||||
[workspace.lints.clippy]
|
||||
dbg_macro = "warn"
|
||||
|
||||
@@ -7,6 +7,9 @@ use tracing_subscriber::{
|
||||
};
|
||||
use url::Url;
|
||||
|
||||
/// Mark for Firezone sockets to prevent routing loops on Linux.
|
||||
pub const FIREZONE_MARK: u32 = 0xfd002021;
|
||||
|
||||
#[cfg(any(target_os = "linux", target_os = "windows"))]
|
||||
pub use tun_device_manager::TunDeviceManager;
|
||||
|
||||
|
||||
@@ -13,7 +13,8 @@ use std::{
|
||||
net::{Ipv4Addr, Ipv6Addr},
|
||||
};
|
||||
|
||||
const FIREZONE_MARK: u32 = 0xfd002021; // Keep this synced with `Sockets` until #5797.
|
||||
use crate::FIREZONE_MARK;
|
||||
|
||||
const FILE_ALREADY_EXISTS: i32 = -17;
|
||||
const FIREZONE_TABLE: u32 = 0x2021_fd00;
|
||||
|
||||
|
||||
@@ -19,6 +19,7 @@ jni = { version = "0.21.1", features = ["invocation"] }
|
||||
log = "0.4"
|
||||
secrecy = { workspace = true }
|
||||
serde_json = "1"
|
||||
socket-factory = { workspace = true }
|
||||
thiserror = "1"
|
||||
tokio = { workspace = true, features = ["rt"] }
|
||||
tracing = { workspace = true, features = ["std", "attributes"] }
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
|
||||
use connlib_client_shared::{
|
||||
callbacks::ResourceDescription, file_logger, keypair, Callbacks, ConnectArgs, Error, LoginUrl,
|
||||
LoginUrlError, Session, Sockets, Tun, V4RouteList, V6RouteList,
|
||||
LoginUrlError, Session, Tun, V4RouteList, V6RouteList,
|
||||
};
|
||||
use ip_network::{Ipv4Network, Ipv6Network};
|
||||
use jni::{
|
||||
@@ -15,7 +15,8 @@ use jni::{
|
||||
JNIEnv, JavaVM,
|
||||
};
|
||||
use secrecy::SecretString;
|
||||
use std::{io, net::IpAddr, path::Path};
|
||||
use socket_factory::SocketFactory;
|
||||
use std::{io, net::IpAddr, os::fd::AsRawFd, path::Path, sync::Arc};
|
||||
use std::{
|
||||
net::{Ipv4Addr, Ipv6Addr},
|
||||
os::fd::RawFd,
|
||||
@@ -86,16 +87,17 @@ impl CallbackHandler {
|
||||
.and_then(f)
|
||||
}
|
||||
|
||||
fn protect_file_descriptor(&self, file_descriptor: RawFd) -> Result<(), CallbackError> {
|
||||
fn protect(&self, socket: RawFd) -> io::Result<()> {
|
||||
self.env(|mut env| {
|
||||
call_method(
|
||||
&mut env,
|
||||
&self.callback_handler,
|
||||
"protectFileDescriptor",
|
||||
"(I)V",
|
||||
&[JValue::Int(file_descriptor)],
|
||||
&[JValue::Int(socket)],
|
||||
)
|
||||
})
|
||||
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -357,18 +359,10 @@ fn connect(
|
||||
.enable_all()
|
||||
.build()?;
|
||||
|
||||
let sockets = Sockets::with_protect({
|
||||
let callbacks = callbacks.clone();
|
||||
move |fd| {
|
||||
callbacks
|
||||
.protect_file_descriptor(fd)
|
||||
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))
|
||||
}
|
||||
});
|
||||
|
||||
let args = ConnectArgs {
|
||||
url,
|
||||
sockets,
|
||||
tcp_socket_factory: Arc::new(protected_tcp_socket_factory(callbacks.clone())),
|
||||
udp_socket_factory: Arc::new(protected_udp_socket_factory(callbacks.clone())),
|
||||
private_key,
|
||||
os_version_override: Some(os_version),
|
||||
app_version: env!("CARGO_PKG_VERSION").to_string(),
|
||||
@@ -523,3 +517,23 @@ pub unsafe extern "system" fn Java_dev_firezone_android_tunnel_ConnlibSession_se
|
||||
|
||||
session.inner.set_tun(tun);
|
||||
}
|
||||
|
||||
fn protected_tcp_socket_factory(
|
||||
callbacks: CallbackHandler,
|
||||
) -> impl SocketFactory<tokio::net::TcpSocket> {
|
||||
move |addr| {
|
||||
let socket = socket_factory::tcp(addr)?;
|
||||
callbacks.protect(socket.as_raw_fd())?;
|
||||
Ok(socket)
|
||||
}
|
||||
}
|
||||
|
||||
fn protected_udp_socket_factory(
|
||||
callbacks: CallbackHandler,
|
||||
) -> impl SocketFactory<tokio::net::UdpSocket> {
|
||||
move |addr| {
|
||||
let socket = socket_factory::udp(addr)?;
|
||||
callbacks.protect(socket.as_raw_fd())?;
|
||||
Ok(socket)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -16,6 +16,7 @@ ip_network = "0.4"
|
||||
libc = "0.2"
|
||||
secrecy = { workspace = true }
|
||||
serde_json = "1"
|
||||
socket-factory = { workspace = true }
|
||||
swift-bridge = { workspace = true }
|
||||
tokio = { workspace = true, features = ["rt"] }
|
||||
tracing = { workspace = true }
|
||||
|
||||
@@ -5,7 +5,7 @@ mod make_writer;
|
||||
|
||||
use connlib_client_shared::{
|
||||
callbacks::ResourceDescription, file_logger, keypair, Callbacks, ConnectArgs, Error, LoginUrl,
|
||||
Session, Sockets, Tun, V4RouteList, V6RouteList,
|
||||
Session, Tun, V4RouteList, V6RouteList,
|
||||
};
|
||||
use ip_network::{Ipv4Network, Ipv6Network};
|
||||
use secrecy::SecretString;
|
||||
@@ -194,7 +194,6 @@ impl WrappedSession {
|
||||
|
||||
let args = ConnectArgs {
|
||||
url,
|
||||
sockets: Sockets::new(),
|
||||
private_key,
|
||||
os_version_override,
|
||||
app_version: env!("CARGO_PKG_VERSION").to_string(),
|
||||
@@ -202,6 +201,8 @@ impl WrappedSession {
|
||||
inner: Arc::new(callback_handler),
|
||||
},
|
||||
max_partition_time: Some(MAX_PARTITION_TIME),
|
||||
tcp_socket_factory: Arc::new(socket_factory::tcp),
|
||||
udp_socket_factory: Arc::new(socket_factory::udp),
|
||||
};
|
||||
let session = Session::connect(args, runtime.handle().clone());
|
||||
let _enter = runtime.enter();
|
||||
|
||||
@@ -17,6 +17,7 @@ ip_network = { version = "0.4", default-features = false }
|
||||
phoenix-channel = { workspace = true }
|
||||
secrecy = { workspace = true }
|
||||
serde = { version = "1.0", default-features = false, features = ["std", "derive"] }
|
||||
socket-factory = { workspace = true }
|
||||
time = { version = "0.3.36", features = ["formatting"] }
|
||||
tokio = { workspace = true, features = ["sync"] }
|
||||
tokio-tungstenite = { version = "0.21", default-features = false, features = ["connect", "handshake", "rustls-tls-webpki-roots"] }
|
||||
|
||||
@@ -5,15 +5,17 @@ pub use connlib_shared::{
|
||||
callbacks, keypair, Callbacks, Error, LoginUrl, LoginUrlError, StaticSecret,
|
||||
};
|
||||
pub use eventloop::Eventloop;
|
||||
pub use firezone_tunnel::{Sockets, Tun};
|
||||
pub use firezone_tunnel::Tun;
|
||||
pub use tracing_appender::non_blocking::WorkerGuard;
|
||||
|
||||
use backoff::ExponentialBackoffBuilder;
|
||||
use connlib_shared::get_user_agent;
|
||||
use firezone_tunnel::ClientTunnel;
|
||||
use phoenix_channel::PhoenixChannel;
|
||||
use socket_factory::SocketFactory;
|
||||
use std::collections::HashMap;
|
||||
use std::net::IpAddr;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use tokio::sync::mpsc::UnboundedReceiver;
|
||||
|
||||
@@ -39,7 +41,8 @@ pub struct Session {
|
||||
/// Arguments for `connect`, since Clippy said 8 args is too many
|
||||
pub struct ConnectArgs<CB> {
|
||||
pub url: LoginUrl,
|
||||
pub sockets: Sockets,
|
||||
pub tcp_socket_factory: Arc<dyn SocketFactory<tokio::net::TcpSocket>>,
|
||||
pub udp_socket_factory: Arc<dyn SocketFactory<tokio::net::UdpSocket>>,
|
||||
pub private_key: StaticSecret,
|
||||
pub os_version_override: Option<String>,
|
||||
pub app_version: String,
|
||||
@@ -120,11 +123,12 @@ where
|
||||
{
|
||||
let ConnectArgs {
|
||||
url,
|
||||
sockets,
|
||||
private_key,
|
||||
os_version_override,
|
||||
app_version,
|
||||
callbacks,
|
||||
udp_socket_factory,
|
||||
tcp_socket_factory,
|
||||
max_partition_time,
|
||||
} = args;
|
||||
|
||||
@@ -139,7 +143,8 @@ where
|
||||
|
||||
let tunnel = ClientTunnel::new(
|
||||
private_key,
|
||||
sockets,
|
||||
tcp_socket_factory.clone(),
|
||||
udp_socket_factory,
|
||||
callbacks,
|
||||
HashMap::from([(url.host().to_string(), addrs)]),
|
||||
)?;
|
||||
@@ -152,6 +157,7 @@ where
|
||||
ExponentialBackoffBuilder::default()
|
||||
.with_max_elapsed_time(max_partition_time)
|
||||
.build(),
|
||||
tcp_socket_factory,
|
||||
);
|
||||
|
||||
let mut eventloop = Eventloop::new(tunnel, portal, rx);
|
||||
@@ -232,14 +238,18 @@ mod tests {
|
||||
#[cfg(any(target_os = "windows", target_os = "linux"))]
|
||||
async fn device_common() {
|
||||
use firezone_tunnel::Tun;
|
||||
use std::collections::HashMap;
|
||||
use std::{collections::HashMap, sync::Arc};
|
||||
|
||||
let (private_key, _public_key) = connlib_shared::keypair();
|
||||
let sockets = crate::Sockets::new();
|
||||
let callbacks = Callbacks::default();
|
||||
let mut tunnel =
|
||||
firezone_tunnel::ClientTunnel::new(private_key, sockets, callbacks, HashMap::new())
|
||||
.unwrap();
|
||||
let mut tunnel = firezone_tunnel::ClientTunnel::new(
|
||||
private_key,
|
||||
Arc::new(socket_factory::tcp),
|
||||
Arc::new(socket_factory::udp),
|
||||
callbacks,
|
||||
HashMap::new(),
|
||||
)
|
||||
.unwrap();
|
||||
let upstream_dns = vec![([192, 168, 1, 1], 53).into()];
|
||||
let interface = connlib_shared::messages::Interface {
|
||||
ipv4: [100, 71, 96, 96].into(),
|
||||
|
||||
@@ -30,7 +30,8 @@ rangemap = "1.5.1"
|
||||
secrecy = { workspace = true }
|
||||
serde = { version = "1.0", default-features = false, features = ["derive", "std"] }
|
||||
snownet = { workspace = true }
|
||||
socket2 = { version = "0.5" }
|
||||
socket-factory = { workspace = true }
|
||||
socket2 = { workspace = true }
|
||||
thiserror = { version = "1.0", default-features = false }
|
||||
tokio = { workspace = true }
|
||||
tracing = { workspace = true }
|
||||
|
||||
@@ -5,22 +5,29 @@ use crate::{
|
||||
};
|
||||
use bytes::Bytes;
|
||||
use connlib_shared::messages::DnsServer;
|
||||
use futures::Future;
|
||||
use futures_bounded::FuturesTupleSet;
|
||||
use futures_util::FutureExt as _;
|
||||
use hickory_proto::iocompat::AsyncIoTokioAsStd;
|
||||
use hickory_proto::TokioTime;
|
||||
use hickory_resolver::{
|
||||
config::{NameServerConfig, Protocol, ResolverConfig, ResolverOpts},
|
||||
TokioAsyncResolver,
|
||||
name_server::{GenericConnector, RuntimeProvider},
|
||||
AsyncResolver, TokioHandle,
|
||||
};
|
||||
use ip_packet::{IpPacket, MutableIpPacket};
|
||||
use quinn_udp::Transmit;
|
||||
use socket_factory::SocketFactory;
|
||||
use std::{
|
||||
collections::HashMap,
|
||||
io,
|
||||
net::IpAddr,
|
||||
net::{IpAddr, SocketAddr},
|
||||
pin::Pin,
|
||||
sync::Arc,
|
||||
task::{ready, Context, Poll},
|
||||
time::{Duration, Instant},
|
||||
};
|
||||
use tokio::net::{TcpSocket, UdpSocket};
|
||||
|
||||
const DNS_QUERIES_QUEUE_SIZE: usize = 100;
|
||||
|
||||
@@ -32,9 +39,13 @@ pub struct Io {
|
||||
device: Device,
|
||||
/// The UDP sockets used to send & receive packets from the network.
|
||||
sockets: Sockets,
|
||||
|
||||
tcp_socket_factory: Arc<dyn SocketFactory<TcpSocket>>,
|
||||
udp_socket_factory: Arc<dyn SocketFactory<UdpSocket>>,
|
||||
|
||||
timeout: Option<Pin<Box<tokio::time::Sleep>>>,
|
||||
|
||||
upstream_dns_servers: HashMap<IpAddr, TokioAsyncResolver>,
|
||||
upstream_dns_servers: HashMap<IpAddr, AsyncResolver<GenericConnector<TokioRuntimeProvider>>>,
|
||||
forwarded_dns_queries: FuturesTupleSet<
|
||||
Result<hickory_resolver::lookup::Lookup, hickory_resolver::error::ResolveError>,
|
||||
DnsQuery<'static>,
|
||||
@@ -58,13 +69,19 @@ impl Io {
|
||||
/// Creates a new I/O abstraction
|
||||
///
|
||||
/// Must be called within a Tokio runtime context so we can bind the sockets.
|
||||
pub fn new(mut sockets: Sockets) -> io::Result<Self> {
|
||||
sockets.rebind()?; // Bind sockets on startup. Must happen within a tokio runtime context.
|
||||
pub fn new(
|
||||
tcp_socket_factory: Arc<dyn SocketFactory<TcpSocket>>,
|
||||
udp_socket_factory: Arc<dyn SocketFactory<UdpSocket>>,
|
||||
) -> io::Result<Self> {
|
||||
let mut sockets = Sockets::default();
|
||||
sockets.rebind(udp_socket_factory.as_ref())?; // Bind sockets on startup. Must happen within a tokio runtime context.
|
||||
|
||||
Ok(Self {
|
||||
device: Device::new(),
|
||||
timeout: None,
|
||||
sockets,
|
||||
tcp_socket_factory,
|
||||
udp_socket_factory,
|
||||
upstream_dns_servers: HashMap::default(),
|
||||
forwarded_dns_queries: FuturesTupleSet::new(
|
||||
Duration::from_secs(60),
|
||||
@@ -107,8 +124,10 @@ impl Io {
|
||||
&mut self.device
|
||||
}
|
||||
|
||||
pub fn sockets_mut(&mut self) -> &mut Sockets {
|
||||
&mut self.sockets
|
||||
pub fn rebind_sockets(&mut self) -> io::Result<()> {
|
||||
self.sockets.rebind(self.udp_socket_factory.as_ref())?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn set_upstream_dns_servers(
|
||||
@@ -119,7 +138,13 @@ impl Io {
|
||||
|
||||
self.forwarded_dns_queries =
|
||||
FuturesTupleSet::new(Duration::from_secs(60), DNS_QUERIES_QUEUE_SIZE);
|
||||
self.upstream_dns_servers = create_resolvers(dns_servers);
|
||||
self.upstream_dns_servers = create_resolvers(
|
||||
dns_servers,
|
||||
TokioRuntimeProvider::new(
|
||||
self.tcp_socket_factory.clone(),
|
||||
self.udp_socket_factory.clone(),
|
||||
),
|
||||
);
|
||||
}
|
||||
|
||||
pub fn perform_dns_query(&mut self, query: DnsQuery<'static>) -> Result<(), DnsQueryError> {
|
||||
@@ -186,9 +211,65 @@ pub enum DnsQueryError {
|
||||
TooManyQueries,
|
||||
}
|
||||
|
||||
/// Identical to [`TokioRuntimeProvider`](hickory_resolver::name_server::TokioRuntimeProvider) but using our own [`SocketFactory`].
|
||||
#[derive(Clone)]
|
||||
struct TokioRuntimeProvider {
|
||||
handle: TokioHandle,
|
||||
tcp_socket_factory: Arc<dyn SocketFactory<TcpSocket>>,
|
||||
udp_socket_factory: Arc<dyn SocketFactory<UdpSocket>>,
|
||||
}
|
||||
|
||||
impl TokioRuntimeProvider {
|
||||
fn new(
|
||||
tcp_socket_factory: Arc<dyn SocketFactory<TcpSocket>>,
|
||||
udp_socket_factory: Arc<dyn SocketFactory<UdpSocket>>,
|
||||
) -> TokioRuntimeProvider {
|
||||
Self {
|
||||
handle: Default::default(),
|
||||
tcp_socket_factory,
|
||||
udp_socket_factory,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl RuntimeProvider for TokioRuntimeProvider {
|
||||
type Handle = TokioHandle;
|
||||
type Timer = TokioTime;
|
||||
type Udp = UdpSocket;
|
||||
type Tcp = AsyncIoTokioAsStd<tokio::net::TcpStream>;
|
||||
|
||||
fn create_handle(&self) -> Self::Handle {
|
||||
self.handle.clone()
|
||||
}
|
||||
|
||||
fn connect_tcp(
|
||||
&self,
|
||||
server_addr: SocketAddr,
|
||||
) -> Pin<Box<dyn Send + Future<Output = io::Result<Self::Tcp>>>> {
|
||||
let socket = (self.tcp_socket_factory)(&server_addr);
|
||||
Box::pin(async move {
|
||||
let socket = socket?;
|
||||
let stream = socket.connect(server_addr).await?;
|
||||
|
||||
Ok(AsyncIoTokioAsStd(stream))
|
||||
})
|
||||
}
|
||||
|
||||
fn bind_udp(
|
||||
&self,
|
||||
local_addr: SocketAddr,
|
||||
_server_addr: SocketAddr,
|
||||
) -> Pin<Box<dyn Send + Future<Output = io::Result<Self::Udp>>>> {
|
||||
let socket = (self.udp_socket_factory)(&local_addr);
|
||||
|
||||
Box::pin(async move { socket })
|
||||
}
|
||||
}
|
||||
|
||||
fn create_resolvers(
|
||||
dns_servers: impl IntoIterator<Item = (IpAddr, DnsServer)>,
|
||||
) -> HashMap<IpAddr, TokioAsyncResolver> {
|
||||
runtime_provider: TokioRuntimeProvider,
|
||||
) -> HashMap<IpAddr, AsyncResolver<GenericConnector<TokioRuntimeProvider>>> {
|
||||
dns_servers
|
||||
.into_iter()
|
||||
.map(|(sentinel, srv)| {
|
||||
@@ -201,7 +282,11 @@ fn create_resolvers(
|
||||
|
||||
(
|
||||
sentinel,
|
||||
TokioAsyncResolver::tokio(resolver_config, resolver_opts),
|
||||
AsyncResolver::new_with_conn(
|
||||
resolver_config,
|
||||
resolver_opts,
|
||||
GenericConnector::new(runtime_provider.clone()),
|
||||
),
|
||||
)
|
||||
})
|
||||
.collect()
|
||||
|
||||
@@ -14,6 +14,7 @@ use io::Io;
|
||||
use std::{
|
||||
collections::{HashMap, HashSet},
|
||||
net::{IpAddr, SocketAddr},
|
||||
sync::Arc,
|
||||
task::{Context, Poll},
|
||||
time::Instant,
|
||||
};
|
||||
@@ -21,7 +22,6 @@ use std::{
|
||||
use bimap::BiMap;
|
||||
pub use client::{ClientState, Request};
|
||||
pub use gateway::GatewayState;
|
||||
pub use sockets::Sockets;
|
||||
use utils::turn;
|
||||
|
||||
mod client;
|
||||
@@ -50,7 +50,7 @@ pub type ClientTunnel<CB> = Tunnel<CB, ClientState>;
|
||||
/// [`Tunnel`] glues together connlib's [`Io`] component and the respective (pure) state of a client or gateway.
|
||||
///
|
||||
/// Most of connlib's functionality is implemented as a pure state machine in [`ClientState`] and [`GatewayState`].
|
||||
/// The only job of [`Tunnel`] is to take input from the TUN [`Device`](crate::device_channel::Device), [`Sockets`] or time and pass it to the respective state.
|
||||
/// The only job of [`Tunnel`] is to take input from the TUN [`Device`](crate::device_channel::Device), [`Sockets`](crate::sockets::Sockets) or time and pass it to the respective state.
|
||||
pub struct Tunnel<CB: Callbacks, TRoleState> {
|
||||
pub callbacks: CB,
|
||||
|
||||
@@ -77,12 +77,13 @@ where
|
||||
{
|
||||
pub fn new(
|
||||
private_key: StaticSecret,
|
||||
sockets: Sockets,
|
||||
tcp_socket_factory: Arc<dyn socket_factory::SocketFactory<tokio::net::TcpSocket>>,
|
||||
udp_socket_factory: Arc<dyn socket_factory::SocketFactory<tokio::net::UdpSocket>>,
|
||||
callbacks: CB,
|
||||
known_hosts: HashMap<String, Vec<IpAddr>>,
|
||||
) -> std::io::Result<Self> {
|
||||
Ok(Self {
|
||||
io: Io::new(sockets)?,
|
||||
io: Io::new(tcp_socket_factory, udp_socket_factory)?,
|
||||
callbacks,
|
||||
role_state: ClientState::new(private_key, known_hosts),
|
||||
write_buf: Box::new([0u8; MTU + 16 + 20]),
|
||||
@@ -94,7 +95,7 @@ where
|
||||
|
||||
pub fn reset(&mut self) -> std::io::Result<()> {
|
||||
self.role_state.reset();
|
||||
self.io.sockets_mut().rebind()?;
|
||||
self.io.rebind_sockets()?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -178,13 +179,9 @@ impl<CB> GatewayTunnel<CB>
|
||||
where
|
||||
CB: Callbacks + 'static,
|
||||
{
|
||||
pub fn new(
|
||||
private_key: StaticSecret,
|
||||
sockets: Sockets,
|
||||
callbacks: CB,
|
||||
) -> std::io::Result<Self> {
|
||||
pub fn new(private_key: StaticSecret, callbacks: CB) -> std::io::Result<Self> {
|
||||
Ok(Self {
|
||||
io: Io::new(sockets)?,
|
||||
io: Io::new(Arc::new(socket_factory::tcp), Arc::new(socket_factory::udp))?,
|
||||
callbacks,
|
||||
role_state: GatewayState::new(private_key),
|
||||
write_buf: Box::new([0u8; MTU + 20 + 16]),
|
||||
|
||||
@@ -1,61 +1,28 @@
|
||||
use core::slice;
|
||||
use quinn_udp::{RecvMeta, UdpSockRef, UdpSocketState};
|
||||
use socket2::{SockAddr, Type};
|
||||
use socket_factory::SocketFactory;
|
||||
use std::{
|
||||
io::{self, IoSliceMut},
|
||||
net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6},
|
||||
net::{Ipv4Addr, Ipv6Addr, SocketAddr},
|
||||
task::{ready, Context, Poll},
|
||||
};
|
||||
use tokio::{io::Interest, net::UdpSocket};
|
||||
|
||||
use crate::Result;
|
||||
|
||||
pub struct Sockets {
|
||||
#[derive(Default)]
|
||||
pub(crate) struct Sockets {
|
||||
socket_v4: Option<Socket>,
|
||||
socket_v6: Option<Socket>,
|
||||
|
||||
#[cfg(unix)]
|
||||
protect: Box<dyn Fn(std::os::fd::RawFd) -> io::Result<()> + Send + 'static>,
|
||||
}
|
||||
|
||||
impl Default for Sockets {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl Sockets {
|
||||
#[cfg(unix)]
|
||||
pub fn with_protect(
|
||||
protect: impl Fn(std::os::fd::RawFd) -> io::Result<()> + Send + 'static,
|
||||
) -> Self {
|
||||
Self {
|
||||
socket_v4: None,
|
||||
socket_v6: None,
|
||||
#[cfg(unix)]
|
||||
protect: Box::new(protect),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
socket_v4: None,
|
||||
socket_v6: None,
|
||||
#[cfg(unix)]
|
||||
protect: Box::new(|_| Ok(())),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn can_handle(&self, addr: &SocketAddr) -> bool {
|
||||
match addr {
|
||||
SocketAddr::V4(_) => self.socket_v4.is_some(),
|
||||
SocketAddr::V6(_) => self.socket_v6.is_some(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn rebind(&mut self) -> io::Result<()> {
|
||||
let socket_v4 = Socket::ip4();
|
||||
let socket_v6 = Socket::ip6();
|
||||
pub fn rebind(
|
||||
&mut self,
|
||||
socket_factory: &dyn SocketFactory<tokio::net::UdpSocket>,
|
||||
) -> io::Result<()> {
|
||||
let socket_v4 = Socket::ip4(socket_factory);
|
||||
let socket_v6 = Socket::ip6(socket_factory);
|
||||
|
||||
match (socket_v4.as_ref(), socket_v6.as_ref()) {
|
||||
(Err(e), Ok(_)) => {
|
||||
@@ -76,19 +43,6 @@ impl Sockets {
|
||||
_ => (),
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
{
|
||||
use std::os::fd::AsRawFd;
|
||||
|
||||
if let Ok(fd) = socket_v4.as_ref().map(|s| s.socket.as_raw_fd()) {
|
||||
(self.protect)(fd)?;
|
||||
}
|
||||
|
||||
if let Ok(fd) = socket_v6.as_ref().map(|s| s.socket.as_raw_fd()) {
|
||||
(self.protect)(fd)?;
|
||||
}
|
||||
}
|
||||
|
||||
self.socket_v4 = socket_v4.ok();
|
||||
self.socket_v6 = socket_v6.ok();
|
||||
|
||||
@@ -216,28 +170,33 @@ struct Socket {
|
||||
}
|
||||
|
||||
impl Socket {
|
||||
fn ip4() -> Result<Socket> {
|
||||
let socket = make_socket(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0))?;
|
||||
fn ip(
|
||||
socket_factory: &dyn SocketFactory<tokio::net::UdpSocket>,
|
||||
addr: &SocketAddr,
|
||||
) -> Result<Socket> {
|
||||
let socket = socket_factory(addr)?;
|
||||
let port = socket.local_addr()?.port();
|
||||
|
||||
Ok(Socket {
|
||||
state: UdpSocketState::new(UdpSockRef::from(&socket))?,
|
||||
port,
|
||||
socket: tokio::net::UdpSocket::from_std(socket)?,
|
||||
socket,
|
||||
buffered_transmits: Vec::new(),
|
||||
})
|
||||
}
|
||||
|
||||
fn ip6() -> Result<Socket> {
|
||||
let socket = make_socket(SocketAddrV6::new(Ipv6Addr::UNSPECIFIED, 0, 0, 0))?;
|
||||
let port = socket.local_addr()?.port();
|
||||
fn ip4(socket_factory: &dyn SocketFactory<tokio::net::UdpSocket>) -> Result<Socket> {
|
||||
Self::ip(
|
||||
socket_factory,
|
||||
&SocketAddr::from((Ipv4Addr::UNSPECIFIED, 0)),
|
||||
)
|
||||
}
|
||||
|
||||
Ok(Socket {
|
||||
state: UdpSocketState::new(UdpSockRef::from(&socket))?,
|
||||
port,
|
||||
socket: tokio::net::UdpSocket::from_std(socket)?,
|
||||
buffered_transmits: Vec::new(),
|
||||
})
|
||||
fn ip6(socket_factory: &dyn SocketFactory<tokio::net::UdpSocket>) -> Result<Socket> {
|
||||
Self::ip(
|
||||
socket_factory,
|
||||
&SocketAddr::from((Ipv6Addr::UNSPECIFIED, 0)),
|
||||
)
|
||||
}
|
||||
|
||||
#[allow(clippy::type_complexity)]
|
||||
@@ -332,25 +291,3 @@ impl Socket {
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
fn make_socket(addr: impl Into<SocketAddr>) -> Result<std::net::UdpSocket> {
|
||||
let addr: SockAddr = addr.into().into();
|
||||
let socket = socket2::Socket::new(addr.domain(), Type::DGRAM, None)?;
|
||||
|
||||
#[cfg(target_os = "linux")]
|
||||
{
|
||||
const FIREZONE_MARK: u32 = 0xfd002021; // Keep this synced with `TunDeviceManager` until #5797.
|
||||
|
||||
socket.set_mark(FIREZONE_MARK)?;
|
||||
}
|
||||
|
||||
// Note: for AF_INET sockets IPV6_V6ONLY is not a valid flag
|
||||
if addr.is_ipv6() {
|
||||
socket.set_only_v6(true)?;
|
||||
}
|
||||
|
||||
socket.set_nonblocking(true)?;
|
||||
socket.bind(&addr)?;
|
||||
|
||||
Ok(socket.into())
|
||||
}
|
||||
|
||||
@@ -27,6 +27,7 @@ phoenix-channel = { workspace = true }
|
||||
secrecy = { workspace = true }
|
||||
serde = { version = "1.0", default-features = false, features = ["std", "derive"] }
|
||||
snownet = { workspace = true }
|
||||
socket-factory = { workspace = true }
|
||||
static_assertions = "1.1.0"
|
||||
tokio = { workspace = true, features = ["sync", "macros", "rt-multi-thread", "fs", "signal"] }
|
||||
tokio-tungstenite = { version = "0.21", default-features = false, features = ["connect", "handshake", "rustls-tls-webpki-roots"] }
|
||||
|
||||
@@ -6,7 +6,8 @@ use connlib_shared::{
|
||||
get_user_agent, keypair, messages::Interface, Callbacks, LoginUrl, StaticSecret,
|
||||
};
|
||||
use firezone_bin_shared::{setup_global_subscriber, CommonArgs, TunDeviceManager};
|
||||
use firezone_tunnel::{GatewayTunnel, Sockets, Tun};
|
||||
use firezone_tunnel::{GatewayTunnel, Tun};
|
||||
|
||||
use futures::channel::mpsc;
|
||||
use futures::{future, StreamExt, TryFutureExt};
|
||||
use ip_network::{Ipv4Network, Ipv6Network};
|
||||
@@ -15,6 +16,7 @@ use secrecy::{Secret, SecretString};
|
||||
use std::convert::Infallible;
|
||||
use std::path::Path;
|
||||
use std::pin::pin;
|
||||
use std::sync::Arc;
|
||||
use tokio::io::AsyncWriteExt;
|
||||
use tokio::signal::ctrl_c;
|
||||
use tracing_subscriber::layer;
|
||||
@@ -100,7 +102,7 @@ async fn get_firezone_id(env_id: Option<String>) -> Result<String> {
|
||||
}
|
||||
|
||||
async fn run(login: LoginUrl, private_key: StaticSecret) -> Result<Infallible> {
|
||||
let mut tunnel = GatewayTunnel::new(private_key, Sockets::new(), CallbackHandler)?;
|
||||
let mut tunnel = GatewayTunnel::new(private_key, CallbackHandler)?;
|
||||
let portal = PhoenixChannel::connect(
|
||||
Secret::new(login),
|
||||
get_user_agent(None, env!("CARGO_PKG_VERSION")),
|
||||
@@ -109,6 +111,7 @@ async fn run(login: LoginUrl, private_key: StaticSecret) -> Result<Infallible> {
|
||||
ExponentialBackoffBuilder::default()
|
||||
.with_max_elapsed_time(None)
|
||||
.build(),
|
||||
Arc::new(socket_factory::tcp),
|
||||
);
|
||||
|
||||
let (sender, receiver) = mpsc::channel::<Interface>(10);
|
||||
|
||||
@@ -20,6 +20,7 @@ ip_network = { version = "0.4", default-features = false }
|
||||
secrecy = { workspace = true }
|
||||
serde = { version = "1.0.203", features = ["derive"] }
|
||||
serde_json = "1.0.117"
|
||||
socket-factory = { workspace = true }
|
||||
thiserror = { version = "1.0", default-features = false }
|
||||
# This actually relies on many other features in Tokio, so this will probably
|
||||
# fail to build outside the workspace. <https://github.com/firezone/firezone/pull/4328#discussion_r1540342142>
|
||||
@@ -40,7 +41,7 @@ mutants = "0.0.3" # Needed to mark functions as exempt from `cargo-mutants` test
|
||||
[target.'cfg(target_os = "linux")'.dependencies]
|
||||
dirs = "5.0.1"
|
||||
libc = "0.2.150"
|
||||
nix = { version = "0.28.0", features = ["fs", "user"] }
|
||||
nix = { version = "0.28.0", features = ["fs", "user", "socket"] }
|
||||
resolv-conf = "0.7.0"
|
||||
rtnetlink = { workspace = true }
|
||||
sd-notify = "0.4.1" # This is a pure Rust re-implementation, so it isn't vulnerable to CVE-2024-3094
|
||||
|
||||
@@ -6,13 +6,13 @@ use crate::{
|
||||
};
|
||||
use anyhow::{Context as _, Result};
|
||||
use clap::Parser;
|
||||
use connlib_client_shared::{file_logger, keypair, ConnectArgs, LoginUrl, Session, Sockets};
|
||||
use connlib_client_shared::{file_logger, keypair, ConnectArgs, LoginUrl, Session};
|
||||
use futures::{
|
||||
future::poll_fn,
|
||||
task::{Context, Poll},
|
||||
Future as _, SinkExt as _, Stream as _,
|
||||
};
|
||||
use std::{net::IpAddr, path::PathBuf, pin::pin, time::Duration};
|
||||
use std::{net::IpAddr, path::PathBuf, pin::pin, sync::Arc, time::Duration};
|
||||
use tokio::{sync::mpsc, time::Instant};
|
||||
use tracing::subscriber::set_global_default;
|
||||
use tracing_subscriber::{layer::SubscriberExt, EnvFilter, Layer, Registry};
|
||||
@@ -341,7 +341,8 @@ impl Handler {
|
||||
self.last_connlib_start_instant = Some(Instant::now());
|
||||
let args = ConnectArgs {
|
||||
url,
|
||||
sockets: Sockets::new(),
|
||||
tcp_socket_factory: Arc::new(crate::tcp_socket_factory),
|
||||
udp_socket_factory: Arc::new(crate::udp_socket_factory),
|
||||
private_key,
|
||||
os_version_override: None,
|
||||
app_version: env!("CARGO_PKG_VERSION").to_string(),
|
||||
|
||||
@@ -20,6 +20,8 @@ use tracing::subscriber::set_global_default;
|
||||
use tracing_subscriber::{fmt, layer::SubscriberExt as _, EnvFilter, Layer as _, Registry};
|
||||
|
||||
use platform::default_token_path;
|
||||
use platform::tcp_socket_factory;
|
||||
use platform::udp_socket_factory;
|
||||
|
||||
/// Generate a persistent device ID, stores it to disk, and reads it back.
|
||||
pub(crate) mod device_id;
|
||||
|
||||
@@ -2,13 +2,31 @@
|
||||
|
||||
use super::TOKEN_ENV_KEY;
|
||||
use anyhow::{bail, Result};
|
||||
use std::path::{Path, PathBuf};
|
||||
use firezone_bin_shared::FIREZONE_MARK;
|
||||
use nix::sys::socket::{setsockopt, sockopt};
|
||||
use std::{
|
||||
io,
|
||||
net::SocketAddr,
|
||||
path::{Path, PathBuf},
|
||||
};
|
||||
|
||||
// The Client currently must run as root to control DNS
|
||||
// Root group and user are used to check file ownership on the token
|
||||
const ROOT_GROUP: u32 = 0;
|
||||
const ROOT_USER: u32 = 0;
|
||||
|
||||
pub(crate) fn tcp_socket_factory(socket_addr: &SocketAddr) -> io::Result<tokio::net::TcpSocket> {
|
||||
let socket = socket_factory::tcp(socket_addr)?;
|
||||
setsockopt(&socket, sockopt::Mark, &FIREZONE_MARK)?;
|
||||
Ok(socket)
|
||||
}
|
||||
|
||||
pub(crate) fn udp_socket_factory(socket_addr: &SocketAddr) -> io::Result<tokio::net::UdpSocket> {
|
||||
let socket = socket_factory::udp(socket_addr)?;
|
||||
setsockopt(&socket, sockopt::Mark, &FIREZONE_MARK)?;
|
||||
Ok(socket)
|
||||
}
|
||||
|
||||
pub(crate) fn default_token_path() -> PathBuf {
|
||||
PathBuf::from("/etc")
|
||||
.join(connlib_shared::BUNDLE_ID)
|
||||
|
||||
@@ -6,13 +6,14 @@ use crate::{
|
||||
};
|
||||
use anyhow::{anyhow, Context as _, Result};
|
||||
use clap::Parser;
|
||||
use connlib_client_shared::{file_logger, keypair, ConnectArgs, LoginUrl, Session, Sockets};
|
||||
use connlib_client_shared::{file_logger, keypair, ConnectArgs, LoginUrl, Session};
|
||||
use firezone_bin_shared::{setup_global_subscriber, TunDeviceManager};
|
||||
use futures::{FutureExt as _, StreamExt as _};
|
||||
use secrecy::SecretString;
|
||||
use std::{
|
||||
path::{Path, PathBuf},
|
||||
pin::pin,
|
||||
sync::Arc,
|
||||
};
|
||||
use tokio::sync::mpsc;
|
||||
use tokio_stream::wrappers::ReceiverStream;
|
||||
@@ -156,7 +157,8 @@ pub fn run_only_headless_client() -> Result<()> {
|
||||
platform::setup_before_connlib()?;
|
||||
let args = ConnectArgs {
|
||||
url,
|
||||
sockets: Sockets::new(),
|
||||
udp_socket_factory: Arc::new(crate::udp_socket_factory),
|
||||
tcp_socket_factory: Arc::new(crate::tcp_socket_factory),
|
||||
private_key,
|
||||
os_version_override: None,
|
||||
app_version: env!("CARGO_PKG_VERSION").to_string(),
|
||||
|
||||
@@ -7,6 +7,9 @@
|
||||
use anyhow::Result;
|
||||
use std::path::{Path, PathBuf};
|
||||
|
||||
pub(crate) use socket_factory::tcp as tcp_socket_factory;
|
||||
pub(crate) use socket_factory::udp as udp_socket_factory;
|
||||
|
||||
#[path = "windows/wintun_install.rs"]
|
||||
mod wintun_install;
|
||||
|
||||
|
||||
@@ -15,6 +15,7 @@ secrecy = { workspace = true }
|
||||
serde = { version = "1.0.203", features = ["derive"] }
|
||||
serde_json = "1.0.117"
|
||||
sha2 = "0.10.8"
|
||||
socket-factory = { workspace = true }
|
||||
thiserror = "1.0.61"
|
||||
tokio = { workspace = true, features = ["net", "time"] }
|
||||
tokio-tungstenite = { workspace = true, features = ["rustls-tls-webpki-roots"] }
|
||||
|
||||
@@ -3,6 +3,7 @@ mod login_url;
|
||||
|
||||
use std::collections::{HashSet, VecDeque};
|
||||
use std::mem;
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::atomic::AtomicU64;
|
||||
use std::sync::Arc;
|
||||
use std::{fmt, future, marker::PhantomData};
|
||||
@@ -16,15 +17,16 @@ use heartbeat::{Heartbeat, MissedLastHeartbeat};
|
||||
use rand_core::{OsRng, RngCore};
|
||||
use secrecy::{ExposeSecret as _, Secret};
|
||||
use serde::{de::DeserializeOwned, Deserialize, Serialize};
|
||||
use socket_factory::SocketFactory;
|
||||
use std::task::{Context, Poll, Waker};
|
||||
use tokio::net::TcpStream;
|
||||
use tokio_tungstenite::connect_async_with_config;
|
||||
use tokio_tungstenite::client_async_tls;
|
||||
use tokio_tungstenite::tungstenite::http::StatusCode;
|
||||
use tokio_tungstenite::{
|
||||
connect_async,
|
||||
tungstenite::{handshake::client::Request, Message},
|
||||
MaybeTlsStream, WebSocketStream,
|
||||
};
|
||||
use url::{Host, Url};
|
||||
|
||||
pub use login_url::{LoginUrl, LoginUrlError};
|
||||
|
||||
@@ -33,6 +35,7 @@ pub struct PhoenixChannel<TInitReq, TInboundMsg, TOutboundRes> {
|
||||
waker: Option<Waker>,
|
||||
pending_messages: VecDeque<String>,
|
||||
next_request_id: Arc<AtomicU64>,
|
||||
socket_factory: Arc<dyn SocketFactory<tokio::net::TcpSocket>>,
|
||||
|
||||
heartbeat: Heartbeat,
|
||||
|
||||
@@ -59,17 +62,70 @@ enum State {
|
||||
}
|
||||
|
||||
impl State {
|
||||
fn connect(url: Secret<LoginUrl>, user_agent: String) -> Self {
|
||||
Self::Connecting(Box::pin(async move {
|
||||
let (stream, _) = connect_async_with_config(make_request(url, user_agent), None, true)
|
||||
.await
|
||||
.map_err(InternalError::WebSocket)?;
|
||||
|
||||
Ok(stream)
|
||||
}))
|
||||
fn connect(
|
||||
url: Secret<LoginUrl>,
|
||||
user_agent: String,
|
||||
socket_factory: Arc<dyn SocketFactory<tokio::net::TcpSocket>>,
|
||||
) -> Self {
|
||||
Self::Connecting(create_and_connect_websocket(url, user_agent, socket_factory).boxed())
|
||||
}
|
||||
}
|
||||
|
||||
async fn create_and_connect_websocket(
|
||||
url: Secret<LoginUrl>,
|
||||
user_agent: String,
|
||||
socket_factory: Arc<dyn SocketFactory<tokio::net::TcpSocket>>,
|
||||
) -> Result<WebSocketStream<MaybeTlsStream<TcpStream>>, InternalError> {
|
||||
let socket = make_socket(url.expose_secret().inner(), &*socket_factory).await?;
|
||||
|
||||
let (stream, _) = client_async_tls(make_request(url, user_agent), socket)
|
||||
.await
|
||||
.map_err(InternalError::WebSocket)?;
|
||||
|
||||
Ok(stream)
|
||||
}
|
||||
|
||||
async fn make_socket(
|
||||
url: &Url,
|
||||
socket_factory: &dyn SocketFactory<tokio::net::TcpSocket>,
|
||||
) -> Result<TcpStream, InternalError> {
|
||||
let port = url
|
||||
.port_or_known_default()
|
||||
.expect("scheme to be http, https, ws or wss");
|
||||
let addrs: Vec<SocketAddr> = match url.host().ok_or(InternalError::InvalidUrl)? {
|
||||
Host::Domain(n) => tokio::net::lookup_host((n, port))
|
||||
.await
|
||||
.map_err(|_| InternalError::InvalidUrl)?
|
||||
.collect(),
|
||||
Host::Ipv6(ip) => {
|
||||
vec![(ip, port).into()]
|
||||
}
|
||||
Host::Ipv4(ip) => {
|
||||
vec![(ip, port).into()]
|
||||
}
|
||||
};
|
||||
|
||||
let mut last_error = None;
|
||||
for addr in addrs {
|
||||
let Ok(socket) = socket_factory(&addr) else {
|
||||
continue;
|
||||
};
|
||||
|
||||
match socket.connect(addr).await {
|
||||
Ok(socket) => return Ok(socket),
|
||||
Err(e) => {
|
||||
last_error = Some(e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let Some(err) = last_error else {
|
||||
return Err(InternalError::InvalidUrl);
|
||||
};
|
||||
|
||||
Err(InternalError::SocketConnection(err))
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum Error {
|
||||
#[error("client error: {0}")]
|
||||
@@ -99,6 +155,8 @@ enum InternalError {
|
||||
MissedHeartbeat,
|
||||
CloseMessage,
|
||||
StreamClosed,
|
||||
InvalidUrl,
|
||||
SocketConnection(std::io::Error),
|
||||
}
|
||||
|
||||
impl fmt::Display for InternalError {
|
||||
@@ -119,6 +177,8 @@ impl fmt::Display for InternalError {
|
||||
InternalError::MissedHeartbeat => write!(f, "portal did not respond to our heartbeat"),
|
||||
InternalError::CloseMessage => write!(f, "portal closed the websocket connection"),
|
||||
InternalError::StreamClosed => write!(f, "websocket stream was closed"),
|
||||
InternalError::InvalidUrl => write!(f, "failed to resolve url"),
|
||||
InternalError::SocketConnection(e) => write!(f, "failed to connect socket: {e}"),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -161,14 +221,13 @@ where
|
||||
///
|
||||
/// The provided URL must contain a host.
|
||||
/// Additionally, you must already provide any query parameters required for authentication.
|
||||
///
|
||||
/// Once the connection is established,
|
||||
pub fn connect(
|
||||
url: Secret<LoginUrl>,
|
||||
user_agent: String,
|
||||
login: &'static str,
|
||||
init_req: TInitReq,
|
||||
reconnect_backoff: ExponentialBackoff,
|
||||
socket_factory: Arc<dyn SocketFactory<tokio::net::TcpSocket>>,
|
||||
) -> Self {
|
||||
let next_request_id = Arc::new(AtomicU64::new(0));
|
||||
|
||||
@@ -178,7 +237,8 @@ where
|
||||
reconnect_backoff,
|
||||
url: url.clone(),
|
||||
user_agent: user_agent.clone(),
|
||||
state: State::connect(url, user_agent),
|
||||
state: State::connect(url, user_agent, socket_factory.clone()),
|
||||
socket_factory,
|
||||
waker: None,
|
||||
pending_messages: Default::default(),
|
||||
_phantom: PhantomData,
|
||||
@@ -220,7 +280,7 @@ where
|
||||
// 2. Set state to `Connecting` without a timer.
|
||||
let url = self.url.clone();
|
||||
let user_agent = self.user_agent.clone();
|
||||
self.state = State::connect(url, user_agent);
|
||||
self.state = State::connect(url, user_agent, self.socket_factory.clone());
|
||||
|
||||
// 3. In case we were already re-connecting, we need to wake the suspended task.
|
||||
if let Some(waker) = self.waker.take() {
|
||||
@@ -293,18 +353,16 @@ where
|
||||
|
||||
let secret_url = self.url.clone();
|
||||
let user_agent = self.user_agent.clone();
|
||||
let socket_factory = self.socket_factory.clone();
|
||||
|
||||
tracing::debug!(?backoff, max_elapsed_time = ?self.reconnect_backoff.max_elapsed_time, "Reconnecting to portal on transient client error: {e}");
|
||||
|
||||
self.state = State::Connecting(Box::pin(async move {
|
||||
tokio::time::sleep(backoff).await;
|
||||
|
||||
let (stream, _) = connect_async(make_request(secret_url, user_agent))
|
||||
create_and_connect_websocket(secret_url, user_agent, socket_factory)
|
||||
.await
|
||||
.map_err(InternalError::WebSocket)?;
|
||||
|
||||
Ok(stream)
|
||||
}));
|
||||
|
||||
continue;
|
||||
}
|
||||
Poll::Pending => {
|
||||
|
||||
@@ -26,7 +26,8 @@ rand = "0.8.5"
|
||||
secrecy = { workspace = true }
|
||||
serde = { version = "1.0.203", features = ["derive"] }
|
||||
sha2 = "0.10.8"
|
||||
socket2 = "0.5.7"
|
||||
socket-factory = { workspace = true }
|
||||
socket2 = { workspace = true }
|
||||
stun_codec = "0.3.4"
|
||||
tokio = { workspace = true, features = ["macros", "rt-multi-thread", "net", "time", "signal"] }
|
||||
tracing = { workspace = true, features = ["log"] }
|
||||
|
||||
@@ -146,6 +146,7 @@ async fn main() -> Result<()> {
|
||||
ExponentialBackoffBuilder::default()
|
||||
.with_max_elapsed_time(Some(MAX_PARTITION_TIME))
|
||||
.build(),
|
||||
Arc::new(socket_factory::tcp),
|
||||
))
|
||||
} else {
|
||||
tracing::warn!(target: "relay", "No portal token supplied, starting standalone mode");
|
||||
|
||||
8
rust/socket-factory/Cargo.toml
Normal file
8
rust/socket-factory/Cargo.toml
Normal file
@@ -0,0 +1,8 @@
|
||||
[package]
|
||||
name = "socket-factory"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
[dependencies]
|
||||
socket2 = { workspace = true }
|
||||
tokio = { version = "1.38", features = ["net"] }
|
||||
35
rust/socket-factory/src/lib.rs
Normal file
35
rust/socket-factory/src/lib.rs
Normal file
@@ -0,0 +1,35 @@
|
||||
use std::net::SocketAddr;
|
||||
|
||||
use socket2::SockAddr;
|
||||
|
||||
pub trait SocketFactory<S>: Fn(&SocketAddr) -> std::io::Result<S> + Send + Sync + 'static {}
|
||||
|
||||
impl<F, S> SocketFactory<S> for F where
|
||||
F: Fn(&SocketAddr) -> std::io::Result<S> + Send + Sync + 'static
|
||||
{
|
||||
}
|
||||
|
||||
pub fn tcp(addr: &SocketAddr) -> std::io::Result<tokio::net::TcpSocket> {
|
||||
let socket = match addr {
|
||||
SocketAddr::V4(_) => tokio::net::TcpSocket::new_v4()?,
|
||||
SocketAddr::V6(_) => tokio::net::TcpSocket::new_v6()?,
|
||||
};
|
||||
|
||||
socket.set_nodelay(true)?;
|
||||
|
||||
Ok(socket)
|
||||
}
|
||||
pub fn udp(addr: &SocketAddr) -> std::io::Result<tokio::net::UdpSocket> {
|
||||
let addr: SockAddr = (*addr).into();
|
||||
let socket = socket2::Socket::new(addr.domain(), socket2::Type::DGRAM, None)?;
|
||||
|
||||
// Note: for AF_INET sockets IPV6_V6ONLY is not a valid flag
|
||||
if addr.is_ipv6() {
|
||||
socket.set_only_v6(true)?;
|
||||
}
|
||||
|
||||
socket.set_nonblocking(true)?;
|
||||
socket.bind(&addr)?;
|
||||
|
||||
std::net::UdpSocket::from(socket).try_into()
|
||||
}
|
||||
Reference in New Issue
Block a user