fix(connlib): protect all sockets from routing loops (#5797)

Currently, only connlib's UDP sockets for sending and receiving STUN &
WireGuard traffic are protected from routing loops. This is was done via
the `Sockets::with_protect` function. Connlib has additional sockets
though:

- A TCP socket to the portal.
- UDP & TCP sockets for DNS resolution via hickory.

Both of these can incur routing loops on certain platforms which becomes
evident as we try to implement #2667.

To fix this, we generalise the idea of "protecting" a socket via a
`SocketFactory` abstraction. By allowing the different platforms to
provide a specialised `SocketFactory`, anything Linux-based can give
special treatment to the socket before handing it to connlib.

As an additional benefit, this allows us to remove the `Sockets`
abstraction from connlib's API again because we can now initialise it
internally via the provided `SocketFactory` for UDP sockets.

---------

Signed-off-by: Gabi <gabrielalejandro7@gmail.com>
Co-authored-by: Thomas Eizinger <thomas@eizinger.io>
This commit is contained in:
Gabi
2024-07-15 21:40:05 -03:00
committed by GitHub
parent 14abda01fd
commit 5b0aaa6f81
28 changed files with 374 additions and 167 deletions

17
rust/Cargo.lock generated
View File

@@ -1066,6 +1066,7 @@ dependencies = [
"log",
"secrecy",
"serde_json",
"socket-factory",
"thiserror",
"tokio",
"tracing",
@@ -1084,6 +1085,7 @@ dependencies = [
"oslog",
"secrecy",
"serde_json",
"socket-factory",
"swift-bridge",
"swift-bridge-build",
"tokio",
@@ -1109,6 +1111,7 @@ dependencies = [
"secrecy",
"serde",
"serde_json",
"socket-factory",
"time",
"tokio",
"tokio-tungstenite",
@@ -1873,6 +1876,7 @@ dependencies = [
"serde",
"serde_json",
"snownet",
"socket-factory",
"static_assertions",
"tokio",
"tokio-tungstenite",
@@ -1962,6 +1966,7 @@ dependencies = [
"secrecy",
"serde",
"serde_json",
"socket-factory",
"tempfile",
"thiserror",
"tokio",
@@ -2003,6 +2008,7 @@ dependencies = [
"secrecy",
"serde",
"sha2",
"socket-factory",
"socket2 0.5.7",
"stun_codec",
"test-strategy",
@@ -2054,6 +2060,7 @@ dependencies = [
"serde",
"serde_json",
"snownet",
"socket-factory",
"socket2 0.5.7",
"test-strategy",
"thiserror",
@@ -3901,6 +3908,7 @@ dependencies = [
"cfg-if",
"cfg_aliases",
"libc",
"memoffset 0.9.1",
]
[[package]]
@@ -4561,6 +4569,7 @@ dependencies = [
"serde",
"serde_json",
"sha2",
"socket-factory",
"thiserror",
"tokio",
"tokio-tungstenite",
@@ -5812,6 +5821,14 @@ dependencies = [
"tracing-subscriber",
]
[[package]]
name = "socket-factory"
version = "0.1.0"
dependencies = [
"socket2 0.5.7",
"tokio",
]
[[package]]
name = "socket2"
version = "0.4.10"

View File

@@ -17,6 +17,7 @@ members = [
"phoenix-channel",
"relay",
"snownet-tests",
"socket-factory",
]
resolver = "2"
@@ -53,6 +54,8 @@ firezone-tunnel = { path = "connlib/tunnel" }
phoenix-channel = { path = "phoenix-channel" }
http-health-check = { path = "http-health-check" }
ip-packet = { path = "ip-packet" }
socket-factory = { path = "socket-factory" }
socket2 = { version = "0.5" }
[workspace.lints.clippy]
dbg_macro = "warn"

View File

@@ -7,6 +7,9 @@ use tracing_subscriber::{
};
use url::Url;
/// Mark for Firezone sockets to prevent routing loops on Linux.
pub const FIREZONE_MARK: u32 = 0xfd002021;
#[cfg(any(target_os = "linux", target_os = "windows"))]
pub use tun_device_manager::TunDeviceManager;

View File

@@ -13,7 +13,8 @@ use std::{
net::{Ipv4Addr, Ipv6Addr},
};
const FIREZONE_MARK: u32 = 0xfd002021; // Keep this synced with `Sockets` until #5797.
use crate::FIREZONE_MARK;
const FILE_ALREADY_EXISTS: i32 = -17;
const FIREZONE_TABLE: u32 = 0x2021_fd00;

View File

@@ -19,6 +19,7 @@ jni = { version = "0.21.1", features = ["invocation"] }
log = "0.4"
secrecy = { workspace = true }
serde_json = "1"
socket-factory = { workspace = true }
thiserror = "1"
tokio = { workspace = true, features = ["rt"] }
tracing = { workspace = true, features = ["std", "attributes"] }

View File

@@ -5,7 +5,7 @@
use connlib_client_shared::{
callbacks::ResourceDescription, file_logger, keypair, Callbacks, ConnectArgs, Error, LoginUrl,
LoginUrlError, Session, Sockets, Tun, V4RouteList, V6RouteList,
LoginUrlError, Session, Tun, V4RouteList, V6RouteList,
};
use ip_network::{Ipv4Network, Ipv6Network};
use jni::{
@@ -15,7 +15,8 @@ use jni::{
JNIEnv, JavaVM,
};
use secrecy::SecretString;
use std::{io, net::IpAddr, path::Path};
use socket_factory::SocketFactory;
use std::{io, net::IpAddr, os::fd::AsRawFd, path::Path, sync::Arc};
use std::{
net::{Ipv4Addr, Ipv6Addr},
os::fd::RawFd,
@@ -86,16 +87,17 @@ impl CallbackHandler {
.and_then(f)
}
fn protect_file_descriptor(&self, file_descriptor: RawFd) -> Result<(), CallbackError> {
fn protect(&self, socket: RawFd) -> io::Result<()> {
self.env(|mut env| {
call_method(
&mut env,
&self.callback_handler,
"protectFileDescriptor",
"(I)V",
&[JValue::Int(file_descriptor)],
&[JValue::Int(socket)],
)
})
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))
}
}
@@ -357,18 +359,10 @@ fn connect(
.enable_all()
.build()?;
let sockets = Sockets::with_protect({
let callbacks = callbacks.clone();
move |fd| {
callbacks
.protect_file_descriptor(fd)
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))
}
});
let args = ConnectArgs {
url,
sockets,
tcp_socket_factory: Arc::new(protected_tcp_socket_factory(callbacks.clone())),
udp_socket_factory: Arc::new(protected_udp_socket_factory(callbacks.clone())),
private_key,
os_version_override: Some(os_version),
app_version: env!("CARGO_PKG_VERSION").to_string(),
@@ -523,3 +517,23 @@ pub unsafe extern "system" fn Java_dev_firezone_android_tunnel_ConnlibSession_se
session.inner.set_tun(tun);
}
fn protected_tcp_socket_factory(
callbacks: CallbackHandler,
) -> impl SocketFactory<tokio::net::TcpSocket> {
move |addr| {
let socket = socket_factory::tcp(addr)?;
callbacks.protect(socket.as_raw_fd())?;
Ok(socket)
}
}
fn protected_udp_socket_factory(
callbacks: CallbackHandler,
) -> impl SocketFactory<tokio::net::UdpSocket> {
move |addr| {
let socket = socket_factory::udp(addr)?;
callbacks.protect(socket.as_raw_fd())?;
Ok(socket)
}
}

View File

@@ -16,6 +16,7 @@ ip_network = "0.4"
libc = "0.2"
secrecy = { workspace = true }
serde_json = "1"
socket-factory = { workspace = true }
swift-bridge = { workspace = true }
tokio = { workspace = true, features = ["rt"] }
tracing = { workspace = true }

View File

@@ -5,7 +5,7 @@ mod make_writer;
use connlib_client_shared::{
callbacks::ResourceDescription, file_logger, keypair, Callbacks, ConnectArgs, Error, LoginUrl,
Session, Sockets, Tun, V4RouteList, V6RouteList,
Session, Tun, V4RouteList, V6RouteList,
};
use ip_network::{Ipv4Network, Ipv6Network};
use secrecy::SecretString;
@@ -194,7 +194,6 @@ impl WrappedSession {
let args = ConnectArgs {
url,
sockets: Sockets::new(),
private_key,
os_version_override,
app_version: env!("CARGO_PKG_VERSION").to_string(),
@@ -202,6 +201,8 @@ impl WrappedSession {
inner: Arc::new(callback_handler),
},
max_partition_time: Some(MAX_PARTITION_TIME),
tcp_socket_factory: Arc::new(socket_factory::tcp),
udp_socket_factory: Arc::new(socket_factory::udp),
};
let session = Session::connect(args, runtime.handle().clone());
let _enter = runtime.enter();

View File

@@ -17,6 +17,7 @@ ip_network = { version = "0.4", default-features = false }
phoenix-channel = { workspace = true }
secrecy = { workspace = true }
serde = { version = "1.0", default-features = false, features = ["std", "derive"] }
socket-factory = { workspace = true }
time = { version = "0.3.36", features = ["formatting"] }
tokio = { workspace = true, features = ["sync"] }
tokio-tungstenite = { version = "0.21", default-features = false, features = ["connect", "handshake", "rustls-tls-webpki-roots"] }

View File

@@ -5,15 +5,17 @@ pub use connlib_shared::{
callbacks, keypair, Callbacks, Error, LoginUrl, LoginUrlError, StaticSecret,
};
pub use eventloop::Eventloop;
pub use firezone_tunnel::{Sockets, Tun};
pub use firezone_tunnel::Tun;
pub use tracing_appender::non_blocking::WorkerGuard;
use backoff::ExponentialBackoffBuilder;
use connlib_shared::get_user_agent;
use firezone_tunnel::ClientTunnel;
use phoenix_channel::PhoenixChannel;
use socket_factory::SocketFactory;
use std::collections::HashMap;
use std::net::IpAddr;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::mpsc::UnboundedReceiver;
@@ -39,7 +41,8 @@ pub struct Session {
/// Arguments for `connect`, since Clippy said 8 args is too many
pub struct ConnectArgs<CB> {
pub url: LoginUrl,
pub sockets: Sockets,
pub tcp_socket_factory: Arc<dyn SocketFactory<tokio::net::TcpSocket>>,
pub udp_socket_factory: Arc<dyn SocketFactory<tokio::net::UdpSocket>>,
pub private_key: StaticSecret,
pub os_version_override: Option<String>,
pub app_version: String,
@@ -120,11 +123,12 @@ where
{
let ConnectArgs {
url,
sockets,
private_key,
os_version_override,
app_version,
callbacks,
udp_socket_factory,
tcp_socket_factory,
max_partition_time,
} = args;
@@ -139,7 +143,8 @@ where
let tunnel = ClientTunnel::new(
private_key,
sockets,
tcp_socket_factory.clone(),
udp_socket_factory,
callbacks,
HashMap::from([(url.host().to_string(), addrs)]),
)?;
@@ -152,6 +157,7 @@ where
ExponentialBackoffBuilder::default()
.with_max_elapsed_time(max_partition_time)
.build(),
tcp_socket_factory,
);
let mut eventloop = Eventloop::new(tunnel, portal, rx);
@@ -232,14 +238,18 @@ mod tests {
#[cfg(any(target_os = "windows", target_os = "linux"))]
async fn device_common() {
use firezone_tunnel::Tun;
use std::collections::HashMap;
use std::{collections::HashMap, sync::Arc};
let (private_key, _public_key) = connlib_shared::keypair();
let sockets = crate::Sockets::new();
let callbacks = Callbacks::default();
let mut tunnel =
firezone_tunnel::ClientTunnel::new(private_key, sockets, callbacks, HashMap::new())
.unwrap();
let mut tunnel = firezone_tunnel::ClientTunnel::new(
private_key,
Arc::new(socket_factory::tcp),
Arc::new(socket_factory::udp),
callbacks,
HashMap::new(),
)
.unwrap();
let upstream_dns = vec![([192, 168, 1, 1], 53).into()];
let interface = connlib_shared::messages::Interface {
ipv4: [100, 71, 96, 96].into(),

View File

@@ -30,7 +30,8 @@ rangemap = "1.5.1"
secrecy = { workspace = true }
serde = { version = "1.0", default-features = false, features = ["derive", "std"] }
snownet = { workspace = true }
socket2 = { version = "0.5" }
socket-factory = { workspace = true }
socket2 = { workspace = true }
thiserror = { version = "1.0", default-features = false }
tokio = { workspace = true }
tracing = { workspace = true }

View File

@@ -5,22 +5,29 @@ use crate::{
};
use bytes::Bytes;
use connlib_shared::messages::DnsServer;
use futures::Future;
use futures_bounded::FuturesTupleSet;
use futures_util::FutureExt as _;
use hickory_proto::iocompat::AsyncIoTokioAsStd;
use hickory_proto::TokioTime;
use hickory_resolver::{
config::{NameServerConfig, Protocol, ResolverConfig, ResolverOpts},
TokioAsyncResolver,
name_server::{GenericConnector, RuntimeProvider},
AsyncResolver, TokioHandle,
};
use ip_packet::{IpPacket, MutableIpPacket};
use quinn_udp::Transmit;
use socket_factory::SocketFactory;
use std::{
collections::HashMap,
io,
net::IpAddr,
net::{IpAddr, SocketAddr},
pin::Pin,
sync::Arc,
task::{ready, Context, Poll},
time::{Duration, Instant},
};
use tokio::net::{TcpSocket, UdpSocket};
const DNS_QUERIES_QUEUE_SIZE: usize = 100;
@@ -32,9 +39,13 @@ pub struct Io {
device: Device,
/// The UDP sockets used to send & receive packets from the network.
sockets: Sockets,
tcp_socket_factory: Arc<dyn SocketFactory<TcpSocket>>,
udp_socket_factory: Arc<dyn SocketFactory<UdpSocket>>,
timeout: Option<Pin<Box<tokio::time::Sleep>>>,
upstream_dns_servers: HashMap<IpAddr, TokioAsyncResolver>,
upstream_dns_servers: HashMap<IpAddr, AsyncResolver<GenericConnector<TokioRuntimeProvider>>>,
forwarded_dns_queries: FuturesTupleSet<
Result<hickory_resolver::lookup::Lookup, hickory_resolver::error::ResolveError>,
DnsQuery<'static>,
@@ -58,13 +69,19 @@ impl Io {
/// Creates a new I/O abstraction
///
/// Must be called within a Tokio runtime context so we can bind the sockets.
pub fn new(mut sockets: Sockets) -> io::Result<Self> {
sockets.rebind()?; // Bind sockets on startup. Must happen within a tokio runtime context.
pub fn new(
tcp_socket_factory: Arc<dyn SocketFactory<TcpSocket>>,
udp_socket_factory: Arc<dyn SocketFactory<UdpSocket>>,
) -> io::Result<Self> {
let mut sockets = Sockets::default();
sockets.rebind(udp_socket_factory.as_ref())?; // Bind sockets on startup. Must happen within a tokio runtime context.
Ok(Self {
device: Device::new(),
timeout: None,
sockets,
tcp_socket_factory,
udp_socket_factory,
upstream_dns_servers: HashMap::default(),
forwarded_dns_queries: FuturesTupleSet::new(
Duration::from_secs(60),
@@ -107,8 +124,10 @@ impl Io {
&mut self.device
}
pub fn sockets_mut(&mut self) -> &mut Sockets {
&mut self.sockets
pub fn rebind_sockets(&mut self) -> io::Result<()> {
self.sockets.rebind(self.udp_socket_factory.as_ref())?;
Ok(())
}
pub fn set_upstream_dns_servers(
@@ -119,7 +138,13 @@ impl Io {
self.forwarded_dns_queries =
FuturesTupleSet::new(Duration::from_secs(60), DNS_QUERIES_QUEUE_SIZE);
self.upstream_dns_servers = create_resolvers(dns_servers);
self.upstream_dns_servers = create_resolvers(
dns_servers,
TokioRuntimeProvider::new(
self.tcp_socket_factory.clone(),
self.udp_socket_factory.clone(),
),
);
}
pub fn perform_dns_query(&mut self, query: DnsQuery<'static>) -> Result<(), DnsQueryError> {
@@ -186,9 +211,65 @@ pub enum DnsQueryError {
TooManyQueries,
}
/// Identical to [`TokioRuntimeProvider`](hickory_resolver::name_server::TokioRuntimeProvider) but using our own [`SocketFactory`].
#[derive(Clone)]
struct TokioRuntimeProvider {
handle: TokioHandle,
tcp_socket_factory: Arc<dyn SocketFactory<TcpSocket>>,
udp_socket_factory: Arc<dyn SocketFactory<UdpSocket>>,
}
impl TokioRuntimeProvider {
fn new(
tcp_socket_factory: Arc<dyn SocketFactory<TcpSocket>>,
udp_socket_factory: Arc<dyn SocketFactory<UdpSocket>>,
) -> TokioRuntimeProvider {
Self {
handle: Default::default(),
tcp_socket_factory,
udp_socket_factory,
}
}
}
impl RuntimeProvider for TokioRuntimeProvider {
type Handle = TokioHandle;
type Timer = TokioTime;
type Udp = UdpSocket;
type Tcp = AsyncIoTokioAsStd<tokio::net::TcpStream>;
fn create_handle(&self) -> Self::Handle {
self.handle.clone()
}
fn connect_tcp(
&self,
server_addr: SocketAddr,
) -> Pin<Box<dyn Send + Future<Output = io::Result<Self::Tcp>>>> {
let socket = (self.tcp_socket_factory)(&server_addr);
Box::pin(async move {
let socket = socket?;
let stream = socket.connect(server_addr).await?;
Ok(AsyncIoTokioAsStd(stream))
})
}
fn bind_udp(
&self,
local_addr: SocketAddr,
_server_addr: SocketAddr,
) -> Pin<Box<dyn Send + Future<Output = io::Result<Self::Udp>>>> {
let socket = (self.udp_socket_factory)(&local_addr);
Box::pin(async move { socket })
}
}
fn create_resolvers(
dns_servers: impl IntoIterator<Item = (IpAddr, DnsServer)>,
) -> HashMap<IpAddr, TokioAsyncResolver> {
runtime_provider: TokioRuntimeProvider,
) -> HashMap<IpAddr, AsyncResolver<GenericConnector<TokioRuntimeProvider>>> {
dns_servers
.into_iter()
.map(|(sentinel, srv)| {
@@ -201,7 +282,11 @@ fn create_resolvers(
(
sentinel,
TokioAsyncResolver::tokio(resolver_config, resolver_opts),
AsyncResolver::new_with_conn(
resolver_config,
resolver_opts,
GenericConnector::new(runtime_provider.clone()),
),
)
})
.collect()

View File

@@ -14,6 +14,7 @@ use io::Io;
use std::{
collections::{HashMap, HashSet},
net::{IpAddr, SocketAddr},
sync::Arc,
task::{Context, Poll},
time::Instant,
};
@@ -21,7 +22,6 @@ use std::{
use bimap::BiMap;
pub use client::{ClientState, Request};
pub use gateway::GatewayState;
pub use sockets::Sockets;
use utils::turn;
mod client;
@@ -50,7 +50,7 @@ pub type ClientTunnel<CB> = Tunnel<CB, ClientState>;
/// [`Tunnel`] glues together connlib's [`Io`] component and the respective (pure) state of a client or gateway.
///
/// Most of connlib's functionality is implemented as a pure state machine in [`ClientState`] and [`GatewayState`].
/// The only job of [`Tunnel`] is to take input from the TUN [`Device`](crate::device_channel::Device), [`Sockets`] or time and pass it to the respective state.
/// The only job of [`Tunnel`] is to take input from the TUN [`Device`](crate::device_channel::Device), [`Sockets`](crate::sockets::Sockets) or time and pass it to the respective state.
pub struct Tunnel<CB: Callbacks, TRoleState> {
pub callbacks: CB,
@@ -77,12 +77,13 @@ where
{
pub fn new(
private_key: StaticSecret,
sockets: Sockets,
tcp_socket_factory: Arc<dyn socket_factory::SocketFactory<tokio::net::TcpSocket>>,
udp_socket_factory: Arc<dyn socket_factory::SocketFactory<tokio::net::UdpSocket>>,
callbacks: CB,
known_hosts: HashMap<String, Vec<IpAddr>>,
) -> std::io::Result<Self> {
Ok(Self {
io: Io::new(sockets)?,
io: Io::new(tcp_socket_factory, udp_socket_factory)?,
callbacks,
role_state: ClientState::new(private_key, known_hosts),
write_buf: Box::new([0u8; MTU + 16 + 20]),
@@ -94,7 +95,7 @@ where
pub fn reset(&mut self) -> std::io::Result<()> {
self.role_state.reset();
self.io.sockets_mut().rebind()?;
self.io.rebind_sockets()?;
Ok(())
}
@@ -178,13 +179,9 @@ impl<CB> GatewayTunnel<CB>
where
CB: Callbacks + 'static,
{
pub fn new(
private_key: StaticSecret,
sockets: Sockets,
callbacks: CB,
) -> std::io::Result<Self> {
pub fn new(private_key: StaticSecret, callbacks: CB) -> std::io::Result<Self> {
Ok(Self {
io: Io::new(sockets)?,
io: Io::new(Arc::new(socket_factory::tcp), Arc::new(socket_factory::udp))?,
callbacks,
role_state: GatewayState::new(private_key),
write_buf: Box::new([0u8; MTU + 20 + 16]),

View File

@@ -1,61 +1,28 @@
use core::slice;
use quinn_udp::{RecvMeta, UdpSockRef, UdpSocketState};
use socket2::{SockAddr, Type};
use socket_factory::SocketFactory;
use std::{
io::{self, IoSliceMut},
net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6},
net::{Ipv4Addr, Ipv6Addr, SocketAddr},
task::{ready, Context, Poll},
};
use tokio::{io::Interest, net::UdpSocket};
use crate::Result;
pub struct Sockets {
#[derive(Default)]
pub(crate) struct Sockets {
socket_v4: Option<Socket>,
socket_v6: Option<Socket>,
#[cfg(unix)]
protect: Box<dyn Fn(std::os::fd::RawFd) -> io::Result<()> + Send + 'static>,
}
impl Default for Sockets {
fn default() -> Self {
Self::new()
}
}
impl Sockets {
#[cfg(unix)]
pub fn with_protect(
protect: impl Fn(std::os::fd::RawFd) -> io::Result<()> + Send + 'static,
) -> Self {
Self {
socket_v4: None,
socket_v6: None,
#[cfg(unix)]
protect: Box::new(protect),
}
}
pub fn new() -> Self {
Self {
socket_v4: None,
socket_v6: None,
#[cfg(unix)]
protect: Box::new(|_| Ok(())),
}
}
pub fn can_handle(&self, addr: &SocketAddr) -> bool {
match addr {
SocketAddr::V4(_) => self.socket_v4.is_some(),
SocketAddr::V6(_) => self.socket_v6.is_some(),
}
}
pub fn rebind(&mut self) -> io::Result<()> {
let socket_v4 = Socket::ip4();
let socket_v6 = Socket::ip6();
pub fn rebind(
&mut self,
socket_factory: &dyn SocketFactory<tokio::net::UdpSocket>,
) -> io::Result<()> {
let socket_v4 = Socket::ip4(socket_factory);
let socket_v6 = Socket::ip6(socket_factory);
match (socket_v4.as_ref(), socket_v6.as_ref()) {
(Err(e), Ok(_)) => {
@@ -76,19 +43,6 @@ impl Sockets {
_ => (),
}
#[cfg(unix)]
{
use std::os::fd::AsRawFd;
if let Ok(fd) = socket_v4.as_ref().map(|s| s.socket.as_raw_fd()) {
(self.protect)(fd)?;
}
if let Ok(fd) = socket_v6.as_ref().map(|s| s.socket.as_raw_fd()) {
(self.protect)(fd)?;
}
}
self.socket_v4 = socket_v4.ok();
self.socket_v6 = socket_v6.ok();
@@ -216,28 +170,33 @@ struct Socket {
}
impl Socket {
fn ip4() -> Result<Socket> {
let socket = make_socket(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0))?;
fn ip(
socket_factory: &dyn SocketFactory<tokio::net::UdpSocket>,
addr: &SocketAddr,
) -> Result<Socket> {
let socket = socket_factory(addr)?;
let port = socket.local_addr()?.port();
Ok(Socket {
state: UdpSocketState::new(UdpSockRef::from(&socket))?,
port,
socket: tokio::net::UdpSocket::from_std(socket)?,
socket,
buffered_transmits: Vec::new(),
})
}
fn ip6() -> Result<Socket> {
let socket = make_socket(SocketAddrV6::new(Ipv6Addr::UNSPECIFIED, 0, 0, 0))?;
let port = socket.local_addr()?.port();
fn ip4(socket_factory: &dyn SocketFactory<tokio::net::UdpSocket>) -> Result<Socket> {
Self::ip(
socket_factory,
&SocketAddr::from((Ipv4Addr::UNSPECIFIED, 0)),
)
}
Ok(Socket {
state: UdpSocketState::new(UdpSockRef::from(&socket))?,
port,
socket: tokio::net::UdpSocket::from_std(socket)?,
buffered_transmits: Vec::new(),
})
fn ip6(socket_factory: &dyn SocketFactory<tokio::net::UdpSocket>) -> Result<Socket> {
Self::ip(
socket_factory,
&SocketAddr::from((Ipv6Addr::UNSPECIFIED, 0)),
)
}
#[allow(clippy::type_complexity)]
@@ -332,25 +291,3 @@ impl Socket {
);
}
}
fn make_socket(addr: impl Into<SocketAddr>) -> Result<std::net::UdpSocket> {
let addr: SockAddr = addr.into().into();
let socket = socket2::Socket::new(addr.domain(), Type::DGRAM, None)?;
#[cfg(target_os = "linux")]
{
const FIREZONE_MARK: u32 = 0xfd002021; // Keep this synced with `TunDeviceManager` until #5797.
socket.set_mark(FIREZONE_MARK)?;
}
// Note: for AF_INET sockets IPV6_V6ONLY is not a valid flag
if addr.is_ipv6() {
socket.set_only_v6(true)?;
}
socket.set_nonblocking(true)?;
socket.bind(&addr)?;
Ok(socket.into())
}

View File

@@ -27,6 +27,7 @@ phoenix-channel = { workspace = true }
secrecy = { workspace = true }
serde = { version = "1.0", default-features = false, features = ["std", "derive"] }
snownet = { workspace = true }
socket-factory = { workspace = true }
static_assertions = "1.1.0"
tokio = { workspace = true, features = ["sync", "macros", "rt-multi-thread", "fs", "signal"] }
tokio-tungstenite = { version = "0.21", default-features = false, features = ["connect", "handshake", "rustls-tls-webpki-roots"] }

View File

@@ -6,7 +6,8 @@ use connlib_shared::{
get_user_agent, keypair, messages::Interface, Callbacks, LoginUrl, StaticSecret,
};
use firezone_bin_shared::{setup_global_subscriber, CommonArgs, TunDeviceManager};
use firezone_tunnel::{GatewayTunnel, Sockets, Tun};
use firezone_tunnel::{GatewayTunnel, Tun};
use futures::channel::mpsc;
use futures::{future, StreamExt, TryFutureExt};
use ip_network::{Ipv4Network, Ipv6Network};
@@ -15,6 +16,7 @@ use secrecy::{Secret, SecretString};
use std::convert::Infallible;
use std::path::Path;
use std::pin::pin;
use std::sync::Arc;
use tokio::io::AsyncWriteExt;
use tokio::signal::ctrl_c;
use tracing_subscriber::layer;
@@ -100,7 +102,7 @@ async fn get_firezone_id(env_id: Option<String>) -> Result<String> {
}
async fn run(login: LoginUrl, private_key: StaticSecret) -> Result<Infallible> {
let mut tunnel = GatewayTunnel::new(private_key, Sockets::new(), CallbackHandler)?;
let mut tunnel = GatewayTunnel::new(private_key, CallbackHandler)?;
let portal = PhoenixChannel::connect(
Secret::new(login),
get_user_agent(None, env!("CARGO_PKG_VERSION")),
@@ -109,6 +111,7 @@ async fn run(login: LoginUrl, private_key: StaticSecret) -> Result<Infallible> {
ExponentialBackoffBuilder::default()
.with_max_elapsed_time(None)
.build(),
Arc::new(socket_factory::tcp),
);
let (sender, receiver) = mpsc::channel::<Interface>(10);

View File

@@ -20,6 +20,7 @@ ip_network = { version = "0.4", default-features = false }
secrecy = { workspace = true }
serde = { version = "1.0.203", features = ["derive"] }
serde_json = "1.0.117"
socket-factory = { workspace = true }
thiserror = { version = "1.0", default-features = false }
# This actually relies on many other features in Tokio, so this will probably
# fail to build outside the workspace. <https://github.com/firezone/firezone/pull/4328#discussion_r1540342142>
@@ -40,7 +41,7 @@ mutants = "0.0.3" # Needed to mark functions as exempt from `cargo-mutants` test
[target.'cfg(target_os = "linux")'.dependencies]
dirs = "5.0.1"
libc = "0.2.150"
nix = { version = "0.28.0", features = ["fs", "user"] }
nix = { version = "0.28.0", features = ["fs", "user", "socket"] }
resolv-conf = "0.7.0"
rtnetlink = { workspace = true }
sd-notify = "0.4.1" # This is a pure Rust re-implementation, so it isn't vulnerable to CVE-2024-3094

View File

@@ -6,13 +6,13 @@ use crate::{
};
use anyhow::{Context as _, Result};
use clap::Parser;
use connlib_client_shared::{file_logger, keypair, ConnectArgs, LoginUrl, Session, Sockets};
use connlib_client_shared::{file_logger, keypair, ConnectArgs, LoginUrl, Session};
use futures::{
future::poll_fn,
task::{Context, Poll},
Future as _, SinkExt as _, Stream as _,
};
use std::{net::IpAddr, path::PathBuf, pin::pin, time::Duration};
use std::{net::IpAddr, path::PathBuf, pin::pin, sync::Arc, time::Duration};
use tokio::{sync::mpsc, time::Instant};
use tracing::subscriber::set_global_default;
use tracing_subscriber::{layer::SubscriberExt, EnvFilter, Layer, Registry};
@@ -341,7 +341,8 @@ impl Handler {
self.last_connlib_start_instant = Some(Instant::now());
let args = ConnectArgs {
url,
sockets: Sockets::new(),
tcp_socket_factory: Arc::new(crate::tcp_socket_factory),
udp_socket_factory: Arc::new(crate::udp_socket_factory),
private_key,
os_version_override: None,
app_version: env!("CARGO_PKG_VERSION").to_string(),

View File

@@ -20,6 +20,8 @@ use tracing::subscriber::set_global_default;
use tracing_subscriber::{fmt, layer::SubscriberExt as _, EnvFilter, Layer as _, Registry};
use platform::default_token_path;
use platform::tcp_socket_factory;
use platform::udp_socket_factory;
/// Generate a persistent device ID, stores it to disk, and reads it back.
pub(crate) mod device_id;

View File

@@ -2,13 +2,31 @@
use super::TOKEN_ENV_KEY;
use anyhow::{bail, Result};
use std::path::{Path, PathBuf};
use firezone_bin_shared::FIREZONE_MARK;
use nix::sys::socket::{setsockopt, sockopt};
use std::{
io,
net::SocketAddr,
path::{Path, PathBuf},
};
// The Client currently must run as root to control DNS
// Root group and user are used to check file ownership on the token
const ROOT_GROUP: u32 = 0;
const ROOT_USER: u32 = 0;
pub(crate) fn tcp_socket_factory(socket_addr: &SocketAddr) -> io::Result<tokio::net::TcpSocket> {
let socket = socket_factory::tcp(socket_addr)?;
setsockopt(&socket, sockopt::Mark, &FIREZONE_MARK)?;
Ok(socket)
}
pub(crate) fn udp_socket_factory(socket_addr: &SocketAddr) -> io::Result<tokio::net::UdpSocket> {
let socket = socket_factory::udp(socket_addr)?;
setsockopt(&socket, sockopt::Mark, &FIREZONE_MARK)?;
Ok(socket)
}
pub(crate) fn default_token_path() -> PathBuf {
PathBuf::from("/etc")
.join(connlib_shared::BUNDLE_ID)

View File

@@ -6,13 +6,14 @@ use crate::{
};
use anyhow::{anyhow, Context as _, Result};
use clap::Parser;
use connlib_client_shared::{file_logger, keypair, ConnectArgs, LoginUrl, Session, Sockets};
use connlib_client_shared::{file_logger, keypair, ConnectArgs, LoginUrl, Session};
use firezone_bin_shared::{setup_global_subscriber, TunDeviceManager};
use futures::{FutureExt as _, StreamExt as _};
use secrecy::SecretString;
use std::{
path::{Path, PathBuf},
pin::pin,
sync::Arc,
};
use tokio::sync::mpsc;
use tokio_stream::wrappers::ReceiverStream;
@@ -156,7 +157,8 @@ pub fn run_only_headless_client() -> Result<()> {
platform::setup_before_connlib()?;
let args = ConnectArgs {
url,
sockets: Sockets::new(),
udp_socket_factory: Arc::new(crate::udp_socket_factory),
tcp_socket_factory: Arc::new(crate::tcp_socket_factory),
private_key,
os_version_override: None,
app_version: env!("CARGO_PKG_VERSION").to_string(),

View File

@@ -7,6 +7,9 @@
use anyhow::Result;
use std::path::{Path, PathBuf};
pub(crate) use socket_factory::tcp as tcp_socket_factory;
pub(crate) use socket_factory::udp as udp_socket_factory;
#[path = "windows/wintun_install.rs"]
mod wintun_install;

View File

@@ -15,6 +15,7 @@ secrecy = { workspace = true }
serde = { version = "1.0.203", features = ["derive"] }
serde_json = "1.0.117"
sha2 = "0.10.8"
socket-factory = { workspace = true }
thiserror = "1.0.61"
tokio = { workspace = true, features = ["net", "time"] }
tokio-tungstenite = { workspace = true, features = ["rustls-tls-webpki-roots"] }

View File

@@ -3,6 +3,7 @@ mod login_url;
use std::collections::{HashSet, VecDeque};
use std::mem;
use std::net::SocketAddr;
use std::sync::atomic::AtomicU64;
use std::sync::Arc;
use std::{fmt, future, marker::PhantomData};
@@ -16,15 +17,16 @@ use heartbeat::{Heartbeat, MissedLastHeartbeat};
use rand_core::{OsRng, RngCore};
use secrecy::{ExposeSecret as _, Secret};
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use socket_factory::SocketFactory;
use std::task::{Context, Poll, Waker};
use tokio::net::TcpStream;
use tokio_tungstenite::connect_async_with_config;
use tokio_tungstenite::client_async_tls;
use tokio_tungstenite::tungstenite::http::StatusCode;
use tokio_tungstenite::{
connect_async,
tungstenite::{handshake::client::Request, Message},
MaybeTlsStream, WebSocketStream,
};
use url::{Host, Url};
pub use login_url::{LoginUrl, LoginUrlError};
@@ -33,6 +35,7 @@ pub struct PhoenixChannel<TInitReq, TInboundMsg, TOutboundRes> {
waker: Option<Waker>,
pending_messages: VecDeque<String>,
next_request_id: Arc<AtomicU64>,
socket_factory: Arc<dyn SocketFactory<tokio::net::TcpSocket>>,
heartbeat: Heartbeat,
@@ -59,17 +62,70 @@ enum State {
}
impl State {
fn connect(url: Secret<LoginUrl>, user_agent: String) -> Self {
Self::Connecting(Box::pin(async move {
let (stream, _) = connect_async_with_config(make_request(url, user_agent), None, true)
.await
.map_err(InternalError::WebSocket)?;
Ok(stream)
}))
fn connect(
url: Secret<LoginUrl>,
user_agent: String,
socket_factory: Arc<dyn SocketFactory<tokio::net::TcpSocket>>,
) -> Self {
Self::Connecting(create_and_connect_websocket(url, user_agent, socket_factory).boxed())
}
}
async fn create_and_connect_websocket(
url: Secret<LoginUrl>,
user_agent: String,
socket_factory: Arc<dyn SocketFactory<tokio::net::TcpSocket>>,
) -> Result<WebSocketStream<MaybeTlsStream<TcpStream>>, InternalError> {
let socket = make_socket(url.expose_secret().inner(), &*socket_factory).await?;
let (stream, _) = client_async_tls(make_request(url, user_agent), socket)
.await
.map_err(InternalError::WebSocket)?;
Ok(stream)
}
async fn make_socket(
url: &Url,
socket_factory: &dyn SocketFactory<tokio::net::TcpSocket>,
) -> Result<TcpStream, InternalError> {
let port = url
.port_or_known_default()
.expect("scheme to be http, https, ws or wss");
let addrs: Vec<SocketAddr> = match url.host().ok_or(InternalError::InvalidUrl)? {
Host::Domain(n) => tokio::net::lookup_host((n, port))
.await
.map_err(|_| InternalError::InvalidUrl)?
.collect(),
Host::Ipv6(ip) => {
vec![(ip, port).into()]
}
Host::Ipv4(ip) => {
vec![(ip, port).into()]
}
};
let mut last_error = None;
for addr in addrs {
let Ok(socket) = socket_factory(&addr) else {
continue;
};
match socket.connect(addr).await {
Ok(socket) => return Ok(socket),
Err(e) => {
last_error = Some(e);
}
}
}
let Some(err) = last_error else {
return Err(InternalError::InvalidUrl);
};
Err(InternalError::SocketConnection(err))
}
#[derive(Debug, thiserror::Error)]
pub enum Error {
#[error("client error: {0}")]
@@ -99,6 +155,8 @@ enum InternalError {
MissedHeartbeat,
CloseMessage,
StreamClosed,
InvalidUrl,
SocketConnection(std::io::Error),
}
impl fmt::Display for InternalError {
@@ -119,6 +177,8 @@ impl fmt::Display for InternalError {
InternalError::MissedHeartbeat => write!(f, "portal did not respond to our heartbeat"),
InternalError::CloseMessage => write!(f, "portal closed the websocket connection"),
InternalError::StreamClosed => write!(f, "websocket stream was closed"),
InternalError::InvalidUrl => write!(f, "failed to resolve url"),
InternalError::SocketConnection(e) => write!(f, "failed to connect socket: {e}"),
}
}
}
@@ -161,14 +221,13 @@ where
///
/// The provided URL must contain a host.
/// Additionally, you must already provide any query parameters required for authentication.
///
/// Once the connection is established,
pub fn connect(
url: Secret<LoginUrl>,
user_agent: String,
login: &'static str,
init_req: TInitReq,
reconnect_backoff: ExponentialBackoff,
socket_factory: Arc<dyn SocketFactory<tokio::net::TcpSocket>>,
) -> Self {
let next_request_id = Arc::new(AtomicU64::new(0));
@@ -178,7 +237,8 @@ where
reconnect_backoff,
url: url.clone(),
user_agent: user_agent.clone(),
state: State::connect(url, user_agent),
state: State::connect(url, user_agent, socket_factory.clone()),
socket_factory,
waker: None,
pending_messages: Default::default(),
_phantom: PhantomData,
@@ -220,7 +280,7 @@ where
// 2. Set state to `Connecting` without a timer.
let url = self.url.clone();
let user_agent = self.user_agent.clone();
self.state = State::connect(url, user_agent);
self.state = State::connect(url, user_agent, self.socket_factory.clone());
// 3. In case we were already re-connecting, we need to wake the suspended task.
if let Some(waker) = self.waker.take() {
@@ -293,18 +353,16 @@ where
let secret_url = self.url.clone();
let user_agent = self.user_agent.clone();
let socket_factory = self.socket_factory.clone();
tracing::debug!(?backoff, max_elapsed_time = ?self.reconnect_backoff.max_elapsed_time, "Reconnecting to portal on transient client error: {e}");
self.state = State::Connecting(Box::pin(async move {
tokio::time::sleep(backoff).await;
let (stream, _) = connect_async(make_request(secret_url, user_agent))
create_and_connect_websocket(secret_url, user_agent, socket_factory)
.await
.map_err(InternalError::WebSocket)?;
Ok(stream)
}));
continue;
}
Poll::Pending => {

View File

@@ -26,7 +26,8 @@ rand = "0.8.5"
secrecy = { workspace = true }
serde = { version = "1.0.203", features = ["derive"] }
sha2 = "0.10.8"
socket2 = "0.5.7"
socket-factory = { workspace = true }
socket2 = { workspace = true }
stun_codec = "0.3.4"
tokio = { workspace = true, features = ["macros", "rt-multi-thread", "net", "time", "signal"] }
tracing = { workspace = true, features = ["log"] }

View File

@@ -146,6 +146,7 @@ async fn main() -> Result<()> {
ExponentialBackoffBuilder::default()
.with_max_elapsed_time(Some(MAX_PARTITION_TIME))
.build(),
Arc::new(socket_factory::tcp),
))
} else {
tracing::warn!(target: "relay", "No portal token supplied, starting standalone mode");

View File

@@ -0,0 +1,8 @@
[package]
name = "socket-factory"
version = "0.1.0"
edition = "2021"
[dependencies]
socket2 = { workspace = true }
tokio = { version = "1.38", features = ["net"] }

View File

@@ -0,0 +1,35 @@
use std::net::SocketAddr;
use socket2::SockAddr;
pub trait SocketFactory<S>: Fn(&SocketAddr) -> std::io::Result<S> + Send + Sync + 'static {}
impl<F, S> SocketFactory<S> for F where
F: Fn(&SocketAddr) -> std::io::Result<S> + Send + Sync + 'static
{
}
pub fn tcp(addr: &SocketAddr) -> std::io::Result<tokio::net::TcpSocket> {
let socket = match addr {
SocketAddr::V4(_) => tokio::net::TcpSocket::new_v4()?,
SocketAddr::V6(_) => tokio::net::TcpSocket::new_v6()?,
};
socket.set_nodelay(true)?;
Ok(socket)
}
pub fn udp(addr: &SocketAddr) -> std::io::Result<tokio::net::UdpSocket> {
let addr: SockAddr = (*addr).into();
let socket = socket2::Socket::new(addr.domain(), socket2::Type::DGRAM, None)?;
// Note: for AF_INET sockets IPV6_V6ONLY is not a valid flag
if addr.is_ipv6() {
socket.set_only_v6(true)?;
}
socket.set_nonblocking(true)?;
socket.bind(&addr)?;
std::net::UdpSocket::from(socket).try_into()
}