mirror of
https://github.com/outbackdingo/firezone.git
synced 2026-01-27 10:18:54 +00:00
connlib+android: enable fd replacement (#2235)
Should be easier to review commit by commit. The gist of this commit is: * `onAddRoute` on Android now takes an address+prefix as to minimize parsing * `onAddRoute` recreates the vpn service each time(TODO: is this too bad for performance?) * `on_add_route` and `onAddRoute` returns the new fd * on android after `on_add_route` we recreate `IfaceConfig` and `DeviceIo` and we store the new values * `peer_handler` now runs on a loop, where each time we fail a write with an error code 9(bad descriptor) we try to take the new `DeviceIo` * we keep an [`AbortHandle`](https://docs.rs/tokio/latest/tokio/task/struct.AbortHandle.html) from the `iface_handler` task, since closing the fd doesn't awake the `read` task for `AsyncFd`(I tried it, right now `close` is only called after dropping the fd) so we explicitly abort the task and start a new one with the new `device_io`. * in android `DeviceIo` has an atomic which tells if it's closed or open and we change it to closed after `on_add_route`, we use this as to never double-close the fd, instead we wait until it's dropped. This *might* affect performance on android since we use non-`Ordering::Relaxed` atomic operation each read/write but it won't affect perfromance in other platforms, furthermore I believe the performance gains if we remove this will be minimal. Fixes #2227 --------- Co-authored-by: Jamil <jamilbk@users.noreply.github.com>
This commit is contained in:
@@ -1,7 +1,6 @@
|
||||
plugins {
|
||||
id("com.android.application")
|
||||
id("kotlin-android")
|
||||
id("kotlin-kapt")
|
||||
id("dagger.hilt.android.plugin")
|
||||
id("kotlin-parcelize")
|
||||
id("androidx.navigation.safeargs")
|
||||
@@ -9,6 +8,7 @@ plugins {
|
||||
id("com.google.gms.google-services")
|
||||
id("com.google.firebase.crashlytics")
|
||||
id("com.diffplug.spotless") version "6.22.0"
|
||||
id("kotlin-kapt")
|
||||
}
|
||||
|
||||
spotless {
|
||||
|
||||
@@ -20,6 +20,7 @@ import dev.firezone.android.core.domain.preference.GetConfigUseCase
|
||||
import dev.firezone.android.core.presentation.MainActivity
|
||||
import dev.firezone.android.tunnel.callback.ConnlibCallback
|
||||
import dev.firezone.android.tunnel.data.TunnelRepository
|
||||
import dev.firezone.android.tunnel.model.Cidr
|
||||
import dev.firezone.android.tunnel.model.Resource
|
||||
import dev.firezone.android.tunnel.model.Tunnel
|
||||
import dev.firezone.android.tunnel.model.TunnelConfig
|
||||
@@ -99,16 +100,26 @@ class TunnelService : VpnService() {
|
||||
return true
|
||||
}
|
||||
|
||||
override fun onAddRoute(cidrAddress: String) {
|
||||
Log.d(TAG, "onAddRoute: $cidrAddress")
|
||||
|
||||
tunnelRepository.addRoute(cidrAddress)
|
||||
override fun onAddRoute(
|
||||
addr: String,
|
||||
prefix: Int,
|
||||
): Int {
|
||||
Log.d(TAG, "onAddRoute: $addr/$prefix")
|
||||
val route = Cidr(addr, prefix)
|
||||
tunnelRepository.addRoute(route)
|
||||
val fd = buildVpnService().establish()?.detachFd() ?: -1
|
||||
protect(fd)
|
||||
return fd
|
||||
}
|
||||
|
||||
override fun onRemoveRoute(cidrAddress: String) {
|
||||
Log.d(TAG, "onRemoveRoute: $cidrAddress")
|
||||
override fun onRemoveRoute(
|
||||
addr: String,
|
||||
prefix: Int,
|
||||
) {
|
||||
Log.d(TAG, "onRemoveRoute: $addr/$prefix")
|
||||
val cidr = Cidr(addr, prefix)
|
||||
|
||||
tunnelRepository.removeRoute(cidrAddress)
|
||||
tunnelRepository.removeRoute(cidr)
|
||||
}
|
||||
|
||||
override fun getSystemDefaultResolvers(): String {
|
||||
@@ -231,16 +242,9 @@ class TunnelService : VpnService() {
|
||||
|
||||
addDnsServer(tunnel.config.dnsAddress)
|
||||
|
||||
/*tunnel.routes.forEach {
|
||||
addRoute(it, 32)
|
||||
}*/
|
||||
|
||||
// TODO: These are the staging Resources. Remove these in favor of the onUpdateResources callback.
|
||||
addRoute("100.100.111.1", 32)
|
||||
addRoute("172.31.82.179", 32)
|
||||
addRoute("172.31.83.10", 32)
|
||||
addRoute("172.31.92.238", 32)
|
||||
addRoute("172.31.93.123", 32)
|
||||
tunnel.routes.forEach {
|
||||
addRoute(it.address, it.prefix)
|
||||
}
|
||||
|
||||
setSession(SESSION_NAME)
|
||||
|
||||
|
||||
@@ -11,9 +11,15 @@ interface ConnlibCallback {
|
||||
|
||||
fun onTunnelReady(): Boolean
|
||||
|
||||
fun onAddRoute(cidrAddress: String)
|
||||
fun onAddRoute(
|
||||
cidrAddress: String,
|
||||
prefix: Int,
|
||||
): Int
|
||||
|
||||
fun onRemoveRoute(cidrAddress: String)
|
||||
fun onRemoveRoute(
|
||||
addr: String,
|
||||
prefix: Int,
|
||||
)
|
||||
|
||||
fun onUpdateResources(resourceListJSON: String)
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
package dev.firezone.android.tunnel.data
|
||||
|
||||
import android.content.SharedPreferences
|
||||
import dev.firezone.android.tunnel.model.Cidr
|
||||
import dev.firezone.android.tunnel.model.Resource
|
||||
import dev.firezone.android.tunnel.model.Tunnel
|
||||
import dev.firezone.android.tunnel.model.TunnelConfig
|
||||
@@ -21,11 +22,11 @@ interface TunnelRepository {
|
||||
|
||||
fun getResources(): List<Resource>
|
||||
|
||||
fun addRoute(route: String)
|
||||
fun addRoute(route: Cidr)
|
||||
|
||||
fun removeRoute(route: String)
|
||||
fun removeRoute(route: Cidr)
|
||||
|
||||
fun getRoutes(): List<String>
|
||||
fun getRoutes(): List<Cidr>
|
||||
|
||||
fun clear()
|
||||
|
||||
|
||||
@@ -8,6 +8,7 @@ import dev.firezone.android.tunnel.data.TunnelRepository.Companion.CONFIG_KEY
|
||||
import dev.firezone.android.tunnel.data.TunnelRepository.Companion.RESOURCES_KEY
|
||||
import dev.firezone.android.tunnel.data.TunnelRepository.Companion.ROUTES_KEY
|
||||
import dev.firezone.android.tunnel.data.TunnelRepository.Companion.STATE_KEY
|
||||
import dev.firezone.android.tunnel.model.Cidr
|
||||
import dev.firezone.android.tunnel.model.Resource
|
||||
import dev.firezone.android.tunnel.model.Tunnel
|
||||
import dev.firezone.android.tunnel.model.TunnelConfig
|
||||
@@ -83,30 +84,30 @@ class TunnelRepositoryImpl
|
||||
}
|
||||
}
|
||||
|
||||
override fun addRoute(route: String) {
|
||||
override fun addRoute(route: Cidr) {
|
||||
synchronized(lock) {
|
||||
getRoutes().toMutableList().run {
|
||||
add(route)
|
||||
val json = moshi.adapter<List<String>>().toJson(this)
|
||||
val json = moshi.adapter<List<Cidr>>().toJson(this)
|
||||
sharedPreferences.edit().putString(ROUTES_KEY, json).apply()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
override fun removeRoute(route: String) {
|
||||
override fun removeRoute(route: Cidr) {
|
||||
synchronized(lock) {
|
||||
getRoutes().toMutableList().run {
|
||||
remove(route)
|
||||
val json = moshi.adapter<List<String>>().toJson(this)
|
||||
val json = moshi.adapter<List<Cidr>>().toJson(this)
|
||||
sharedPreferences.edit().putString(ROUTES_KEY, json).apply()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
override fun getRoutes(): List<String> =
|
||||
override fun getRoutes(): List<Cidr> =
|
||||
synchronized(lock) {
|
||||
val json = sharedPreferences.getString(ROUTES_KEY, "[]") ?: "[]"
|
||||
return moshi.adapter<List<String>>().fromJson(json) ?: emptyList()
|
||||
return moshi.adapter<List<Cidr>>().fromJson(json) ?: emptyList()
|
||||
}
|
||||
|
||||
override fun clear() {
|
||||
|
||||
@@ -0,0 +1,14 @@
|
||||
/* Licensed under Apache 2.0 (C) 2023 Firezone, Inc. */
|
||||
package dev.firezone.android.tunnel.model
|
||||
|
||||
import android.os.Parcelable
|
||||
import com.squareup.moshi.JsonClass
|
||||
import kotlinx.parcelize.Parcelize
|
||||
|
||||
@JsonClass(generateAdapter = true)
|
||||
@Parcelize
|
||||
data class Cidr(
|
||||
// TODO: Not convinced of using String to store address, we can make a moshi InetAddress adapter
|
||||
val address: String,
|
||||
val prefix: Int,
|
||||
) : Parcelable
|
||||
@@ -10,7 +10,7 @@ import kotlinx.parcelize.Parcelize
|
||||
data class Tunnel(
|
||||
val config: TunnelConfig = TunnelConfig(),
|
||||
var state: State = State.Down,
|
||||
val routes: List<String> = emptyList(),
|
||||
val routes: List<Cidr> = emptyList(),
|
||||
val resources: List<Resource> = emptyList(),
|
||||
) : Parcelable {
|
||||
enum class State {
|
||||
|
||||
@@ -122,8 +122,8 @@ fn init_logging(log_dir: &Path, log_filter: String) -> file_logger::Handle {
|
||||
.expect("Logging guard should never be initialized twice");
|
||||
|
||||
let _ = tracing_subscriber::registry()
|
||||
.with(file_layer.with_filter(EnvFilter::new(log_filter)))
|
||||
.with(android_layer())
|
||||
.with(file_layer.with_filter(EnvFilter::new(log_filter.clone())))
|
||||
.with(android_layer().with_filter(EnvFilter::new(log_filter.clone())))
|
||||
.try_init();
|
||||
|
||||
handle
|
||||
@@ -138,7 +138,7 @@ impl Callbacks for CallbackHandler {
|
||||
tunnel_address_v6: Ipv6Addr,
|
||||
dns_address: Ipv4Addr,
|
||||
dns_fallback_strategy: String,
|
||||
) -> Result<RawFd, Self::Error> {
|
||||
) -> Result<Option<RawFd>, Self::Error> {
|
||||
self.env(|mut env| {
|
||||
let tunnel_address_v4 =
|
||||
env.new_string(tunnel_address_v4.to_string())
|
||||
@@ -179,6 +179,7 @@ impl Callbacks for CallbackHandler {
|
||||
],
|
||||
)
|
||||
.and_then(|val| val.i())
|
||||
.map(Some)
|
||||
.map_err(|source| CallbackError::CallMethodFailed { name, source })
|
||||
})
|
||||
}
|
||||
@@ -195,27 +196,31 @@ impl Callbacks for CallbackHandler {
|
||||
})
|
||||
}
|
||||
|
||||
fn on_add_route(&self, route: IpNetwork) -> Result<(), Self::Error> {
|
||||
fn on_add_route(&self, route: IpNetwork) -> Result<Option<RawFd>, Self::Error> {
|
||||
self.env(|mut env| {
|
||||
let route = env.new_string(route.to_string()).map_err(|source| {
|
||||
CallbackError::NewStringFailed {
|
||||
name: "route",
|
||||
let ip = env
|
||||
.new_string(route.network_address().to_string())
|
||||
.map_err(|source| CallbackError::NewStringFailed {
|
||||
name: "route_ip",
|
||||
source,
|
||||
}
|
||||
})?;
|
||||
call_method(
|
||||
&mut env,
|
||||
})?;
|
||||
|
||||
let name = "onAddRoute";
|
||||
env.call_method(
|
||||
&self.callback_handler,
|
||||
"onAddRoute",
|
||||
"(Ljava/lang/String;)V",
|
||||
&[JValue::from(&route)],
|
||||
name,
|
||||
"(Ljava/lang/String;I)I",
|
||||
&[JValue::from(&ip), JValue::Int(route.netmask().into())],
|
||||
)
|
||||
.and_then(|val| val.i())
|
||||
.map(Some)
|
||||
.map_err(|source| CallbackError::CallMethodFailed { name, source })
|
||||
})
|
||||
}
|
||||
|
||||
fn on_remove_route(&self, route: IpNetwork) -> Result<(), Self::Error> {
|
||||
self.env(|mut env| {
|
||||
let route = env.new_string(route.to_string()).map_err(|source| {
|
||||
let ip = env.new_string(route.to_string()).map_err(|source| {
|
||||
CallbackError::NewStringFailed {
|
||||
name: "route",
|
||||
source,
|
||||
@@ -225,8 +230,8 @@ impl Callbacks for CallbackHandler {
|
||||
&mut env,
|
||||
&self.callback_handler,
|
||||
"onRemoveRoute",
|
||||
"(Ljava/lang/String;)V",
|
||||
&[JValue::from(&route)],
|
||||
"(Ljava/lang/String;I)V",
|
||||
&[JValue::from(&ip), JValue::Int(route.netmask().into())],
|
||||
)
|
||||
})
|
||||
}
|
||||
@@ -301,7 +306,7 @@ impl Callbacks for CallbackHandler {
|
||||
fn throw(env: &mut JNIEnv, class: &str, msg: impl Into<JNIString>) {
|
||||
if let Err(err) = env.throw_new(class, msg) {
|
||||
// We can't panic, since unwinding across the FFI boundary is UB...
|
||||
tracing::error!("failed to throw Java exception: {err}");
|
||||
tracing::error!(?err, "failed to throw Java exception");
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -89,14 +89,14 @@ impl Callbacks for CallbackHandler {
|
||||
tunnel_address_v6: Ipv6Addr,
|
||||
dns_address: Ipv4Addr,
|
||||
dns_fallback_strategy: String,
|
||||
) -> Result<RawFd, Self::Error> {
|
||||
) -> Result<Option<RawFd>, Self::Error> {
|
||||
self.inner.on_set_interface_config(
|
||||
tunnel_address_v4.to_string(),
|
||||
tunnel_address_v6.to_string(),
|
||||
dns_address.to_string(),
|
||||
dns_fallback_strategy.to_string(),
|
||||
);
|
||||
Ok(-1)
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
fn on_tunnel_ready(&self) -> Result<(), Self::Error> {
|
||||
@@ -104,9 +104,9 @@ impl Callbacks for CallbackHandler {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn on_add_route(&self, route: IpNetwork) -> Result<(), Self::Error> {
|
||||
fn on_add_route(&self, route: IpNetwork) -> Result<Option<RawFd>, Self::Error> {
|
||||
self.inner.on_add_route(route.to_string());
|
||||
Ok(())
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
fn on_remove_route(&self, route: IpNetwork) -> Result<(), Self::Error> {
|
||||
|
||||
@@ -78,7 +78,11 @@ where
|
||||
runtime_stopper: tx.clone(),
|
||||
callbacks,
|
||||
};
|
||||
// In android we get an stack-overflow due to tokio
|
||||
// taking too much of the stack-space:
|
||||
// See: https://github.com/firezone/firezone/issues/2227
|
||||
let runtime = tokio::runtime::Builder::new_multi_thread()
|
||||
.thread_stack_size(3 * 1024 * 1024)
|
||||
.enable_all()
|
||||
.build()?;
|
||||
{
|
||||
|
||||
@@ -14,14 +14,17 @@ pub trait Callbacks: Clone + Send + Sync {
|
||||
type Error: Debug + Display + Error;
|
||||
|
||||
/// Called when the tunnel address is set.
|
||||
///
|
||||
/// This should return a new `fd` if there is one.
|
||||
/// (Only happens on android for now)
|
||||
fn on_set_interface_config(
|
||||
&self,
|
||||
_: Ipv4Addr,
|
||||
_: Ipv6Addr,
|
||||
_: Ipv4Addr,
|
||||
_: String,
|
||||
) -> Result<RawFd, Self::Error> {
|
||||
Ok(-1)
|
||||
) -> Result<Option<RawFd>, Self::Error> {
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
/// Called when the tunnel is connected.
|
||||
@@ -31,8 +34,11 @@ pub trait Callbacks: Clone + Send + Sync {
|
||||
}
|
||||
|
||||
/// Called when when a route is added.
|
||||
fn on_add_route(&self, _: IpNetwork) -> Result<(), Self::Error> {
|
||||
Ok(())
|
||||
///
|
||||
/// This should return a new `fd` if there is one.
|
||||
/// (Only happens on android for now)
|
||||
fn on_add_route(&self, _: IpNetwork) -> Result<Option<RawFd>, Self::Error> {
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
/// Called when when a route is removed.
|
||||
|
||||
@@ -18,7 +18,7 @@ impl<CB: Callbacks> Callbacks for CallbackErrorFacade<CB> {
|
||||
tunnel_address_v6: Ipv6Addr,
|
||||
dns_address: Ipv4Addr,
|
||||
dns_fallback_strategy: String,
|
||||
) -> Result<RawFd> {
|
||||
) -> Result<Option<RawFd>> {
|
||||
let result = self
|
||||
.0
|
||||
.on_set_interface_config(
|
||||
@@ -29,7 +29,7 @@ impl<CB: Callbacks> Callbacks for CallbackErrorFacade<CB> {
|
||||
)
|
||||
.map_err(|err| Error::OnSetInterfaceConfigFailed(err.to_string()));
|
||||
if let Err(err) = result.as_ref() {
|
||||
tracing::error!("{err}");
|
||||
tracing::error!(?err);
|
||||
}
|
||||
result
|
||||
}
|
||||
@@ -40,18 +40,18 @@ impl<CB: Callbacks> Callbacks for CallbackErrorFacade<CB> {
|
||||
.on_tunnel_ready()
|
||||
.map_err(|err| Error::OnTunnelReadyFailed(err.to_string()));
|
||||
if let Err(err) = result.as_ref() {
|
||||
tracing::error!("{err}");
|
||||
tracing::error!(?err);
|
||||
}
|
||||
result
|
||||
}
|
||||
|
||||
fn on_add_route(&self, route: IpNetwork) -> Result<()> {
|
||||
fn on_add_route(&self, route: IpNetwork) -> Result<Option<RawFd>> {
|
||||
let result = self
|
||||
.0
|
||||
.on_add_route(route)
|
||||
.map_err(|err| Error::OnAddRouteFailed(err.to_string()));
|
||||
if let Err(err) = result.as_ref() {
|
||||
tracing::error!("{err}");
|
||||
tracing::error!(?err);
|
||||
}
|
||||
result
|
||||
}
|
||||
@@ -62,7 +62,7 @@ impl<CB: Callbacks> Callbacks for CallbackErrorFacade<CB> {
|
||||
.on_remove_route(route)
|
||||
.map_err(|err| Error::OnRemoveRouteFailed(err.to_string()));
|
||||
if let Err(err) = result.as_ref() {
|
||||
tracing::error!("{err}");
|
||||
tracing::error!(?err);
|
||||
}
|
||||
result
|
||||
}
|
||||
@@ -73,14 +73,14 @@ impl<CB: Callbacks> Callbacks for CallbackErrorFacade<CB> {
|
||||
.on_update_resources(resource_list)
|
||||
.map_err(|err| Error::OnUpdateResourcesFailed(err.to_string()));
|
||||
if let Err(err) = result.as_ref() {
|
||||
tracing::error!("{err}");
|
||||
tracing::error!(?err);
|
||||
}
|
||||
result
|
||||
}
|
||||
|
||||
fn on_disconnect(&self, error: Option<&Error>) -> Result<()> {
|
||||
if let Err(err) = self.0.on_disconnect(error) {
|
||||
tracing::error!("`on_disconnect` failed: {err}");
|
||||
tracing::error!(?err, "`on_disconnect` failed");
|
||||
}
|
||||
// There's nothing we can really do if `on_disconnect` fails.
|
||||
Ok(())
|
||||
@@ -88,7 +88,7 @@ impl<CB: Callbacks> Callbacks for CallbackErrorFacade<CB> {
|
||||
|
||||
fn on_error(&self, error: &Error) -> Result<()> {
|
||||
if let Err(err) = self.0.on_error(error) {
|
||||
tracing::error!("`on_error` failed: {err}");
|
||||
tracing::error!(?err, "`on_error` failed");
|
||||
}
|
||||
// There's nothing we really want to do if `on_error` fails.
|
||||
Ok(())
|
||||
|
||||
@@ -96,6 +96,9 @@ pub enum ConnlibError {
|
||||
/// No iface found
|
||||
#[error("No iface found")]
|
||||
NoIface,
|
||||
/// Expected file descriptor and none was found
|
||||
#[error("No filedescriptor")]
|
||||
NoFd,
|
||||
/// No MTU found
|
||||
#[error("No MTU found")]
|
||||
NoMtu,
|
||||
@@ -120,6 +123,9 @@ pub enum ConnlibError {
|
||||
/// Invalid source address for peer
|
||||
#[error("Invalid source address")]
|
||||
InvalidSource,
|
||||
/// Any parse error
|
||||
#[error("parse error")]
|
||||
ParseError,
|
||||
}
|
||||
|
||||
impl ConnlibError {
|
||||
|
||||
@@ -145,12 +145,8 @@ where
|
||||
})
|
||||
});
|
||||
|
||||
let Some(device_io) = self.device_io.read().clone() else {
|
||||
return Err(Error::NoIface);
|
||||
};
|
||||
|
||||
let tunnel = Arc::clone(self);
|
||||
tokio::spawn(async move { tunnel.peer_handler(peer, device_io).await });
|
||||
tokio::spawn(async move { tunnel.start_peer_handler(peer).await });
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -96,12 +96,13 @@ where
|
||||
tracing::trace!("new_data_channel_open");
|
||||
Box::pin(async move {
|
||||
{
|
||||
let Some(iface_config) = tunnel.iface_config.read().clone() else {
|
||||
let Some(device) = tunnel.device.read().await.clone() else {
|
||||
let e = Error::NoIface;
|
||||
tracing::error!(err = ?e, "channel_open");
|
||||
let _ = tunnel.callbacks().on_error(&e);
|
||||
return;
|
||||
};
|
||||
let iface_config = device.config;
|
||||
for &ip in &peer.ips {
|
||||
if let Err(e) = iface_config.add_route(ip, tunnel.callbacks()).await
|
||||
{
|
||||
|
||||
@@ -9,6 +9,8 @@ use tokio::io::{unix::AsyncFd, Interest};
|
||||
|
||||
use tun::{IfaceDevice, IfaceStream};
|
||||
|
||||
use crate::Device;
|
||||
|
||||
mod tun;
|
||||
|
||||
pub(crate) struct IfaceConfig {
|
||||
@@ -53,23 +55,32 @@ impl IfaceConfig {
|
||||
&self,
|
||||
route: IpNetwork,
|
||||
callbacks: &CallbackErrorFacade<impl Callbacks>,
|
||||
) -> Result<()> {
|
||||
self.iface.add_route(route, callbacks).await
|
||||
) -> Result<Option<Device>> {
|
||||
let Some((iface, stream)) = self.iface.add_route(route, callbacks).await? else {
|
||||
return Ok(None);
|
||||
};
|
||||
let io = DeviceIo(stream);
|
||||
let mtu = iface.mtu().await?;
|
||||
let config = Arc::new(IfaceConfig {
|
||||
iface,
|
||||
mtu: AtomicUsize::new(mtu),
|
||||
});
|
||||
Ok(Some(Device { io, config }))
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn create_iface(
|
||||
config: &Interface,
|
||||
callbacks: &CallbackErrorFacade<impl Callbacks>,
|
||||
) -> Result<(IfaceConfig, DeviceIo)> {
|
||||
) -> Result<Device> {
|
||||
let (iface, stream) = IfaceDevice::new(config, callbacks).await?;
|
||||
iface.up().await?;
|
||||
let device_io = DeviceIo(stream);
|
||||
let io = DeviceIo(stream);
|
||||
let mtu = iface.mtu().await?;
|
||||
let iface_config = IfaceConfig {
|
||||
let config = Arc::new(IfaceConfig {
|
||||
iface,
|
||||
mtu: AtomicUsize::new(mtu),
|
||||
};
|
||||
});
|
||||
|
||||
Ok((iface_config, device_io))
|
||||
Ok(Device { config, io })
|
||||
}
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
use crate::Device;
|
||||
use connlib_shared::{messages::Interface, CallbackErrorFacade, Callbacks, Result};
|
||||
use ip_network::IpNetwork;
|
||||
|
||||
@@ -33,7 +34,7 @@ impl IfaceConfig {
|
||||
&self,
|
||||
_: IpNetwork,
|
||||
_: &CallbackErrorFacade<impl Callbacks>,
|
||||
) -> Result<()> {
|
||||
) -> Result<Option<Device>> {
|
||||
todo!()
|
||||
}
|
||||
}
|
||||
@@ -41,6 +42,6 @@ impl IfaceConfig {
|
||||
pub(crate) async fn create_iface(
|
||||
_: &Interface,
|
||||
_: &CallbackErrorFacade<impl Callbacks>,
|
||||
) -> Result<(IfaceConfig, DeviceIo)> {
|
||||
) -> Result<Device> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
35
rust/connlib/tunnel/src/device_channel/tun/closeable.rs
Normal file
35
rust/connlib/tunnel/src/device_channel/tun/closeable.rs
Normal file
@@ -0,0 +1,35 @@
|
||||
use std::os::fd::{AsRawFd, RawFd};
|
||||
use std::sync::atomic::{AtomicBool, Ordering};
|
||||
|
||||
#[derive(Debug)]
|
||||
pub(crate) struct Closeable {
|
||||
closed: AtomicBool,
|
||||
value: RawFd,
|
||||
}
|
||||
|
||||
impl AsRawFd for Closeable {
|
||||
fn as_raw_fd(&self) -> RawFd {
|
||||
self.value.as_raw_fd()
|
||||
}
|
||||
}
|
||||
|
||||
impl Closeable {
|
||||
pub(crate) fn new(fd: RawFd) -> Self {
|
||||
Self {
|
||||
closed: AtomicBool::new(false),
|
||||
value: fd,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn with<U>(&self, f: impl FnOnce(RawFd) -> U) -> std::io::Result<U> {
|
||||
if self.closed.load(Ordering::Acquire) {
|
||||
return Err(std::io::Error::from_raw_os_error(9));
|
||||
}
|
||||
|
||||
Ok(f(self.value))
|
||||
}
|
||||
|
||||
pub(crate) fn close(&self) {
|
||||
self.closed.store(true, Ordering::Release);
|
||||
}
|
||||
}
|
||||
@@ -1,4 +1,5 @@
|
||||
use crate::InterfaceConfig;
|
||||
use closeable::Closeable;
|
||||
use connlib_shared::{CallbackErrorFacade, Callbacks, Error, Result, DNS_SENTINEL};
|
||||
use ip_network::IpNetwork;
|
||||
use libc::{
|
||||
@@ -13,6 +14,7 @@ use std::{
|
||||
};
|
||||
use tokio::io::unix::AsyncFd;
|
||||
|
||||
mod closeable;
|
||||
mod wrapped_socket;
|
||||
// Android doesn't support Split DNS. So we intercept all requests and forward
|
||||
// the non-Firezone name resolution requests to the upstream DNS resolver.
|
||||
@@ -50,24 +52,27 @@ pub(crate) struct IfaceDevice(Arc<AsyncFd<IfaceStream>>);
|
||||
|
||||
#[derive(Debug)]
|
||||
pub(crate) struct IfaceStream {
|
||||
fd: RawFd,
|
||||
fd: Closeable,
|
||||
}
|
||||
|
||||
impl AsRawFd for IfaceStream {
|
||||
fn as_raw_fd(&self) -> RawFd {
|
||||
self.fd
|
||||
self.fd.as_raw_fd()
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for IfaceStream {
|
||||
fn drop(&mut self) {
|
||||
unsafe { close(self.fd) };
|
||||
unsafe { close(self.fd.as_raw_fd()) };
|
||||
}
|
||||
}
|
||||
|
||||
impl IfaceStream {
|
||||
fn write(&self, buf: &[u8]) -> std::io::Result<usize> {
|
||||
match unsafe { write(self.fd, buf.as_ptr() as _, buf.len() as _) } {
|
||||
match self
|
||||
.fd
|
||||
.with(|fd| unsafe { write(fd, buf.as_ptr() as _, buf.len() as _) })?
|
||||
{
|
||||
-1 => Err(io::Error::last_os_error()),
|
||||
n => Ok(n as usize),
|
||||
}
|
||||
@@ -81,12 +86,21 @@ impl IfaceStream {
|
||||
self.write(src)
|
||||
}
|
||||
|
||||
pub fn read<'a>(&self, dst: &'a mut [u8]) -> std::io::Result<usize> {
|
||||
match unsafe { read(self.fd, dst.as_mut_ptr() as _, dst.len()) } {
|
||||
pub fn read(&self, dst: &mut [u8]) -> std::io::Result<usize> {
|
||||
// We don't read(or write) again from the fd because the given fd number might be reclaimed
|
||||
// so this could make an spurious read/write to another fd and we DEFINITELY don't want that
|
||||
match self
|
||||
.fd
|
||||
.with(|fd| unsafe { read(fd, dst.as_mut_ptr() as _, dst.len()) })?
|
||||
{
|
||||
-1 => Err(io::Error::last_os_error()),
|
||||
n => Ok(n as usize),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn close(&self) {
|
||||
self.fd.close();
|
||||
}
|
||||
}
|
||||
|
||||
impl IfaceDevice {
|
||||
@@ -94,13 +108,17 @@ impl IfaceDevice {
|
||||
config: &InterfaceConfig,
|
||||
callbacks: &CallbackErrorFacade<impl Callbacks>,
|
||||
) -> Result<(Self, Arc<AsyncFd<IfaceStream>>)> {
|
||||
let fd = callbacks.on_set_interface_config(
|
||||
config.ipv4,
|
||||
config.ipv6,
|
||||
DNS_SENTINEL,
|
||||
DNS_FALLBACK_STRATEGY.to_string(),
|
||||
)?;
|
||||
let iface_stream = Arc::new(AsyncFd::new(IfaceStream { fd: fd.into() })?);
|
||||
let fd = callbacks
|
||||
.on_set_interface_config(
|
||||
config.ipv4,
|
||||
config.ipv6,
|
||||
DNS_SENTINEL,
|
||||
DNS_FALLBACK_STRATEGY.to_string(),
|
||||
)?
|
||||
.ok_or(Error::NoFd)?;
|
||||
let iface_stream = Arc::new(AsyncFd::new(IfaceStream {
|
||||
fd: Closeable::new(fd.into()),
|
||||
})?);
|
||||
let this = Self(Arc::clone(&iface_stream));
|
||||
|
||||
Ok((this, iface_stream))
|
||||
@@ -112,7 +130,12 @@ impl IfaceDevice {
|
||||
ifr_ifru: unsafe { std::mem::zeroed() },
|
||||
};
|
||||
|
||||
match unsafe { ioctl(self.0.get_ref().fd, TUNGETIFF as _, &mut ifr) } {
|
||||
match self
|
||||
.0
|
||||
.get_ref()
|
||||
.fd
|
||||
.with(|fd| unsafe { ioctl(fd, TUNGETIFF as _, &mut ifr) })?
|
||||
{
|
||||
0 => {
|
||||
let name_cstr = unsafe { std::ffi::CStr::from_ptr(ifr.ifr_name.as_ptr() as _) };
|
||||
Ok(name_cstr.to_string_lossy().into_owned())
|
||||
@@ -151,8 +174,15 @@ impl IfaceDevice {
|
||||
&self,
|
||||
route: IpNetwork,
|
||||
callbacks: &CallbackErrorFacade<impl Callbacks>,
|
||||
) -> Result<()> {
|
||||
callbacks.on_add_route(route)
|
||||
) -> Result<Option<(Self, Arc<AsyncFd<IfaceStream>>)>> {
|
||||
self.0.get_ref().close();
|
||||
let fd = callbacks.on_add_route(route)?.ok_or(Error::NoFd)?;
|
||||
let iface_stream = Arc::new(AsyncFd::new(IfaceStream {
|
||||
fd: Closeable::new(fd.into()),
|
||||
})?);
|
||||
let this = Self(Arc::clone(&iface_stream));
|
||||
|
||||
Ok(Some((this, iface_stream)))
|
||||
}
|
||||
|
||||
pub async fn up(&self) -> Result<()> {
|
||||
|
||||
@@ -195,12 +195,12 @@ impl IfaceDevice {
|
||||
}
|
||||
|
||||
if addr.sc_id == info.ctl_id {
|
||||
let _ = callbacks.on_set_interface_config(
|
||||
callbacks.on_set_interface_config(
|
||||
config.ipv4,
|
||||
config.ipv6,
|
||||
DNS_SENTINEL,
|
||||
"system_resolver".to_string(),
|
||||
);
|
||||
)?;
|
||||
|
||||
set_non_blocking(fd)?;
|
||||
|
||||
@@ -241,8 +241,10 @@ impl IfaceDevice {
|
||||
&self,
|
||||
route: IpNetwork,
|
||||
callbacks: &CallbackErrorFacade<impl Callbacks>,
|
||||
) -> Result<()> {
|
||||
callbacks.on_add_route(route)
|
||||
) -> Result<Option<(Self, Arc<AsyncFd<IfaceStream>>)>> {
|
||||
// This will always be None in macos
|
||||
callbacks.on_add_route(route)?;
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
pub async fn up(&self) -> Result<()> {
|
||||
|
||||
@@ -175,7 +175,7 @@ impl IfaceDevice {
|
||||
&self,
|
||||
route: IpNetwork,
|
||||
_callbacks: &CallbackErrorFacade<impl Callbacks>,
|
||||
) -> Result<()> {
|
||||
) -> Result<Option<(Self, Arc<AsyncFd<IfaceStream>>)>> {
|
||||
let req = self
|
||||
.handle
|
||||
.route()
|
||||
@@ -211,7 +211,7 @@ impl IfaceDevice {
|
||||
}
|
||||
*/
|
||||
|
||||
Ok(())
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
#[tracing::instrument(level = "trace", skip(self, _callbacks))]
|
||||
|
||||
@@ -138,9 +138,9 @@ where
|
||||
) -> Result<()> {
|
||||
if let Some(r) = self.check_for_dns(src) {
|
||||
match r {
|
||||
dns::SendPacket::Ipv4(r) => self.write4_device_infallible(device_writer, &r[..]),
|
||||
dns::SendPacket::Ipv6(r) => self.write6_device_infallible(device_writer, &r[..]),
|
||||
}
|
||||
dns::SendPacket::Ipv4(r) => device_writer.write4(&r[..])?,
|
||||
dns::SendPacket::Ipv6(r) => device_writer.write6(&r[..])?,
|
||||
};
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
@@ -170,29 +170,30 @@ where
|
||||
device_io: DeviceIo,
|
||||
) {
|
||||
let device_writer = device_io.clone();
|
||||
let mut src = [0u8; MAX_UDP_SIZE];
|
||||
let mut dst = [0u8; MAX_UDP_SIZE];
|
||||
loop {
|
||||
let mut src = [0u8; MAX_UDP_SIZE];
|
||||
let mut dst = [0u8; MAX_UDP_SIZE];
|
||||
let res = {
|
||||
// TODO: We should check here if what we read is a whole packet
|
||||
// there's no docs on tun device on when a whole packet is read, is it \n or another thing?
|
||||
// found some comments saying that a single read syscall represents a single packet but no docs on that
|
||||
// See https://stackoverflow.com/questions/18461365/how-to-read-packet-by-packet-from-linux-tun-tap
|
||||
match device_io.read(&mut src[..iface_config.mtu()]).await {
|
||||
Ok(res) => res,
|
||||
Err(e) => {
|
||||
tracing::error!(error = ?e, from = "iface", action = "read");
|
||||
let _ = self.callbacks.on_error(&e.into());
|
||||
continue;
|
||||
}
|
||||
let res = match device_io.read(&mut src[..iface_config.mtu()]).await {
|
||||
Ok(res) => res,
|
||||
Err(e) => {
|
||||
tracing::error!(err = ?e, "failed to read interface: {e:#}");
|
||||
let _ = self.callbacks.on_error(&e.into());
|
||||
break;
|
||||
}
|
||||
};
|
||||
|
||||
tracing::trace!(target: "wire", action = "read", bytes = res, from = "iface");
|
||||
// TODO
|
||||
let _ = self
|
||||
|
||||
if res == 0 {
|
||||
break;
|
||||
}
|
||||
|
||||
if let Err(e) = self
|
||||
.handle_iface_packet(&device_writer, &mut src[..res], &mut dst)
|
||||
.await;
|
||||
.await
|
||||
{
|
||||
let _ = self.callbacks.on_error(&e);
|
||||
tracing::error!(err = ?e, "failed to handle packet {e:#}")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -18,7 +18,7 @@ use itertools::Itertools;
|
||||
use parking_lot::{Mutex, RwLock};
|
||||
use peer::{Peer, PeerStats};
|
||||
use resource_table::ResourceTable;
|
||||
use tokio::time::MissedTickBehavior;
|
||||
use tokio::{task::AbortHandle, time::MissedTickBehavior};
|
||||
use webrtc::{
|
||||
api::{
|
||||
interceptor_registry::register_default_interceptors, media_engine::MediaEngine,
|
||||
@@ -140,15 +140,19 @@ struct AwaitingConnectionDetails {
|
||||
pub response_received: bool,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct Device {
|
||||
pub config: Arc<IfaceConfig>,
|
||||
pub io: DeviceIo,
|
||||
}
|
||||
|
||||
// TODO: We should use newtypes for each kind of Id
|
||||
/// Tunnel is a wireguard state machine that uses webrtc's ICE channels instead of UDP sockets
|
||||
/// to communicate between peers.
|
||||
pub struct Tunnel<C: ControlSignal, CB: Callbacks> {
|
||||
next_index: Mutex<IndexLfsr>,
|
||||
// We use a tokio's mutex here since it makes things easier and we only need it
|
||||
// during init, so the performance hit is neglibile
|
||||
iface_config: RwLock<Option<Arc<IfaceConfig>>>,
|
||||
device_io: RwLock<Option<DeviceIo>>,
|
||||
// We use a tokio Mutex here since this is only read/write during config so there's no relevant performance impact
|
||||
device: tokio::sync::RwLock<Option<Device>>,
|
||||
rate_limiter: Arc<RateLimiter>,
|
||||
private_key: StaticSecret,
|
||||
public_key: PublicKey,
|
||||
@@ -164,6 +168,7 @@ pub struct Tunnel<C: ControlSignal, CB: Callbacks> {
|
||||
control_signaler: C,
|
||||
gateway_public_keys: Mutex<HashMap<GatewayId, PublicKey>>,
|
||||
callbacks: CallbackErrorFacade<CB>,
|
||||
iface_handler_abort: Mutex<Option<AbortHandle>>,
|
||||
}
|
||||
|
||||
// TODO: For now we only use these fields with debug
|
||||
@@ -249,9 +254,9 @@ where
|
||||
let gateway_public_keys = Default::default();
|
||||
let resources_gateways = Default::default();
|
||||
let gateway_awaiting_connection = Default::default();
|
||||
let iface_config = Default::default();
|
||||
let device_io = Default::default();
|
||||
let device = Default::default();
|
||||
let ice_candidate_queue = Default::default();
|
||||
let iface_handler_abort = Default::default();
|
||||
|
||||
// ICE
|
||||
let mut media_engine = MediaEngine::default();
|
||||
@@ -287,31 +292,54 @@ where
|
||||
next_index,
|
||||
webrtc_api,
|
||||
resources,
|
||||
iface_config,
|
||||
device_io,
|
||||
device,
|
||||
awaiting_connection,
|
||||
gateway_awaiting_connection,
|
||||
control_signaler,
|
||||
resources_gateways,
|
||||
ice_candidate_queue,
|
||||
callbacks: CallbackErrorFacade(callbacks),
|
||||
iface_handler_abort,
|
||||
})
|
||||
}
|
||||
|
||||
#[tracing::instrument(level = "trace", skip(self))]
|
||||
pub async fn add_route(self: &Arc<Self>, route: IpNetwork) -> Result<()> {
|
||||
let mut device = self.device.write().await;
|
||||
|
||||
if let Some(new_device) = device
|
||||
.as_ref()
|
||||
.ok_or(Error::ControlProtocolError)?
|
||||
.config
|
||||
.add_route(route, self.callbacks())
|
||||
.await?
|
||||
{
|
||||
*device = Some(new_device.clone());
|
||||
let dev = Arc::clone(self);
|
||||
self.iface_handler_abort.lock().replace(
|
||||
tokio::spawn(
|
||||
async move { dev.iface_handler(new_device.config, new_device.io).await },
|
||||
)
|
||||
.abort_handle(),
|
||||
);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Adds a the given resource to the tunnel.
|
||||
///
|
||||
/// Once added, when a packet for the resource is intercepted a new data channel will be created
|
||||
/// and packets will be wrapped with wireguard and sent through it.
|
||||
#[tracing::instrument(level = "trace", skip(self))]
|
||||
pub async fn add_resource(&self, resource_description: ResourceDescription) -> Result<()> {
|
||||
pub async fn add_resource(
|
||||
self: &Arc<Self>,
|
||||
resource_description: ResourceDescription,
|
||||
) -> Result<()> {
|
||||
let mut any_valid_route = false;
|
||||
{
|
||||
let Some(iface_config) = self.iface_config.read().clone() else {
|
||||
tracing::error!("add_resource_before_initialization");
|
||||
return Err(Error::ControlProtocolError);
|
||||
};
|
||||
for ip in resource_description.ips() {
|
||||
if let Err(e) = iface_config.add_route(ip, self.callbacks()).await {
|
||||
if let Err(e) = self.add_route(ip).await {
|
||||
tracing::warn!(route = %ip, error = ?e, "add_route");
|
||||
let _ = self.callbacks().on_error(&e);
|
||||
} else {
|
||||
@@ -336,17 +364,17 @@ where
|
||||
/// Sets the interface configuration and starts background tasks.
|
||||
#[tracing::instrument(level = "trace", skip(self))]
|
||||
pub async fn set_interface(self: &Arc<Self>, config: &InterfaceConfig) -> Result<()> {
|
||||
let (iface_config, device_io) = create_iface(config, self.callbacks()).await?;
|
||||
iface_config
|
||||
.add_route(DNS_SENTINEL.into(), self.callbacks())
|
||||
.await?;
|
||||
let iface_config = Arc::new(iface_config);
|
||||
let device = create_iface(config, self.callbacks()).await?;
|
||||
*self.device.write().await = Some(device.clone());
|
||||
|
||||
*self.device_io.write() = Some(device_io.clone());
|
||||
*self.iface_config.write() = Some(Arc::clone(&iface_config));
|
||||
self.start_timers()?;
|
||||
self.start_timers().await?;
|
||||
let dev = Arc::clone(self);
|
||||
tokio::spawn(async move { dev.iface_handler(iface_config, device_io).await });
|
||||
*self.iface_handler_abort.lock() = Some(
|
||||
tokio::spawn(async move { dev.iface_handler(device.config, device.io).await })
|
||||
.abort_handle(),
|
||||
);
|
||||
|
||||
self.add_route(DNS_SENTINEL.into()).await?;
|
||||
|
||||
self.callbacks.on_tunnel_ready()?;
|
||||
|
||||
@@ -453,17 +481,22 @@ where
|
||||
});
|
||||
}
|
||||
|
||||
fn start_refresh_mtu_timer(self: &Arc<Self>) -> Result<()> {
|
||||
let Some(iface_config) = self.iface_config.read().clone() else {
|
||||
return Err(Error::NoIface);
|
||||
};
|
||||
async fn start_refresh_mtu_timer(self: &Arc<Self>) -> Result<()> {
|
||||
let dev = self.clone();
|
||||
let callbacks = self.callbacks().clone();
|
||||
tokio::spawn(async move {
|
||||
let mut interval = tokio::time::interval(REFRESH_MTU_INTERVAL);
|
||||
interval.set_missed_tick_behavior(MissedTickBehavior::Delay);
|
||||
loop {
|
||||
interval.tick().await;
|
||||
if let Err(e) = iface_config.refresh_mtu().await {
|
||||
|
||||
let Some(device) = dev.device.read().await.clone() else {
|
||||
let err = Error::ControlProtocolError;
|
||||
tracing::error!(?err, "get_iface_config");
|
||||
let _ = callbacks.0.on_error(&err);
|
||||
continue;
|
||||
};
|
||||
if let Err(e) = device.config.refresh_mtu().await {
|
||||
tracing::error!(error = ?e, "refresh_mtu");
|
||||
let _ = callbacks.0.on_error(&e);
|
||||
}
|
||||
@@ -473,29 +506,13 @@ where
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn start_timers(self: &Arc<Self>) -> Result<()> {
|
||||
self.start_refresh_mtu_timer()?;
|
||||
async fn start_timers(self: &Arc<Self>) -> Result<()> {
|
||||
self.start_refresh_mtu_timer().await?;
|
||||
self.start_rate_limiter_refresh_timer();
|
||||
self.start_peers_refresh_timer();
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn write4_device_infallible(&self, device_io: &DeviceIo, packet: &[u8]) {
|
||||
if let Err(e) = device_io.write4(packet) {
|
||||
tracing::error!(?e, "iface_write");
|
||||
let _ = self.callbacks().on_error(&e.into());
|
||||
}
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn write6_device_infallible(&self, device_io: &DeviceIo, packet: &[u8]) {
|
||||
if let Err(e) = device_io.write6(packet) {
|
||||
tracing::error!(?e, "iface_write");
|
||||
let _ = self.callbacks().on_error(&e.into());
|
||||
}
|
||||
}
|
||||
|
||||
fn get_resource(&self, buff: &[u8]) -> Option<ResourceDescription> {
|
||||
let addr = Tunn::dst_address(buff)?;
|
||||
let resources = self.resources.read();
|
||||
|
||||
@@ -66,26 +66,26 @@ where
|
||||
peer: &Arc<Peer>,
|
||||
device_io: &DeviceIo,
|
||||
decapsulate_result: TunnResult<'a>,
|
||||
) -> bool {
|
||||
) -> Result<bool> {
|
||||
match decapsulate_result {
|
||||
TunnResult::Done => false,
|
||||
TunnResult::Done => Ok(false),
|
||||
TunnResult::Err(e) => {
|
||||
tracing::error!(error = ?e, "decapsulate_packet");
|
||||
let _ = self.callbacks().on_error(&e.into());
|
||||
false
|
||||
Ok(false)
|
||||
}
|
||||
TunnResult::WriteToNetwork(packet) => {
|
||||
let bytes = Bytes::copy_from_slice(packet);
|
||||
peer.send_infallible(bytes, &self.callbacks).await;
|
||||
true
|
||||
Ok(true)
|
||||
}
|
||||
TunnResult::WriteToTunnelV4(packet, addr) => {
|
||||
self.send_to_resource(device_io, peer, addr.into(), packet);
|
||||
false
|
||||
self.send_to_resource(device_io, peer, addr.into(), packet)?;
|
||||
Ok(false)
|
||||
}
|
||||
TunnResult::WriteToTunnelV6(packet, addr) => {
|
||||
self.send_to_resource(device_io, peer, addr.into(), packet);
|
||||
false
|
||||
self.send_to_resource(device_io, peer, addr.into(), packet)?;
|
||||
Ok(false)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -108,7 +108,7 @@ where
|
||||
|
||||
if self
|
||||
.handle_decapsulated_packet(peer, device_writer, decapsulate_result)
|
||||
.await
|
||||
.await?
|
||||
{
|
||||
// Flush pending queue
|
||||
while let TunnResult::WriteToNetwork(packet) = {
|
||||
@@ -125,10 +125,16 @@ where
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) async fn peer_handler(self: &Arc<Self>, peer: Arc<Peer>, device_io: DeviceIo) {
|
||||
async fn peer_handler(
|
||||
self: &Arc<Self>,
|
||||
peer: &Arc<Peer>,
|
||||
device_io: DeviceIo,
|
||||
) -> std::io::Result<()> {
|
||||
let mut src_buf = [0u8; MAX_UDP_SIZE];
|
||||
let mut dst_buf = [0u8; MAX_UDP_SIZE];
|
||||
while let Ok(size) = peer.channel.read(&mut src_buf[..]).await {
|
||||
tracing::trace!(target: "wire", action = "read", bytes = size, from = "peer");
|
||||
|
||||
// TODO: Double check that this can only happen on closed channel
|
||||
// I think it's possible to transmit a 0-byte message through the channel
|
||||
// but we would never use that.
|
||||
@@ -137,14 +143,38 @@ where
|
||||
break;
|
||||
}
|
||||
|
||||
tracing::trace!(target: "wire", action = "read", bytes = size, from = "peer");
|
||||
let _ = self
|
||||
.handle_peer_packet(&peer, &device_io, &src_buf[..size], &mut dst_buf)
|
||||
.await;
|
||||
if let Err(Error::Io(e)) = self
|
||||
.handle_peer_packet(peer, &device_io, &src_buf[..size], &mut dst_buf)
|
||||
.await
|
||||
{
|
||||
return Err(e);
|
||||
}
|
||||
}
|
||||
|
||||
let peer_stats = peer.stats();
|
||||
tracing::debug!(peer = ?peer_stats, "peer_stopped");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) async fn start_peer_handler(self: &Arc<Self>, peer: Arc<Peer>) {
|
||||
loop {
|
||||
let Some(device) = self.device.read().await.clone() else {
|
||||
let err = Error::NoIface;
|
||||
tracing::error!(?err);
|
||||
let _ = self.callbacks().on_disconnect(Some(&err));
|
||||
break;
|
||||
};
|
||||
let device_io = device.io;
|
||||
|
||||
if let Err(err) = self.peer_handler(&peer, device_io).await {
|
||||
if err.raw_os_error() != Some(9) {
|
||||
tracing::error!(?err);
|
||||
let _ = self.callbacks().on_error(&err.into());
|
||||
break;
|
||||
} else {
|
||||
tracing::warn!("bad_file_descriptor");
|
||||
}
|
||||
}
|
||||
}
|
||||
tracing::debug!(peer = ?peer.stats(), "peer_stopped");
|
||||
self.stop_peer(peer.index, peer.conn_id).await;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -24,15 +24,17 @@ where
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn send_packet(&self, device_io: &DeviceIo, packet: &mut [u8], dst_addr: IpAddr) {
|
||||
fn send_packet(
|
||||
&self,
|
||||
device_io: &DeviceIo,
|
||||
packet: &mut [u8],
|
||||
dst_addr: IpAddr,
|
||||
) -> std::io::Result<()> {
|
||||
match dst_addr {
|
||||
IpAddr::V4(_) => {
|
||||
self.write4_device_infallible(device_io, packet);
|
||||
}
|
||||
IpAddr::V6(_) => {
|
||||
self.write6_device_infallible(device_io, packet);
|
||||
}
|
||||
}
|
||||
IpAddr::V4(_) => device_io.write4(packet)?,
|
||||
IpAddr::V6(_) => device_io.write6(packet)?,
|
||||
};
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
@@ -42,26 +44,20 @@ where
|
||||
peer: &Arc<Peer>,
|
||||
addr: IpAddr,
|
||||
packet: &mut [u8],
|
||||
) {
|
||||
) -> Result<()> {
|
||||
let Some((dst, resource)) = peer.get_packet_resource(packet) else {
|
||||
// If there's no associated resource it means that we are in a client, then the packet comes from a gateway
|
||||
// and we just trust gateways.
|
||||
// In gateways this should never happen.
|
||||
tracing::trace!(target: "wire", action = "writing", to = "iface", %addr, bytes = %packet.len());
|
||||
self.send_packet(device_io, packet, addr);
|
||||
return;
|
||||
self.send_packet(device_io, packet, addr)?;
|
||||
return Ok(());
|
||||
};
|
||||
|
||||
match get_resource_addr_and_port(peer, &resource, &addr, &dst) {
|
||||
Ok((dst_addr, _dst_port)) => {
|
||||
self.update_packet(packet, dst_addr);
|
||||
self.send_packet(device_io, packet, addr);
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::error!(err = ?e, "resource_parse");
|
||||
let _ = self.callbacks().on_error(&e);
|
||||
}
|
||||
}
|
||||
let (dst_addr, _dst_port) = get_resource_addr_and_port(peer, &resource, &addr, &dst)?;
|
||||
self.update_packet(packet, dst_addr);
|
||||
self.send_packet(device_io, packet, addr)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) fn send_to_resource(
|
||||
@@ -70,11 +66,13 @@ where
|
||||
peer: &Arc<Peer>,
|
||||
addr: IpAddr,
|
||||
packet: &mut [u8],
|
||||
) {
|
||||
) -> Result<()> {
|
||||
if peer.is_allowed(addr) {
|
||||
self.packet_allowed(device_io, peer, addr, packet);
|
||||
self.packet_allowed(device_io, peer, addr, packet)?;
|
||||
Ok(())
|
||||
} else {
|
||||
tracing::warn!(%addr, "Received packet from peer with an unallowed ip");
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -90,7 +88,6 @@ fn get_resource_addr_and_port(
|
||||
dst: &IpAddr,
|
||||
) -> Result<(IpAddr, Option<u16>)> {
|
||||
match resource {
|
||||
// Note: for now no translation is needed for the ip since we do a peer/connection per resource
|
||||
ResourceDescription::Dns(r) => {
|
||||
let mut address = r.address.split(':');
|
||||
let Some(dst_addr) = address.next() else {
|
||||
|
||||
Reference in New Issue
Block a user