From 5b0aaa6f819cae23d1149ed079268ea3db555a54 Mon Sep 17 00:00:00 2001 From: Gabi Date: Mon, 15 Jul 2024 21:40:05 -0300 Subject: [PATCH] fix(connlib): protect all sockets from routing loops (#5797) Currently, only connlib's UDP sockets for sending and receiving STUN & WireGuard traffic are protected from routing loops. This is was done via the `Sockets::with_protect` function. Connlib has additional sockets though: - A TCP socket to the portal. - UDP & TCP sockets for DNS resolution via hickory. Both of these can incur routing loops on certain platforms which becomes evident as we try to implement #2667. To fix this, we generalise the idea of "protecting" a socket via a `SocketFactory` abstraction. By allowing the different platforms to provide a specialised `SocketFactory`, anything Linux-based can give special treatment to the socket before handing it to connlib. As an additional benefit, this allows us to remove the `Sockets` abstraction from connlib's API again because we can now initialise it internally via the provided `SocketFactory` for UDP sockets. --------- Signed-off-by: Gabi Co-authored-by: Thomas Eizinger --- rust/Cargo.lock | 17 +++ rust/Cargo.toml | 3 + rust/bin-shared/src/lib.rs | 3 + .../src/tun_device_manager/linux.rs | 3 +- rust/connlib/clients/android/Cargo.toml | 1 + rust/connlib/clients/android/src/lib.rs | 42 ++++--- rust/connlib/clients/apple/Cargo.toml | 1 + rust/connlib/clients/apple/src/lib.rs | 5 +- rust/connlib/clients/shared/Cargo.toml | 1 + rust/connlib/clients/shared/src/lib.rs | 28 +++-- rust/connlib/tunnel/Cargo.toml | 3 +- rust/connlib/tunnel/src/io.rs | 105 ++++++++++++++-- rust/connlib/tunnel/src/lib.rs | 19 ++- rust/connlib/tunnel/src/sockets.rs | 117 ++++-------------- rust/gateway/Cargo.toml | 1 + rust/gateway/src/main.rs | 7 +- rust/headless-client/Cargo.toml | 3 +- rust/headless-client/src/ipc_service.rs | 7 +- rust/headless-client/src/lib.rs | 2 + rust/headless-client/src/linux.rs | 20 ++- rust/headless-client/src/standalone.rs | 6 +- rust/headless-client/src/windows.rs | 3 + rust/phoenix-channel/Cargo.toml | 1 + rust/phoenix-channel/src/lib.rs | 96 +++++++++++--- rust/relay/Cargo.toml | 3 +- rust/relay/src/main.rs | 1 + rust/socket-factory/Cargo.toml | 8 ++ rust/socket-factory/src/lib.rs | 35 ++++++ 28 files changed, 374 insertions(+), 167 deletions(-) create mode 100644 rust/socket-factory/Cargo.toml create mode 100644 rust/socket-factory/src/lib.rs diff --git a/rust/Cargo.lock b/rust/Cargo.lock index 42f8c755c..66703e4ab 100644 --- a/rust/Cargo.lock +++ b/rust/Cargo.lock @@ -1066,6 +1066,7 @@ dependencies = [ "log", "secrecy", "serde_json", + "socket-factory", "thiserror", "tokio", "tracing", @@ -1084,6 +1085,7 @@ dependencies = [ "oslog", "secrecy", "serde_json", + "socket-factory", "swift-bridge", "swift-bridge-build", "tokio", @@ -1109,6 +1111,7 @@ dependencies = [ "secrecy", "serde", "serde_json", + "socket-factory", "time", "tokio", "tokio-tungstenite", @@ -1873,6 +1876,7 @@ dependencies = [ "serde", "serde_json", "snownet", + "socket-factory", "static_assertions", "tokio", "tokio-tungstenite", @@ -1962,6 +1966,7 @@ dependencies = [ "secrecy", "serde", "serde_json", + "socket-factory", "tempfile", "thiserror", "tokio", @@ -2003,6 +2008,7 @@ dependencies = [ "secrecy", "serde", "sha2", + "socket-factory", "socket2 0.5.7", "stun_codec", "test-strategy", @@ -2054,6 +2060,7 @@ dependencies = [ "serde", "serde_json", "snownet", + "socket-factory", "socket2 0.5.7", "test-strategy", "thiserror", @@ -3901,6 +3908,7 @@ dependencies = [ "cfg-if", "cfg_aliases", "libc", + "memoffset 0.9.1", ] [[package]] @@ -4561,6 +4569,7 @@ dependencies = [ "serde", "serde_json", "sha2", + "socket-factory", "thiserror", "tokio", "tokio-tungstenite", @@ -5812,6 +5821,14 @@ dependencies = [ "tracing-subscriber", ] +[[package]] +name = "socket-factory" +version = "0.1.0" +dependencies = [ + "socket2 0.5.7", + "tokio", +] + [[package]] name = "socket2" version = "0.4.10" diff --git a/rust/Cargo.toml b/rust/Cargo.toml index e2a6af41a..85c161507 100644 --- a/rust/Cargo.toml +++ b/rust/Cargo.toml @@ -17,6 +17,7 @@ members = [ "phoenix-channel", "relay", "snownet-tests", + "socket-factory", ] resolver = "2" @@ -53,6 +54,8 @@ firezone-tunnel = { path = "connlib/tunnel" } phoenix-channel = { path = "phoenix-channel" } http-health-check = { path = "http-health-check" } ip-packet = { path = "ip-packet" } +socket-factory = { path = "socket-factory" } +socket2 = { version = "0.5" } [workspace.lints.clippy] dbg_macro = "warn" diff --git a/rust/bin-shared/src/lib.rs b/rust/bin-shared/src/lib.rs index 388782682..0ad0dfc19 100644 --- a/rust/bin-shared/src/lib.rs +++ b/rust/bin-shared/src/lib.rs @@ -7,6 +7,9 @@ use tracing_subscriber::{ }; use url::Url; +/// Mark for Firezone sockets to prevent routing loops on Linux. +pub const FIREZONE_MARK: u32 = 0xfd002021; + #[cfg(any(target_os = "linux", target_os = "windows"))] pub use tun_device_manager::TunDeviceManager; diff --git a/rust/bin-shared/src/tun_device_manager/linux.rs b/rust/bin-shared/src/tun_device_manager/linux.rs index 8049cd931..b298b42dc 100644 --- a/rust/bin-shared/src/tun_device_manager/linux.rs +++ b/rust/bin-shared/src/tun_device_manager/linux.rs @@ -13,7 +13,8 @@ use std::{ net::{Ipv4Addr, Ipv6Addr}, }; -const FIREZONE_MARK: u32 = 0xfd002021; // Keep this synced with `Sockets` until #5797. +use crate::FIREZONE_MARK; + const FILE_ALREADY_EXISTS: i32 = -17; const FIREZONE_TABLE: u32 = 0x2021_fd00; diff --git a/rust/connlib/clients/android/Cargo.toml b/rust/connlib/clients/android/Cargo.toml index 8d36e1a44..0ab3fe1e6 100644 --- a/rust/connlib/clients/android/Cargo.toml +++ b/rust/connlib/clients/android/Cargo.toml @@ -19,6 +19,7 @@ jni = { version = "0.21.1", features = ["invocation"] } log = "0.4" secrecy = { workspace = true } serde_json = "1" +socket-factory = { workspace = true } thiserror = "1" tokio = { workspace = true, features = ["rt"] } tracing = { workspace = true, features = ["std", "attributes"] } diff --git a/rust/connlib/clients/android/src/lib.rs b/rust/connlib/clients/android/src/lib.rs index 0996b923c..40782fddc 100644 --- a/rust/connlib/clients/android/src/lib.rs +++ b/rust/connlib/clients/android/src/lib.rs @@ -5,7 +5,7 @@ use connlib_client_shared::{ callbacks::ResourceDescription, file_logger, keypair, Callbacks, ConnectArgs, Error, LoginUrl, - LoginUrlError, Session, Sockets, Tun, V4RouteList, V6RouteList, + LoginUrlError, Session, Tun, V4RouteList, V6RouteList, }; use ip_network::{Ipv4Network, Ipv6Network}; use jni::{ @@ -15,7 +15,8 @@ use jni::{ JNIEnv, JavaVM, }; use secrecy::SecretString; -use std::{io, net::IpAddr, path::Path}; +use socket_factory::SocketFactory; +use std::{io, net::IpAddr, os::fd::AsRawFd, path::Path, sync::Arc}; use std::{ net::{Ipv4Addr, Ipv6Addr}, os::fd::RawFd, @@ -86,16 +87,17 @@ impl CallbackHandler { .and_then(f) } - fn protect_file_descriptor(&self, file_descriptor: RawFd) -> Result<(), CallbackError> { + fn protect(&self, socket: RawFd) -> io::Result<()> { self.env(|mut env| { call_method( &mut env, &self.callback_handler, "protectFileDescriptor", "(I)V", - &[JValue::Int(file_descriptor)], + &[JValue::Int(socket)], ) }) + .map_err(|e| io::Error::new(io::ErrorKind::Other, e)) } } @@ -357,18 +359,10 @@ fn connect( .enable_all() .build()?; - let sockets = Sockets::with_protect({ - let callbacks = callbacks.clone(); - move |fd| { - callbacks - .protect_file_descriptor(fd) - .map_err(|e| io::Error::new(io::ErrorKind::Other, e)) - } - }); - let args = ConnectArgs { url, - sockets, + tcp_socket_factory: Arc::new(protected_tcp_socket_factory(callbacks.clone())), + udp_socket_factory: Arc::new(protected_udp_socket_factory(callbacks.clone())), private_key, os_version_override: Some(os_version), app_version: env!("CARGO_PKG_VERSION").to_string(), @@ -523,3 +517,23 @@ pub unsafe extern "system" fn Java_dev_firezone_android_tunnel_ConnlibSession_se session.inner.set_tun(tun); } + +fn protected_tcp_socket_factory( + callbacks: CallbackHandler, +) -> impl SocketFactory { + move |addr| { + let socket = socket_factory::tcp(addr)?; + callbacks.protect(socket.as_raw_fd())?; + Ok(socket) + } +} + +fn protected_udp_socket_factory( + callbacks: CallbackHandler, +) -> impl SocketFactory { + move |addr| { + let socket = socket_factory::udp(addr)?; + callbacks.protect(socket.as_raw_fd())?; + Ok(socket) + } +} diff --git a/rust/connlib/clients/apple/Cargo.toml b/rust/connlib/clients/apple/Cargo.toml index 77032c504..ec7ad60e0 100644 --- a/rust/connlib/clients/apple/Cargo.toml +++ b/rust/connlib/clients/apple/Cargo.toml @@ -16,6 +16,7 @@ ip_network = "0.4" libc = "0.2" secrecy = { workspace = true } serde_json = "1" +socket-factory = { workspace = true } swift-bridge = { workspace = true } tokio = { workspace = true, features = ["rt"] } tracing = { workspace = true } diff --git a/rust/connlib/clients/apple/src/lib.rs b/rust/connlib/clients/apple/src/lib.rs index 65ef126ec..b01bde9e5 100644 --- a/rust/connlib/clients/apple/src/lib.rs +++ b/rust/connlib/clients/apple/src/lib.rs @@ -5,7 +5,7 @@ mod make_writer; use connlib_client_shared::{ callbacks::ResourceDescription, file_logger, keypair, Callbacks, ConnectArgs, Error, LoginUrl, - Session, Sockets, Tun, V4RouteList, V6RouteList, + Session, Tun, V4RouteList, V6RouteList, }; use ip_network::{Ipv4Network, Ipv6Network}; use secrecy::SecretString; @@ -194,7 +194,6 @@ impl WrappedSession { let args = ConnectArgs { url, - sockets: Sockets::new(), private_key, os_version_override, app_version: env!("CARGO_PKG_VERSION").to_string(), @@ -202,6 +201,8 @@ impl WrappedSession { inner: Arc::new(callback_handler), }, max_partition_time: Some(MAX_PARTITION_TIME), + tcp_socket_factory: Arc::new(socket_factory::tcp), + udp_socket_factory: Arc::new(socket_factory::udp), }; let session = Session::connect(args, runtime.handle().clone()); let _enter = runtime.enter(); diff --git a/rust/connlib/clients/shared/Cargo.toml b/rust/connlib/clients/shared/Cargo.toml index 96942854d..57ee33e95 100644 --- a/rust/connlib/clients/shared/Cargo.toml +++ b/rust/connlib/clients/shared/Cargo.toml @@ -17,6 +17,7 @@ ip_network = { version = "0.4", default-features = false } phoenix-channel = { workspace = true } secrecy = { workspace = true } serde = { version = "1.0", default-features = false, features = ["std", "derive"] } +socket-factory = { workspace = true } time = { version = "0.3.36", features = ["formatting"] } tokio = { workspace = true, features = ["sync"] } tokio-tungstenite = { version = "0.21", default-features = false, features = ["connect", "handshake", "rustls-tls-webpki-roots"] } diff --git a/rust/connlib/clients/shared/src/lib.rs b/rust/connlib/clients/shared/src/lib.rs index 4816e4554..f3110d265 100644 --- a/rust/connlib/clients/shared/src/lib.rs +++ b/rust/connlib/clients/shared/src/lib.rs @@ -5,15 +5,17 @@ pub use connlib_shared::{ callbacks, keypair, Callbacks, Error, LoginUrl, LoginUrlError, StaticSecret, }; pub use eventloop::Eventloop; -pub use firezone_tunnel::{Sockets, Tun}; +pub use firezone_tunnel::Tun; pub use tracing_appender::non_blocking::WorkerGuard; use backoff::ExponentialBackoffBuilder; use connlib_shared::get_user_agent; use firezone_tunnel::ClientTunnel; use phoenix_channel::PhoenixChannel; +use socket_factory::SocketFactory; use std::collections::HashMap; use std::net::IpAddr; +use std::sync::Arc; use std::time::Duration; use tokio::sync::mpsc::UnboundedReceiver; @@ -39,7 +41,8 @@ pub struct Session { /// Arguments for `connect`, since Clippy said 8 args is too many pub struct ConnectArgs { pub url: LoginUrl, - pub sockets: Sockets, + pub tcp_socket_factory: Arc>, + pub udp_socket_factory: Arc>, pub private_key: StaticSecret, pub os_version_override: Option, pub app_version: String, @@ -120,11 +123,12 @@ where { let ConnectArgs { url, - sockets, private_key, os_version_override, app_version, callbacks, + udp_socket_factory, + tcp_socket_factory, max_partition_time, } = args; @@ -139,7 +143,8 @@ where let tunnel = ClientTunnel::new( private_key, - sockets, + tcp_socket_factory.clone(), + udp_socket_factory, callbacks, HashMap::from([(url.host().to_string(), addrs)]), )?; @@ -152,6 +157,7 @@ where ExponentialBackoffBuilder::default() .with_max_elapsed_time(max_partition_time) .build(), + tcp_socket_factory, ); let mut eventloop = Eventloop::new(tunnel, portal, rx); @@ -232,14 +238,18 @@ mod tests { #[cfg(any(target_os = "windows", target_os = "linux"))] async fn device_common() { use firezone_tunnel::Tun; - use std::collections::HashMap; + use std::{collections::HashMap, sync::Arc}; let (private_key, _public_key) = connlib_shared::keypair(); - let sockets = crate::Sockets::new(); let callbacks = Callbacks::default(); - let mut tunnel = - firezone_tunnel::ClientTunnel::new(private_key, sockets, callbacks, HashMap::new()) - .unwrap(); + let mut tunnel = firezone_tunnel::ClientTunnel::new( + private_key, + Arc::new(socket_factory::tcp), + Arc::new(socket_factory::udp), + callbacks, + HashMap::new(), + ) + .unwrap(); let upstream_dns = vec![([192, 168, 1, 1], 53).into()]; let interface = connlib_shared::messages::Interface { ipv4: [100, 71, 96, 96].into(), diff --git a/rust/connlib/tunnel/Cargo.toml b/rust/connlib/tunnel/Cargo.toml index 69b13f2ce..8f8255cd6 100644 --- a/rust/connlib/tunnel/Cargo.toml +++ b/rust/connlib/tunnel/Cargo.toml @@ -30,7 +30,8 @@ rangemap = "1.5.1" secrecy = { workspace = true } serde = { version = "1.0", default-features = false, features = ["derive", "std"] } snownet = { workspace = true } -socket2 = { version = "0.5" } +socket-factory = { workspace = true } +socket2 = { workspace = true } thiserror = { version = "1.0", default-features = false } tokio = { workspace = true } tracing = { workspace = true } diff --git a/rust/connlib/tunnel/src/io.rs b/rust/connlib/tunnel/src/io.rs index 8a9919ce2..bab1ea803 100644 --- a/rust/connlib/tunnel/src/io.rs +++ b/rust/connlib/tunnel/src/io.rs @@ -5,22 +5,29 @@ use crate::{ }; use bytes::Bytes; use connlib_shared::messages::DnsServer; +use futures::Future; use futures_bounded::FuturesTupleSet; use futures_util::FutureExt as _; +use hickory_proto::iocompat::AsyncIoTokioAsStd; +use hickory_proto::TokioTime; use hickory_resolver::{ config::{NameServerConfig, Protocol, ResolverConfig, ResolverOpts}, - TokioAsyncResolver, + name_server::{GenericConnector, RuntimeProvider}, + AsyncResolver, TokioHandle, }; use ip_packet::{IpPacket, MutableIpPacket}; use quinn_udp::Transmit; +use socket_factory::SocketFactory; use std::{ collections::HashMap, io, - net::IpAddr, + net::{IpAddr, SocketAddr}, pin::Pin, + sync::Arc, task::{ready, Context, Poll}, time::{Duration, Instant}, }; +use tokio::net::{TcpSocket, UdpSocket}; const DNS_QUERIES_QUEUE_SIZE: usize = 100; @@ -32,9 +39,13 @@ pub struct Io { device: Device, /// The UDP sockets used to send & receive packets from the network. sockets: Sockets, + + tcp_socket_factory: Arc>, + udp_socket_factory: Arc>, + timeout: Option>>, - upstream_dns_servers: HashMap, + upstream_dns_servers: HashMap>>, forwarded_dns_queries: FuturesTupleSet< Result, DnsQuery<'static>, @@ -58,13 +69,19 @@ impl Io { /// Creates a new I/O abstraction /// /// Must be called within a Tokio runtime context so we can bind the sockets. - pub fn new(mut sockets: Sockets) -> io::Result { - sockets.rebind()?; // Bind sockets on startup. Must happen within a tokio runtime context. + pub fn new( + tcp_socket_factory: Arc>, + udp_socket_factory: Arc>, + ) -> io::Result { + let mut sockets = Sockets::default(); + sockets.rebind(udp_socket_factory.as_ref())?; // Bind sockets on startup. Must happen within a tokio runtime context. Ok(Self { device: Device::new(), timeout: None, sockets, + tcp_socket_factory, + udp_socket_factory, upstream_dns_servers: HashMap::default(), forwarded_dns_queries: FuturesTupleSet::new( Duration::from_secs(60), @@ -107,8 +124,10 @@ impl Io { &mut self.device } - pub fn sockets_mut(&mut self) -> &mut Sockets { - &mut self.sockets + pub fn rebind_sockets(&mut self) -> io::Result<()> { + self.sockets.rebind(self.udp_socket_factory.as_ref())?; + + Ok(()) } pub fn set_upstream_dns_servers( @@ -119,7 +138,13 @@ impl Io { self.forwarded_dns_queries = FuturesTupleSet::new(Duration::from_secs(60), DNS_QUERIES_QUEUE_SIZE); - self.upstream_dns_servers = create_resolvers(dns_servers); + self.upstream_dns_servers = create_resolvers( + dns_servers, + TokioRuntimeProvider::new( + self.tcp_socket_factory.clone(), + self.udp_socket_factory.clone(), + ), + ); } pub fn perform_dns_query(&mut self, query: DnsQuery<'static>) -> Result<(), DnsQueryError> { @@ -186,9 +211,65 @@ pub enum DnsQueryError { TooManyQueries, } +/// Identical to [`TokioRuntimeProvider`](hickory_resolver::name_server::TokioRuntimeProvider) but using our own [`SocketFactory`]. +#[derive(Clone)] +struct TokioRuntimeProvider { + handle: TokioHandle, + tcp_socket_factory: Arc>, + udp_socket_factory: Arc>, +} + +impl TokioRuntimeProvider { + fn new( + tcp_socket_factory: Arc>, + udp_socket_factory: Arc>, + ) -> TokioRuntimeProvider { + Self { + handle: Default::default(), + tcp_socket_factory, + udp_socket_factory, + } + } +} + +impl RuntimeProvider for TokioRuntimeProvider { + type Handle = TokioHandle; + type Timer = TokioTime; + type Udp = UdpSocket; + type Tcp = AsyncIoTokioAsStd; + + fn create_handle(&self) -> Self::Handle { + self.handle.clone() + } + + fn connect_tcp( + &self, + server_addr: SocketAddr, + ) -> Pin>>> { + let socket = (self.tcp_socket_factory)(&server_addr); + Box::pin(async move { + let socket = socket?; + let stream = socket.connect(server_addr).await?; + + Ok(AsyncIoTokioAsStd(stream)) + }) + } + + fn bind_udp( + &self, + local_addr: SocketAddr, + _server_addr: SocketAddr, + ) -> Pin>>> { + let socket = (self.udp_socket_factory)(&local_addr); + + Box::pin(async move { socket }) + } +} + fn create_resolvers( dns_servers: impl IntoIterator, -) -> HashMap { + runtime_provider: TokioRuntimeProvider, +) -> HashMap>> { dns_servers .into_iter() .map(|(sentinel, srv)| { @@ -201,7 +282,11 @@ fn create_resolvers( ( sentinel, - TokioAsyncResolver::tokio(resolver_config, resolver_opts), + AsyncResolver::new_with_conn( + resolver_config, + resolver_opts, + GenericConnector::new(runtime_provider.clone()), + ), ) }) .collect() diff --git a/rust/connlib/tunnel/src/lib.rs b/rust/connlib/tunnel/src/lib.rs index 6e1c31aac..c5118fbda 100644 --- a/rust/connlib/tunnel/src/lib.rs +++ b/rust/connlib/tunnel/src/lib.rs @@ -14,6 +14,7 @@ use io::Io; use std::{ collections::{HashMap, HashSet}, net::{IpAddr, SocketAddr}, + sync::Arc, task::{Context, Poll}, time::Instant, }; @@ -21,7 +22,6 @@ use std::{ use bimap::BiMap; pub use client::{ClientState, Request}; pub use gateway::GatewayState; -pub use sockets::Sockets; use utils::turn; mod client; @@ -50,7 +50,7 @@ pub type ClientTunnel = Tunnel; /// [`Tunnel`] glues together connlib's [`Io`] component and the respective (pure) state of a client or gateway. /// /// Most of connlib's functionality is implemented as a pure state machine in [`ClientState`] and [`GatewayState`]. -/// The only job of [`Tunnel`] is to take input from the TUN [`Device`](crate::device_channel::Device), [`Sockets`] or time and pass it to the respective state. +/// The only job of [`Tunnel`] is to take input from the TUN [`Device`](crate::device_channel::Device), [`Sockets`](crate::sockets::Sockets) or time and pass it to the respective state. pub struct Tunnel { pub callbacks: CB, @@ -77,12 +77,13 @@ where { pub fn new( private_key: StaticSecret, - sockets: Sockets, + tcp_socket_factory: Arc>, + udp_socket_factory: Arc>, callbacks: CB, known_hosts: HashMap>, ) -> std::io::Result { Ok(Self { - io: Io::new(sockets)?, + io: Io::new(tcp_socket_factory, udp_socket_factory)?, callbacks, role_state: ClientState::new(private_key, known_hosts), write_buf: Box::new([0u8; MTU + 16 + 20]), @@ -94,7 +95,7 @@ where pub fn reset(&mut self) -> std::io::Result<()> { self.role_state.reset(); - self.io.sockets_mut().rebind()?; + self.io.rebind_sockets()?; Ok(()) } @@ -178,13 +179,9 @@ impl GatewayTunnel where CB: Callbacks + 'static, { - pub fn new( - private_key: StaticSecret, - sockets: Sockets, - callbacks: CB, - ) -> std::io::Result { + pub fn new(private_key: StaticSecret, callbacks: CB) -> std::io::Result { Ok(Self { - io: Io::new(sockets)?, + io: Io::new(Arc::new(socket_factory::tcp), Arc::new(socket_factory::udp))?, callbacks, role_state: GatewayState::new(private_key), write_buf: Box::new([0u8; MTU + 20 + 16]), diff --git a/rust/connlib/tunnel/src/sockets.rs b/rust/connlib/tunnel/src/sockets.rs index d4855c5fa..0e9b4790a 100644 --- a/rust/connlib/tunnel/src/sockets.rs +++ b/rust/connlib/tunnel/src/sockets.rs @@ -1,61 +1,28 @@ use core::slice; use quinn_udp::{RecvMeta, UdpSockRef, UdpSocketState}; -use socket2::{SockAddr, Type}; +use socket_factory::SocketFactory; use std::{ io::{self, IoSliceMut}, - net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6}, + net::{Ipv4Addr, Ipv6Addr, SocketAddr}, task::{ready, Context, Poll}, }; use tokio::{io::Interest, net::UdpSocket}; use crate::Result; -pub struct Sockets { +#[derive(Default)] +pub(crate) struct Sockets { socket_v4: Option, socket_v6: Option, - - #[cfg(unix)] - protect: Box io::Result<()> + Send + 'static>, -} - -impl Default for Sockets { - fn default() -> Self { - Self::new() - } } impl Sockets { - #[cfg(unix)] - pub fn with_protect( - protect: impl Fn(std::os::fd::RawFd) -> io::Result<()> + Send + 'static, - ) -> Self { - Self { - socket_v4: None, - socket_v6: None, - #[cfg(unix)] - protect: Box::new(protect), - } - } - - pub fn new() -> Self { - Self { - socket_v4: None, - socket_v6: None, - #[cfg(unix)] - protect: Box::new(|_| Ok(())), - } - } - - pub fn can_handle(&self, addr: &SocketAddr) -> bool { - match addr { - SocketAddr::V4(_) => self.socket_v4.is_some(), - SocketAddr::V6(_) => self.socket_v6.is_some(), - } - } - - pub fn rebind(&mut self) -> io::Result<()> { - let socket_v4 = Socket::ip4(); - let socket_v6 = Socket::ip6(); + pub fn rebind( + &mut self, + socket_factory: &dyn SocketFactory, + ) -> io::Result<()> { + let socket_v4 = Socket::ip4(socket_factory); + let socket_v6 = Socket::ip6(socket_factory); match (socket_v4.as_ref(), socket_v6.as_ref()) { (Err(e), Ok(_)) => { @@ -76,19 +43,6 @@ impl Sockets { _ => (), } - #[cfg(unix)] - { - use std::os::fd::AsRawFd; - - if let Ok(fd) = socket_v4.as_ref().map(|s| s.socket.as_raw_fd()) { - (self.protect)(fd)?; - } - - if let Ok(fd) = socket_v6.as_ref().map(|s| s.socket.as_raw_fd()) { - (self.protect)(fd)?; - } - } - self.socket_v4 = socket_v4.ok(); self.socket_v6 = socket_v6.ok(); @@ -216,28 +170,33 @@ struct Socket { } impl Socket { - fn ip4() -> Result { - let socket = make_socket(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0))?; + fn ip( + socket_factory: &dyn SocketFactory, + addr: &SocketAddr, + ) -> Result { + let socket = socket_factory(addr)?; let port = socket.local_addr()?.port(); Ok(Socket { state: UdpSocketState::new(UdpSockRef::from(&socket))?, port, - socket: tokio::net::UdpSocket::from_std(socket)?, + socket, buffered_transmits: Vec::new(), }) } - fn ip6() -> Result { - let socket = make_socket(SocketAddrV6::new(Ipv6Addr::UNSPECIFIED, 0, 0, 0))?; - let port = socket.local_addr()?.port(); + fn ip4(socket_factory: &dyn SocketFactory) -> Result { + Self::ip( + socket_factory, + &SocketAddr::from((Ipv4Addr::UNSPECIFIED, 0)), + ) + } - Ok(Socket { - state: UdpSocketState::new(UdpSockRef::from(&socket))?, - port, - socket: tokio::net::UdpSocket::from_std(socket)?, - buffered_transmits: Vec::new(), - }) + fn ip6(socket_factory: &dyn SocketFactory) -> Result { + Self::ip( + socket_factory, + &SocketAddr::from((Ipv6Addr::UNSPECIFIED, 0)), + ) } #[allow(clippy::type_complexity)] @@ -332,25 +291,3 @@ impl Socket { ); } } - -fn make_socket(addr: impl Into) -> Result { - let addr: SockAddr = addr.into().into(); - let socket = socket2::Socket::new(addr.domain(), Type::DGRAM, None)?; - - #[cfg(target_os = "linux")] - { - const FIREZONE_MARK: u32 = 0xfd002021; // Keep this synced with `TunDeviceManager` until #5797. - - socket.set_mark(FIREZONE_MARK)?; - } - - // Note: for AF_INET sockets IPV6_V6ONLY is not a valid flag - if addr.is_ipv6() { - socket.set_only_v6(true)?; - } - - socket.set_nonblocking(true)?; - socket.bind(&addr)?; - - Ok(socket.into()) -} diff --git a/rust/gateway/Cargo.toml b/rust/gateway/Cargo.toml index 5367b327c..91cd43cdf 100644 --- a/rust/gateway/Cargo.toml +++ b/rust/gateway/Cargo.toml @@ -27,6 +27,7 @@ phoenix-channel = { workspace = true } secrecy = { workspace = true } serde = { version = "1.0", default-features = false, features = ["std", "derive"] } snownet = { workspace = true } +socket-factory = { workspace = true } static_assertions = "1.1.0" tokio = { workspace = true, features = ["sync", "macros", "rt-multi-thread", "fs", "signal"] } tokio-tungstenite = { version = "0.21", default-features = false, features = ["connect", "handshake", "rustls-tls-webpki-roots"] } diff --git a/rust/gateway/src/main.rs b/rust/gateway/src/main.rs index 8bce7455c..c631ec8a5 100644 --- a/rust/gateway/src/main.rs +++ b/rust/gateway/src/main.rs @@ -6,7 +6,8 @@ use connlib_shared::{ get_user_agent, keypair, messages::Interface, Callbacks, LoginUrl, StaticSecret, }; use firezone_bin_shared::{setup_global_subscriber, CommonArgs, TunDeviceManager}; -use firezone_tunnel::{GatewayTunnel, Sockets, Tun}; +use firezone_tunnel::{GatewayTunnel, Tun}; + use futures::channel::mpsc; use futures::{future, StreamExt, TryFutureExt}; use ip_network::{Ipv4Network, Ipv6Network}; @@ -15,6 +16,7 @@ use secrecy::{Secret, SecretString}; use std::convert::Infallible; use std::path::Path; use std::pin::pin; +use std::sync::Arc; use tokio::io::AsyncWriteExt; use tokio::signal::ctrl_c; use tracing_subscriber::layer; @@ -100,7 +102,7 @@ async fn get_firezone_id(env_id: Option) -> Result { } async fn run(login: LoginUrl, private_key: StaticSecret) -> Result { - let mut tunnel = GatewayTunnel::new(private_key, Sockets::new(), CallbackHandler)?; + let mut tunnel = GatewayTunnel::new(private_key, CallbackHandler)?; let portal = PhoenixChannel::connect( Secret::new(login), get_user_agent(None, env!("CARGO_PKG_VERSION")), @@ -109,6 +111,7 @@ async fn run(login: LoginUrl, private_key: StaticSecret) -> Result { ExponentialBackoffBuilder::default() .with_max_elapsed_time(None) .build(), + Arc::new(socket_factory::tcp), ); let (sender, receiver) = mpsc::channel::(10); diff --git a/rust/headless-client/Cargo.toml b/rust/headless-client/Cargo.toml index a7a3b88c4..bc3b2ce9a 100644 --- a/rust/headless-client/Cargo.toml +++ b/rust/headless-client/Cargo.toml @@ -20,6 +20,7 @@ ip_network = { version = "0.4", default-features = false } secrecy = { workspace = true } serde = { version = "1.0.203", features = ["derive"] } serde_json = "1.0.117" +socket-factory = { workspace = true } thiserror = { version = "1.0", default-features = false } # This actually relies on many other features in Tokio, so this will probably # fail to build outside the workspace. @@ -40,7 +41,7 @@ mutants = "0.0.3" # Needed to mark functions as exempt from `cargo-mutants` test [target.'cfg(target_os = "linux")'.dependencies] dirs = "5.0.1" libc = "0.2.150" -nix = { version = "0.28.0", features = ["fs", "user"] } +nix = { version = "0.28.0", features = ["fs", "user", "socket"] } resolv-conf = "0.7.0" rtnetlink = { workspace = true } sd-notify = "0.4.1" # This is a pure Rust re-implementation, so it isn't vulnerable to CVE-2024-3094 diff --git a/rust/headless-client/src/ipc_service.rs b/rust/headless-client/src/ipc_service.rs index cd8f355e4..cc62196bb 100644 --- a/rust/headless-client/src/ipc_service.rs +++ b/rust/headless-client/src/ipc_service.rs @@ -6,13 +6,13 @@ use crate::{ }; use anyhow::{Context as _, Result}; use clap::Parser; -use connlib_client_shared::{file_logger, keypair, ConnectArgs, LoginUrl, Session, Sockets}; +use connlib_client_shared::{file_logger, keypair, ConnectArgs, LoginUrl, Session}; use futures::{ future::poll_fn, task::{Context, Poll}, Future as _, SinkExt as _, Stream as _, }; -use std::{net::IpAddr, path::PathBuf, pin::pin, time::Duration}; +use std::{net::IpAddr, path::PathBuf, pin::pin, sync::Arc, time::Duration}; use tokio::{sync::mpsc, time::Instant}; use tracing::subscriber::set_global_default; use tracing_subscriber::{layer::SubscriberExt, EnvFilter, Layer, Registry}; @@ -341,7 +341,8 @@ impl Handler { self.last_connlib_start_instant = Some(Instant::now()); let args = ConnectArgs { url, - sockets: Sockets::new(), + tcp_socket_factory: Arc::new(crate::tcp_socket_factory), + udp_socket_factory: Arc::new(crate::udp_socket_factory), private_key, os_version_override: None, app_version: env!("CARGO_PKG_VERSION").to_string(), diff --git a/rust/headless-client/src/lib.rs b/rust/headless-client/src/lib.rs index 8d2892ac2..e3d50ab17 100644 --- a/rust/headless-client/src/lib.rs +++ b/rust/headless-client/src/lib.rs @@ -20,6 +20,8 @@ use tracing::subscriber::set_global_default; use tracing_subscriber::{fmt, layer::SubscriberExt as _, EnvFilter, Layer as _, Registry}; use platform::default_token_path; +use platform::tcp_socket_factory; +use platform::udp_socket_factory; /// Generate a persistent device ID, stores it to disk, and reads it back. pub(crate) mod device_id; diff --git a/rust/headless-client/src/linux.rs b/rust/headless-client/src/linux.rs index 48c925e72..7b9267739 100644 --- a/rust/headless-client/src/linux.rs +++ b/rust/headless-client/src/linux.rs @@ -2,13 +2,31 @@ use super::TOKEN_ENV_KEY; use anyhow::{bail, Result}; -use std::path::{Path, PathBuf}; +use firezone_bin_shared::FIREZONE_MARK; +use nix::sys::socket::{setsockopt, sockopt}; +use std::{ + io, + net::SocketAddr, + path::{Path, PathBuf}, +}; // The Client currently must run as root to control DNS // Root group and user are used to check file ownership on the token const ROOT_GROUP: u32 = 0; const ROOT_USER: u32 = 0; +pub(crate) fn tcp_socket_factory(socket_addr: &SocketAddr) -> io::Result { + let socket = socket_factory::tcp(socket_addr)?; + setsockopt(&socket, sockopt::Mark, &FIREZONE_MARK)?; + Ok(socket) +} + +pub(crate) fn udp_socket_factory(socket_addr: &SocketAddr) -> io::Result { + let socket = socket_factory::udp(socket_addr)?; + setsockopt(&socket, sockopt::Mark, &FIREZONE_MARK)?; + Ok(socket) +} + pub(crate) fn default_token_path() -> PathBuf { PathBuf::from("/etc") .join(connlib_shared::BUNDLE_ID) diff --git a/rust/headless-client/src/standalone.rs b/rust/headless-client/src/standalone.rs index 240395773..001650f54 100644 --- a/rust/headless-client/src/standalone.rs +++ b/rust/headless-client/src/standalone.rs @@ -6,13 +6,14 @@ use crate::{ }; use anyhow::{anyhow, Context as _, Result}; use clap::Parser; -use connlib_client_shared::{file_logger, keypair, ConnectArgs, LoginUrl, Session, Sockets}; +use connlib_client_shared::{file_logger, keypair, ConnectArgs, LoginUrl, Session}; use firezone_bin_shared::{setup_global_subscriber, TunDeviceManager}; use futures::{FutureExt as _, StreamExt as _}; use secrecy::SecretString; use std::{ path::{Path, PathBuf}, pin::pin, + sync::Arc, }; use tokio::sync::mpsc; use tokio_stream::wrappers::ReceiverStream; @@ -156,7 +157,8 @@ pub fn run_only_headless_client() -> Result<()> { platform::setup_before_connlib()?; let args = ConnectArgs { url, - sockets: Sockets::new(), + udp_socket_factory: Arc::new(crate::udp_socket_factory), + tcp_socket_factory: Arc::new(crate::tcp_socket_factory), private_key, os_version_override: None, app_version: env!("CARGO_PKG_VERSION").to_string(), diff --git a/rust/headless-client/src/windows.rs b/rust/headless-client/src/windows.rs index e88c5847e..048075586 100644 --- a/rust/headless-client/src/windows.rs +++ b/rust/headless-client/src/windows.rs @@ -7,6 +7,9 @@ use anyhow::Result; use std::path::{Path, PathBuf}; +pub(crate) use socket_factory::tcp as tcp_socket_factory; +pub(crate) use socket_factory::udp as udp_socket_factory; + #[path = "windows/wintun_install.rs"] mod wintun_install; diff --git a/rust/phoenix-channel/Cargo.toml b/rust/phoenix-channel/Cargo.toml index 690aa91ab..27ef708b1 100644 --- a/rust/phoenix-channel/Cargo.toml +++ b/rust/phoenix-channel/Cargo.toml @@ -15,6 +15,7 @@ secrecy = { workspace = true } serde = { version = "1.0.203", features = ["derive"] } serde_json = "1.0.117" sha2 = "0.10.8" +socket-factory = { workspace = true } thiserror = "1.0.61" tokio = { workspace = true, features = ["net", "time"] } tokio-tungstenite = { workspace = true, features = ["rustls-tls-webpki-roots"] } diff --git a/rust/phoenix-channel/src/lib.rs b/rust/phoenix-channel/src/lib.rs index 7bdcaf482..4ab57df5a 100644 --- a/rust/phoenix-channel/src/lib.rs +++ b/rust/phoenix-channel/src/lib.rs @@ -3,6 +3,7 @@ mod login_url; use std::collections::{HashSet, VecDeque}; use std::mem; +use std::net::SocketAddr; use std::sync::atomic::AtomicU64; use std::sync::Arc; use std::{fmt, future, marker::PhantomData}; @@ -16,15 +17,16 @@ use heartbeat::{Heartbeat, MissedLastHeartbeat}; use rand_core::{OsRng, RngCore}; use secrecy::{ExposeSecret as _, Secret}; use serde::{de::DeserializeOwned, Deserialize, Serialize}; +use socket_factory::SocketFactory; use std::task::{Context, Poll, Waker}; use tokio::net::TcpStream; -use tokio_tungstenite::connect_async_with_config; +use tokio_tungstenite::client_async_tls; use tokio_tungstenite::tungstenite::http::StatusCode; use tokio_tungstenite::{ - connect_async, tungstenite::{handshake::client::Request, Message}, MaybeTlsStream, WebSocketStream, }; +use url::{Host, Url}; pub use login_url::{LoginUrl, LoginUrlError}; @@ -33,6 +35,7 @@ pub struct PhoenixChannel { waker: Option, pending_messages: VecDeque, next_request_id: Arc, + socket_factory: Arc>, heartbeat: Heartbeat, @@ -59,17 +62,70 @@ enum State { } impl State { - fn connect(url: Secret, user_agent: String) -> Self { - Self::Connecting(Box::pin(async move { - let (stream, _) = connect_async_with_config(make_request(url, user_agent), None, true) - .await - .map_err(InternalError::WebSocket)?; - - Ok(stream) - })) + fn connect( + url: Secret, + user_agent: String, + socket_factory: Arc>, + ) -> Self { + Self::Connecting(create_and_connect_websocket(url, user_agent, socket_factory).boxed()) } } +async fn create_and_connect_websocket( + url: Secret, + user_agent: String, + socket_factory: Arc>, +) -> Result>, InternalError> { + let socket = make_socket(url.expose_secret().inner(), &*socket_factory).await?; + + let (stream, _) = client_async_tls(make_request(url, user_agent), socket) + .await + .map_err(InternalError::WebSocket)?; + + Ok(stream) +} + +async fn make_socket( + url: &Url, + socket_factory: &dyn SocketFactory, +) -> Result { + let port = url + .port_or_known_default() + .expect("scheme to be http, https, ws or wss"); + let addrs: Vec = match url.host().ok_or(InternalError::InvalidUrl)? { + Host::Domain(n) => tokio::net::lookup_host((n, port)) + .await + .map_err(|_| InternalError::InvalidUrl)? + .collect(), + Host::Ipv6(ip) => { + vec![(ip, port).into()] + } + Host::Ipv4(ip) => { + vec![(ip, port).into()] + } + }; + + let mut last_error = None; + for addr in addrs { + let Ok(socket) = socket_factory(&addr) else { + continue; + }; + + match socket.connect(addr).await { + Ok(socket) => return Ok(socket), + Err(e) => { + last_error = Some(e); + } + } + } + + let Some(err) = last_error else { + return Err(InternalError::InvalidUrl); + }; + + Err(InternalError::SocketConnection(err)) +} + #[derive(Debug, thiserror::Error)] pub enum Error { #[error("client error: {0}")] @@ -99,6 +155,8 @@ enum InternalError { MissedHeartbeat, CloseMessage, StreamClosed, + InvalidUrl, + SocketConnection(std::io::Error), } impl fmt::Display for InternalError { @@ -119,6 +177,8 @@ impl fmt::Display for InternalError { InternalError::MissedHeartbeat => write!(f, "portal did not respond to our heartbeat"), InternalError::CloseMessage => write!(f, "portal closed the websocket connection"), InternalError::StreamClosed => write!(f, "websocket stream was closed"), + InternalError::InvalidUrl => write!(f, "failed to resolve url"), + InternalError::SocketConnection(e) => write!(f, "failed to connect socket: {e}"), } } } @@ -161,14 +221,13 @@ where /// /// The provided URL must contain a host. /// Additionally, you must already provide any query parameters required for authentication. - /// - /// Once the connection is established, pub fn connect( url: Secret, user_agent: String, login: &'static str, init_req: TInitReq, reconnect_backoff: ExponentialBackoff, + socket_factory: Arc>, ) -> Self { let next_request_id = Arc::new(AtomicU64::new(0)); @@ -178,7 +237,8 @@ where reconnect_backoff, url: url.clone(), user_agent: user_agent.clone(), - state: State::connect(url, user_agent), + state: State::connect(url, user_agent, socket_factory.clone()), + socket_factory, waker: None, pending_messages: Default::default(), _phantom: PhantomData, @@ -220,7 +280,7 @@ where // 2. Set state to `Connecting` without a timer. let url = self.url.clone(); let user_agent = self.user_agent.clone(); - self.state = State::connect(url, user_agent); + self.state = State::connect(url, user_agent, self.socket_factory.clone()); // 3. In case we were already re-connecting, we need to wake the suspended task. if let Some(waker) = self.waker.take() { @@ -293,18 +353,16 @@ where let secret_url = self.url.clone(); let user_agent = self.user_agent.clone(); + let socket_factory = self.socket_factory.clone(); tracing::debug!(?backoff, max_elapsed_time = ?self.reconnect_backoff.max_elapsed_time, "Reconnecting to portal on transient client error: {e}"); self.state = State::Connecting(Box::pin(async move { tokio::time::sleep(backoff).await; - - let (stream, _) = connect_async(make_request(secret_url, user_agent)) + create_and_connect_websocket(secret_url, user_agent, socket_factory) .await - .map_err(InternalError::WebSocket)?; - - Ok(stream) })); + continue; } Poll::Pending => { diff --git a/rust/relay/Cargo.toml b/rust/relay/Cargo.toml index fa629039c..2b6be6b3e 100644 --- a/rust/relay/Cargo.toml +++ b/rust/relay/Cargo.toml @@ -26,7 +26,8 @@ rand = "0.8.5" secrecy = { workspace = true } serde = { version = "1.0.203", features = ["derive"] } sha2 = "0.10.8" -socket2 = "0.5.7" +socket-factory = { workspace = true } +socket2 = { workspace = true } stun_codec = "0.3.4" tokio = { workspace = true, features = ["macros", "rt-multi-thread", "net", "time", "signal"] } tracing = { workspace = true, features = ["log"] } diff --git a/rust/relay/src/main.rs b/rust/relay/src/main.rs index 2fe23209c..d24d8c73c 100644 --- a/rust/relay/src/main.rs +++ b/rust/relay/src/main.rs @@ -146,6 +146,7 @@ async fn main() -> Result<()> { ExponentialBackoffBuilder::default() .with_max_elapsed_time(Some(MAX_PARTITION_TIME)) .build(), + Arc::new(socket_factory::tcp), )) } else { tracing::warn!(target: "relay", "No portal token supplied, starting standalone mode"); diff --git a/rust/socket-factory/Cargo.toml b/rust/socket-factory/Cargo.toml new file mode 100644 index 000000000..13712112f --- /dev/null +++ b/rust/socket-factory/Cargo.toml @@ -0,0 +1,8 @@ +[package] +name = "socket-factory" +version = "0.1.0" +edition = "2021" + +[dependencies] +socket2 = { workspace = true } +tokio = { version = "1.38", features = ["net"] } diff --git a/rust/socket-factory/src/lib.rs b/rust/socket-factory/src/lib.rs new file mode 100644 index 000000000..db18ef752 --- /dev/null +++ b/rust/socket-factory/src/lib.rs @@ -0,0 +1,35 @@ +use std::net::SocketAddr; + +use socket2::SockAddr; + +pub trait SocketFactory: Fn(&SocketAddr) -> std::io::Result + Send + Sync + 'static {} + +impl SocketFactory for F where + F: Fn(&SocketAddr) -> std::io::Result + Send + Sync + 'static +{ +} + +pub fn tcp(addr: &SocketAddr) -> std::io::Result { + let socket = match addr { + SocketAddr::V4(_) => tokio::net::TcpSocket::new_v4()?, + SocketAddr::V6(_) => tokio::net::TcpSocket::new_v6()?, + }; + + socket.set_nodelay(true)?; + + Ok(socket) +} +pub fn udp(addr: &SocketAddr) -> std::io::Result { + let addr: SockAddr = (*addr).into(); + let socket = socket2::Socket::new(addr.domain(), socket2::Type::DGRAM, None)?; + + // Note: for AF_INET sockets IPV6_V6ONLY is not a valid flag + if addr.is_ipv6() { + socket.set_only_v6(true)?; + } + + socket.set_nonblocking(true)?; + socket.bind(&addr)?; + + std::net::UdpSocket::from(socket).try_into() +}