mirror of
https://github.com/outbackdingo/firezone.git
synced 2026-01-27 18:18:55 +00:00
refactor(connlib): explicitly set DNS from clients instead of requesting it via callback (#4240)
Extracted from #4163 Dependant PRs: #4198 #4133 #4163
This commit is contained in:
@@ -79,6 +79,21 @@ impl CallbackHandler {
|
||||
.map_err(CallbackError::AttachCurrentThreadFailed)
|
||||
.and_then(f)
|
||||
}
|
||||
|
||||
fn get_system_default_resolvers(&self) -> Vec<IpAddr> {
|
||||
self.env(|mut env| {
|
||||
let name = "getSystemDefaultResolvers";
|
||||
let addrs = env
|
||||
.call_method(&self.callback_handler, name, "()[[B", &[])
|
||||
.and_then(JValueGen::l)
|
||||
.and_then(|arr| convert_byte_array_array(&mut env, arr.into()))
|
||||
.map_err(|source| CallbackError::CallMethodFailed { name, source })?;
|
||||
|
||||
Ok(Some(addrs.iter().filter_map(|v| to_ip(v)).collect()))
|
||||
})
|
||||
.expect("getSystemDefaultResolvers callback failed")
|
||||
.unwrap_or_default()
|
||||
}
|
||||
}
|
||||
|
||||
fn call_method(
|
||||
@@ -286,20 +301,6 @@ impl Callbacks for CallbackHandler {
|
||||
None
|
||||
})
|
||||
}
|
||||
|
||||
fn get_system_default_resolvers(&self) -> Option<Vec<IpAddr>> {
|
||||
self.env(|mut env| {
|
||||
let name = "getSystemDefaultResolvers";
|
||||
let addrs = env
|
||||
.call_method(&self.callback_handler, name, "()[[B", &[])
|
||||
.and_then(JValueGen::l)
|
||||
.and_then(|arr| convert_byte_array_array(&mut env, arr.into()))
|
||||
.map_err(|source| CallbackError::CallMethodFailed { name, source })?;
|
||||
|
||||
Ok(Some(addrs.iter().filter_map(|v| to_ip(v)).collect()))
|
||||
})
|
||||
.expect("getSystemDefaultResolvers callback failed")
|
||||
}
|
||||
}
|
||||
|
||||
fn to_ip(val: &[u8]) -> Option<IpAddr> {
|
||||
@@ -427,11 +428,13 @@ fn connect(
|
||||
login,
|
||||
private_key,
|
||||
Some(os_version),
|
||||
callback_handler,
|
||||
callback_handler.clone(),
|
||||
Some(MAX_PARTITION_TIME),
|
||||
runtime.handle().clone(),
|
||||
)?;
|
||||
|
||||
session.set_dns(callback_handler.get_system_default_resolvers());
|
||||
|
||||
Ok(SessionWrapper {
|
||||
inner: session,
|
||||
runtime,
|
||||
|
||||
@@ -69,6 +69,7 @@ mod ffi {
|
||||
#[swift_bridge(swift_name = "onDisconnect")]
|
||||
fn on_disconnect(&self, error: String);
|
||||
|
||||
// TODO: remove in favor of set_dns
|
||||
#[swift_bridge(swift_name = "getSystemDefaultResolvers")]
|
||||
fn get_system_default_resolvers(&self) -> String;
|
||||
}
|
||||
@@ -141,19 +142,6 @@ impl Callbacks for CallbackHandler {
|
||||
self.inner.on_disconnect(error.to_string());
|
||||
}
|
||||
|
||||
fn get_system_default_resolvers(&self) -> Option<Vec<IpAddr>> {
|
||||
let resolvers_json = self.inner.get_system_default_resolvers();
|
||||
tracing::debug!(
|
||||
"get_system_default_resolvers returned: {:?}",
|
||||
resolvers_json
|
||||
);
|
||||
|
||||
let resolvers: Vec<IpAddr> = serde_json::from_str(&resolvers_json)
|
||||
.expect("developer error: failed to deserialize resolvers");
|
||||
|
||||
Some(resolvers)
|
||||
}
|
||||
|
||||
fn roll_log_file(&self) -> Option<PathBuf> {
|
||||
self.handle.roll_to_new_file().unwrap_or_else(|e| {
|
||||
tracing::error!("Failed to roll over to new log file: {e}");
|
||||
@@ -193,6 +181,10 @@ impl WrappedSession {
|
||||
let handle = init_logging(log_dir.into(), log_filter).map_err(|e| e.to_string())?;
|
||||
let secret = SecretString::from(token);
|
||||
|
||||
let resolvers_json = callback_handler.get_system_default_resolvers();
|
||||
let resolvers: Vec<IpAddr> = serde_json::from_str(&resolvers_json)
|
||||
.expect("developer error: failed to deserialize resolvers");
|
||||
|
||||
let (private_key, public_key) = keypair();
|
||||
let login = LoginUrl::client(
|
||||
api_url.as_str(),
|
||||
@@ -223,6 +215,8 @@ impl WrappedSession {
|
||||
)
|
||||
.map_err(|err| err.to_string())?;
|
||||
|
||||
session.set_dns(resolvers);
|
||||
|
||||
Ok(Self {
|
||||
inner: session,
|
||||
runtime,
|
||||
|
||||
@@ -15,11 +15,12 @@ use phoenix_channel::{ErrorReply, OutboundRequestId, PhoenixChannel};
|
||||
use std::{
|
||||
collections::HashMap,
|
||||
io,
|
||||
net::IpAddr,
|
||||
path::PathBuf,
|
||||
task::{Context, Poll},
|
||||
time::Duration,
|
||||
time::{Duration, Instant},
|
||||
};
|
||||
use tokio::time::{Instant, Interval, MissedTickBehavior};
|
||||
use tokio::time::{Interval, MissedTickBehavior};
|
||||
use url::Url;
|
||||
|
||||
pub struct Eventloop<C: Callbacks> {
|
||||
@@ -27,7 +28,7 @@ pub struct Eventloop<C: Callbacks> {
|
||||
tunnel_init: bool,
|
||||
|
||||
portal: PhoenixChannel<(), IngressMessages, ReplyMessages>,
|
||||
rx: tokio::sync::mpsc::Receiver<Command>,
|
||||
rx: tokio::sync::mpsc::UnboundedReceiver<Command>,
|
||||
|
||||
connection_intents: SentConnectionIntents,
|
||||
log_upload_interval: tokio::time::Interval,
|
||||
@@ -37,13 +38,14 @@ pub struct Eventloop<C: Callbacks> {
|
||||
pub enum Command {
|
||||
Stop,
|
||||
Reconnect,
|
||||
SetDns(Vec<IpAddr>),
|
||||
}
|
||||
|
||||
impl<C: Callbacks> Eventloop<C> {
|
||||
pub(crate) fn new(
|
||||
tunnel: ClientTunnel<C>,
|
||||
portal: PhoenixChannel<(), IngressMessages, ReplyMessages>,
|
||||
rx: tokio::sync::mpsc::Receiver<Command>,
|
||||
rx: tokio::sync::mpsc::UnboundedReceiver<Command>,
|
||||
) -> Self {
|
||||
Self {
|
||||
tunnel,
|
||||
@@ -65,6 +67,7 @@ where
|
||||
loop {
|
||||
match self.rx.poll_recv(cx) {
|
||||
Poll::Ready(Some(Command::Stop)) | Poll::Ready(None) => return Poll::Ready(Ok(())),
|
||||
Poll::Ready(Some(Command::SetDns(dns))) => self.tunnel.set_dns(dns, Instant::now()),
|
||||
Poll::Ready(Some(Command::Reconnect)) => {
|
||||
self.portal.reconnect();
|
||||
self.tunnel.reconnect();
|
||||
@@ -180,7 +183,7 @@ where
|
||||
resources,
|
||||
}) => {
|
||||
if !self.tunnel_init {
|
||||
if let Err(e) = self.tunnel.set_interface(&interface) {
|
||||
if let Err(e) = self.tunnel.set_interface(interface) {
|
||||
tracing::warn!("Failed to set interface on tunnel: {e}");
|
||||
return;
|
||||
}
|
||||
@@ -364,7 +367,7 @@ async fn upload(_path: PathBuf, _url: Url) -> io::Result<()> {
|
||||
|
||||
fn upload_interval() -> Interval {
|
||||
let duration = upload_interval_duration_from_env_or_default();
|
||||
let mut interval = tokio::time::interval_at(Instant::now() + duration, duration);
|
||||
let mut interval = tokio::time::interval_at(tokio::time::Instant::now() + duration, duration);
|
||||
interval.set_missed_tick_behavior(MissedTickBehavior::Skip);
|
||||
|
||||
interval
|
||||
|
||||
@@ -3,12 +3,14 @@ pub use connlib_shared::messages::ResourceDescription;
|
||||
pub use connlib_shared::{
|
||||
keypair, Callbacks, Cidrv4, Cidrv6, Error, LoginUrl, LoginUrlError, StaticSecret,
|
||||
};
|
||||
use tokio::sync::mpsc::UnboundedReceiver;
|
||||
pub use tracing_appender::non_blocking::WorkerGuard;
|
||||
|
||||
use backoff::ExponentialBackoffBuilder;
|
||||
use connlib_shared::get_user_agent;
|
||||
use firezone_tunnel::ClientTunnel;
|
||||
use phoenix_channel::PhoenixChannel;
|
||||
use std::net::IpAddr;
|
||||
use std::time::Duration;
|
||||
|
||||
mod eventloop;
|
||||
@@ -26,7 +28,7 @@ use tokio::task::JoinHandle;
|
||||
///
|
||||
/// A session is created using [Session::connect], then to stop a session we use [Session::disconnect].
|
||||
pub struct Session {
|
||||
channel: tokio::sync::mpsc::Sender<Command>,
|
||||
channel: tokio::sync::mpsc::UnboundedSender<Command>,
|
||||
}
|
||||
|
||||
impl Session {
|
||||
@@ -41,7 +43,7 @@ impl Session {
|
||||
max_partition_time: Option<Duration>,
|
||||
handle: tokio::runtime::Handle,
|
||||
) -> connlib_shared::Result<Self> {
|
||||
let (tx, rx) = tokio::sync::mpsc::channel(1);
|
||||
let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
|
||||
|
||||
let connect_handle = handle.spawn(connect(
|
||||
url,
|
||||
@@ -68,15 +70,19 @@ impl Session {
|
||||
///
|
||||
/// In case of destructive network state changes, i.e. the user switched from wifi to cellular,
|
||||
/// reconnect allows connlib to re-establish connections faster because we don't have to wait for timeouts first.
|
||||
pub fn reconnect(&mut self) {
|
||||
let _ = self.channel.try_send(Command::Reconnect);
|
||||
pub fn reconnect(&self) {
|
||||
let _ = self.channel.send(Command::Reconnect);
|
||||
}
|
||||
|
||||
pub fn set_dns(&self, new_dns: Vec<IpAddr>) {
|
||||
let _ = self.channel.send(Command::SetDns(new_dns));
|
||||
}
|
||||
|
||||
/// Disconnect a [`Session`].
|
||||
///
|
||||
/// This consumes [`Session`] which cleans up all state associated with it.
|
||||
pub fn disconnect(self) {
|
||||
let _ = self.channel.try_send(Command::Stop);
|
||||
let _ = self.channel.send(Command::Stop);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -89,7 +95,7 @@ async fn connect<CB>(
|
||||
os_version_override: Option<String>,
|
||||
callbacks: CB,
|
||||
max_partition_time: Option<Duration>,
|
||||
rx: tokio::sync::mpsc::Receiver<Command>,
|
||||
rx: UnboundedReceiver<Command>,
|
||||
) -> Result<(), Error>
|
||||
where
|
||||
CB: Callbacks + 'static,
|
||||
|
||||
@@ -73,14 +73,6 @@ pub trait Callbacks: Clone + Send + Sync {
|
||||
std::process::exit(0);
|
||||
}
|
||||
|
||||
/// Returns the system's default resolver(s)
|
||||
///
|
||||
/// It's okay for clients to include Firezone's own DNS here, e.g. 100.100.111.1.
|
||||
/// connlib internally filters them out.
|
||||
fn get_system_default_resolvers(&self) -> Option<Vec<IpAddr>> {
|
||||
None
|
||||
}
|
||||
|
||||
/// Protects the socket file descriptor from routing loops.
|
||||
#[cfg(target_os = "android")]
|
||||
fn protect_file_descriptor(&self, file_descriptor: std::os::fd::RawFd);
|
||||
|
||||
@@ -16,7 +16,7 @@ use ip_network_table::IpNetworkTable;
|
||||
use itertools::Itertools;
|
||||
|
||||
use crate::utils::{earliest, stun, turn};
|
||||
use crate::{ClientEvent, ClientTunnel};
|
||||
use crate::ClientTunnel;
|
||||
use secrecy::{ExposeSecret as _, Secret};
|
||||
use snownet::ClientNode;
|
||||
use std::collections::hash_map::Entry;
|
||||
@@ -39,6 +39,22 @@ const DNS_SENTINELS_V6: &str = "fd00:2021:1111:8000:100:100:111:0/120";
|
||||
// therefore, only the first time it's added that happens, after that it doesn't matter.
|
||||
const DNS_REFRESH_INTERVAL: Duration = Duration::from_secs(300);
|
||||
|
||||
#[derive(Debug, PartialEq, Eq)]
|
||||
pub(crate) enum Event {
|
||||
SignalIceCandidate {
|
||||
conn_id: GatewayId,
|
||||
candidate: String,
|
||||
},
|
||||
ConnectionIntent {
|
||||
resource: ResourceId,
|
||||
connected_gateway_ids: HashSet<GatewayId>,
|
||||
},
|
||||
RefreshResources {
|
||||
connections: Vec<ReuseConnection>,
|
||||
},
|
||||
RefreshInterfance,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
|
||||
pub struct DnsResource {
|
||||
pub id: ResourceId,
|
||||
@@ -168,15 +184,26 @@ where
|
||||
);
|
||||
}
|
||||
|
||||
/// Updates the system's dns
|
||||
pub fn set_dns(&mut self, new_dns: Vec<IpAddr>, now: Instant) {
|
||||
self.role_state.update_system_resolvers(new_dns, now);
|
||||
}
|
||||
|
||||
/// Sets the interface configuration and starts background tasks.
|
||||
#[tracing::instrument(level = "trace", skip(self))]
|
||||
pub fn set_interface(&mut self, config: &InterfaceConfig) -> connlib_shared::Result<()> {
|
||||
self.role_state.interface_config = Some(config.clone());
|
||||
pub fn set_interface(&mut self, config: InterfaceConfig) -> connlib_shared::Result<()> {
|
||||
self.role_state.interface_config = Some(config);
|
||||
self.update_interface()
|
||||
}
|
||||
|
||||
pub(crate) fn update_interface(&mut self) -> connlib_shared::Result<()> {
|
||||
let Some(config) = self.role_state.interface_config.as_ref().cloned() else {
|
||||
return Ok(());
|
||||
};
|
||||
|
||||
let effective_dns_servers = effective_dns_servers(
|
||||
config.upstream_dns.clone(),
|
||||
self.callbacks
|
||||
.get_system_default_resolvers()
|
||||
.unwrap_or_default(),
|
||||
self.role_state.system_resolvers.clone(),
|
||||
);
|
||||
|
||||
let dns_mapping = sentinel_dns_mapping(&effective_dns_servers);
|
||||
@@ -186,7 +213,7 @@ where
|
||||
let callbacks = self.callbacks.clone();
|
||||
|
||||
self.io.device_mut().initialize(
|
||||
config,
|
||||
&config,
|
||||
// We can just sort in here because sentinel ips are created in order
|
||||
dns_mapping.left_values().copied().sorted().collect(),
|
||||
&callbacks,
|
||||
@@ -322,7 +349,7 @@ pub struct ClientState {
|
||||
|
||||
dns_mapping: BiMap<IpAddr, DnsServer>,
|
||||
|
||||
buffered_events: VecDeque<ClientEvent>,
|
||||
buffered_events: VecDeque<Event>,
|
||||
interface_config: Option<InterfaceConfig>,
|
||||
buffered_packets: VecDeque<IpPacket<'static>>,
|
||||
|
||||
@@ -330,6 +357,9 @@ pub struct ClientState {
|
||||
buffered_dns_queries: VecDeque<DnsQuery<'static>>,
|
||||
|
||||
next_dns_refresh: Option<Instant>,
|
||||
next_system_resolver_refresh: Option<Instant>,
|
||||
|
||||
system_resolvers: Vec<IpAddr>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
@@ -358,6 +388,8 @@ impl ClientState {
|
||||
buffered_dns_queries: Default::default(),
|
||||
next_dns_refresh: Default::default(),
|
||||
node: ClientNode::new(private_key),
|
||||
system_resolvers: Default::default(),
|
||||
next_system_resolver_refresh: Default::default(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -739,19 +771,21 @@ impl ClientState {
|
||||
|
||||
tracing::debug!("Sending connection intent");
|
||||
|
||||
self.buffered_events
|
||||
.push_back(ClientEvent::ConnectionIntent {
|
||||
resource,
|
||||
connected_gateway_ids: gateways,
|
||||
});
|
||||
self.buffered_events.push_back(Event::ConnectionIntent {
|
||||
resource,
|
||||
connected_gateway_ids: gateways,
|
||||
});
|
||||
}
|
||||
|
||||
pub fn gateway_by_resource(&self, resource: &ResourceId) -> Option<GatewayId> {
|
||||
self.resources_gateways.get(resource).copied()
|
||||
}
|
||||
|
||||
fn set_dns_mapping(&mut self, mapping: BiMap<IpAddr, DnsServer>) {
|
||||
self.dns_mapping = mapping.clone();
|
||||
fn set_dns_mapping(&mut self, new_mapping: BiMap<IpAddr, DnsServer>) {
|
||||
self.dns_mapping = new_mapping.clone();
|
||||
self.peers
|
||||
.iter_mut()
|
||||
.for_each(|p| p.transform.set_dns(new_mapping.clone()));
|
||||
}
|
||||
|
||||
pub fn dns_mapping(&self) -> BiMap<IpAddr, DnsServer> {
|
||||
@@ -816,6 +850,16 @@ impl ClientState {
|
||||
.map(|(_, res)| res.id)
|
||||
}
|
||||
|
||||
fn update_system_resolvers(&mut self, new_dns: Vec<IpAddr>, now: Instant) {
|
||||
if !dns_updated(&self.system_resolvers, &new_dns) {
|
||||
tracing::debug!("Updated dns called but no change to system's resolver");
|
||||
return;
|
||||
}
|
||||
|
||||
self.next_system_resolver_refresh = Some(now + std::time::Duration::from_millis(500));
|
||||
self.system_resolvers = new_dns;
|
||||
}
|
||||
|
||||
pub fn poll_packets(&mut self) -> Option<IpPacket<'static>> {
|
||||
self.buffered_packets.pop_front()
|
||||
}
|
||||
@@ -825,7 +869,8 @@ impl ClientState {
|
||||
}
|
||||
|
||||
pub fn poll_timeout(&mut self) -> Option<Instant> {
|
||||
earliest(self.next_dns_refresh, self.node.poll_timeout())
|
||||
let timeout = earliest(self.next_dns_refresh, self.node.poll_timeout());
|
||||
earliest(timeout, self.next_system_resolver_refresh)
|
||||
}
|
||||
|
||||
pub fn handle_timeout(&mut self, now: Instant) {
|
||||
@@ -856,7 +901,7 @@ impl ClientState {
|
||||
}
|
||||
|
||||
self.buffered_events
|
||||
.push_back(ClientEvent::RefreshResources { connections });
|
||||
.push_back(Event::RefreshResources { connections });
|
||||
|
||||
self.next_dns_refresh = Some(now + DNS_REFRESH_INTERVAL);
|
||||
}
|
||||
@@ -864,6 +909,11 @@ impl ClientState {
|
||||
Some(_) => {}
|
||||
}
|
||||
|
||||
if self.next_system_resolver_refresh.is_some_and(|e| now >= e) {
|
||||
self.buffered_events.push_back(Event::RefreshInterfance);
|
||||
self.next_system_resolver_refresh = None;
|
||||
}
|
||||
|
||||
while let Some(event) = self.node.poll_event() {
|
||||
match event {
|
||||
snownet::Event::ConnectionFailed(id) => {
|
||||
@@ -872,18 +922,16 @@ impl ClientState {
|
||||
snownet::Event::SignalIceCandidate {
|
||||
connection,
|
||||
candidate,
|
||||
} => self
|
||||
.buffered_events
|
||||
.push_back(ClientEvent::SignalIceCandidate {
|
||||
conn_id: connection,
|
||||
candidate,
|
||||
}),
|
||||
} => self.buffered_events.push_back(Event::SignalIceCandidate {
|
||||
conn_id: connection,
|
||||
candidate,
|
||||
}),
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn poll_event(&mut self) -> Option<ClientEvent> {
|
||||
pub(crate) fn poll_event(&mut self) -> Option<Event> {
|
||||
self.buffered_events.pop_front()
|
||||
}
|
||||
|
||||
@@ -896,6 +944,10 @@ impl ClientState {
|
||||
}
|
||||
}
|
||||
|
||||
fn dns_updated(old_dns: &[IpAddr], new_dns: &[IpAddr]) -> bool {
|
||||
HashSet::<&IpAddr>::from_iter(old_dns.iter()) != HashSet::<&IpAddr>::from_iter(new_dns.iter())
|
||||
}
|
||||
|
||||
fn effective_dns_servers(
|
||||
upstream_dns: Vec<DnsServer>,
|
||||
default_resolvers: Vec<IpAddr>,
|
||||
@@ -1024,15 +1076,89 @@ impl IpProvider {
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use rand_core::OsRng;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn ignores_ip4_igmp_multicast() {
|
||||
assert!(is_definitely_not_a_resource("224.0.0.22".parse().unwrap()))
|
||||
assert!(is_definitely_not_a_resource(ip("224.0.0.22")))
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn ignores_ip6_multicast_all_routers() {
|
||||
assert!(is_definitely_not_a_resource("ff02::2".parse().unwrap()))
|
||||
assert!(is_definitely_not_a_resource(ip("ff02::2")))
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn dns_updated_when_dns_changes() {
|
||||
assert!(dns_updated(&[ip("1.0.0.1")], &[ip("1.1.1.1")]))
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn dns_not_updated_when_dns_remains_the_same() {
|
||||
assert!(!dns_updated(&[ip("1.1.1.1")], &[ip("1.1.1.1")]))
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn dns_updated_ignores_order() {
|
||||
assert!(!dns_updated(
|
||||
&[ip("1.0.0.1"), ip("1.1.1.1")],
|
||||
&[ip("1.1.1.1"), ip("1.0.0.1")]
|
||||
))
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn update_system_dns_works() {
|
||||
let mut client_state = ClientState::for_test();
|
||||
|
||||
let now = Instant::now();
|
||||
client_state.update_system_resolvers(vec![ip("1.1.1.1")], now);
|
||||
let now = now + Duration::from_millis(500);
|
||||
client_state.handle_timeout(now);
|
||||
|
||||
assert_eq!(client_state.poll_event(), Some(Event::RefreshInterfance));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn update_system_dns_without_change_is_a_no_op() {
|
||||
let mut client_state = ClientState::for_test();
|
||||
|
||||
let now = Instant::now();
|
||||
client_state.update_system_resolvers(vec![ip("1.1.1.1")], now);
|
||||
let now = now + Duration::from_millis(500);
|
||||
client_state.handle_timeout(now);
|
||||
client_state.poll_event();
|
||||
|
||||
client_state.update_system_resolvers(vec![ip("1.1.1.1")], now);
|
||||
let now = now + Duration::from_millis(500);
|
||||
client_state.handle_timeout(now);
|
||||
assert!(client_state.poll_event().is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn update_system_dns_with_change_works() {
|
||||
let mut client_state = ClientState::for_test();
|
||||
|
||||
let now = Instant::now();
|
||||
client_state.update_system_resolvers(vec![ip("1.1.1.1")], now);
|
||||
let now = now + Duration::from_millis(500);
|
||||
client_state.handle_timeout(now);
|
||||
client_state.poll_event();
|
||||
|
||||
client_state.update_system_resolvers(vec![ip("1.0.0.1")], now);
|
||||
let now = now + Duration::from_millis(500);
|
||||
client_state.handle_timeout(now);
|
||||
assert_eq!(client_state.poll_event(), Some(Event::RefreshInterfance));
|
||||
}
|
||||
|
||||
impl ClientState {
|
||||
fn for_test() -> ClientState {
|
||||
ClientState::new(StaticSecret::random_from_rng(OsRng))
|
||||
}
|
||||
}
|
||||
|
||||
fn ip(addr: &str) -> IpAddr {
|
||||
addr.parse().unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -119,6 +119,8 @@ impl Io {
|
||||
&mut self,
|
||||
dns_servers: impl IntoIterator<Item = (IpAddr, DnsServer)>,
|
||||
) {
|
||||
self.forwarded_dns_queries =
|
||||
FuturesTupleSet::new(Duration::from_secs(60), DNS_QUERIES_QUEUE_SIZE);
|
||||
self.upstream_dns_servers = create_resolvers(dns_servers);
|
||||
}
|
||||
|
||||
|
||||
@@ -76,8 +76,27 @@ where
|
||||
|
||||
pub fn poll_next_event(&mut self, cx: &mut Context<'_>) -> Poll<Result<ClientEvent>> {
|
||||
loop {
|
||||
if let Some(other) = self.role_state.poll_event() {
|
||||
return Poll::Ready(Ok(other));
|
||||
match self.role_state.poll_event() {
|
||||
Some(client::Event::RefreshInterfance) => {
|
||||
self.update_interface()?;
|
||||
continue;
|
||||
}
|
||||
Some(client::Event::SignalIceCandidate { conn_id, candidate }) => {
|
||||
return Poll::Ready(Ok(ClientEvent::SignalIceCandidate { conn_id, candidate }))
|
||||
}
|
||||
Some(client::Event::ConnectionIntent {
|
||||
resource,
|
||||
connected_gateway_ids,
|
||||
}) => {
|
||||
return Poll::Ready(Ok(ClientEvent::ConnectionIntent {
|
||||
resource,
|
||||
connected_gateway_ids,
|
||||
}))
|
||||
}
|
||||
Some(client::Event::RefreshResources { connections }) => {
|
||||
return Poll::Ready(Ok(ClientEvent::RefreshResources { connections }))
|
||||
}
|
||||
None => {}
|
||||
}
|
||||
|
||||
if let Some(packet) = self.role_state.poll_packets() {
|
||||
@@ -239,6 +258,7 @@ where
|
||||
Ok(io)
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Eq)]
|
||||
pub enum ClientEvent {
|
||||
SignalIceCandidate {
|
||||
conn_id: GatewayId,
|
||||
|
||||
@@ -133,6 +133,7 @@ impl PacketTransformClient {
|
||||
}
|
||||
|
||||
pub fn set_dns(&mut self, mapping: BiMap<IpAddr, DnsServer>) {
|
||||
self.mangled_dns_ids.clear();
|
||||
self.dns_mapping = mapping;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -13,7 +13,7 @@ use arc_swap::ArcSwap;
|
||||
use connlib_client_shared::{file_logger, ResourceDescription};
|
||||
use connlib_shared::{keypair, messages::ResourceId, LoginUrl, BUNDLE_ID};
|
||||
use secrecy::{ExposeSecret, SecretString};
|
||||
use std::{net::IpAddr, path::PathBuf, str::FromStr, sync::Arc, time::Duration};
|
||||
use std::{path::PathBuf, str::FromStr, sync::Arc, time::Duration};
|
||||
use system_tray_menu::Event as TrayMenuEvent;
|
||||
use tauri::{Manager, SystemTray, SystemTrayEvent};
|
||||
use tokio::sync::{mpsc, oneshot, Notify};
|
||||
@@ -471,18 +471,6 @@ impl connlib_client_shared::Callbacks for CallbackHandler {
|
||||
self.notify_controller.notify_one();
|
||||
}
|
||||
|
||||
fn get_system_default_resolvers(&self) -> Option<Vec<IpAddr>> {
|
||||
let resolvers = match client::resolvers::get() {
|
||||
Ok(resolvers) => resolvers,
|
||||
Err(e) => {
|
||||
tracing::error!("Failed to get system default resolvers: {e}");
|
||||
return None;
|
||||
}
|
||||
};
|
||||
|
||||
Some(resolvers)
|
||||
}
|
||||
|
||||
fn roll_log_file(&self) -> Option<PathBuf> {
|
||||
self.logger.roll_to_new_file().unwrap_or_else(|e| {
|
||||
tracing::debug!("Failed to roll over to new file: {e}");
|
||||
@@ -554,6 +542,8 @@ impl Controller {
|
||||
tokio::runtime::Handle::current(),
|
||||
)?;
|
||||
|
||||
connlib.set_dns(client::resolvers::get().unwrap_or_default());
|
||||
|
||||
self.session = Some(Session {
|
||||
callback_handler,
|
||||
connlib,
|
||||
|
||||
@@ -19,11 +19,7 @@ async fn main() -> Result<()> {
|
||||
let (layer, handle) = cli.log_dir.as_deref().map(file_logger::layer).unzip();
|
||||
setup_global_subscriber(layer);
|
||||
|
||||
let dns_control_method = get_dns_control_from_env();
|
||||
let callbacks = CallbackHandler {
|
||||
dns_control_method: dns_control_method.clone(),
|
||||
handle,
|
||||
};
|
||||
let callbacks = CallbackHandler { handle };
|
||||
|
||||
// AKA "Device ID", not the Firezone slug
|
||||
let firezone_id = match cli.firezone_id {
|
||||
@@ -40,15 +36,17 @@ async fn main() -> Result<()> {
|
||||
public_key.to_bytes(),
|
||||
)?;
|
||||
|
||||
let mut session = Session::connect(
|
||||
let session = Session::connect(
|
||||
login,
|
||||
private_key,
|
||||
None,
|
||||
callbacks,
|
||||
callbacks.clone(),
|
||||
max_partition_time,
|
||||
tokio::runtime::Handle::current(),
|
||||
)
|
||||
.unwrap();
|
||||
// TODO: this should be added dynamically
|
||||
session.set_dns(system_resolvers(get_dns_control_from_env()).unwrap_or_default());
|
||||
|
||||
let mut sigint = tokio::signal::unix::signal(SignalKind::interrupt())?;
|
||||
let mut sighup = tokio::signal::unix::signal(SignalKind::hangup())?;
|
||||
@@ -76,37 +74,21 @@ async fn main() -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn system_resolvers(dns_control_method: Option<DnsControlMethod>) -> Result<Vec<IpAddr>> {
|
||||
match dns_control_method {
|
||||
None => get_system_default_resolvers_resolv_conf(),
|
||||
Some(DnsControlMethod::EtcResolvConf) => get_system_default_resolvers_resolv_conf(),
|
||||
Some(DnsControlMethod::NetworkManager) => get_system_default_resolvers_network_manager(),
|
||||
Some(DnsControlMethod::Systemd) => get_system_default_resolvers_systemd_resolved(),
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct CallbackHandler {
|
||||
dns_control_method: Option<DnsControlMethod>,
|
||||
handle: Option<file_logger::Handle>,
|
||||
}
|
||||
|
||||
impl Callbacks for CallbackHandler {
|
||||
/// May return Firezone's own servers, e.g. `100.100.111.1`.
|
||||
fn get_system_default_resolvers(&self) -> Option<Vec<IpAddr>> {
|
||||
let maybe_resolvers = match self.dns_control_method {
|
||||
None => get_system_default_resolvers_resolv_conf(),
|
||||
Some(DnsControlMethod::EtcResolvConf) => get_system_default_resolvers_resolv_conf(),
|
||||
Some(DnsControlMethod::NetworkManager) => {
|
||||
get_system_default_resolvers_network_manager()
|
||||
}
|
||||
Some(DnsControlMethod::Systemd) => get_system_default_resolvers_systemd_resolved(),
|
||||
};
|
||||
|
||||
let resolvers = match maybe_resolvers {
|
||||
Ok(resolvers) => resolvers,
|
||||
Err(e) => {
|
||||
tracing::error!("Failed to get system default resolvers: {e}");
|
||||
return None;
|
||||
}
|
||||
};
|
||||
|
||||
tracing::info!(?resolvers);
|
||||
|
||||
Some(resolvers)
|
||||
}
|
||||
|
||||
fn on_disconnect(&self, error: &connlib_client_shared::Error) {
|
||||
tracing::error!("Disconnected: {error}");
|
||||
|
||||
|
||||
Reference in New Issue
Block a user