connlib+android: enable fd replacement (#2235)

Should be easier to review commit by commit.

The gist of this commit is:
* `onAddRoute` on Android now takes an address+prefix as to minimize
parsing
* `onAddRoute` recreates the vpn service each time(TODO: is this too bad
for performance?)
* `on_add_route` and `onAddRoute` returns the new fd
* on android after `on_add_route` we recreate `IfaceConfig` and
`DeviceIo` and we store the new values
* `peer_handler` now runs on a loop, where each time we fail a write
with an error code 9(bad descriptor) we try to take the new `DeviceIo`
* we keep an
[`AbortHandle`](https://docs.rs/tokio/latest/tokio/task/struct.AbortHandle.html)
from the `iface_handler` task, since closing the fd doesn't awake the
`read` task for `AsyncFd`(I tried it, right now `close` is only called
after dropping the fd) so we explicitly abort the task and start a new
one with the new `device_io`.
* in android `DeviceIo` has an atomic which tells if it's closed or open
and we change it to closed after `on_add_route`, we use this as to never
double-close the fd, instead we wait until it's dropped. This *might*
affect performance on android since we use non-`Ordering::Relaxed`
atomic operation each read/write but it won't affect perfromance in
other platforms, furthermore I believe the performance gains if we
remove this will be minimal.

Fixes #2227

---------

Co-authored-by: Jamil <jamilbk@users.noreply.github.com>
This commit is contained in:
Gabi
2023-10-08 23:52:45 -03:00
committed by GitHub
parent f08e7bb5be
commit e516bcc8dd
25 changed files with 378 additions and 210 deletions

View File

@@ -1,7 +1,6 @@
plugins {
id("com.android.application")
id("kotlin-android")
id("kotlin-kapt")
id("dagger.hilt.android.plugin")
id("kotlin-parcelize")
id("androidx.navigation.safeargs")
@@ -9,6 +8,7 @@ plugins {
id("com.google.gms.google-services")
id("com.google.firebase.crashlytics")
id("com.diffplug.spotless") version "6.22.0"
id("kotlin-kapt")
}
spotless {

View File

@@ -20,6 +20,7 @@ import dev.firezone.android.core.domain.preference.GetConfigUseCase
import dev.firezone.android.core.presentation.MainActivity
import dev.firezone.android.tunnel.callback.ConnlibCallback
import dev.firezone.android.tunnel.data.TunnelRepository
import dev.firezone.android.tunnel.model.Cidr
import dev.firezone.android.tunnel.model.Resource
import dev.firezone.android.tunnel.model.Tunnel
import dev.firezone.android.tunnel.model.TunnelConfig
@@ -99,16 +100,26 @@ class TunnelService : VpnService() {
return true
}
override fun onAddRoute(cidrAddress: String) {
Log.d(TAG, "onAddRoute: $cidrAddress")
tunnelRepository.addRoute(cidrAddress)
override fun onAddRoute(
addr: String,
prefix: Int,
): Int {
Log.d(TAG, "onAddRoute: $addr/$prefix")
val route = Cidr(addr, prefix)
tunnelRepository.addRoute(route)
val fd = buildVpnService().establish()?.detachFd() ?: -1
protect(fd)
return fd
}
override fun onRemoveRoute(cidrAddress: String) {
Log.d(TAG, "onRemoveRoute: $cidrAddress")
override fun onRemoveRoute(
addr: String,
prefix: Int,
) {
Log.d(TAG, "onRemoveRoute: $addr/$prefix")
val cidr = Cidr(addr, prefix)
tunnelRepository.removeRoute(cidrAddress)
tunnelRepository.removeRoute(cidr)
}
override fun getSystemDefaultResolvers(): String {
@@ -231,16 +242,9 @@ class TunnelService : VpnService() {
addDnsServer(tunnel.config.dnsAddress)
/*tunnel.routes.forEach {
addRoute(it, 32)
}*/
// TODO: These are the staging Resources. Remove these in favor of the onUpdateResources callback.
addRoute("100.100.111.1", 32)
addRoute("172.31.82.179", 32)
addRoute("172.31.83.10", 32)
addRoute("172.31.92.238", 32)
addRoute("172.31.93.123", 32)
tunnel.routes.forEach {
addRoute(it.address, it.prefix)
}
setSession(SESSION_NAME)

View File

@@ -11,9 +11,15 @@ interface ConnlibCallback {
fun onTunnelReady(): Boolean
fun onAddRoute(cidrAddress: String)
fun onAddRoute(
cidrAddress: String,
prefix: Int,
): Int
fun onRemoveRoute(cidrAddress: String)
fun onRemoveRoute(
addr: String,
prefix: Int,
)
fun onUpdateResources(resourceListJSON: String)

View File

@@ -2,6 +2,7 @@
package dev.firezone.android.tunnel.data
import android.content.SharedPreferences
import dev.firezone.android.tunnel.model.Cidr
import dev.firezone.android.tunnel.model.Resource
import dev.firezone.android.tunnel.model.Tunnel
import dev.firezone.android.tunnel.model.TunnelConfig
@@ -21,11 +22,11 @@ interface TunnelRepository {
fun getResources(): List<Resource>
fun addRoute(route: String)
fun addRoute(route: Cidr)
fun removeRoute(route: String)
fun removeRoute(route: Cidr)
fun getRoutes(): List<String>
fun getRoutes(): List<Cidr>
fun clear()

View File

@@ -8,6 +8,7 @@ import dev.firezone.android.tunnel.data.TunnelRepository.Companion.CONFIG_KEY
import dev.firezone.android.tunnel.data.TunnelRepository.Companion.RESOURCES_KEY
import dev.firezone.android.tunnel.data.TunnelRepository.Companion.ROUTES_KEY
import dev.firezone.android.tunnel.data.TunnelRepository.Companion.STATE_KEY
import dev.firezone.android.tunnel.model.Cidr
import dev.firezone.android.tunnel.model.Resource
import dev.firezone.android.tunnel.model.Tunnel
import dev.firezone.android.tunnel.model.TunnelConfig
@@ -83,30 +84,30 @@ class TunnelRepositoryImpl
}
}
override fun addRoute(route: String) {
override fun addRoute(route: Cidr) {
synchronized(lock) {
getRoutes().toMutableList().run {
add(route)
val json = moshi.adapter<List<String>>().toJson(this)
val json = moshi.adapter<List<Cidr>>().toJson(this)
sharedPreferences.edit().putString(ROUTES_KEY, json).apply()
}
}
}
override fun removeRoute(route: String) {
override fun removeRoute(route: Cidr) {
synchronized(lock) {
getRoutes().toMutableList().run {
remove(route)
val json = moshi.adapter<List<String>>().toJson(this)
val json = moshi.adapter<List<Cidr>>().toJson(this)
sharedPreferences.edit().putString(ROUTES_KEY, json).apply()
}
}
}
override fun getRoutes(): List<String> =
override fun getRoutes(): List<Cidr> =
synchronized(lock) {
val json = sharedPreferences.getString(ROUTES_KEY, "[]") ?: "[]"
return moshi.adapter<List<String>>().fromJson(json) ?: emptyList()
return moshi.adapter<List<Cidr>>().fromJson(json) ?: emptyList()
}
override fun clear() {

View File

@@ -0,0 +1,14 @@
/* Licensed under Apache 2.0 (C) 2023 Firezone, Inc. */
package dev.firezone.android.tunnel.model
import android.os.Parcelable
import com.squareup.moshi.JsonClass
import kotlinx.parcelize.Parcelize
@JsonClass(generateAdapter = true)
@Parcelize
data class Cidr(
// TODO: Not convinced of using String to store address, we can make a moshi InetAddress adapter
val address: String,
val prefix: Int,
) : Parcelable

View File

@@ -10,7 +10,7 @@ import kotlinx.parcelize.Parcelize
data class Tunnel(
val config: TunnelConfig = TunnelConfig(),
var state: State = State.Down,
val routes: List<String> = emptyList(),
val routes: List<Cidr> = emptyList(),
val resources: List<Resource> = emptyList(),
) : Parcelable {
enum class State {

View File

@@ -122,8 +122,8 @@ fn init_logging(log_dir: &Path, log_filter: String) -> file_logger::Handle {
.expect("Logging guard should never be initialized twice");
let _ = tracing_subscriber::registry()
.with(file_layer.with_filter(EnvFilter::new(log_filter)))
.with(android_layer())
.with(file_layer.with_filter(EnvFilter::new(log_filter.clone())))
.with(android_layer().with_filter(EnvFilter::new(log_filter.clone())))
.try_init();
handle
@@ -138,7 +138,7 @@ impl Callbacks for CallbackHandler {
tunnel_address_v6: Ipv6Addr,
dns_address: Ipv4Addr,
dns_fallback_strategy: String,
) -> Result<RawFd, Self::Error> {
) -> Result<Option<RawFd>, Self::Error> {
self.env(|mut env| {
let tunnel_address_v4 =
env.new_string(tunnel_address_v4.to_string())
@@ -179,6 +179,7 @@ impl Callbacks for CallbackHandler {
],
)
.and_then(|val| val.i())
.map(Some)
.map_err(|source| CallbackError::CallMethodFailed { name, source })
})
}
@@ -195,27 +196,31 @@ impl Callbacks for CallbackHandler {
})
}
fn on_add_route(&self, route: IpNetwork) -> Result<(), Self::Error> {
fn on_add_route(&self, route: IpNetwork) -> Result<Option<RawFd>, Self::Error> {
self.env(|mut env| {
let route = env.new_string(route.to_string()).map_err(|source| {
CallbackError::NewStringFailed {
name: "route",
let ip = env
.new_string(route.network_address().to_string())
.map_err(|source| CallbackError::NewStringFailed {
name: "route_ip",
source,
}
})?;
call_method(
&mut env,
})?;
let name = "onAddRoute";
env.call_method(
&self.callback_handler,
"onAddRoute",
"(Ljava/lang/String;)V",
&[JValue::from(&route)],
name,
"(Ljava/lang/String;I)I",
&[JValue::from(&ip), JValue::Int(route.netmask().into())],
)
.and_then(|val| val.i())
.map(Some)
.map_err(|source| CallbackError::CallMethodFailed { name, source })
})
}
fn on_remove_route(&self, route: IpNetwork) -> Result<(), Self::Error> {
self.env(|mut env| {
let route = env.new_string(route.to_string()).map_err(|source| {
let ip = env.new_string(route.to_string()).map_err(|source| {
CallbackError::NewStringFailed {
name: "route",
source,
@@ -225,8 +230,8 @@ impl Callbacks for CallbackHandler {
&mut env,
&self.callback_handler,
"onRemoveRoute",
"(Ljava/lang/String;)V",
&[JValue::from(&route)],
"(Ljava/lang/String;I)V",
&[JValue::from(&ip), JValue::Int(route.netmask().into())],
)
})
}
@@ -301,7 +306,7 @@ impl Callbacks for CallbackHandler {
fn throw(env: &mut JNIEnv, class: &str, msg: impl Into<JNIString>) {
if let Err(err) = env.throw_new(class, msg) {
// We can't panic, since unwinding across the FFI boundary is UB...
tracing::error!("failed to throw Java exception: {err}");
tracing::error!(?err, "failed to throw Java exception");
}
}

View File

@@ -89,14 +89,14 @@ impl Callbacks for CallbackHandler {
tunnel_address_v6: Ipv6Addr,
dns_address: Ipv4Addr,
dns_fallback_strategy: String,
) -> Result<RawFd, Self::Error> {
) -> Result<Option<RawFd>, Self::Error> {
self.inner.on_set_interface_config(
tunnel_address_v4.to_string(),
tunnel_address_v6.to_string(),
dns_address.to_string(),
dns_fallback_strategy.to_string(),
);
Ok(-1)
Ok(None)
}
fn on_tunnel_ready(&self) -> Result<(), Self::Error> {
@@ -104,9 +104,9 @@ impl Callbacks for CallbackHandler {
Ok(())
}
fn on_add_route(&self, route: IpNetwork) -> Result<(), Self::Error> {
fn on_add_route(&self, route: IpNetwork) -> Result<Option<RawFd>, Self::Error> {
self.inner.on_add_route(route.to_string());
Ok(())
Ok(None)
}
fn on_remove_route(&self, route: IpNetwork) -> Result<(), Self::Error> {

View File

@@ -78,7 +78,11 @@ where
runtime_stopper: tx.clone(),
callbacks,
};
// In android we get an stack-overflow due to tokio
// taking too much of the stack-space:
// See: https://github.com/firezone/firezone/issues/2227
let runtime = tokio::runtime::Builder::new_multi_thread()
.thread_stack_size(3 * 1024 * 1024)
.enable_all()
.build()?;
{

View File

@@ -14,14 +14,17 @@ pub trait Callbacks: Clone + Send + Sync {
type Error: Debug + Display + Error;
/// Called when the tunnel address is set.
///
/// This should return a new `fd` if there is one.
/// (Only happens on android for now)
fn on_set_interface_config(
&self,
_: Ipv4Addr,
_: Ipv6Addr,
_: Ipv4Addr,
_: String,
) -> Result<RawFd, Self::Error> {
Ok(-1)
) -> Result<Option<RawFd>, Self::Error> {
Ok(None)
}
/// Called when the tunnel is connected.
@@ -31,8 +34,11 @@ pub trait Callbacks: Clone + Send + Sync {
}
/// Called when when a route is added.
fn on_add_route(&self, _: IpNetwork) -> Result<(), Self::Error> {
Ok(())
///
/// This should return a new `fd` if there is one.
/// (Only happens on android for now)
fn on_add_route(&self, _: IpNetwork) -> Result<Option<RawFd>, Self::Error> {
Ok(None)
}
/// Called when when a route is removed.

View File

@@ -18,7 +18,7 @@ impl<CB: Callbacks> Callbacks for CallbackErrorFacade<CB> {
tunnel_address_v6: Ipv6Addr,
dns_address: Ipv4Addr,
dns_fallback_strategy: String,
) -> Result<RawFd> {
) -> Result<Option<RawFd>> {
let result = self
.0
.on_set_interface_config(
@@ -29,7 +29,7 @@ impl<CB: Callbacks> Callbacks for CallbackErrorFacade<CB> {
)
.map_err(|err| Error::OnSetInterfaceConfigFailed(err.to_string()));
if let Err(err) = result.as_ref() {
tracing::error!("{err}");
tracing::error!(?err);
}
result
}
@@ -40,18 +40,18 @@ impl<CB: Callbacks> Callbacks for CallbackErrorFacade<CB> {
.on_tunnel_ready()
.map_err(|err| Error::OnTunnelReadyFailed(err.to_string()));
if let Err(err) = result.as_ref() {
tracing::error!("{err}");
tracing::error!(?err);
}
result
}
fn on_add_route(&self, route: IpNetwork) -> Result<()> {
fn on_add_route(&self, route: IpNetwork) -> Result<Option<RawFd>> {
let result = self
.0
.on_add_route(route)
.map_err(|err| Error::OnAddRouteFailed(err.to_string()));
if let Err(err) = result.as_ref() {
tracing::error!("{err}");
tracing::error!(?err);
}
result
}
@@ -62,7 +62,7 @@ impl<CB: Callbacks> Callbacks for CallbackErrorFacade<CB> {
.on_remove_route(route)
.map_err(|err| Error::OnRemoveRouteFailed(err.to_string()));
if let Err(err) = result.as_ref() {
tracing::error!("{err}");
tracing::error!(?err);
}
result
}
@@ -73,14 +73,14 @@ impl<CB: Callbacks> Callbacks for CallbackErrorFacade<CB> {
.on_update_resources(resource_list)
.map_err(|err| Error::OnUpdateResourcesFailed(err.to_string()));
if let Err(err) = result.as_ref() {
tracing::error!("{err}");
tracing::error!(?err);
}
result
}
fn on_disconnect(&self, error: Option<&Error>) -> Result<()> {
if let Err(err) = self.0.on_disconnect(error) {
tracing::error!("`on_disconnect` failed: {err}");
tracing::error!(?err, "`on_disconnect` failed");
}
// There's nothing we can really do if `on_disconnect` fails.
Ok(())
@@ -88,7 +88,7 @@ impl<CB: Callbacks> Callbacks for CallbackErrorFacade<CB> {
fn on_error(&self, error: &Error) -> Result<()> {
if let Err(err) = self.0.on_error(error) {
tracing::error!("`on_error` failed: {err}");
tracing::error!(?err, "`on_error` failed");
}
// There's nothing we really want to do if `on_error` fails.
Ok(())

View File

@@ -96,6 +96,9 @@ pub enum ConnlibError {
/// No iface found
#[error("No iface found")]
NoIface,
/// Expected file descriptor and none was found
#[error("No filedescriptor")]
NoFd,
/// No MTU found
#[error("No MTU found")]
NoMtu,
@@ -120,6 +123,9 @@ pub enum ConnlibError {
/// Invalid source address for peer
#[error("Invalid source address")]
InvalidSource,
/// Any parse error
#[error("parse error")]
ParseError,
}
impl ConnlibError {

View File

@@ -145,12 +145,8 @@ where
})
});
let Some(device_io) = self.device_io.read().clone() else {
return Err(Error::NoIface);
};
let tunnel = Arc::clone(self);
tokio::spawn(async move { tunnel.peer_handler(peer, device_io).await });
tokio::spawn(async move { tunnel.start_peer_handler(peer).await });
Ok(())
}

View File

@@ -96,12 +96,13 @@ where
tracing::trace!("new_data_channel_open");
Box::pin(async move {
{
let Some(iface_config) = tunnel.iface_config.read().clone() else {
let Some(device) = tunnel.device.read().await.clone() else {
let e = Error::NoIface;
tracing::error!(err = ?e, "channel_open");
let _ = tunnel.callbacks().on_error(&e);
return;
};
let iface_config = device.config;
for &ip in &peer.ips {
if let Err(e) = iface_config.add_route(ip, tunnel.callbacks()).await
{

View File

@@ -9,6 +9,8 @@ use tokio::io::{unix::AsyncFd, Interest};
use tun::{IfaceDevice, IfaceStream};
use crate::Device;
mod tun;
pub(crate) struct IfaceConfig {
@@ -53,23 +55,32 @@ impl IfaceConfig {
&self,
route: IpNetwork,
callbacks: &CallbackErrorFacade<impl Callbacks>,
) -> Result<()> {
self.iface.add_route(route, callbacks).await
) -> Result<Option<Device>> {
let Some((iface, stream)) = self.iface.add_route(route, callbacks).await? else {
return Ok(None);
};
let io = DeviceIo(stream);
let mtu = iface.mtu().await?;
let config = Arc::new(IfaceConfig {
iface,
mtu: AtomicUsize::new(mtu),
});
Ok(Some(Device { io, config }))
}
}
pub(crate) async fn create_iface(
config: &Interface,
callbacks: &CallbackErrorFacade<impl Callbacks>,
) -> Result<(IfaceConfig, DeviceIo)> {
) -> Result<Device> {
let (iface, stream) = IfaceDevice::new(config, callbacks).await?;
iface.up().await?;
let device_io = DeviceIo(stream);
let io = DeviceIo(stream);
let mtu = iface.mtu().await?;
let iface_config = IfaceConfig {
let config = Arc::new(IfaceConfig {
iface,
mtu: AtomicUsize::new(mtu),
};
});
Ok((iface_config, device_io))
Ok(Device { config, io })
}

View File

@@ -1,3 +1,4 @@
use crate::Device;
use connlib_shared::{messages::Interface, CallbackErrorFacade, Callbacks, Result};
use ip_network::IpNetwork;
@@ -33,7 +34,7 @@ impl IfaceConfig {
&self,
_: IpNetwork,
_: &CallbackErrorFacade<impl Callbacks>,
) -> Result<()> {
) -> Result<Option<Device>> {
todo!()
}
}
@@ -41,6 +42,6 @@ impl IfaceConfig {
pub(crate) async fn create_iface(
_: &Interface,
_: &CallbackErrorFacade<impl Callbacks>,
) -> Result<(IfaceConfig, DeviceIo)> {
) -> Result<Device> {
todo!()
}

View File

@@ -0,0 +1,35 @@
use std::os::fd::{AsRawFd, RawFd};
use std::sync::atomic::{AtomicBool, Ordering};
#[derive(Debug)]
pub(crate) struct Closeable {
closed: AtomicBool,
value: RawFd,
}
impl AsRawFd for Closeable {
fn as_raw_fd(&self) -> RawFd {
self.value.as_raw_fd()
}
}
impl Closeable {
pub(crate) fn new(fd: RawFd) -> Self {
Self {
closed: AtomicBool::new(false),
value: fd,
}
}
pub(crate) fn with<U>(&self, f: impl FnOnce(RawFd) -> U) -> std::io::Result<U> {
if self.closed.load(Ordering::Acquire) {
return Err(std::io::Error::from_raw_os_error(9));
}
Ok(f(self.value))
}
pub(crate) fn close(&self) {
self.closed.store(true, Ordering::Release);
}
}

View File

@@ -1,4 +1,5 @@
use crate::InterfaceConfig;
use closeable::Closeable;
use connlib_shared::{CallbackErrorFacade, Callbacks, Error, Result, DNS_SENTINEL};
use ip_network::IpNetwork;
use libc::{
@@ -13,6 +14,7 @@ use std::{
};
use tokio::io::unix::AsyncFd;
mod closeable;
mod wrapped_socket;
// Android doesn't support Split DNS. So we intercept all requests and forward
// the non-Firezone name resolution requests to the upstream DNS resolver.
@@ -50,24 +52,27 @@ pub(crate) struct IfaceDevice(Arc<AsyncFd<IfaceStream>>);
#[derive(Debug)]
pub(crate) struct IfaceStream {
fd: RawFd,
fd: Closeable,
}
impl AsRawFd for IfaceStream {
fn as_raw_fd(&self) -> RawFd {
self.fd
self.fd.as_raw_fd()
}
}
impl Drop for IfaceStream {
fn drop(&mut self) {
unsafe { close(self.fd) };
unsafe { close(self.fd.as_raw_fd()) };
}
}
impl IfaceStream {
fn write(&self, buf: &[u8]) -> std::io::Result<usize> {
match unsafe { write(self.fd, buf.as_ptr() as _, buf.len() as _) } {
match self
.fd
.with(|fd| unsafe { write(fd, buf.as_ptr() as _, buf.len() as _) })?
{
-1 => Err(io::Error::last_os_error()),
n => Ok(n as usize),
}
@@ -81,12 +86,21 @@ impl IfaceStream {
self.write(src)
}
pub fn read<'a>(&self, dst: &'a mut [u8]) -> std::io::Result<usize> {
match unsafe { read(self.fd, dst.as_mut_ptr() as _, dst.len()) } {
pub fn read(&self, dst: &mut [u8]) -> std::io::Result<usize> {
// We don't read(or write) again from the fd because the given fd number might be reclaimed
// so this could make an spurious read/write to another fd and we DEFINITELY don't want that
match self
.fd
.with(|fd| unsafe { read(fd, dst.as_mut_ptr() as _, dst.len()) })?
{
-1 => Err(io::Error::last_os_error()),
n => Ok(n as usize),
}
}
pub fn close(&self) {
self.fd.close();
}
}
impl IfaceDevice {
@@ -94,13 +108,17 @@ impl IfaceDevice {
config: &InterfaceConfig,
callbacks: &CallbackErrorFacade<impl Callbacks>,
) -> Result<(Self, Arc<AsyncFd<IfaceStream>>)> {
let fd = callbacks.on_set_interface_config(
config.ipv4,
config.ipv6,
DNS_SENTINEL,
DNS_FALLBACK_STRATEGY.to_string(),
)?;
let iface_stream = Arc::new(AsyncFd::new(IfaceStream { fd: fd.into() })?);
let fd = callbacks
.on_set_interface_config(
config.ipv4,
config.ipv6,
DNS_SENTINEL,
DNS_FALLBACK_STRATEGY.to_string(),
)?
.ok_or(Error::NoFd)?;
let iface_stream = Arc::new(AsyncFd::new(IfaceStream {
fd: Closeable::new(fd.into()),
})?);
let this = Self(Arc::clone(&iface_stream));
Ok((this, iface_stream))
@@ -112,7 +130,12 @@ impl IfaceDevice {
ifr_ifru: unsafe { std::mem::zeroed() },
};
match unsafe { ioctl(self.0.get_ref().fd, TUNGETIFF as _, &mut ifr) } {
match self
.0
.get_ref()
.fd
.with(|fd| unsafe { ioctl(fd, TUNGETIFF as _, &mut ifr) })?
{
0 => {
let name_cstr = unsafe { std::ffi::CStr::from_ptr(ifr.ifr_name.as_ptr() as _) };
Ok(name_cstr.to_string_lossy().into_owned())
@@ -151,8 +174,15 @@ impl IfaceDevice {
&self,
route: IpNetwork,
callbacks: &CallbackErrorFacade<impl Callbacks>,
) -> Result<()> {
callbacks.on_add_route(route)
) -> Result<Option<(Self, Arc<AsyncFd<IfaceStream>>)>> {
self.0.get_ref().close();
let fd = callbacks.on_add_route(route)?.ok_or(Error::NoFd)?;
let iface_stream = Arc::new(AsyncFd::new(IfaceStream {
fd: Closeable::new(fd.into()),
})?);
let this = Self(Arc::clone(&iface_stream));
Ok(Some((this, iface_stream)))
}
pub async fn up(&self) -> Result<()> {

View File

@@ -195,12 +195,12 @@ impl IfaceDevice {
}
if addr.sc_id == info.ctl_id {
let _ = callbacks.on_set_interface_config(
callbacks.on_set_interface_config(
config.ipv4,
config.ipv6,
DNS_SENTINEL,
"system_resolver".to_string(),
);
)?;
set_non_blocking(fd)?;
@@ -241,8 +241,10 @@ impl IfaceDevice {
&self,
route: IpNetwork,
callbacks: &CallbackErrorFacade<impl Callbacks>,
) -> Result<()> {
callbacks.on_add_route(route)
) -> Result<Option<(Self, Arc<AsyncFd<IfaceStream>>)>> {
// This will always be None in macos
callbacks.on_add_route(route)?;
Ok(None)
}
pub async fn up(&self) -> Result<()> {

View File

@@ -175,7 +175,7 @@ impl IfaceDevice {
&self,
route: IpNetwork,
_callbacks: &CallbackErrorFacade<impl Callbacks>,
) -> Result<()> {
) -> Result<Option<(Self, Arc<AsyncFd<IfaceStream>>)>> {
let req = self
.handle
.route()
@@ -211,7 +211,7 @@ impl IfaceDevice {
}
*/
Ok(())
Ok(None)
}
#[tracing::instrument(level = "trace", skip(self, _callbacks))]

View File

@@ -138,9 +138,9 @@ where
) -> Result<()> {
if let Some(r) = self.check_for_dns(src) {
match r {
dns::SendPacket::Ipv4(r) => self.write4_device_infallible(device_writer, &r[..]),
dns::SendPacket::Ipv6(r) => self.write6_device_infallible(device_writer, &r[..]),
}
dns::SendPacket::Ipv4(r) => device_writer.write4(&r[..])?,
dns::SendPacket::Ipv6(r) => device_writer.write6(&r[..])?,
};
return Ok(());
}
@@ -170,29 +170,30 @@ where
device_io: DeviceIo,
) {
let device_writer = device_io.clone();
let mut src = [0u8; MAX_UDP_SIZE];
let mut dst = [0u8; MAX_UDP_SIZE];
loop {
let mut src = [0u8; MAX_UDP_SIZE];
let mut dst = [0u8; MAX_UDP_SIZE];
let res = {
// TODO: We should check here if what we read is a whole packet
// there's no docs on tun device on when a whole packet is read, is it \n or another thing?
// found some comments saying that a single read syscall represents a single packet but no docs on that
// See https://stackoverflow.com/questions/18461365/how-to-read-packet-by-packet-from-linux-tun-tap
match device_io.read(&mut src[..iface_config.mtu()]).await {
Ok(res) => res,
Err(e) => {
tracing::error!(error = ?e, from = "iface", action = "read");
let _ = self.callbacks.on_error(&e.into());
continue;
}
let res = match device_io.read(&mut src[..iface_config.mtu()]).await {
Ok(res) => res,
Err(e) => {
tracing::error!(err = ?e, "failed to read interface: {e:#}");
let _ = self.callbacks.on_error(&e.into());
break;
}
};
tracing::trace!(target: "wire", action = "read", bytes = res, from = "iface");
// TODO
let _ = self
if res == 0 {
break;
}
if let Err(e) = self
.handle_iface_packet(&device_writer, &mut src[..res], &mut dst)
.await;
.await
{
let _ = self.callbacks.on_error(&e);
tracing::error!(err = ?e, "failed to handle packet {e:#}")
}
}
}
}

View File

@@ -18,7 +18,7 @@ use itertools::Itertools;
use parking_lot::{Mutex, RwLock};
use peer::{Peer, PeerStats};
use resource_table::ResourceTable;
use tokio::time::MissedTickBehavior;
use tokio::{task::AbortHandle, time::MissedTickBehavior};
use webrtc::{
api::{
interceptor_registry::register_default_interceptors, media_engine::MediaEngine,
@@ -140,15 +140,19 @@ struct AwaitingConnectionDetails {
pub response_received: bool,
}
#[derive(Clone)]
struct Device {
pub config: Arc<IfaceConfig>,
pub io: DeviceIo,
}
// TODO: We should use newtypes for each kind of Id
/// Tunnel is a wireguard state machine that uses webrtc's ICE channels instead of UDP sockets
/// to communicate between peers.
pub struct Tunnel<C: ControlSignal, CB: Callbacks> {
next_index: Mutex<IndexLfsr>,
// We use a tokio's mutex here since it makes things easier and we only need it
// during init, so the performance hit is neglibile
iface_config: RwLock<Option<Arc<IfaceConfig>>>,
device_io: RwLock<Option<DeviceIo>>,
// We use a tokio Mutex here since this is only read/write during config so there's no relevant performance impact
device: tokio::sync::RwLock<Option<Device>>,
rate_limiter: Arc<RateLimiter>,
private_key: StaticSecret,
public_key: PublicKey,
@@ -164,6 +168,7 @@ pub struct Tunnel<C: ControlSignal, CB: Callbacks> {
control_signaler: C,
gateway_public_keys: Mutex<HashMap<GatewayId, PublicKey>>,
callbacks: CallbackErrorFacade<CB>,
iface_handler_abort: Mutex<Option<AbortHandle>>,
}
// TODO: For now we only use these fields with debug
@@ -249,9 +254,9 @@ where
let gateway_public_keys = Default::default();
let resources_gateways = Default::default();
let gateway_awaiting_connection = Default::default();
let iface_config = Default::default();
let device_io = Default::default();
let device = Default::default();
let ice_candidate_queue = Default::default();
let iface_handler_abort = Default::default();
// ICE
let mut media_engine = MediaEngine::default();
@@ -287,31 +292,54 @@ where
next_index,
webrtc_api,
resources,
iface_config,
device_io,
device,
awaiting_connection,
gateway_awaiting_connection,
control_signaler,
resources_gateways,
ice_candidate_queue,
callbacks: CallbackErrorFacade(callbacks),
iface_handler_abort,
})
}
#[tracing::instrument(level = "trace", skip(self))]
pub async fn add_route(self: &Arc<Self>, route: IpNetwork) -> Result<()> {
let mut device = self.device.write().await;
if let Some(new_device) = device
.as_ref()
.ok_or(Error::ControlProtocolError)?
.config
.add_route(route, self.callbacks())
.await?
{
*device = Some(new_device.clone());
let dev = Arc::clone(self);
self.iface_handler_abort.lock().replace(
tokio::spawn(
async move { dev.iface_handler(new_device.config, new_device.io).await },
)
.abort_handle(),
);
}
Ok(())
}
/// Adds a the given resource to the tunnel.
///
/// Once added, when a packet for the resource is intercepted a new data channel will be created
/// and packets will be wrapped with wireguard and sent through it.
#[tracing::instrument(level = "trace", skip(self))]
pub async fn add_resource(&self, resource_description: ResourceDescription) -> Result<()> {
pub async fn add_resource(
self: &Arc<Self>,
resource_description: ResourceDescription,
) -> Result<()> {
let mut any_valid_route = false;
{
let Some(iface_config) = self.iface_config.read().clone() else {
tracing::error!("add_resource_before_initialization");
return Err(Error::ControlProtocolError);
};
for ip in resource_description.ips() {
if let Err(e) = iface_config.add_route(ip, self.callbacks()).await {
if let Err(e) = self.add_route(ip).await {
tracing::warn!(route = %ip, error = ?e, "add_route");
let _ = self.callbacks().on_error(&e);
} else {
@@ -336,17 +364,17 @@ where
/// Sets the interface configuration and starts background tasks.
#[tracing::instrument(level = "trace", skip(self))]
pub async fn set_interface(self: &Arc<Self>, config: &InterfaceConfig) -> Result<()> {
let (iface_config, device_io) = create_iface(config, self.callbacks()).await?;
iface_config
.add_route(DNS_SENTINEL.into(), self.callbacks())
.await?;
let iface_config = Arc::new(iface_config);
let device = create_iface(config, self.callbacks()).await?;
*self.device.write().await = Some(device.clone());
*self.device_io.write() = Some(device_io.clone());
*self.iface_config.write() = Some(Arc::clone(&iface_config));
self.start_timers()?;
self.start_timers().await?;
let dev = Arc::clone(self);
tokio::spawn(async move { dev.iface_handler(iface_config, device_io).await });
*self.iface_handler_abort.lock() = Some(
tokio::spawn(async move { dev.iface_handler(device.config, device.io).await })
.abort_handle(),
);
self.add_route(DNS_SENTINEL.into()).await?;
self.callbacks.on_tunnel_ready()?;
@@ -453,17 +481,22 @@ where
});
}
fn start_refresh_mtu_timer(self: &Arc<Self>) -> Result<()> {
let Some(iface_config) = self.iface_config.read().clone() else {
return Err(Error::NoIface);
};
async fn start_refresh_mtu_timer(self: &Arc<Self>) -> Result<()> {
let dev = self.clone();
let callbacks = self.callbacks().clone();
tokio::spawn(async move {
let mut interval = tokio::time::interval(REFRESH_MTU_INTERVAL);
interval.set_missed_tick_behavior(MissedTickBehavior::Delay);
loop {
interval.tick().await;
if let Err(e) = iface_config.refresh_mtu().await {
let Some(device) = dev.device.read().await.clone() else {
let err = Error::ControlProtocolError;
tracing::error!(?err, "get_iface_config");
let _ = callbacks.0.on_error(&err);
continue;
};
if let Err(e) = device.config.refresh_mtu().await {
tracing::error!(error = ?e, "refresh_mtu");
let _ = callbacks.0.on_error(&e);
}
@@ -473,29 +506,13 @@ where
Ok(())
}
fn start_timers(self: &Arc<Self>) -> Result<()> {
self.start_refresh_mtu_timer()?;
async fn start_timers(self: &Arc<Self>) -> Result<()> {
self.start_refresh_mtu_timer().await?;
self.start_rate_limiter_refresh_timer();
self.start_peers_refresh_timer();
Ok(())
}
#[inline(always)]
fn write4_device_infallible(&self, device_io: &DeviceIo, packet: &[u8]) {
if let Err(e) = device_io.write4(packet) {
tracing::error!(?e, "iface_write");
let _ = self.callbacks().on_error(&e.into());
}
}
#[inline(always)]
fn write6_device_infallible(&self, device_io: &DeviceIo, packet: &[u8]) {
if let Err(e) = device_io.write6(packet) {
tracing::error!(?e, "iface_write");
let _ = self.callbacks().on_error(&e.into());
}
}
fn get_resource(&self, buff: &[u8]) -> Option<ResourceDescription> {
let addr = Tunn::dst_address(buff)?;
let resources = self.resources.read();

View File

@@ -66,26 +66,26 @@ where
peer: &Arc<Peer>,
device_io: &DeviceIo,
decapsulate_result: TunnResult<'a>,
) -> bool {
) -> Result<bool> {
match decapsulate_result {
TunnResult::Done => false,
TunnResult::Done => Ok(false),
TunnResult::Err(e) => {
tracing::error!(error = ?e, "decapsulate_packet");
let _ = self.callbacks().on_error(&e.into());
false
Ok(false)
}
TunnResult::WriteToNetwork(packet) => {
let bytes = Bytes::copy_from_slice(packet);
peer.send_infallible(bytes, &self.callbacks).await;
true
Ok(true)
}
TunnResult::WriteToTunnelV4(packet, addr) => {
self.send_to_resource(device_io, peer, addr.into(), packet);
false
self.send_to_resource(device_io, peer, addr.into(), packet)?;
Ok(false)
}
TunnResult::WriteToTunnelV6(packet, addr) => {
self.send_to_resource(device_io, peer, addr.into(), packet);
false
self.send_to_resource(device_io, peer, addr.into(), packet)?;
Ok(false)
}
}
}
@@ -108,7 +108,7 @@ where
if self
.handle_decapsulated_packet(peer, device_writer, decapsulate_result)
.await
.await?
{
// Flush pending queue
while let TunnResult::WriteToNetwork(packet) = {
@@ -125,10 +125,16 @@ where
Ok(())
}
pub(crate) async fn peer_handler(self: &Arc<Self>, peer: Arc<Peer>, device_io: DeviceIo) {
async fn peer_handler(
self: &Arc<Self>,
peer: &Arc<Peer>,
device_io: DeviceIo,
) -> std::io::Result<()> {
let mut src_buf = [0u8; MAX_UDP_SIZE];
let mut dst_buf = [0u8; MAX_UDP_SIZE];
while let Ok(size) = peer.channel.read(&mut src_buf[..]).await {
tracing::trace!(target: "wire", action = "read", bytes = size, from = "peer");
// TODO: Double check that this can only happen on closed channel
// I think it's possible to transmit a 0-byte message through the channel
// but we would never use that.
@@ -137,14 +143,38 @@ where
break;
}
tracing::trace!(target: "wire", action = "read", bytes = size, from = "peer");
let _ = self
.handle_peer_packet(&peer, &device_io, &src_buf[..size], &mut dst_buf)
.await;
if let Err(Error::Io(e)) = self
.handle_peer_packet(peer, &device_io, &src_buf[..size], &mut dst_buf)
.await
{
return Err(e);
}
}
let peer_stats = peer.stats();
tracing::debug!(peer = ?peer_stats, "peer_stopped");
Ok(())
}
pub(crate) async fn start_peer_handler(self: &Arc<Self>, peer: Arc<Peer>) {
loop {
let Some(device) = self.device.read().await.clone() else {
let err = Error::NoIface;
tracing::error!(?err);
let _ = self.callbacks().on_disconnect(Some(&err));
break;
};
let device_io = device.io;
if let Err(err) = self.peer_handler(&peer, device_io).await {
if err.raw_os_error() != Some(9) {
tracing::error!(?err);
let _ = self.callbacks().on_error(&err.into());
break;
} else {
tracing::warn!("bad_file_descriptor");
}
}
}
tracing::debug!(peer = ?peer.stats(), "peer_stopped");
self.stop_peer(peer.index, peer.conn_id).await;
}
}

View File

@@ -24,15 +24,17 @@ where
}
#[inline(always)]
fn send_packet(&self, device_io: &DeviceIo, packet: &mut [u8], dst_addr: IpAddr) {
fn send_packet(
&self,
device_io: &DeviceIo,
packet: &mut [u8],
dst_addr: IpAddr,
) -> std::io::Result<()> {
match dst_addr {
IpAddr::V4(_) => {
self.write4_device_infallible(device_io, packet);
}
IpAddr::V6(_) => {
self.write6_device_infallible(device_io, packet);
}
}
IpAddr::V4(_) => device_io.write4(packet)?,
IpAddr::V6(_) => device_io.write6(packet)?,
};
Ok(())
}
#[inline(always)]
@@ -42,26 +44,20 @@ where
peer: &Arc<Peer>,
addr: IpAddr,
packet: &mut [u8],
) {
) -> Result<()> {
let Some((dst, resource)) = peer.get_packet_resource(packet) else {
// If there's no associated resource it means that we are in a client, then the packet comes from a gateway
// and we just trust gateways.
// In gateways this should never happen.
tracing::trace!(target: "wire", action = "writing", to = "iface", %addr, bytes = %packet.len());
self.send_packet(device_io, packet, addr);
return;
self.send_packet(device_io, packet, addr)?;
return Ok(());
};
match get_resource_addr_and_port(peer, &resource, &addr, &dst) {
Ok((dst_addr, _dst_port)) => {
self.update_packet(packet, dst_addr);
self.send_packet(device_io, packet, addr);
}
Err(e) => {
tracing::error!(err = ?e, "resource_parse");
let _ = self.callbacks().on_error(&e);
}
}
let (dst_addr, _dst_port) = get_resource_addr_and_port(peer, &resource, &addr, &dst)?;
self.update_packet(packet, dst_addr);
self.send_packet(device_io, packet, addr)?;
Ok(())
}
pub(crate) fn send_to_resource(
@@ -70,11 +66,13 @@ where
peer: &Arc<Peer>,
addr: IpAddr,
packet: &mut [u8],
) {
) -> Result<()> {
if peer.is_allowed(addr) {
self.packet_allowed(device_io, peer, addr, packet);
self.packet_allowed(device_io, peer, addr, packet)?;
Ok(())
} else {
tracing::warn!(%addr, "Received packet from peer with an unallowed ip");
Ok(())
}
}
}
@@ -90,7 +88,6 @@ fn get_resource_addr_and_port(
dst: &IpAddr,
) -> Result<(IpAddr, Option<u16>)> {
match resource {
// Note: for now no translation is needed for the ip since we do a peer/connection per resource
ResourceDescription::Dns(r) => {
let mut address = r.address.split(':');
let Some(dst_addr) = address.next() else {