diff --git a/kotlin/android/app/build.gradle.kts b/kotlin/android/app/build.gradle.kts index 589c13a3f..9a27d9e35 100644 --- a/kotlin/android/app/build.gradle.kts +++ b/kotlin/android/app/build.gradle.kts @@ -1,7 +1,6 @@ plugins { id("com.android.application") id("kotlin-android") - id("kotlin-kapt") id("dagger.hilt.android.plugin") id("kotlin-parcelize") id("androidx.navigation.safeargs") @@ -9,6 +8,7 @@ plugins { id("com.google.gms.google-services") id("com.google.firebase.crashlytics") id("com.diffplug.spotless") version "6.22.0" + id("kotlin-kapt") } spotless { diff --git a/kotlin/android/app/src/main/java/dev/firezone/android/tunnel/TunnelService.kt b/kotlin/android/app/src/main/java/dev/firezone/android/tunnel/TunnelService.kt index 43233ce26..d07f05ac6 100644 --- a/kotlin/android/app/src/main/java/dev/firezone/android/tunnel/TunnelService.kt +++ b/kotlin/android/app/src/main/java/dev/firezone/android/tunnel/TunnelService.kt @@ -20,6 +20,7 @@ import dev.firezone.android.core.domain.preference.GetConfigUseCase import dev.firezone.android.core.presentation.MainActivity import dev.firezone.android.tunnel.callback.ConnlibCallback import dev.firezone.android.tunnel.data.TunnelRepository +import dev.firezone.android.tunnel.model.Cidr import dev.firezone.android.tunnel.model.Resource import dev.firezone.android.tunnel.model.Tunnel import dev.firezone.android.tunnel.model.TunnelConfig @@ -99,16 +100,26 @@ class TunnelService : VpnService() { return true } - override fun onAddRoute(cidrAddress: String) { - Log.d(TAG, "onAddRoute: $cidrAddress") - - tunnelRepository.addRoute(cidrAddress) + override fun onAddRoute( + addr: String, + prefix: Int, + ): Int { + Log.d(TAG, "onAddRoute: $addr/$prefix") + val route = Cidr(addr, prefix) + tunnelRepository.addRoute(route) + val fd = buildVpnService().establish()?.detachFd() ?: -1 + protect(fd) + return fd } - override fun onRemoveRoute(cidrAddress: String) { - Log.d(TAG, "onRemoveRoute: $cidrAddress") + override fun onRemoveRoute( + addr: String, + prefix: Int, + ) { + Log.d(TAG, "onRemoveRoute: $addr/$prefix") + val cidr = Cidr(addr, prefix) - tunnelRepository.removeRoute(cidrAddress) + tunnelRepository.removeRoute(cidr) } override fun getSystemDefaultResolvers(): String { @@ -231,16 +242,9 @@ class TunnelService : VpnService() { addDnsServer(tunnel.config.dnsAddress) - /*tunnel.routes.forEach { - addRoute(it, 32) - }*/ - - // TODO: These are the staging Resources. Remove these in favor of the onUpdateResources callback. - addRoute("100.100.111.1", 32) - addRoute("172.31.82.179", 32) - addRoute("172.31.83.10", 32) - addRoute("172.31.92.238", 32) - addRoute("172.31.93.123", 32) + tunnel.routes.forEach { + addRoute(it.address, it.prefix) + } setSession(SESSION_NAME) diff --git a/kotlin/android/app/src/main/java/dev/firezone/android/tunnel/callback/ConnlibCallback.kt b/kotlin/android/app/src/main/java/dev/firezone/android/tunnel/callback/ConnlibCallback.kt index c1633144b..244f14817 100644 --- a/kotlin/android/app/src/main/java/dev/firezone/android/tunnel/callback/ConnlibCallback.kt +++ b/kotlin/android/app/src/main/java/dev/firezone/android/tunnel/callback/ConnlibCallback.kt @@ -11,9 +11,15 @@ interface ConnlibCallback { fun onTunnelReady(): Boolean - fun onAddRoute(cidrAddress: String) + fun onAddRoute( + cidrAddress: String, + prefix: Int, + ): Int - fun onRemoveRoute(cidrAddress: String) + fun onRemoveRoute( + addr: String, + prefix: Int, + ) fun onUpdateResources(resourceListJSON: String) diff --git a/kotlin/android/app/src/main/java/dev/firezone/android/tunnel/data/TunnelRepository.kt b/kotlin/android/app/src/main/java/dev/firezone/android/tunnel/data/TunnelRepository.kt index 8b6425933..1f3ca1ffb 100644 --- a/kotlin/android/app/src/main/java/dev/firezone/android/tunnel/data/TunnelRepository.kt +++ b/kotlin/android/app/src/main/java/dev/firezone/android/tunnel/data/TunnelRepository.kt @@ -2,6 +2,7 @@ package dev.firezone.android.tunnel.data import android.content.SharedPreferences +import dev.firezone.android.tunnel.model.Cidr import dev.firezone.android.tunnel.model.Resource import dev.firezone.android.tunnel.model.Tunnel import dev.firezone.android.tunnel.model.TunnelConfig @@ -21,11 +22,11 @@ interface TunnelRepository { fun getResources(): List - fun addRoute(route: String) + fun addRoute(route: Cidr) - fun removeRoute(route: String) + fun removeRoute(route: Cidr) - fun getRoutes(): List + fun getRoutes(): List fun clear() diff --git a/kotlin/android/app/src/main/java/dev/firezone/android/tunnel/data/TunnelRepositoryImpl.kt b/kotlin/android/app/src/main/java/dev/firezone/android/tunnel/data/TunnelRepositoryImpl.kt index 9f1f83e34..398378f18 100644 --- a/kotlin/android/app/src/main/java/dev/firezone/android/tunnel/data/TunnelRepositoryImpl.kt +++ b/kotlin/android/app/src/main/java/dev/firezone/android/tunnel/data/TunnelRepositoryImpl.kt @@ -8,6 +8,7 @@ import dev.firezone.android.tunnel.data.TunnelRepository.Companion.CONFIG_KEY import dev.firezone.android.tunnel.data.TunnelRepository.Companion.RESOURCES_KEY import dev.firezone.android.tunnel.data.TunnelRepository.Companion.ROUTES_KEY import dev.firezone.android.tunnel.data.TunnelRepository.Companion.STATE_KEY +import dev.firezone.android.tunnel.model.Cidr import dev.firezone.android.tunnel.model.Resource import dev.firezone.android.tunnel.model.Tunnel import dev.firezone.android.tunnel.model.TunnelConfig @@ -83,30 +84,30 @@ class TunnelRepositoryImpl } } - override fun addRoute(route: String) { + override fun addRoute(route: Cidr) { synchronized(lock) { getRoutes().toMutableList().run { add(route) - val json = moshi.adapter>().toJson(this) + val json = moshi.adapter>().toJson(this) sharedPreferences.edit().putString(ROUTES_KEY, json).apply() } } } - override fun removeRoute(route: String) { + override fun removeRoute(route: Cidr) { synchronized(lock) { getRoutes().toMutableList().run { remove(route) - val json = moshi.adapter>().toJson(this) + val json = moshi.adapter>().toJson(this) sharedPreferences.edit().putString(ROUTES_KEY, json).apply() } } } - override fun getRoutes(): List = + override fun getRoutes(): List = synchronized(lock) { val json = sharedPreferences.getString(ROUTES_KEY, "[]") ?: "[]" - return moshi.adapter>().fromJson(json) ?: emptyList() + return moshi.adapter>().fromJson(json) ?: emptyList() } override fun clear() { diff --git a/kotlin/android/app/src/main/java/dev/firezone/android/tunnel/model/Cidr.kt b/kotlin/android/app/src/main/java/dev/firezone/android/tunnel/model/Cidr.kt new file mode 100644 index 000000000..28c1716b5 --- /dev/null +++ b/kotlin/android/app/src/main/java/dev/firezone/android/tunnel/model/Cidr.kt @@ -0,0 +1,14 @@ +/* Licensed under Apache 2.0 (C) 2023 Firezone, Inc. */ +package dev.firezone.android.tunnel.model + +import android.os.Parcelable +import com.squareup.moshi.JsonClass +import kotlinx.parcelize.Parcelize + +@JsonClass(generateAdapter = true) +@Parcelize +data class Cidr( + // TODO: Not convinced of using String to store address, we can make a moshi InetAddress adapter + val address: String, + val prefix: Int, +) : Parcelable diff --git a/kotlin/android/app/src/main/java/dev/firezone/android/tunnel/model/Tunnel.kt b/kotlin/android/app/src/main/java/dev/firezone/android/tunnel/model/Tunnel.kt index beaeda170..216cc8951 100644 --- a/kotlin/android/app/src/main/java/dev/firezone/android/tunnel/model/Tunnel.kt +++ b/kotlin/android/app/src/main/java/dev/firezone/android/tunnel/model/Tunnel.kt @@ -10,7 +10,7 @@ import kotlinx.parcelize.Parcelize data class Tunnel( val config: TunnelConfig = TunnelConfig(), var state: State = State.Down, - val routes: List = emptyList(), + val routes: List = emptyList(), val resources: List = emptyList(), ) : Parcelable { enum class State { diff --git a/rust/connlib/clients/android/src/lib.rs b/rust/connlib/clients/android/src/lib.rs index 6abc259f4..b9d935e21 100644 --- a/rust/connlib/clients/android/src/lib.rs +++ b/rust/connlib/clients/android/src/lib.rs @@ -122,8 +122,8 @@ fn init_logging(log_dir: &Path, log_filter: String) -> file_logger::Handle { .expect("Logging guard should never be initialized twice"); let _ = tracing_subscriber::registry() - .with(file_layer.with_filter(EnvFilter::new(log_filter))) - .with(android_layer()) + .with(file_layer.with_filter(EnvFilter::new(log_filter.clone()))) + .with(android_layer().with_filter(EnvFilter::new(log_filter.clone()))) .try_init(); handle @@ -138,7 +138,7 @@ impl Callbacks for CallbackHandler { tunnel_address_v6: Ipv6Addr, dns_address: Ipv4Addr, dns_fallback_strategy: String, - ) -> Result { + ) -> Result, Self::Error> { self.env(|mut env| { let tunnel_address_v4 = env.new_string(tunnel_address_v4.to_string()) @@ -179,6 +179,7 @@ impl Callbacks for CallbackHandler { ], ) .and_then(|val| val.i()) + .map(Some) .map_err(|source| CallbackError::CallMethodFailed { name, source }) }) } @@ -195,27 +196,31 @@ impl Callbacks for CallbackHandler { }) } - fn on_add_route(&self, route: IpNetwork) -> Result<(), Self::Error> { + fn on_add_route(&self, route: IpNetwork) -> Result, Self::Error> { self.env(|mut env| { - let route = env.new_string(route.to_string()).map_err(|source| { - CallbackError::NewStringFailed { - name: "route", + let ip = env + .new_string(route.network_address().to_string()) + .map_err(|source| CallbackError::NewStringFailed { + name: "route_ip", source, - } - })?; - call_method( - &mut env, + })?; + + let name = "onAddRoute"; + env.call_method( &self.callback_handler, - "onAddRoute", - "(Ljava/lang/String;)V", - &[JValue::from(&route)], + name, + "(Ljava/lang/String;I)I", + &[JValue::from(&ip), JValue::Int(route.netmask().into())], ) + .and_then(|val| val.i()) + .map(Some) + .map_err(|source| CallbackError::CallMethodFailed { name, source }) }) } fn on_remove_route(&self, route: IpNetwork) -> Result<(), Self::Error> { self.env(|mut env| { - let route = env.new_string(route.to_string()).map_err(|source| { + let ip = env.new_string(route.to_string()).map_err(|source| { CallbackError::NewStringFailed { name: "route", source, @@ -225,8 +230,8 @@ impl Callbacks for CallbackHandler { &mut env, &self.callback_handler, "onRemoveRoute", - "(Ljava/lang/String;)V", - &[JValue::from(&route)], + "(Ljava/lang/String;I)V", + &[JValue::from(&ip), JValue::Int(route.netmask().into())], ) }) } @@ -301,7 +306,7 @@ impl Callbacks for CallbackHandler { fn throw(env: &mut JNIEnv, class: &str, msg: impl Into) { if let Err(err) = env.throw_new(class, msg) { // We can't panic, since unwinding across the FFI boundary is UB... - tracing::error!("failed to throw Java exception: {err}"); + tracing::error!(?err, "failed to throw Java exception"); } } diff --git a/rust/connlib/clients/apple/src/lib.rs b/rust/connlib/clients/apple/src/lib.rs index 996f17818..e63319a0a 100644 --- a/rust/connlib/clients/apple/src/lib.rs +++ b/rust/connlib/clients/apple/src/lib.rs @@ -89,14 +89,14 @@ impl Callbacks for CallbackHandler { tunnel_address_v6: Ipv6Addr, dns_address: Ipv4Addr, dns_fallback_strategy: String, - ) -> Result { + ) -> Result, Self::Error> { self.inner.on_set_interface_config( tunnel_address_v4.to_string(), tunnel_address_v6.to_string(), dns_address.to_string(), dns_fallback_strategy.to_string(), ); - Ok(-1) + Ok(None) } fn on_tunnel_ready(&self) -> Result<(), Self::Error> { @@ -104,9 +104,9 @@ impl Callbacks for CallbackHandler { Ok(()) } - fn on_add_route(&self, route: IpNetwork) -> Result<(), Self::Error> { + fn on_add_route(&self, route: IpNetwork) -> Result, Self::Error> { self.inner.on_add_route(route.to_string()); - Ok(()) + Ok(None) } fn on_remove_route(&self, route: IpNetwork) -> Result<(), Self::Error> { diff --git a/rust/connlib/clients/shared/src/lib.rs b/rust/connlib/clients/shared/src/lib.rs index 17bf4a7ea..094b88a51 100644 --- a/rust/connlib/clients/shared/src/lib.rs +++ b/rust/connlib/clients/shared/src/lib.rs @@ -78,7 +78,11 @@ where runtime_stopper: tx.clone(), callbacks, }; + // In android we get an stack-overflow due to tokio + // taking too much of the stack-space: + // See: https://github.com/firezone/firezone/issues/2227 let runtime = tokio::runtime::Builder::new_multi_thread() + .thread_stack_size(3 * 1024 * 1024) .enable_all() .build()?; { diff --git a/rust/connlib/shared/src/callbacks.rs b/rust/connlib/shared/src/callbacks.rs index 5a8d681cc..4e619ce17 100644 --- a/rust/connlib/shared/src/callbacks.rs +++ b/rust/connlib/shared/src/callbacks.rs @@ -14,14 +14,17 @@ pub trait Callbacks: Clone + Send + Sync { type Error: Debug + Display + Error; /// Called when the tunnel address is set. + /// + /// This should return a new `fd` if there is one. + /// (Only happens on android for now) fn on_set_interface_config( &self, _: Ipv4Addr, _: Ipv6Addr, _: Ipv4Addr, _: String, - ) -> Result { - Ok(-1) + ) -> Result, Self::Error> { + Ok(None) } /// Called when the tunnel is connected. @@ -31,8 +34,11 @@ pub trait Callbacks: Clone + Send + Sync { } /// Called when when a route is added. - fn on_add_route(&self, _: IpNetwork) -> Result<(), Self::Error> { - Ok(()) + /// + /// This should return a new `fd` if there is one. + /// (Only happens on android for now) + fn on_add_route(&self, _: IpNetwork) -> Result, Self::Error> { + Ok(None) } /// Called when when a route is removed. diff --git a/rust/connlib/shared/src/callbacks_error_facade.rs b/rust/connlib/shared/src/callbacks_error_facade.rs index e9d9b1882..c694af707 100644 --- a/rust/connlib/shared/src/callbacks_error_facade.rs +++ b/rust/connlib/shared/src/callbacks_error_facade.rs @@ -18,7 +18,7 @@ impl Callbacks for CallbackErrorFacade { tunnel_address_v6: Ipv6Addr, dns_address: Ipv4Addr, dns_fallback_strategy: String, - ) -> Result { + ) -> Result> { let result = self .0 .on_set_interface_config( @@ -29,7 +29,7 @@ impl Callbacks for CallbackErrorFacade { ) .map_err(|err| Error::OnSetInterfaceConfigFailed(err.to_string())); if let Err(err) = result.as_ref() { - tracing::error!("{err}"); + tracing::error!(?err); } result } @@ -40,18 +40,18 @@ impl Callbacks for CallbackErrorFacade { .on_tunnel_ready() .map_err(|err| Error::OnTunnelReadyFailed(err.to_string())); if let Err(err) = result.as_ref() { - tracing::error!("{err}"); + tracing::error!(?err); } result } - fn on_add_route(&self, route: IpNetwork) -> Result<()> { + fn on_add_route(&self, route: IpNetwork) -> Result> { let result = self .0 .on_add_route(route) .map_err(|err| Error::OnAddRouteFailed(err.to_string())); if let Err(err) = result.as_ref() { - tracing::error!("{err}"); + tracing::error!(?err); } result } @@ -62,7 +62,7 @@ impl Callbacks for CallbackErrorFacade { .on_remove_route(route) .map_err(|err| Error::OnRemoveRouteFailed(err.to_string())); if let Err(err) = result.as_ref() { - tracing::error!("{err}"); + tracing::error!(?err); } result } @@ -73,14 +73,14 @@ impl Callbacks for CallbackErrorFacade { .on_update_resources(resource_list) .map_err(|err| Error::OnUpdateResourcesFailed(err.to_string())); if let Err(err) = result.as_ref() { - tracing::error!("{err}"); + tracing::error!(?err); } result } fn on_disconnect(&self, error: Option<&Error>) -> Result<()> { if let Err(err) = self.0.on_disconnect(error) { - tracing::error!("`on_disconnect` failed: {err}"); + tracing::error!(?err, "`on_disconnect` failed"); } // There's nothing we can really do if `on_disconnect` fails. Ok(()) @@ -88,7 +88,7 @@ impl Callbacks for CallbackErrorFacade { fn on_error(&self, error: &Error) -> Result<()> { if let Err(err) = self.0.on_error(error) { - tracing::error!("`on_error` failed: {err}"); + tracing::error!(?err, "`on_error` failed"); } // There's nothing we really want to do if `on_error` fails. Ok(()) diff --git a/rust/connlib/shared/src/error.rs b/rust/connlib/shared/src/error.rs index d461aa71f..d0a8ab9ec 100644 --- a/rust/connlib/shared/src/error.rs +++ b/rust/connlib/shared/src/error.rs @@ -96,6 +96,9 @@ pub enum ConnlibError { /// No iface found #[error("No iface found")] NoIface, + /// Expected file descriptor and none was found + #[error("No filedescriptor")] + NoFd, /// No MTU found #[error("No MTU found")] NoMtu, @@ -120,6 +123,9 @@ pub enum ConnlibError { /// Invalid source address for peer #[error("Invalid source address")] InvalidSource, + /// Any parse error + #[error("parse error")] + ParseError, } impl ConnlibError { diff --git a/rust/connlib/tunnel/src/control_protocol.rs b/rust/connlib/tunnel/src/control_protocol.rs index 797ce02c5..027bab19e 100644 --- a/rust/connlib/tunnel/src/control_protocol.rs +++ b/rust/connlib/tunnel/src/control_protocol.rs @@ -145,12 +145,8 @@ where }) }); - let Some(device_io) = self.device_io.read().clone() else { - return Err(Error::NoIface); - }; - let tunnel = Arc::clone(self); - tokio::spawn(async move { tunnel.peer_handler(peer, device_io).await }); + tokio::spawn(async move { tunnel.start_peer_handler(peer).await }); Ok(()) } diff --git a/rust/connlib/tunnel/src/control_protocol/gateway.rs b/rust/connlib/tunnel/src/control_protocol/gateway.rs index afe898442..1bc781a02 100644 --- a/rust/connlib/tunnel/src/control_protocol/gateway.rs +++ b/rust/connlib/tunnel/src/control_protocol/gateway.rs @@ -96,12 +96,13 @@ where tracing::trace!("new_data_channel_open"); Box::pin(async move { { - let Some(iface_config) = tunnel.iface_config.read().clone() else { + let Some(device) = tunnel.device.read().await.clone() else { let e = Error::NoIface; tracing::error!(err = ?e, "channel_open"); let _ = tunnel.callbacks().on_error(&e); return; }; + let iface_config = device.config; for &ip in &peer.ips { if let Err(e) = iface_config.add_route(ip, tunnel.callbacks()).await { diff --git a/rust/connlib/tunnel/src/device_channel/device_channel_unix.rs b/rust/connlib/tunnel/src/device_channel/device_channel_unix.rs index e3598e6f6..3de06cf94 100644 --- a/rust/connlib/tunnel/src/device_channel/device_channel_unix.rs +++ b/rust/connlib/tunnel/src/device_channel/device_channel_unix.rs @@ -9,6 +9,8 @@ use tokio::io::{unix::AsyncFd, Interest}; use tun::{IfaceDevice, IfaceStream}; +use crate::Device; + mod tun; pub(crate) struct IfaceConfig { @@ -53,23 +55,32 @@ impl IfaceConfig { &self, route: IpNetwork, callbacks: &CallbackErrorFacade, - ) -> Result<()> { - self.iface.add_route(route, callbacks).await + ) -> Result> { + let Some((iface, stream)) = self.iface.add_route(route, callbacks).await? else { + return Ok(None); + }; + let io = DeviceIo(stream); + let mtu = iface.mtu().await?; + let config = Arc::new(IfaceConfig { + iface, + mtu: AtomicUsize::new(mtu), + }); + Ok(Some(Device { io, config })) } } pub(crate) async fn create_iface( config: &Interface, callbacks: &CallbackErrorFacade, -) -> Result<(IfaceConfig, DeviceIo)> { +) -> Result { let (iface, stream) = IfaceDevice::new(config, callbacks).await?; iface.up().await?; - let device_io = DeviceIo(stream); + let io = DeviceIo(stream); let mtu = iface.mtu().await?; - let iface_config = IfaceConfig { + let config = Arc::new(IfaceConfig { iface, mtu: AtomicUsize::new(mtu), - }; + }); - Ok((iface_config, device_io)) + Ok(Device { config, io }) } diff --git a/rust/connlib/tunnel/src/device_channel/device_channel_win.rs b/rust/connlib/tunnel/src/device_channel/device_channel_win.rs index d61eac66a..ecc3ab681 100644 --- a/rust/connlib/tunnel/src/device_channel/device_channel_win.rs +++ b/rust/connlib/tunnel/src/device_channel/device_channel_win.rs @@ -1,3 +1,4 @@ +use crate::Device; use connlib_shared::{messages::Interface, CallbackErrorFacade, Callbacks, Result}; use ip_network::IpNetwork; @@ -33,7 +34,7 @@ impl IfaceConfig { &self, _: IpNetwork, _: &CallbackErrorFacade, - ) -> Result<()> { + ) -> Result> { todo!() } } @@ -41,6 +42,6 @@ impl IfaceConfig { pub(crate) async fn create_iface( _: &Interface, _: &CallbackErrorFacade, -) -> Result<(IfaceConfig, DeviceIo)> { +) -> Result { todo!() } diff --git a/rust/connlib/tunnel/src/device_channel/tun/closeable.rs b/rust/connlib/tunnel/src/device_channel/tun/closeable.rs new file mode 100644 index 000000000..1107310e6 --- /dev/null +++ b/rust/connlib/tunnel/src/device_channel/tun/closeable.rs @@ -0,0 +1,35 @@ +use std::os::fd::{AsRawFd, RawFd}; +use std::sync::atomic::{AtomicBool, Ordering}; + +#[derive(Debug)] +pub(crate) struct Closeable { + closed: AtomicBool, + value: RawFd, +} + +impl AsRawFd for Closeable { + fn as_raw_fd(&self) -> RawFd { + self.value.as_raw_fd() + } +} + +impl Closeable { + pub(crate) fn new(fd: RawFd) -> Self { + Self { + closed: AtomicBool::new(false), + value: fd, + } + } + + pub(crate) fn with(&self, f: impl FnOnce(RawFd) -> U) -> std::io::Result { + if self.closed.load(Ordering::Acquire) { + return Err(std::io::Error::from_raw_os_error(9)); + } + + Ok(f(self.value)) + } + + pub(crate) fn close(&self) { + self.closed.store(true, Ordering::Release); + } +} diff --git a/rust/connlib/tunnel/src/device_channel/tun/tun_android.rs b/rust/connlib/tunnel/src/device_channel/tun/tun_android.rs index 753269045..2d01e74df 100644 --- a/rust/connlib/tunnel/src/device_channel/tun/tun_android.rs +++ b/rust/connlib/tunnel/src/device_channel/tun/tun_android.rs @@ -1,4 +1,5 @@ use crate::InterfaceConfig; +use closeable::Closeable; use connlib_shared::{CallbackErrorFacade, Callbacks, Error, Result, DNS_SENTINEL}; use ip_network::IpNetwork; use libc::{ @@ -13,6 +14,7 @@ use std::{ }; use tokio::io::unix::AsyncFd; +mod closeable; mod wrapped_socket; // Android doesn't support Split DNS. So we intercept all requests and forward // the non-Firezone name resolution requests to the upstream DNS resolver. @@ -50,24 +52,27 @@ pub(crate) struct IfaceDevice(Arc>); #[derive(Debug)] pub(crate) struct IfaceStream { - fd: RawFd, + fd: Closeable, } impl AsRawFd for IfaceStream { fn as_raw_fd(&self) -> RawFd { - self.fd + self.fd.as_raw_fd() } } impl Drop for IfaceStream { fn drop(&mut self) { - unsafe { close(self.fd) }; + unsafe { close(self.fd.as_raw_fd()) }; } } impl IfaceStream { fn write(&self, buf: &[u8]) -> std::io::Result { - match unsafe { write(self.fd, buf.as_ptr() as _, buf.len() as _) } { + match self + .fd + .with(|fd| unsafe { write(fd, buf.as_ptr() as _, buf.len() as _) })? + { -1 => Err(io::Error::last_os_error()), n => Ok(n as usize), } @@ -81,12 +86,21 @@ impl IfaceStream { self.write(src) } - pub fn read<'a>(&self, dst: &'a mut [u8]) -> std::io::Result { - match unsafe { read(self.fd, dst.as_mut_ptr() as _, dst.len()) } { + pub fn read(&self, dst: &mut [u8]) -> std::io::Result { + // We don't read(or write) again from the fd because the given fd number might be reclaimed + // so this could make an spurious read/write to another fd and we DEFINITELY don't want that + match self + .fd + .with(|fd| unsafe { read(fd, dst.as_mut_ptr() as _, dst.len()) })? + { -1 => Err(io::Error::last_os_error()), n => Ok(n as usize), } } + + pub fn close(&self) { + self.fd.close(); + } } impl IfaceDevice { @@ -94,13 +108,17 @@ impl IfaceDevice { config: &InterfaceConfig, callbacks: &CallbackErrorFacade, ) -> Result<(Self, Arc>)> { - let fd = callbacks.on_set_interface_config( - config.ipv4, - config.ipv6, - DNS_SENTINEL, - DNS_FALLBACK_STRATEGY.to_string(), - )?; - let iface_stream = Arc::new(AsyncFd::new(IfaceStream { fd: fd.into() })?); + let fd = callbacks + .on_set_interface_config( + config.ipv4, + config.ipv6, + DNS_SENTINEL, + DNS_FALLBACK_STRATEGY.to_string(), + )? + .ok_or(Error::NoFd)?; + let iface_stream = Arc::new(AsyncFd::new(IfaceStream { + fd: Closeable::new(fd.into()), + })?); let this = Self(Arc::clone(&iface_stream)); Ok((this, iface_stream)) @@ -112,7 +130,12 @@ impl IfaceDevice { ifr_ifru: unsafe { std::mem::zeroed() }, }; - match unsafe { ioctl(self.0.get_ref().fd, TUNGETIFF as _, &mut ifr) } { + match self + .0 + .get_ref() + .fd + .with(|fd| unsafe { ioctl(fd, TUNGETIFF as _, &mut ifr) })? + { 0 => { let name_cstr = unsafe { std::ffi::CStr::from_ptr(ifr.ifr_name.as_ptr() as _) }; Ok(name_cstr.to_string_lossy().into_owned()) @@ -151,8 +174,15 @@ impl IfaceDevice { &self, route: IpNetwork, callbacks: &CallbackErrorFacade, - ) -> Result<()> { - callbacks.on_add_route(route) + ) -> Result>)>> { + self.0.get_ref().close(); + let fd = callbacks.on_add_route(route)?.ok_or(Error::NoFd)?; + let iface_stream = Arc::new(AsyncFd::new(IfaceStream { + fd: Closeable::new(fd.into()), + })?); + let this = Self(Arc::clone(&iface_stream)); + + Ok(Some((this, iface_stream))) } pub async fn up(&self) -> Result<()> { diff --git a/rust/connlib/tunnel/src/device_channel/tun/tun_darwin.rs b/rust/connlib/tunnel/src/device_channel/tun/tun_darwin.rs index c2775a23b..6098177b6 100644 --- a/rust/connlib/tunnel/src/device_channel/tun/tun_darwin.rs +++ b/rust/connlib/tunnel/src/device_channel/tun/tun_darwin.rs @@ -195,12 +195,12 @@ impl IfaceDevice { } if addr.sc_id == info.ctl_id { - let _ = callbacks.on_set_interface_config( + callbacks.on_set_interface_config( config.ipv4, config.ipv6, DNS_SENTINEL, "system_resolver".to_string(), - ); + )?; set_non_blocking(fd)?; @@ -241,8 +241,10 @@ impl IfaceDevice { &self, route: IpNetwork, callbacks: &CallbackErrorFacade, - ) -> Result<()> { - callbacks.on_add_route(route) + ) -> Result>)>> { + // This will always be None in macos + callbacks.on_add_route(route)?; + Ok(None) } pub async fn up(&self) -> Result<()> { diff --git a/rust/connlib/tunnel/src/device_channel/tun/tun_linux.rs b/rust/connlib/tunnel/src/device_channel/tun/tun_linux.rs index c525f04d5..96cbb2d67 100644 --- a/rust/connlib/tunnel/src/device_channel/tun/tun_linux.rs +++ b/rust/connlib/tunnel/src/device_channel/tun/tun_linux.rs @@ -175,7 +175,7 @@ impl IfaceDevice { &self, route: IpNetwork, _callbacks: &CallbackErrorFacade, - ) -> Result<()> { + ) -> Result>)>> { let req = self .handle .route() @@ -211,7 +211,7 @@ impl IfaceDevice { } */ - Ok(()) + Ok(None) } #[tracing::instrument(level = "trace", skip(self, _callbacks))] diff --git a/rust/connlib/tunnel/src/iface_handler.rs b/rust/connlib/tunnel/src/iface_handler.rs index 1b2d71c69..4c6a120dd 100644 --- a/rust/connlib/tunnel/src/iface_handler.rs +++ b/rust/connlib/tunnel/src/iface_handler.rs @@ -138,9 +138,9 @@ where ) -> Result<()> { if let Some(r) = self.check_for_dns(src) { match r { - dns::SendPacket::Ipv4(r) => self.write4_device_infallible(device_writer, &r[..]), - dns::SendPacket::Ipv6(r) => self.write6_device_infallible(device_writer, &r[..]), - } + dns::SendPacket::Ipv4(r) => device_writer.write4(&r[..])?, + dns::SendPacket::Ipv6(r) => device_writer.write6(&r[..])?, + }; return Ok(()); } @@ -170,29 +170,30 @@ where device_io: DeviceIo, ) { let device_writer = device_io.clone(); + let mut src = [0u8; MAX_UDP_SIZE]; + let mut dst = [0u8; MAX_UDP_SIZE]; loop { - let mut src = [0u8; MAX_UDP_SIZE]; - let mut dst = [0u8; MAX_UDP_SIZE]; - let res = { - // TODO: We should check here if what we read is a whole packet - // there's no docs on tun device on when a whole packet is read, is it \n or another thing? - // found some comments saying that a single read syscall represents a single packet but no docs on that - // See https://stackoverflow.com/questions/18461365/how-to-read-packet-by-packet-from-linux-tun-tap - match device_io.read(&mut src[..iface_config.mtu()]).await { - Ok(res) => res, - Err(e) => { - tracing::error!(error = ?e, from = "iface", action = "read"); - let _ = self.callbacks.on_error(&e.into()); - continue; - } + let res = match device_io.read(&mut src[..iface_config.mtu()]).await { + Ok(res) => res, + Err(e) => { + tracing::error!(err = ?e, "failed to read interface: {e:#}"); + let _ = self.callbacks.on_error(&e.into()); + break; } }; - tracing::trace!(target: "wire", action = "read", bytes = res, from = "iface"); - // TODO - let _ = self + + if res == 0 { + break; + } + + if let Err(e) = self .handle_iface_packet(&device_writer, &mut src[..res], &mut dst) - .await; + .await + { + let _ = self.callbacks.on_error(&e); + tracing::error!(err = ?e, "failed to handle packet {e:#}") + } } } } diff --git a/rust/connlib/tunnel/src/lib.rs b/rust/connlib/tunnel/src/lib.rs index 14676129d..10c7e4e56 100644 --- a/rust/connlib/tunnel/src/lib.rs +++ b/rust/connlib/tunnel/src/lib.rs @@ -18,7 +18,7 @@ use itertools::Itertools; use parking_lot::{Mutex, RwLock}; use peer::{Peer, PeerStats}; use resource_table::ResourceTable; -use tokio::time::MissedTickBehavior; +use tokio::{task::AbortHandle, time::MissedTickBehavior}; use webrtc::{ api::{ interceptor_registry::register_default_interceptors, media_engine::MediaEngine, @@ -140,15 +140,19 @@ struct AwaitingConnectionDetails { pub response_received: bool, } +#[derive(Clone)] +struct Device { + pub config: Arc, + pub io: DeviceIo, +} + // TODO: We should use newtypes for each kind of Id /// Tunnel is a wireguard state machine that uses webrtc's ICE channels instead of UDP sockets /// to communicate between peers. pub struct Tunnel { next_index: Mutex, - // We use a tokio's mutex here since it makes things easier and we only need it - // during init, so the performance hit is neglibile - iface_config: RwLock>>, - device_io: RwLock>, + // We use a tokio Mutex here since this is only read/write during config so there's no relevant performance impact + device: tokio::sync::RwLock>, rate_limiter: Arc, private_key: StaticSecret, public_key: PublicKey, @@ -164,6 +168,7 @@ pub struct Tunnel { control_signaler: C, gateway_public_keys: Mutex>, callbacks: CallbackErrorFacade, + iface_handler_abort: Mutex>, } // TODO: For now we only use these fields with debug @@ -249,9 +254,9 @@ where let gateway_public_keys = Default::default(); let resources_gateways = Default::default(); let gateway_awaiting_connection = Default::default(); - let iface_config = Default::default(); - let device_io = Default::default(); + let device = Default::default(); let ice_candidate_queue = Default::default(); + let iface_handler_abort = Default::default(); // ICE let mut media_engine = MediaEngine::default(); @@ -287,31 +292,54 @@ where next_index, webrtc_api, resources, - iface_config, - device_io, + device, awaiting_connection, gateway_awaiting_connection, control_signaler, resources_gateways, ice_candidate_queue, callbacks: CallbackErrorFacade(callbacks), + iface_handler_abort, }) } + #[tracing::instrument(level = "trace", skip(self))] + pub async fn add_route(self: &Arc, route: IpNetwork) -> Result<()> { + let mut device = self.device.write().await; + + if let Some(new_device) = device + .as_ref() + .ok_or(Error::ControlProtocolError)? + .config + .add_route(route, self.callbacks()) + .await? + { + *device = Some(new_device.clone()); + let dev = Arc::clone(self); + self.iface_handler_abort.lock().replace( + tokio::spawn( + async move { dev.iface_handler(new_device.config, new_device.io).await }, + ) + .abort_handle(), + ); + } + + Ok(()) + } + /// Adds a the given resource to the tunnel. /// /// Once added, when a packet for the resource is intercepted a new data channel will be created /// and packets will be wrapped with wireguard and sent through it. #[tracing::instrument(level = "trace", skip(self))] - pub async fn add_resource(&self, resource_description: ResourceDescription) -> Result<()> { + pub async fn add_resource( + self: &Arc, + resource_description: ResourceDescription, + ) -> Result<()> { let mut any_valid_route = false; { - let Some(iface_config) = self.iface_config.read().clone() else { - tracing::error!("add_resource_before_initialization"); - return Err(Error::ControlProtocolError); - }; for ip in resource_description.ips() { - if let Err(e) = iface_config.add_route(ip, self.callbacks()).await { + if let Err(e) = self.add_route(ip).await { tracing::warn!(route = %ip, error = ?e, "add_route"); let _ = self.callbacks().on_error(&e); } else { @@ -336,17 +364,17 @@ where /// Sets the interface configuration and starts background tasks. #[tracing::instrument(level = "trace", skip(self))] pub async fn set_interface(self: &Arc, config: &InterfaceConfig) -> Result<()> { - let (iface_config, device_io) = create_iface(config, self.callbacks()).await?; - iface_config - .add_route(DNS_SENTINEL.into(), self.callbacks()) - .await?; - let iface_config = Arc::new(iface_config); + let device = create_iface(config, self.callbacks()).await?; + *self.device.write().await = Some(device.clone()); - *self.device_io.write() = Some(device_io.clone()); - *self.iface_config.write() = Some(Arc::clone(&iface_config)); - self.start_timers()?; + self.start_timers().await?; let dev = Arc::clone(self); - tokio::spawn(async move { dev.iface_handler(iface_config, device_io).await }); + *self.iface_handler_abort.lock() = Some( + tokio::spawn(async move { dev.iface_handler(device.config, device.io).await }) + .abort_handle(), + ); + + self.add_route(DNS_SENTINEL.into()).await?; self.callbacks.on_tunnel_ready()?; @@ -453,17 +481,22 @@ where }); } - fn start_refresh_mtu_timer(self: &Arc) -> Result<()> { - let Some(iface_config) = self.iface_config.read().clone() else { - return Err(Error::NoIface); - }; + async fn start_refresh_mtu_timer(self: &Arc) -> Result<()> { + let dev = self.clone(); let callbacks = self.callbacks().clone(); tokio::spawn(async move { let mut interval = tokio::time::interval(REFRESH_MTU_INTERVAL); interval.set_missed_tick_behavior(MissedTickBehavior::Delay); loop { interval.tick().await; - if let Err(e) = iface_config.refresh_mtu().await { + + let Some(device) = dev.device.read().await.clone() else { + let err = Error::ControlProtocolError; + tracing::error!(?err, "get_iface_config"); + let _ = callbacks.0.on_error(&err); + continue; + }; + if let Err(e) = device.config.refresh_mtu().await { tracing::error!(error = ?e, "refresh_mtu"); let _ = callbacks.0.on_error(&e); } @@ -473,29 +506,13 @@ where Ok(()) } - fn start_timers(self: &Arc) -> Result<()> { - self.start_refresh_mtu_timer()?; + async fn start_timers(self: &Arc) -> Result<()> { + self.start_refresh_mtu_timer().await?; self.start_rate_limiter_refresh_timer(); self.start_peers_refresh_timer(); Ok(()) } - #[inline(always)] - fn write4_device_infallible(&self, device_io: &DeviceIo, packet: &[u8]) { - if let Err(e) = device_io.write4(packet) { - tracing::error!(?e, "iface_write"); - let _ = self.callbacks().on_error(&e.into()); - } - } - - #[inline(always)] - fn write6_device_infallible(&self, device_io: &DeviceIo, packet: &[u8]) { - if let Err(e) = device_io.write6(packet) { - tracing::error!(?e, "iface_write"); - let _ = self.callbacks().on_error(&e.into()); - } - } - fn get_resource(&self, buff: &[u8]) -> Option { let addr = Tunn::dst_address(buff)?; let resources = self.resources.read(); diff --git a/rust/connlib/tunnel/src/peer_handler.rs b/rust/connlib/tunnel/src/peer_handler.rs index 2fdd6f703..d6c57a155 100644 --- a/rust/connlib/tunnel/src/peer_handler.rs +++ b/rust/connlib/tunnel/src/peer_handler.rs @@ -66,26 +66,26 @@ where peer: &Arc, device_io: &DeviceIo, decapsulate_result: TunnResult<'a>, - ) -> bool { + ) -> Result { match decapsulate_result { - TunnResult::Done => false, + TunnResult::Done => Ok(false), TunnResult::Err(e) => { tracing::error!(error = ?e, "decapsulate_packet"); let _ = self.callbacks().on_error(&e.into()); - false + Ok(false) } TunnResult::WriteToNetwork(packet) => { let bytes = Bytes::copy_from_slice(packet); peer.send_infallible(bytes, &self.callbacks).await; - true + Ok(true) } TunnResult::WriteToTunnelV4(packet, addr) => { - self.send_to_resource(device_io, peer, addr.into(), packet); - false + self.send_to_resource(device_io, peer, addr.into(), packet)?; + Ok(false) } TunnResult::WriteToTunnelV6(packet, addr) => { - self.send_to_resource(device_io, peer, addr.into(), packet); - false + self.send_to_resource(device_io, peer, addr.into(), packet)?; + Ok(false) } } } @@ -108,7 +108,7 @@ where if self .handle_decapsulated_packet(peer, device_writer, decapsulate_result) - .await + .await? { // Flush pending queue while let TunnResult::WriteToNetwork(packet) = { @@ -125,10 +125,16 @@ where Ok(()) } - pub(crate) async fn peer_handler(self: &Arc, peer: Arc, device_io: DeviceIo) { + async fn peer_handler( + self: &Arc, + peer: &Arc, + device_io: DeviceIo, + ) -> std::io::Result<()> { let mut src_buf = [0u8; MAX_UDP_SIZE]; let mut dst_buf = [0u8; MAX_UDP_SIZE]; while let Ok(size) = peer.channel.read(&mut src_buf[..]).await { + tracing::trace!(target: "wire", action = "read", bytes = size, from = "peer"); + // TODO: Double check that this can only happen on closed channel // I think it's possible to transmit a 0-byte message through the channel // but we would never use that. @@ -137,14 +143,38 @@ where break; } - tracing::trace!(target: "wire", action = "read", bytes = size, from = "peer"); - let _ = self - .handle_peer_packet(&peer, &device_io, &src_buf[..size], &mut dst_buf) - .await; + if let Err(Error::Io(e)) = self + .handle_peer_packet(peer, &device_io, &src_buf[..size], &mut dst_buf) + .await + { + return Err(e); + } } - let peer_stats = peer.stats(); - tracing::debug!(peer = ?peer_stats, "peer_stopped"); + Ok(()) + } + + pub(crate) async fn start_peer_handler(self: &Arc, peer: Arc) { + loop { + let Some(device) = self.device.read().await.clone() else { + let err = Error::NoIface; + tracing::error!(?err); + let _ = self.callbacks().on_disconnect(Some(&err)); + break; + }; + let device_io = device.io; + + if let Err(err) = self.peer_handler(&peer, device_io).await { + if err.raw_os_error() != Some(9) { + tracing::error!(?err); + let _ = self.callbacks().on_error(&err.into()); + break; + } else { + tracing::warn!("bad_file_descriptor"); + } + } + } + tracing::debug!(peer = ?peer.stats(), "peer_stopped"); self.stop_peer(peer.index, peer.conn_id).await; } } diff --git a/rust/connlib/tunnel/src/resource_sender.rs b/rust/connlib/tunnel/src/resource_sender.rs index 13a4d6749..f0eb9d8d0 100644 --- a/rust/connlib/tunnel/src/resource_sender.rs +++ b/rust/connlib/tunnel/src/resource_sender.rs @@ -24,15 +24,17 @@ where } #[inline(always)] - fn send_packet(&self, device_io: &DeviceIo, packet: &mut [u8], dst_addr: IpAddr) { + fn send_packet( + &self, + device_io: &DeviceIo, + packet: &mut [u8], + dst_addr: IpAddr, + ) -> std::io::Result<()> { match dst_addr { - IpAddr::V4(_) => { - self.write4_device_infallible(device_io, packet); - } - IpAddr::V6(_) => { - self.write6_device_infallible(device_io, packet); - } - } + IpAddr::V4(_) => device_io.write4(packet)?, + IpAddr::V6(_) => device_io.write6(packet)?, + }; + Ok(()) } #[inline(always)] @@ -42,26 +44,20 @@ where peer: &Arc, addr: IpAddr, packet: &mut [u8], - ) { + ) -> Result<()> { let Some((dst, resource)) = peer.get_packet_resource(packet) else { // If there's no associated resource it means that we are in a client, then the packet comes from a gateway // and we just trust gateways. // In gateways this should never happen. tracing::trace!(target: "wire", action = "writing", to = "iface", %addr, bytes = %packet.len()); - self.send_packet(device_io, packet, addr); - return; + self.send_packet(device_io, packet, addr)?; + return Ok(()); }; - match get_resource_addr_and_port(peer, &resource, &addr, &dst) { - Ok((dst_addr, _dst_port)) => { - self.update_packet(packet, dst_addr); - self.send_packet(device_io, packet, addr); - } - Err(e) => { - tracing::error!(err = ?e, "resource_parse"); - let _ = self.callbacks().on_error(&e); - } - } + let (dst_addr, _dst_port) = get_resource_addr_and_port(peer, &resource, &addr, &dst)?; + self.update_packet(packet, dst_addr); + self.send_packet(device_io, packet, addr)?; + Ok(()) } pub(crate) fn send_to_resource( @@ -70,11 +66,13 @@ where peer: &Arc, addr: IpAddr, packet: &mut [u8], - ) { + ) -> Result<()> { if peer.is_allowed(addr) { - self.packet_allowed(device_io, peer, addr, packet); + self.packet_allowed(device_io, peer, addr, packet)?; + Ok(()) } else { tracing::warn!(%addr, "Received packet from peer with an unallowed ip"); + Ok(()) } } } @@ -90,7 +88,6 @@ fn get_resource_addr_and_port( dst: &IpAddr, ) -> Result<(IpAddr, Option)> { match resource { - // Note: for now no translation is needed for the ip since we do a peer/connection per resource ResourceDescription::Dns(r) => { let mut address = r.address.split(':'); let Some(dst_addr) = address.next() else {