mirror of
https://github.com/outbackdingo/firezone.git
synced 2026-01-27 10:18:54 +00:00
refactor(connlib): implement new FFI guidelines (#4263)
This updates connlib to follow the new guidelines described in #4262. I only made the bare-minimum changes to the clients. With these changes `reconnect` should only be called when the network interface actually changed, meaning clients have to be updated to reflect that.
This commit is contained in:
@@ -5,7 +5,7 @@
|
||||
|
||||
use connlib_client_shared::{
|
||||
file_logger, keypair, Callbacks, Cidrv4, Cidrv6, Error, LoginUrl, LoginUrlError,
|
||||
ResourceDescription, Session,
|
||||
ResourceDescription, Session, Sockets,
|
||||
};
|
||||
use jni::{
|
||||
objects::{GlobalRef, JClass, JObject, JString, JValue},
|
||||
@@ -67,6 +67,8 @@ pub enum CallbackError {
|
||||
name: &'static str,
|
||||
source: jni::errors::Error,
|
||||
},
|
||||
#[error(transparent)]
|
||||
Io(#[from] io::Error),
|
||||
}
|
||||
|
||||
impl CallbackHandler {
|
||||
@@ -79,6 +81,18 @@ impl CallbackHandler {
|
||||
.map_err(CallbackError::AttachCurrentThreadFailed)
|
||||
.and_then(f)
|
||||
}
|
||||
|
||||
fn protect_file_descriptor(&self, file_descriptor: RawFd) -> Result<(), CallbackError> {
|
||||
self.env(|mut env| {
|
||||
call_method(
|
||||
&mut env,
|
||||
&self.callback_handler,
|
||||
"protectFileDescriptor",
|
||||
"(I)V",
|
||||
&[JValue::Int(file_descriptor)],
|
||||
)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
fn call_method(
|
||||
@@ -227,20 +241,6 @@ impl Callbacks for CallbackHandler {
|
||||
.expect("onUpdateRoutes callback failed")
|
||||
}
|
||||
|
||||
#[cfg(target_os = "android")]
|
||||
fn protect_file_descriptor(&self, file_descriptor: RawFd) {
|
||||
self.env(|mut env| {
|
||||
call_method(
|
||||
&mut env,
|
||||
&self.callback_handler,
|
||||
"protectFileDescriptor",
|
||||
"(I)V",
|
||||
&[JValue::Int(file_descriptor)],
|
||||
)
|
||||
})
|
||||
.expect("protectFileDescriptor callback failed");
|
||||
}
|
||||
|
||||
fn on_update_resources(&self, resource_list: Vec<ResourceDescription>) {
|
||||
self.env(|mut env| {
|
||||
let resource_list = env
|
||||
@@ -326,6 +326,8 @@ enum ConnectError {
|
||||
InvalidLoginUrl(#[from] LoginUrlError<url::ParseError>),
|
||||
#[error("Unable to create tokio runtime: {0}")]
|
||||
UnableToCreateRuntime(#[from] io::Error),
|
||||
#[error(transparent)]
|
||||
CallbackError(#[from] CallbackError),
|
||||
}
|
||||
|
||||
macro_rules! string_from_jstring {
|
||||
@@ -386,17 +388,28 @@ fn connect(
|
||||
.enable_all()
|
||||
.build()?;
|
||||
|
||||
let sockets = Sockets::new()?;
|
||||
|
||||
if let Some(ip4_socket) = sockets.ip4_socket_fd() {
|
||||
callback_handler.protect_file_descriptor(ip4_socket)?;
|
||||
}
|
||||
if let Some(ip6_socket) = sockets.ip6_socket_fd() {
|
||||
callback_handler.protect_file_descriptor(ip6_socket)?;
|
||||
}
|
||||
|
||||
let session = Session::connect(
|
||||
login,
|
||||
sockets,
|
||||
private_key,
|
||||
Some(os_version),
|
||||
callback_handler,
|
||||
callback_handler.clone(),
|
||||
Some(MAX_PARTITION_TIME),
|
||||
runtime.handle().clone(),
|
||||
)?;
|
||||
|
||||
Ok(SessionWrapper {
|
||||
inner: session,
|
||||
callbacks: callback_handler,
|
||||
runtime,
|
||||
})
|
||||
}
|
||||
@@ -450,11 +463,29 @@ pub unsafe extern "system" fn Java_dev_firezone_android_tunnel_ConnlibSession_co
|
||||
|
||||
pub struct SessionWrapper {
|
||||
inner: Session,
|
||||
callbacks: CallbackHandler,
|
||||
|
||||
#[allow(dead_code)] // Only here so we don't drop the memory early.
|
||||
runtime: Runtime,
|
||||
}
|
||||
|
||||
impl SessionWrapper {
|
||||
fn reconnect(&self) -> Result<(), CallbackError> {
|
||||
let sockets = Sockets::new()?;
|
||||
|
||||
if let Some(ip4_socket) = sockets.ip4_socket_fd() {
|
||||
self.callbacks.protect_file_descriptor(ip4_socket)?;
|
||||
}
|
||||
if let Some(ip6_socket) = sockets.ip6_socket_fd() {
|
||||
self.callbacks.protect_file_descriptor(ip6_socket)?;
|
||||
}
|
||||
|
||||
self.inner.reconnect(sockets);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// # Safety
|
||||
/// Pointers must be valid
|
||||
#[allow(non_snake_case)]
|
||||
@@ -497,9 +528,11 @@ pub unsafe extern "system" fn Java_dev_firezone_android_tunnel_ConnlibSession_se
|
||||
#[allow(non_snake_case)]
|
||||
#[no_mangle]
|
||||
pub unsafe extern "system" fn Java_dev_firezone_android_tunnel_ConnlibSession_reconnect(
|
||||
_: JNIEnv,
|
||||
mut env: JNIEnv,
|
||||
_: JClass,
|
||||
session: *const SessionWrapper,
|
||||
) {
|
||||
(*session).inner.reconnect();
|
||||
if let Err(e) = (*session).reconnect() {
|
||||
throw(&mut env, "java/lang/Exception", e.to_string());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
|
||||
use connlib_client_shared::{
|
||||
file_logger, keypair, Callbacks, Cidrv4, Cidrv6, Error, LoginUrl, ResourceDescription, Session,
|
||||
Sockets,
|
||||
};
|
||||
use secrecy::SecretString;
|
||||
use std::{
|
||||
@@ -204,6 +205,7 @@ impl WrappedSession {
|
||||
|
||||
let session = Session::connect(
|
||||
login,
|
||||
Sockets::new().map_err(|err| err.to_string())?,
|
||||
private_key,
|
||||
os_version_override,
|
||||
CallbackHandler {
|
||||
|
||||
@@ -10,7 +10,7 @@ use connlib_shared::{
|
||||
messages::{ConnectionAccepted, GatewayResponse, ResourceAccepted, ResourceId},
|
||||
Callbacks,
|
||||
};
|
||||
use firezone_tunnel::ClientTunnel;
|
||||
use firezone_tunnel::{ClientTunnel, Sockets};
|
||||
use phoenix_channel::{ErrorReply, OutboundRequestId, PhoenixChannel};
|
||||
use std::{
|
||||
collections::HashMap,
|
||||
@@ -18,7 +18,7 @@ use std::{
|
||||
net::IpAddr,
|
||||
path::PathBuf,
|
||||
task::{Context, Poll},
|
||||
time::{Duration, Instant},
|
||||
time::Duration,
|
||||
};
|
||||
use tokio::time::{Interval, MissedTickBehavior};
|
||||
use url::Url;
|
||||
@@ -37,7 +37,7 @@ pub struct Eventloop<C: Callbacks> {
|
||||
/// Commands that can be sent to the [`Eventloop`].
|
||||
pub enum Command {
|
||||
Stop,
|
||||
Reconnect,
|
||||
Reconnect(Sockets),
|
||||
SetDns(Vec<IpAddr>),
|
||||
}
|
||||
|
||||
@@ -66,10 +66,14 @@ where
|
||||
loop {
|
||||
match self.rx.poll_recv(cx) {
|
||||
Poll::Ready(Some(Command::Stop)) | Poll::Ready(None) => return Poll::Ready(Ok(())),
|
||||
Poll::Ready(Some(Command::SetDns(dns))) => self.tunnel.set_dns(dns, Instant::now()),
|
||||
Poll::Ready(Some(Command::Reconnect)) => {
|
||||
Poll::Ready(Some(Command::SetDns(dns))) => {
|
||||
if let Err(e) = self.tunnel.set_dns(dns) {
|
||||
tracing::warn!("Failed to update DNS: {e}");
|
||||
}
|
||||
}
|
||||
Poll::Ready(Some(Command::Reconnect(sockets))) => {
|
||||
self.portal.reconnect();
|
||||
self.tunnel.reconnect();
|
||||
self.tunnel.reconnect(sockets);
|
||||
|
||||
continue;
|
||||
}
|
||||
|
||||
@@ -3,7 +3,7 @@ pub use connlib_shared::messages::ResourceDescription;
|
||||
pub use connlib_shared::{
|
||||
keypair, Callbacks, Cidrv4, Cidrv6, Error, LoginUrl, LoginUrlError, StaticSecret,
|
||||
};
|
||||
use tokio::sync::mpsc::UnboundedReceiver;
|
||||
pub use firezone_tunnel::Sockets;
|
||||
pub use tracing_appender::non_blocking::WorkerGuard;
|
||||
|
||||
use backoff::ExponentialBackoffBuilder;
|
||||
@@ -12,6 +12,7 @@ use firezone_tunnel::ClientTunnel;
|
||||
use phoenix_channel::PhoenixChannel;
|
||||
use std::net::IpAddr;
|
||||
use std::time::Duration;
|
||||
use tokio::sync::mpsc::UnboundedReceiver;
|
||||
|
||||
mod eventloop;
|
||||
pub mod file_logger;
|
||||
@@ -37,6 +38,7 @@ impl Session {
|
||||
/// This connects to the portal a specified using [`LoginUrl`] and creates a wireguard tunnel using the provided private key.
|
||||
pub fn connect<CB: Callbacks + 'static>(
|
||||
url: LoginUrl,
|
||||
sockets: Sockets,
|
||||
private_key: StaticSecret,
|
||||
os_version_override: Option<String>,
|
||||
callbacks: CB,
|
||||
@@ -47,6 +49,7 @@ impl Session {
|
||||
|
||||
let connect_handle = handle.spawn(connect(
|
||||
url,
|
||||
sockets,
|
||||
private_key,
|
||||
os_version_override,
|
||||
callbacks.clone(),
|
||||
@@ -60,20 +63,34 @@ impl Session {
|
||||
|
||||
/// Attempts to reconnect a [`Session`].
|
||||
///
|
||||
/// This can and should be called by client applications on any network state changes.
|
||||
/// It is a signal to connlib to:
|
||||
/// Reconnecting a session will:
|
||||
///
|
||||
/// - validate all currently used network paths to relays and peers
|
||||
/// - ensure we are connected to the portal
|
||||
/// - Close and re-open a connection to the portal.
|
||||
/// - Refresh all allocations
|
||||
/// - Replace the currently used [`Sockets`] with the provided one
|
||||
///
|
||||
/// Reconnect is non-destructive and can be called several times in a row.
|
||||
/// # Implementation note
|
||||
///
|
||||
/// In case of destructive network state changes, i.e. the user switched from wifi to cellular,
|
||||
/// reconnect allows connlib to re-establish connections faster because we don't have to wait for timeouts first.
|
||||
pub fn reconnect(&self) {
|
||||
let _ = self.channel.send(Command::Reconnect);
|
||||
/// The reason we replace [`Sockets`] are:
|
||||
///
|
||||
/// 1. On MacOS, as socket bound to the unspecified IP cannot send to interfaces attached after the socket has been created.
|
||||
/// 2. Switching between networks changes the 3-tuple of the client.
|
||||
/// The TURN protocol identifies a client's allocation based on the 3-tuple.
|
||||
/// Consequently, an allocation is invalid after switching networks and we clear the state.
|
||||
/// Changing the IP would be enough for that.
|
||||
/// However, if the user would now change _back_ to the previous network,
|
||||
/// the TURN server would recognise the old allocation but the client already lost all its state associated with it.
|
||||
/// To avoid race-conditions like this, we initialize a new [`Sockets`] instance which allocates a new port.
|
||||
pub fn reconnect(&self, sockets: Sockets) {
|
||||
let _ = self.channel.send(Command::Reconnect(sockets));
|
||||
}
|
||||
|
||||
/// Sets a new set of upstream DNS servers for this [`Session`].
|
||||
///
|
||||
/// Changing the DNS servers clears all cached DNS requests which may be disruptive to the UX.
|
||||
/// Clients should only call this when relevant.
|
||||
///
|
||||
/// The implementation is idempotent; calling it with the same set of servers is safe.
|
||||
pub fn set_dns(&self, new_dns: Vec<IpAddr>) {
|
||||
let _ = self.channel.send(Command::SetDns(new_dns));
|
||||
}
|
||||
@@ -91,6 +108,7 @@ impl Session {
|
||||
/// When this function exits, the tunnel failed unrecoverably and you need to call it again.
|
||||
async fn connect<CB>(
|
||||
url: LoginUrl,
|
||||
sockets: Sockets,
|
||||
private_key: StaticSecret,
|
||||
os_version_override: Option<String>,
|
||||
callbacks: CB,
|
||||
@@ -100,7 +118,7 @@ async fn connect<CB>(
|
||||
where
|
||||
CB: Callbacks + 'static,
|
||||
{
|
||||
let tunnel = ClientTunnel::new(private_key, callbacks.clone())?;
|
||||
let tunnel = ClientTunnel::new(private_key, sockets, callbacks.clone());
|
||||
|
||||
let portal = PhoenixChannel::connect(
|
||||
Secret::new(url),
|
||||
|
||||
@@ -73,10 +73,6 @@ pub trait Callbacks: Clone + Send + Sync {
|
||||
std::process::exit(0);
|
||||
}
|
||||
|
||||
/// Protects the socket file descriptor from routing loops.
|
||||
#[cfg(target_os = "android")]
|
||||
fn protect_file_descriptor(&self, file_descriptor: std::os::fd::RawFd);
|
||||
|
||||
fn roll_log_file(&self) -> Option<PathBuf> {
|
||||
None
|
||||
}
|
||||
|
||||
@@ -16,7 +16,7 @@ use ip_network_table::IpNetworkTable;
|
||||
use itertools::Itertools;
|
||||
|
||||
use crate::utils::{earliest, stun, turn};
|
||||
use crate::ClientTunnel;
|
||||
use crate::{ClientEvent, ClientTunnel};
|
||||
use secrecy::{ExposeSecret as _, Secret};
|
||||
use snownet::ClientNode;
|
||||
use std::collections::hash_map::Entry;
|
||||
@@ -39,22 +39,6 @@ const DNS_SENTINELS_V6: &str = "fd00:2021:1111:8000:100:100:111:0/120";
|
||||
// therefore, only the first time it's added that happens, after that it doesn't matter.
|
||||
const DNS_REFRESH_INTERVAL: Duration = Duration::from_secs(300);
|
||||
|
||||
#[derive(Debug, PartialEq, Eq)]
|
||||
pub(crate) enum Event {
|
||||
SignalIceCandidate {
|
||||
conn_id: GatewayId,
|
||||
candidate: String,
|
||||
},
|
||||
ConnectionIntent {
|
||||
resource: ResourceId,
|
||||
connected_gateway_ids: HashSet<GatewayId>,
|
||||
},
|
||||
RefreshResources {
|
||||
connections: Vec<ReuseConnection>,
|
||||
},
|
||||
RefreshInterface,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
|
||||
pub struct DnsResource {
|
||||
pub id: ResourceId,
|
||||
@@ -182,8 +166,16 @@ where
|
||||
}
|
||||
|
||||
/// Updates the system's dns
|
||||
pub fn set_dns(&mut self, new_dns: Vec<IpAddr>, now: Instant) {
|
||||
self.role_state.update_system_resolvers(new_dns, now);
|
||||
pub fn set_dns(&mut self, new_dns: Vec<IpAddr>) -> connlib_shared::Result<()> {
|
||||
let dns_changed = self.role_state.update_system_resolvers(new_dns);
|
||||
|
||||
if !dns_changed {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
self.update_interface()?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tracing::instrument(level = "trace", skip(self))]
|
||||
@@ -335,7 +327,7 @@ pub struct ClientState {
|
||||
|
||||
dns_mapping: BiMap<IpAddr, DnsServer>,
|
||||
|
||||
buffered_events: VecDeque<Event>,
|
||||
buffered_events: VecDeque<ClientEvent>,
|
||||
interface_config: Option<InterfaceConfig>,
|
||||
buffered_packets: VecDeque<IpPacket<'static>>,
|
||||
|
||||
@@ -343,7 +335,6 @@ pub struct ClientState {
|
||||
buffered_dns_queries: VecDeque<DnsQuery<'static>>,
|
||||
|
||||
next_dns_refresh: Option<Instant>,
|
||||
next_system_resolver_refresh: Option<Instant>,
|
||||
|
||||
system_resolvers: Vec<IpAddr>,
|
||||
}
|
||||
@@ -375,7 +366,6 @@ impl ClientState {
|
||||
next_dns_refresh: Default::default(),
|
||||
node: ClientNode::new(private_key),
|
||||
system_resolvers: Default::default(),
|
||||
next_system_resolver_refresh: Default::default(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -758,10 +748,11 @@ impl ClientState {
|
||||
|
||||
tracing::debug!("Sending connection intent");
|
||||
|
||||
self.buffered_events.push_back(Event::ConnectionIntent {
|
||||
resource,
|
||||
connected_gateway_ids: gateways,
|
||||
});
|
||||
self.buffered_events
|
||||
.push_back(ClientEvent::ConnectionIntent {
|
||||
resource,
|
||||
connected_gateway_ids: gateways,
|
||||
});
|
||||
}
|
||||
|
||||
pub fn gateway_by_resource(&self, resource: &ResourceId) -> Option<GatewayId> {
|
||||
@@ -837,15 +828,16 @@ impl ClientState {
|
||||
.map(|(_, res)| res.id)
|
||||
}
|
||||
|
||||
fn update_system_resolvers(&mut self, new_dns: Vec<IpAddr>, now: Instant) {
|
||||
fn update_system_resolvers(&mut self, new_dns: Vec<IpAddr>) -> bool {
|
||||
if !dns_updated(&self.system_resolvers, &new_dns) {
|
||||
tracing::debug!("Updated dns called but no change to system's resolver");
|
||||
return;
|
||||
return false;
|
||||
}
|
||||
|
||||
tracing::info!("Found new system resolvers: {new_dns:?}");
|
||||
self.next_system_resolver_refresh = Some(now + std::time::Duration::from_millis(500));
|
||||
self.system_resolvers = new_dns;
|
||||
|
||||
true
|
||||
}
|
||||
|
||||
pub fn poll_packets(&mut self) -> Option<IpPacket<'static>> {
|
||||
@@ -857,8 +849,7 @@ impl ClientState {
|
||||
}
|
||||
|
||||
pub fn poll_timeout(&mut self) -> Option<Instant> {
|
||||
let timeout = earliest(self.next_dns_refresh, self.node.poll_timeout());
|
||||
earliest(timeout, self.next_system_resolver_refresh)
|
||||
earliest(self.next_dns_refresh, self.node.poll_timeout())
|
||||
}
|
||||
|
||||
pub fn handle_timeout(&mut self, now: Instant) {
|
||||
@@ -889,7 +880,7 @@ impl ClientState {
|
||||
}
|
||||
|
||||
self.buffered_events
|
||||
.push_back(Event::RefreshResources { connections });
|
||||
.push_back(ClientEvent::RefreshResources { connections });
|
||||
|
||||
self.next_dns_refresh = Some(now + DNS_REFRESH_INTERVAL);
|
||||
}
|
||||
@@ -897,11 +888,6 @@ impl ClientState {
|
||||
Some(_) => {}
|
||||
}
|
||||
|
||||
if self.next_system_resolver_refresh.is_some_and(|e| now >= e) {
|
||||
self.buffered_events.push_back(Event::RefreshInterface);
|
||||
self.next_system_resolver_refresh = None;
|
||||
}
|
||||
|
||||
while let Some(event) = self.node.poll_event() {
|
||||
match event {
|
||||
snownet::Event::ConnectionFailed(id) => {
|
||||
@@ -910,16 +896,18 @@ impl ClientState {
|
||||
snownet::Event::SignalIceCandidate {
|
||||
connection,
|
||||
candidate,
|
||||
} => self.buffered_events.push_back(Event::SignalIceCandidate {
|
||||
conn_id: connection,
|
||||
candidate,
|
||||
}),
|
||||
} => self
|
||||
.buffered_events
|
||||
.push_back(ClientEvent::SignalIceCandidate {
|
||||
conn_id: connection,
|
||||
candidate,
|
||||
}),
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn poll_event(&mut self) -> Option<Event> {
|
||||
pub(crate) fn poll_event(&mut self) -> Option<ClientEvent> {
|
||||
self.buffered_events.pop_front()
|
||||
}
|
||||
|
||||
@@ -1102,44 +1090,29 @@ mod tests {
|
||||
fn update_system_dns_works() {
|
||||
let mut client_state = ClientState::for_test();
|
||||
|
||||
let now = Instant::now();
|
||||
client_state.update_system_resolvers(vec![ip("1.1.1.1")], now);
|
||||
let now = now + Duration::from_millis(500);
|
||||
client_state.handle_timeout(now);
|
||||
let changed = client_state.update_system_resolvers(vec![ip("1.1.1.1")]);
|
||||
|
||||
assert_eq!(client_state.poll_event(), Some(Event::RefreshInterface));
|
||||
assert!(changed);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn update_system_dns_without_change_is_a_no_op() {
|
||||
let mut client_state = ClientState::for_test();
|
||||
|
||||
let now = Instant::now();
|
||||
client_state.update_system_resolvers(vec![ip("1.1.1.1")], now);
|
||||
let now = now + Duration::from_millis(500);
|
||||
client_state.handle_timeout(now);
|
||||
client_state.poll_event();
|
||||
client_state.update_system_resolvers(vec![ip("1.1.1.1")]);
|
||||
let changed = client_state.update_system_resolvers(vec![ip("1.1.1.1")]);
|
||||
|
||||
client_state.update_system_resolvers(vec![ip("1.1.1.1")], now);
|
||||
let now = now + Duration::from_millis(500);
|
||||
client_state.handle_timeout(now);
|
||||
assert!(client_state.poll_event().is_none());
|
||||
assert!(!changed)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn update_system_dns_with_change_works() {
|
||||
let mut client_state = ClientState::for_test();
|
||||
|
||||
let now = Instant::now();
|
||||
client_state.update_system_resolvers(vec![ip("1.1.1.1")], now);
|
||||
let now = now + Duration::from_millis(500);
|
||||
client_state.handle_timeout(now);
|
||||
client_state.poll_event();
|
||||
client_state.update_system_resolvers(vec![ip("1.1.1.1")]);
|
||||
let changed = client_state.update_system_resolvers(vec![ip("1.0.0.1")]);
|
||||
|
||||
client_state.update_system_resolvers(vec![ip("1.0.0.1")], now);
|
||||
let now = now + Duration::from_millis(500);
|
||||
client_state.handle_timeout(now);
|
||||
assert_eq!(client_state.poll_event(), Some(Event::RefreshInterface));
|
||||
assert!(changed)
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
||||
@@ -43,17 +43,17 @@ pub enum Input<'a, I> {
|
||||
}
|
||||
|
||||
impl Io {
|
||||
pub fn new() -> io::Result<Self> {
|
||||
Ok(Self {
|
||||
pub fn new(sockets: Sockets) -> Self {
|
||||
Self {
|
||||
device: Device::new(),
|
||||
timeout: None,
|
||||
sockets: Sockets::new()?,
|
||||
sockets,
|
||||
upstream_dns_servers: HashMap::default(),
|
||||
forwarded_dns_queries: FuturesTupleSet::new(
|
||||
Duration::from_secs(60),
|
||||
DNS_QUERIES_QUEUE_SIZE,
|
||||
),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
pub fn poll<'b>(
|
||||
@@ -115,6 +115,10 @@ impl Io {
|
||||
&self.sockets
|
||||
}
|
||||
|
||||
pub(crate) fn set_sockets(&mut self, sockets: Sockets) {
|
||||
self.sockets = sockets;
|
||||
}
|
||||
|
||||
pub fn set_upstream_dns_servers(
|
||||
&mut self,
|
||||
dns_servers: impl IntoIterator<Item = (IpAddr, DnsServer)>,
|
||||
|
||||
@@ -17,6 +17,7 @@ use std::{
|
||||
|
||||
pub use client::{ClientState, Request};
|
||||
pub use gateway::GatewayState;
|
||||
pub use sockets::Sockets;
|
||||
|
||||
mod client;
|
||||
mod device_channel;
|
||||
@@ -58,45 +59,27 @@ impl<CB> ClientTunnel<CB>
|
||||
where
|
||||
CB: Callbacks + 'static,
|
||||
{
|
||||
pub fn new(private_key: StaticSecret, callbacks: CB) -> Result<Self> {
|
||||
Ok(Self {
|
||||
io: new_io(&callbacks)?,
|
||||
pub fn new(private_key: StaticSecret, sockets: Sockets, callbacks: CB) -> Self {
|
||||
Self {
|
||||
io: Io::new(sockets),
|
||||
callbacks,
|
||||
role_state: ClientState::new(private_key),
|
||||
write_buf: Box::new([0u8; MAX_UDP_SIZE]),
|
||||
ip4_read_buf: Box::new([0u8; MAX_UDP_SIZE]),
|
||||
ip6_read_buf: Box::new([0u8; MAX_UDP_SIZE]),
|
||||
device_read_buf: Box::new([0u8; MAX_UDP_SIZE]),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
pub fn reconnect(&mut self) {
|
||||
pub fn reconnect(&mut self, sockets: Sockets) {
|
||||
self.role_state.reconnect(Instant::now());
|
||||
self.io.set_sockets(sockets);
|
||||
}
|
||||
|
||||
pub fn poll_next_event(&mut self, cx: &mut Context<'_>) -> Poll<Result<ClientEvent>> {
|
||||
loop {
|
||||
match self.role_state.poll_event() {
|
||||
Some(client::Event::RefreshInterface) => {
|
||||
self.update_interface()?;
|
||||
continue;
|
||||
}
|
||||
Some(client::Event::SignalIceCandidate { conn_id, candidate }) => {
|
||||
return Poll::Ready(Ok(ClientEvent::SignalIceCandidate { conn_id, candidate }))
|
||||
}
|
||||
Some(client::Event::ConnectionIntent {
|
||||
resource,
|
||||
connected_gateway_ids,
|
||||
}) => {
|
||||
return Poll::Ready(Ok(ClientEvent::ConnectionIntent {
|
||||
resource,
|
||||
connected_gateway_ids,
|
||||
}))
|
||||
}
|
||||
Some(client::Event::RefreshResources { connections }) => {
|
||||
return Poll::Ready(Ok(ClientEvent::RefreshResources { connections }))
|
||||
}
|
||||
None => {}
|
||||
if let Some(e) = self.role_state.poll_event() {
|
||||
return Poll::Ready(Ok(e));
|
||||
}
|
||||
|
||||
if let Some(packet) = self.role_state.poll_packets() {
|
||||
@@ -166,16 +149,16 @@ impl<CB> GatewayTunnel<CB>
|
||||
where
|
||||
CB: Callbacks + 'static,
|
||||
{
|
||||
pub fn new(private_key: StaticSecret, callbacks: CB) -> Result<Self> {
|
||||
Ok(Self {
|
||||
io: new_io(&callbacks)?,
|
||||
pub fn new(private_key: StaticSecret, sockets: Sockets, callbacks: CB) -> Self {
|
||||
Self {
|
||||
io: Io::new(sockets),
|
||||
callbacks,
|
||||
role_state: GatewayState::new(private_key),
|
||||
write_buf: Box::new([0u8; MAX_UDP_SIZE]),
|
||||
ip4_read_buf: Box::new([0u8; MAX_UDP_SIZE]),
|
||||
ip6_read_buf: Box::new([0u8; MAX_UDP_SIZE]),
|
||||
device_read_buf: Box::new([0u8; MAX_UDP_SIZE]),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
pub fn poll_next_event(&mut self, cx: &mut Context<'_>) -> Poll<Result<GatewayEvent>> {
|
||||
@@ -237,27 +220,6 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg_attr(not(target_os = "android"), allow(unused_variables))]
|
||||
fn new_io<CB>(callbacks: &CB) -> Result<Io>
|
||||
where
|
||||
CB: Callbacks,
|
||||
{
|
||||
let io = Io::new()?;
|
||||
|
||||
// TODO: Eventually, this should move into the `connlib-client-android` crate.
|
||||
#[cfg(target_os = "android")]
|
||||
{
|
||||
if let Some(ip4_socket) = io.sockets_ref().ip4_socket_fd() {
|
||||
callbacks.protect_file_descriptor(ip4_socket);
|
||||
}
|
||||
if let Some(ip6_socket) = io.sockets_ref().ip6_socket_fd() {
|
||||
callbacks.protect_file_descriptor(ip6_socket);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(io)
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Eq)]
|
||||
pub enum ClientEvent {
|
||||
SignalIceCandidate {
|
||||
|
||||
@@ -52,14 +52,14 @@ impl Sockets {
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(target_os = "android")]
|
||||
#[cfg(unix)]
|
||||
pub fn ip4_socket_fd(&self) -> Option<std::os::fd::RawFd> {
|
||||
use std::os::fd::AsRawFd;
|
||||
|
||||
self.socket_v4.as_ref().map(|s| s.socket.as_raw_fd())
|
||||
}
|
||||
|
||||
#[cfg(target_os = "android")]
|
||||
#[cfg(unix)]
|
||||
pub fn ip6_socket_fd(&self) -> Option<std::os::fd::RawFd> {
|
||||
use std::os::fd::AsRawFd;
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@ use backoff::ExponentialBackoffBuilder;
|
||||
use clap::Parser;
|
||||
use connlib_shared::{get_user_agent, keypair, Callbacks, LoginUrl, StaticSecret};
|
||||
use firezone_cli_utils::{setup_global_subscriber, CommonArgs};
|
||||
use firezone_tunnel::GatewayTunnel;
|
||||
use firezone_tunnel::{GatewayTunnel, Sockets};
|
||||
use futures::{future, TryFutureExt};
|
||||
use secrecy::{Secret, SecretString};
|
||||
use std::convert::Infallible;
|
||||
@@ -91,7 +91,7 @@ async fn get_firezone_id(env_id: Option<String>) -> Result<String> {
|
||||
}
|
||||
|
||||
async fn run(login: LoginUrl, private_key: StaticSecret) -> Result<Infallible> {
|
||||
let mut tunnel = GatewayTunnel::new(private_key, CallbackHandler)?;
|
||||
let mut tunnel = GatewayTunnel::new(private_key, Sockets::new()?, CallbackHandler);
|
||||
|
||||
let (portal, init) = phoenix_channel::init::<_, InitGateway, _, _>(
|
||||
Secret::new(login),
|
||||
|
||||
@@ -10,7 +10,7 @@ use crate::client::{
|
||||
};
|
||||
use anyhow::{bail, Context, Result};
|
||||
use arc_swap::ArcSwap;
|
||||
use connlib_client_shared::{file_logger, ResourceDescription};
|
||||
use connlib_client_shared::{file_logger, ResourceDescription, Sockets};
|
||||
use connlib_shared::{keypair, messages::ResourceId, LoginUrl, BUNDLE_ID};
|
||||
use secrecy::{ExposeSecret, SecretString};
|
||||
use std::{path::PathBuf, str::FromStr, sync::Arc, time::Duration};
|
||||
@@ -534,6 +534,7 @@ impl Controller {
|
||||
)?;
|
||||
let connlib = connlib_client_shared::Session::connect(
|
||||
login,
|
||||
Sockets::new()?,
|
||||
private_key,
|
||||
None,
|
||||
callback_handler.clone(),
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
use anyhow::{Context, Result};
|
||||
use clap::Parser;
|
||||
use connlib_client_shared::{file_logger, Callbacks, Session};
|
||||
use connlib_client_shared::{file_logger, Callbacks, Session, Sockets};
|
||||
use connlib_shared::{
|
||||
keypair,
|
||||
linux::{etc_resolv_conf, get_dns_control_from_env, DnsControlMethod},
|
||||
@@ -38,6 +38,7 @@ async fn main() -> Result<()> {
|
||||
|
||||
let session = Session::connect(
|
||||
login,
|
||||
Sockets::new()?,
|
||||
private_key,
|
||||
None,
|
||||
callbacks.clone(),
|
||||
@@ -55,19 +56,19 @@ async fn main() -> Result<()> {
|
||||
if sigint.poll_recv(cx).is_ready() {
|
||||
tracing::debug!("Received SIGINT");
|
||||
|
||||
return Poll::Ready(());
|
||||
return Poll::Ready(std::io::Result::Ok(()));
|
||||
}
|
||||
|
||||
if sighup.poll_recv(cx).is_ready() {
|
||||
tracing::debug!("Received SIGHUP");
|
||||
|
||||
session.reconnect();
|
||||
session.reconnect(Sockets::new()?);
|
||||
continue;
|
||||
}
|
||||
|
||||
return Poll::Pending;
|
||||
})
|
||||
.await;
|
||||
.await?;
|
||||
|
||||
session.disconnect();
|
||||
|
||||
|
||||
Reference in New Issue
Block a user