refactor(connlib): implement new FFI guidelines (#4263)

This updates connlib to follow the new guidelines described in #4262. I
only made the bare-minimum changes to the clients. With these changes
`reconnect` should only be called when the network interface actually
changed, meaning clients have to be updated to reflect that.
This commit is contained in:
Thomas Eizinger
2024-03-23 15:13:05 +11:00
committed by GitHub
parent 703f07fed5
commit e628fa5d06
12 changed files with 162 additions and 168 deletions

View File

@@ -5,7 +5,7 @@
use connlib_client_shared::{
file_logger, keypair, Callbacks, Cidrv4, Cidrv6, Error, LoginUrl, LoginUrlError,
ResourceDescription, Session,
ResourceDescription, Session, Sockets,
};
use jni::{
objects::{GlobalRef, JClass, JObject, JString, JValue},
@@ -67,6 +67,8 @@ pub enum CallbackError {
name: &'static str,
source: jni::errors::Error,
},
#[error(transparent)]
Io(#[from] io::Error),
}
impl CallbackHandler {
@@ -79,6 +81,18 @@ impl CallbackHandler {
.map_err(CallbackError::AttachCurrentThreadFailed)
.and_then(f)
}
fn protect_file_descriptor(&self, file_descriptor: RawFd) -> Result<(), CallbackError> {
self.env(|mut env| {
call_method(
&mut env,
&self.callback_handler,
"protectFileDescriptor",
"(I)V",
&[JValue::Int(file_descriptor)],
)
})
}
}
fn call_method(
@@ -227,20 +241,6 @@ impl Callbacks for CallbackHandler {
.expect("onUpdateRoutes callback failed")
}
#[cfg(target_os = "android")]
fn protect_file_descriptor(&self, file_descriptor: RawFd) {
self.env(|mut env| {
call_method(
&mut env,
&self.callback_handler,
"protectFileDescriptor",
"(I)V",
&[JValue::Int(file_descriptor)],
)
})
.expect("protectFileDescriptor callback failed");
}
fn on_update_resources(&self, resource_list: Vec<ResourceDescription>) {
self.env(|mut env| {
let resource_list = env
@@ -326,6 +326,8 @@ enum ConnectError {
InvalidLoginUrl(#[from] LoginUrlError<url::ParseError>),
#[error("Unable to create tokio runtime: {0}")]
UnableToCreateRuntime(#[from] io::Error),
#[error(transparent)]
CallbackError(#[from] CallbackError),
}
macro_rules! string_from_jstring {
@@ -386,17 +388,28 @@ fn connect(
.enable_all()
.build()?;
let sockets = Sockets::new()?;
if let Some(ip4_socket) = sockets.ip4_socket_fd() {
callback_handler.protect_file_descriptor(ip4_socket)?;
}
if let Some(ip6_socket) = sockets.ip6_socket_fd() {
callback_handler.protect_file_descriptor(ip6_socket)?;
}
let session = Session::connect(
login,
sockets,
private_key,
Some(os_version),
callback_handler,
callback_handler.clone(),
Some(MAX_PARTITION_TIME),
runtime.handle().clone(),
)?;
Ok(SessionWrapper {
inner: session,
callbacks: callback_handler,
runtime,
})
}
@@ -450,11 +463,29 @@ pub unsafe extern "system" fn Java_dev_firezone_android_tunnel_ConnlibSession_co
pub struct SessionWrapper {
inner: Session,
callbacks: CallbackHandler,
#[allow(dead_code)] // Only here so we don't drop the memory early.
runtime: Runtime,
}
impl SessionWrapper {
fn reconnect(&self) -> Result<(), CallbackError> {
let sockets = Sockets::new()?;
if let Some(ip4_socket) = sockets.ip4_socket_fd() {
self.callbacks.protect_file_descriptor(ip4_socket)?;
}
if let Some(ip6_socket) = sockets.ip6_socket_fd() {
self.callbacks.protect_file_descriptor(ip6_socket)?;
}
self.inner.reconnect(sockets);
Ok(())
}
}
/// # Safety
/// Pointers must be valid
#[allow(non_snake_case)]
@@ -497,9 +528,11 @@ pub unsafe extern "system" fn Java_dev_firezone_android_tunnel_ConnlibSession_se
#[allow(non_snake_case)]
#[no_mangle]
pub unsafe extern "system" fn Java_dev_firezone_android_tunnel_ConnlibSession_reconnect(
_: JNIEnv,
mut env: JNIEnv,
_: JClass,
session: *const SessionWrapper,
) {
(*session).inner.reconnect();
if let Err(e) = (*session).reconnect() {
throw(&mut env, "java/lang/Exception", e.to_string());
}
}

View File

@@ -3,6 +3,7 @@
use connlib_client_shared::{
file_logger, keypair, Callbacks, Cidrv4, Cidrv6, Error, LoginUrl, ResourceDescription, Session,
Sockets,
};
use secrecy::SecretString;
use std::{
@@ -204,6 +205,7 @@ impl WrappedSession {
let session = Session::connect(
login,
Sockets::new().map_err(|err| err.to_string())?,
private_key,
os_version_override,
CallbackHandler {

View File

@@ -10,7 +10,7 @@ use connlib_shared::{
messages::{ConnectionAccepted, GatewayResponse, ResourceAccepted, ResourceId},
Callbacks,
};
use firezone_tunnel::ClientTunnel;
use firezone_tunnel::{ClientTunnel, Sockets};
use phoenix_channel::{ErrorReply, OutboundRequestId, PhoenixChannel};
use std::{
collections::HashMap,
@@ -18,7 +18,7 @@ use std::{
net::IpAddr,
path::PathBuf,
task::{Context, Poll},
time::{Duration, Instant},
time::Duration,
};
use tokio::time::{Interval, MissedTickBehavior};
use url::Url;
@@ -37,7 +37,7 @@ pub struct Eventloop<C: Callbacks> {
/// Commands that can be sent to the [`Eventloop`].
pub enum Command {
Stop,
Reconnect,
Reconnect(Sockets),
SetDns(Vec<IpAddr>),
}
@@ -66,10 +66,14 @@ where
loop {
match self.rx.poll_recv(cx) {
Poll::Ready(Some(Command::Stop)) | Poll::Ready(None) => return Poll::Ready(Ok(())),
Poll::Ready(Some(Command::SetDns(dns))) => self.tunnel.set_dns(dns, Instant::now()),
Poll::Ready(Some(Command::Reconnect)) => {
Poll::Ready(Some(Command::SetDns(dns))) => {
if let Err(e) = self.tunnel.set_dns(dns) {
tracing::warn!("Failed to update DNS: {e}");
}
}
Poll::Ready(Some(Command::Reconnect(sockets))) => {
self.portal.reconnect();
self.tunnel.reconnect();
self.tunnel.reconnect(sockets);
continue;
}

View File

@@ -3,7 +3,7 @@ pub use connlib_shared::messages::ResourceDescription;
pub use connlib_shared::{
keypair, Callbacks, Cidrv4, Cidrv6, Error, LoginUrl, LoginUrlError, StaticSecret,
};
use tokio::sync::mpsc::UnboundedReceiver;
pub use firezone_tunnel::Sockets;
pub use tracing_appender::non_blocking::WorkerGuard;
use backoff::ExponentialBackoffBuilder;
@@ -12,6 +12,7 @@ use firezone_tunnel::ClientTunnel;
use phoenix_channel::PhoenixChannel;
use std::net::IpAddr;
use std::time::Duration;
use tokio::sync::mpsc::UnboundedReceiver;
mod eventloop;
pub mod file_logger;
@@ -37,6 +38,7 @@ impl Session {
/// This connects to the portal a specified using [`LoginUrl`] and creates a wireguard tunnel using the provided private key.
pub fn connect<CB: Callbacks + 'static>(
url: LoginUrl,
sockets: Sockets,
private_key: StaticSecret,
os_version_override: Option<String>,
callbacks: CB,
@@ -47,6 +49,7 @@ impl Session {
let connect_handle = handle.spawn(connect(
url,
sockets,
private_key,
os_version_override,
callbacks.clone(),
@@ -60,20 +63,34 @@ impl Session {
/// Attempts to reconnect a [`Session`].
///
/// This can and should be called by client applications on any network state changes.
/// It is a signal to connlib to:
/// Reconnecting a session will:
///
/// - validate all currently used network paths to relays and peers
/// - ensure we are connected to the portal
/// - Close and re-open a connection to the portal.
/// - Refresh all allocations
/// - Replace the currently used [`Sockets`] with the provided one
///
/// Reconnect is non-destructive and can be called several times in a row.
/// # Implementation note
///
/// In case of destructive network state changes, i.e. the user switched from wifi to cellular,
/// reconnect allows connlib to re-establish connections faster because we don't have to wait for timeouts first.
pub fn reconnect(&self) {
let _ = self.channel.send(Command::Reconnect);
/// The reason we replace [`Sockets`] are:
///
/// 1. On MacOS, as socket bound to the unspecified IP cannot send to interfaces attached after the socket has been created.
/// 2. Switching between networks changes the 3-tuple of the client.
/// The TURN protocol identifies a client's allocation based on the 3-tuple.
/// Consequently, an allocation is invalid after switching networks and we clear the state.
/// Changing the IP would be enough for that.
/// However, if the user would now change _back_ to the previous network,
/// the TURN server would recognise the old allocation but the client already lost all its state associated with it.
/// To avoid race-conditions like this, we initialize a new [`Sockets`] instance which allocates a new port.
pub fn reconnect(&self, sockets: Sockets) {
let _ = self.channel.send(Command::Reconnect(sockets));
}
/// Sets a new set of upstream DNS servers for this [`Session`].
///
/// Changing the DNS servers clears all cached DNS requests which may be disruptive to the UX.
/// Clients should only call this when relevant.
///
/// The implementation is idempotent; calling it with the same set of servers is safe.
pub fn set_dns(&self, new_dns: Vec<IpAddr>) {
let _ = self.channel.send(Command::SetDns(new_dns));
}
@@ -91,6 +108,7 @@ impl Session {
/// When this function exits, the tunnel failed unrecoverably and you need to call it again.
async fn connect<CB>(
url: LoginUrl,
sockets: Sockets,
private_key: StaticSecret,
os_version_override: Option<String>,
callbacks: CB,
@@ -100,7 +118,7 @@ async fn connect<CB>(
where
CB: Callbacks + 'static,
{
let tunnel = ClientTunnel::new(private_key, callbacks.clone())?;
let tunnel = ClientTunnel::new(private_key, sockets, callbacks.clone());
let portal = PhoenixChannel::connect(
Secret::new(url),

View File

@@ -73,10 +73,6 @@ pub trait Callbacks: Clone + Send + Sync {
std::process::exit(0);
}
/// Protects the socket file descriptor from routing loops.
#[cfg(target_os = "android")]
fn protect_file_descriptor(&self, file_descriptor: std::os::fd::RawFd);
fn roll_log_file(&self) -> Option<PathBuf> {
None
}

View File

@@ -16,7 +16,7 @@ use ip_network_table::IpNetworkTable;
use itertools::Itertools;
use crate::utils::{earliest, stun, turn};
use crate::ClientTunnel;
use crate::{ClientEvent, ClientTunnel};
use secrecy::{ExposeSecret as _, Secret};
use snownet::ClientNode;
use std::collections::hash_map::Entry;
@@ -39,22 +39,6 @@ const DNS_SENTINELS_V6: &str = "fd00:2021:1111:8000:100:100:111:0/120";
// therefore, only the first time it's added that happens, after that it doesn't matter.
const DNS_REFRESH_INTERVAL: Duration = Duration::from_secs(300);
#[derive(Debug, PartialEq, Eq)]
pub(crate) enum Event {
SignalIceCandidate {
conn_id: GatewayId,
candidate: String,
},
ConnectionIntent {
resource: ResourceId,
connected_gateway_ids: HashSet<GatewayId>,
},
RefreshResources {
connections: Vec<ReuseConnection>,
},
RefreshInterface,
}
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub struct DnsResource {
pub id: ResourceId,
@@ -182,8 +166,16 @@ where
}
/// Updates the system's dns
pub fn set_dns(&mut self, new_dns: Vec<IpAddr>, now: Instant) {
self.role_state.update_system_resolvers(new_dns, now);
pub fn set_dns(&mut self, new_dns: Vec<IpAddr>) -> connlib_shared::Result<()> {
let dns_changed = self.role_state.update_system_resolvers(new_dns);
if !dns_changed {
return Ok(());
}
self.update_interface()?;
Ok(())
}
#[tracing::instrument(level = "trace", skip(self))]
@@ -335,7 +327,7 @@ pub struct ClientState {
dns_mapping: BiMap<IpAddr, DnsServer>,
buffered_events: VecDeque<Event>,
buffered_events: VecDeque<ClientEvent>,
interface_config: Option<InterfaceConfig>,
buffered_packets: VecDeque<IpPacket<'static>>,
@@ -343,7 +335,6 @@ pub struct ClientState {
buffered_dns_queries: VecDeque<DnsQuery<'static>>,
next_dns_refresh: Option<Instant>,
next_system_resolver_refresh: Option<Instant>,
system_resolvers: Vec<IpAddr>,
}
@@ -375,7 +366,6 @@ impl ClientState {
next_dns_refresh: Default::default(),
node: ClientNode::new(private_key),
system_resolvers: Default::default(),
next_system_resolver_refresh: Default::default(),
}
}
@@ -758,10 +748,11 @@ impl ClientState {
tracing::debug!("Sending connection intent");
self.buffered_events.push_back(Event::ConnectionIntent {
resource,
connected_gateway_ids: gateways,
});
self.buffered_events
.push_back(ClientEvent::ConnectionIntent {
resource,
connected_gateway_ids: gateways,
});
}
pub fn gateway_by_resource(&self, resource: &ResourceId) -> Option<GatewayId> {
@@ -837,15 +828,16 @@ impl ClientState {
.map(|(_, res)| res.id)
}
fn update_system_resolvers(&mut self, new_dns: Vec<IpAddr>, now: Instant) {
fn update_system_resolvers(&mut self, new_dns: Vec<IpAddr>) -> bool {
if !dns_updated(&self.system_resolvers, &new_dns) {
tracing::debug!("Updated dns called but no change to system's resolver");
return;
return false;
}
tracing::info!("Found new system resolvers: {new_dns:?}");
self.next_system_resolver_refresh = Some(now + std::time::Duration::from_millis(500));
self.system_resolvers = new_dns;
true
}
pub fn poll_packets(&mut self) -> Option<IpPacket<'static>> {
@@ -857,8 +849,7 @@ impl ClientState {
}
pub fn poll_timeout(&mut self) -> Option<Instant> {
let timeout = earliest(self.next_dns_refresh, self.node.poll_timeout());
earliest(timeout, self.next_system_resolver_refresh)
earliest(self.next_dns_refresh, self.node.poll_timeout())
}
pub fn handle_timeout(&mut self, now: Instant) {
@@ -889,7 +880,7 @@ impl ClientState {
}
self.buffered_events
.push_back(Event::RefreshResources { connections });
.push_back(ClientEvent::RefreshResources { connections });
self.next_dns_refresh = Some(now + DNS_REFRESH_INTERVAL);
}
@@ -897,11 +888,6 @@ impl ClientState {
Some(_) => {}
}
if self.next_system_resolver_refresh.is_some_and(|e| now >= e) {
self.buffered_events.push_back(Event::RefreshInterface);
self.next_system_resolver_refresh = None;
}
while let Some(event) = self.node.poll_event() {
match event {
snownet::Event::ConnectionFailed(id) => {
@@ -910,16 +896,18 @@ impl ClientState {
snownet::Event::SignalIceCandidate {
connection,
candidate,
} => self.buffered_events.push_back(Event::SignalIceCandidate {
conn_id: connection,
candidate,
}),
} => self
.buffered_events
.push_back(ClientEvent::SignalIceCandidate {
conn_id: connection,
candidate,
}),
_ => {}
}
}
}
pub(crate) fn poll_event(&mut self) -> Option<Event> {
pub(crate) fn poll_event(&mut self) -> Option<ClientEvent> {
self.buffered_events.pop_front()
}
@@ -1102,44 +1090,29 @@ mod tests {
fn update_system_dns_works() {
let mut client_state = ClientState::for_test();
let now = Instant::now();
client_state.update_system_resolvers(vec![ip("1.1.1.1")], now);
let now = now + Duration::from_millis(500);
client_state.handle_timeout(now);
let changed = client_state.update_system_resolvers(vec![ip("1.1.1.1")]);
assert_eq!(client_state.poll_event(), Some(Event::RefreshInterface));
assert!(changed);
}
#[test]
fn update_system_dns_without_change_is_a_no_op() {
let mut client_state = ClientState::for_test();
let now = Instant::now();
client_state.update_system_resolvers(vec![ip("1.1.1.1")], now);
let now = now + Duration::from_millis(500);
client_state.handle_timeout(now);
client_state.poll_event();
client_state.update_system_resolvers(vec![ip("1.1.1.1")]);
let changed = client_state.update_system_resolvers(vec![ip("1.1.1.1")]);
client_state.update_system_resolvers(vec![ip("1.1.1.1")], now);
let now = now + Duration::from_millis(500);
client_state.handle_timeout(now);
assert!(client_state.poll_event().is_none());
assert!(!changed)
}
#[test]
fn update_system_dns_with_change_works() {
let mut client_state = ClientState::for_test();
let now = Instant::now();
client_state.update_system_resolvers(vec![ip("1.1.1.1")], now);
let now = now + Duration::from_millis(500);
client_state.handle_timeout(now);
client_state.poll_event();
client_state.update_system_resolvers(vec![ip("1.1.1.1")]);
let changed = client_state.update_system_resolvers(vec![ip("1.0.0.1")]);
client_state.update_system_resolvers(vec![ip("1.0.0.1")], now);
let now = now + Duration::from_millis(500);
client_state.handle_timeout(now);
assert_eq!(client_state.poll_event(), Some(Event::RefreshInterface));
assert!(changed)
}
#[test]

View File

@@ -43,17 +43,17 @@ pub enum Input<'a, I> {
}
impl Io {
pub fn new() -> io::Result<Self> {
Ok(Self {
pub fn new(sockets: Sockets) -> Self {
Self {
device: Device::new(),
timeout: None,
sockets: Sockets::new()?,
sockets,
upstream_dns_servers: HashMap::default(),
forwarded_dns_queries: FuturesTupleSet::new(
Duration::from_secs(60),
DNS_QUERIES_QUEUE_SIZE,
),
})
}
}
pub fn poll<'b>(
@@ -115,6 +115,10 @@ impl Io {
&self.sockets
}
pub(crate) fn set_sockets(&mut self, sockets: Sockets) {
self.sockets = sockets;
}
pub fn set_upstream_dns_servers(
&mut self,
dns_servers: impl IntoIterator<Item = (IpAddr, DnsServer)>,

View File

@@ -17,6 +17,7 @@ use std::{
pub use client::{ClientState, Request};
pub use gateway::GatewayState;
pub use sockets::Sockets;
mod client;
mod device_channel;
@@ -58,45 +59,27 @@ impl<CB> ClientTunnel<CB>
where
CB: Callbacks + 'static,
{
pub fn new(private_key: StaticSecret, callbacks: CB) -> Result<Self> {
Ok(Self {
io: new_io(&callbacks)?,
pub fn new(private_key: StaticSecret, sockets: Sockets, callbacks: CB) -> Self {
Self {
io: Io::new(sockets),
callbacks,
role_state: ClientState::new(private_key),
write_buf: Box::new([0u8; MAX_UDP_SIZE]),
ip4_read_buf: Box::new([0u8; MAX_UDP_SIZE]),
ip6_read_buf: Box::new([0u8; MAX_UDP_SIZE]),
device_read_buf: Box::new([0u8; MAX_UDP_SIZE]),
})
}
}
pub fn reconnect(&mut self) {
pub fn reconnect(&mut self, sockets: Sockets) {
self.role_state.reconnect(Instant::now());
self.io.set_sockets(sockets);
}
pub fn poll_next_event(&mut self, cx: &mut Context<'_>) -> Poll<Result<ClientEvent>> {
loop {
match self.role_state.poll_event() {
Some(client::Event::RefreshInterface) => {
self.update_interface()?;
continue;
}
Some(client::Event::SignalIceCandidate { conn_id, candidate }) => {
return Poll::Ready(Ok(ClientEvent::SignalIceCandidate { conn_id, candidate }))
}
Some(client::Event::ConnectionIntent {
resource,
connected_gateway_ids,
}) => {
return Poll::Ready(Ok(ClientEvent::ConnectionIntent {
resource,
connected_gateway_ids,
}))
}
Some(client::Event::RefreshResources { connections }) => {
return Poll::Ready(Ok(ClientEvent::RefreshResources { connections }))
}
None => {}
if let Some(e) = self.role_state.poll_event() {
return Poll::Ready(Ok(e));
}
if let Some(packet) = self.role_state.poll_packets() {
@@ -166,16 +149,16 @@ impl<CB> GatewayTunnel<CB>
where
CB: Callbacks + 'static,
{
pub fn new(private_key: StaticSecret, callbacks: CB) -> Result<Self> {
Ok(Self {
io: new_io(&callbacks)?,
pub fn new(private_key: StaticSecret, sockets: Sockets, callbacks: CB) -> Self {
Self {
io: Io::new(sockets),
callbacks,
role_state: GatewayState::new(private_key),
write_buf: Box::new([0u8; MAX_UDP_SIZE]),
ip4_read_buf: Box::new([0u8; MAX_UDP_SIZE]),
ip6_read_buf: Box::new([0u8; MAX_UDP_SIZE]),
device_read_buf: Box::new([0u8; MAX_UDP_SIZE]),
})
}
}
pub fn poll_next_event(&mut self, cx: &mut Context<'_>) -> Poll<Result<GatewayEvent>> {
@@ -237,27 +220,6 @@ where
}
}
#[cfg_attr(not(target_os = "android"), allow(unused_variables))]
fn new_io<CB>(callbacks: &CB) -> Result<Io>
where
CB: Callbacks,
{
let io = Io::new()?;
// TODO: Eventually, this should move into the `connlib-client-android` crate.
#[cfg(target_os = "android")]
{
if let Some(ip4_socket) = io.sockets_ref().ip4_socket_fd() {
callbacks.protect_file_descriptor(ip4_socket);
}
if let Some(ip6_socket) = io.sockets_ref().ip6_socket_fd() {
callbacks.protect_file_descriptor(ip6_socket);
}
}
Ok(io)
}
#[derive(Debug, PartialEq, Eq)]
pub enum ClientEvent {
SignalIceCandidate {

View File

@@ -52,14 +52,14 @@ impl Sockets {
}
}
#[cfg(target_os = "android")]
#[cfg(unix)]
pub fn ip4_socket_fd(&self) -> Option<std::os::fd::RawFd> {
use std::os::fd::AsRawFd;
self.socket_v4.as_ref().map(|s| s.socket.as_raw_fd())
}
#[cfg(target_os = "android")]
#[cfg(unix)]
pub fn ip6_socket_fd(&self) -> Option<std::os::fd::RawFd> {
use std::os::fd::AsRawFd;

View File

@@ -5,7 +5,7 @@ use backoff::ExponentialBackoffBuilder;
use clap::Parser;
use connlib_shared::{get_user_agent, keypair, Callbacks, LoginUrl, StaticSecret};
use firezone_cli_utils::{setup_global_subscriber, CommonArgs};
use firezone_tunnel::GatewayTunnel;
use firezone_tunnel::{GatewayTunnel, Sockets};
use futures::{future, TryFutureExt};
use secrecy::{Secret, SecretString};
use std::convert::Infallible;
@@ -91,7 +91,7 @@ async fn get_firezone_id(env_id: Option<String>) -> Result<String> {
}
async fn run(login: LoginUrl, private_key: StaticSecret) -> Result<Infallible> {
let mut tunnel = GatewayTunnel::new(private_key, CallbackHandler)?;
let mut tunnel = GatewayTunnel::new(private_key, Sockets::new()?, CallbackHandler);
let (portal, init) = phoenix_channel::init::<_, InitGateway, _, _>(
Secret::new(login),

View File

@@ -10,7 +10,7 @@ use crate::client::{
};
use anyhow::{bail, Context, Result};
use arc_swap::ArcSwap;
use connlib_client_shared::{file_logger, ResourceDescription};
use connlib_client_shared::{file_logger, ResourceDescription, Sockets};
use connlib_shared::{keypair, messages::ResourceId, LoginUrl, BUNDLE_ID};
use secrecy::{ExposeSecret, SecretString};
use std::{path::PathBuf, str::FromStr, sync::Arc, time::Duration};
@@ -534,6 +534,7 @@ impl Controller {
)?;
let connlib = connlib_client_shared::Session::connect(
login,
Sockets::new()?,
private_key,
None,
callback_handler.clone(),

View File

@@ -1,6 +1,6 @@
use anyhow::{Context, Result};
use clap::Parser;
use connlib_client_shared::{file_logger, Callbacks, Session};
use connlib_client_shared::{file_logger, Callbacks, Session, Sockets};
use connlib_shared::{
keypair,
linux::{etc_resolv_conf, get_dns_control_from_env, DnsControlMethod},
@@ -38,6 +38,7 @@ async fn main() -> Result<()> {
let session = Session::connect(
login,
Sockets::new()?,
private_key,
None,
callbacks.clone(),
@@ -55,19 +56,19 @@ async fn main() -> Result<()> {
if sigint.poll_recv(cx).is_ready() {
tracing::debug!("Received SIGINT");
return Poll::Ready(());
return Poll::Ready(std::io::Result::Ok(()));
}
if sighup.poll_recv(cx).is_ready() {
tracing::debug!("Received SIGHUP");
session.reconnect();
session.reconnect(Sockets::new()?);
continue;
}
return Poll::Pending;
})
.await;
.await?;
session.disconnect();