diff --git a/rust/connlib/clients/shared/src/eventloop.rs b/rust/connlib/clients/shared/src/eventloop.rs index 289ca7495..7631f7456 100644 --- a/rust/connlib/clients/shared/src/eventloop.rs +++ b/rust/connlib/clients/shared/src/eventloop.rs @@ -7,7 +7,7 @@ use crate::{ }; use anyhow::Result; use connlib_shared::{ - messages::{ConnectionAccepted, GatewayId, GatewayResponse, ResourceAccepted, ResourceId}, + messages::{ConnectionAccepted, GatewayResponse, ResourceAccepted, ResourceId}, Callbacks, }; use firezone_tunnel::ClientTunnel; @@ -80,7 +80,7 @@ where continue; } Poll::Ready(Err(e)) => { - tracing::error!("Tunnel failed: {e}"); + tracing::warn!("Tunnel error: {e}"); continue; } Poll::Pending => {} @@ -104,9 +104,9 @@ where } } - fn handle_tunnel_event(&mut self, event: firezone_tunnel::Event) { + fn handle_tunnel_event(&mut self, event: firezone_tunnel::ClientEvent) { match event { - firezone_tunnel::Event::SignalIceCandidate { + firezone_tunnel::ClientEvent::SignalIceCandidate { conn_id: gateway, candidate, } => { @@ -120,7 +120,7 @@ where }), ); } - firezone_tunnel::Event::ConnectionIntent { + firezone_tunnel::ClientEvent::ConnectionIntent { connected_gateway_ids, resource, .. @@ -134,15 +134,12 @@ where ); self.connection_intents.register_new_intent(id, resource); } - firezone_tunnel::Event::RefreshResources { connections } => { + firezone_tunnel::ClientEvent::RefreshResources { connections } => { for connection in connections { self.portal .send(PHOENIX_TOPIC, EgressMessages::ReuseConnection(connection)); } } - firezone_tunnel::Event::SendPacket { .. } | firezone_tunnel::Event::StopPeer { .. 
} => { - unreachable!("Handled internally") - } } } diff --git a/rust/connlib/shared/src/error.rs b/rust/connlib/shared/src/error.rs index 08961588e..47217f6a8 100644 --- a/rust/connlib/shared/src/error.rs +++ b/rust/connlib/shared/src/error.rs @@ -170,10 +170,6 @@ pub enum ConnlibError { Snownet(#[from] snownet::Error), #[error("Detected non-allowed packet in channel")] UnallowedPacket, - #[error("No available ipv4 socket")] - NoIpv4, - #[error("No available ipv6 socket")] - NoIpv6, // Error variants for `systemd-resolved` DNS control #[error("Failed to control system DNS with `resolvectl`")] diff --git a/rust/connlib/tunnel/src/client.rs b/rust/connlib/tunnel/src/client.rs index a8afd2003..c60b8b7cc 100644 --- a/rust/connlib/tunnel/src/client.rs +++ b/rust/connlib/tunnel/src/client.rs @@ -1,7 +1,7 @@ use crate::ip_packet::{IpPacket, MutableIpPacket}; use crate::peer::PacketTransformClient; use crate::peer_store::PeerStore; -use crate::{dns, dns::DnsQuery, Event, Tunnel, DNS_QUERIES_QUEUE_SIZE}; +use crate::{dns, dns::DnsQuery, Tunnel}; use bimap::BiMap; use connlib_shared::error::{ConnlibError as Error, ConnlibError}; use connlib_shared::messages::{ @@ -10,22 +10,18 @@ use connlib_shared::messages::{ }; use connlib_shared::{Callbacks, Dname, IpProvider}; use domain::base::Rtype; -use futures_bounded::FuturesTupleSet; use ip_network::IpNetwork; use ip_network_table::IpNetworkTable; use itertools::Itertools; use snownet::Client; -use hickory_resolver::config::{NameServerConfig, Protocol, ResolverConfig}; -use hickory_resolver::TokioAsyncResolver; +use crate::ClientEvent; use std::collections::hash_map::Entry; use std::collections::{HashMap, HashSet, VecDeque}; use std::iter; use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; use std::str::FromStr; -use std::task::{Context, Poll}; use std::time::{Duration, Instant}; -use tokio::time::{Interval, MissedTickBehavior}; // Using str here because Ipv4/6Network doesn't support `const` 🙃 const IPV4_RESOURCES: &str = 
"100.96.0.0/11"; @@ -35,6 +31,11 @@ const DNS_PORT: u16 = 53; const DNS_SENTINELS_V4: &str = "100.100.111.0/24"; const DNS_SENTINELS_V6: &str = "fd00:2021:1111:8000:100:100:111:0/120"; +// With this single timer this might mean that some DNS are refreshed too often +// however... this also mean any resource is refresh within a 5 mins interval +// therefore, only the first time it's added that happens, after that it doesn't matter. +const DNS_REFRESH_INTERVAL: Duration = Duration::from_secs(300); + #[derive(Debug, Clone, Hash, PartialEq, Eq)] pub struct DnsResource { pub id: ResourceId, @@ -182,17 +183,21 @@ where let dns_mapping = sentinel_dns_mapping(&effective_dns_servers); self.role_state.set_dns_mapping(dns_mapping.clone()); + self.io.set_upstream_dns_servers(dns_mapping.clone()); - self.device.initialize( + let callbacks = self.callbacks().clone(); + + self.io.device_mut().initialize( config, // We can just sort in here because sentinel ips are created in order dns_mapping.left_values().copied().sorted().collect(), - &self.callbacks().clone(), + &callbacks, )?; - self.device + self.io + .device_mut() .set_routes(self.role_state.routes().collect(), &self.callbacks)?; - let name = self.device.name().to_owned(); + let name = self.io.device_mut().name().to_owned(); self.callbacks.on_tunnel_ready()?; @@ -210,15 +215,15 @@ where #[tracing::instrument(level = "trace", skip(self))] pub fn update_routes(&mut self) -> connlib_shared::Result<()> { - self.device + self.io + .device_mut() .set_routes(self.role_state.routes().collect(), &self.callbacks)?; Ok(()) } pub fn add_ice_candidate(&mut self, conn_id: GatewayId, ice_candidate: String) { - self.connections_state - .node + self.node .add_remote_candidate(conn_id, ice_candidate, Instant::now()); } } @@ -236,20 +241,18 @@ pub struct ClientState { pub peers: PeerStore>, - forwarded_dns_queries: FuturesTupleSet< - Result, - DnsQuery<'static>, - >, - pub ip_provider: IpProvider, - refresh_dns_timer: Interval, - dns_mapping: 
BiMap, - dns_resolvers: HashMap, - buffered_events: VecDeque>, + buffered_events: VecDeque, interface_config: Option, + buffered_packets: VecDeque>, + + /// DNS queries that we need to forward to the system resolver. + buffered_dns_queries: VecDeque>, + + next_dns_refresh: Option, } #[derive(Debug, Clone, PartialEq, Eq)] @@ -267,8 +270,7 @@ impl ClientState { ) -> Option<(GatewayId, MutableIpPacket<'a>)> { let (packet, dest) = match self.handle_dns(packet, now) { Ok(response) => { - self.buffered_events - .push_back(Event::SendPacket(response?.to_owned())); + self.buffered_packets.push_back(response?.to_owned()); return None; } Err(non_dns_packet) => non_dns_packet, @@ -315,7 +317,7 @@ impl ClientState { } } - self.add_pending_dns_query(query); + self.buffered_dns_queries.push_back(query.into_owned()); Ok(None) } @@ -453,10 +455,11 @@ impl ClientState { tracing::debug!("Sending connection intent"); - self.buffered_events.push_back(Event::ConnectionIntent { - resource, - connected_gateway_ids: gateways, - }); + self.buffered_events + .push_back(ClientEvent::ConnectionIntent { + resource, + connected_gateway_ids: gateways, + }); } pub fn create_peer_config_for_new_connection( @@ -483,7 +486,6 @@ impl ClientState { fn set_dns_mapping(&mut self, mapping: BiMap) { self.dns_mapping = mapping.clone(); - self.dns_resolvers = create_resolvers(mapping); } pub fn dns_mapping(&self) -> BiMap { @@ -548,39 +550,21 @@ impl ClientState { .map(|(_, res)| res.id) } - fn add_pending_dns_query(&mut self, query: DnsQuery) { - let upstream = query.query.destination(); - let Some(resolver) = self.dns_resolvers.get(&upstream).cloned() else { - tracing::warn!(%upstream, "Dropping DNS query because of unknown upstream DNS server"); - return; - }; - - let query = query.into_owned(); - - if self - .forwarded_dns_queries - .try_push( - { - let name = query.name.clone(); - let record_type = query.record_type; - - async move { resolver.lookup(&name, record_type).await } - }, - query, - ) - 
.is_err() - { - tracing::warn!("Too many DNS queries, dropping existing one"); - } + pub fn poll_packets(&mut self) -> Option> { + self.buffered_packets.pop_front() } - pub fn poll_next_event(&mut self, cx: &mut Context<'_>) -> Poll> { - loop { - if let Some(event) = self.buffered_events.pop_front() { - return Poll::Ready(event); - } + pub fn poll_dns_queries(&mut self) -> Option> { + self.buffered_dns_queries.pop_front() + } - if self.refresh_dns_timer.poll_tick(cx).is_ready() { + pub fn poll_timeout(&self) -> Option { + self.next_dns_refresh + } + + pub fn handle_timeout(&mut self, now: Instant) { + match self.next_dns_refresh { + Some(next_dns_refresh) if now >= next_dns_refresh => { let mut connections = Vec::new(); self.peers @@ -602,63 +586,27 @@ impl ClientState { payload: Some(resource.address.clone()), }); } - return Poll::Ready(Event::RefreshResources { connections }); - } - match self.forwarded_dns_queries.poll_unpin(cx) { - Poll::Ready((Ok(response), query)) => { - match dns::build_response_from_resolve_result(query.query, response) { - Ok(Some(packet)) => return Poll::Ready(Event::SendPacket(packet)), - Ok(None) => continue, - Err(e) => { - tracing::warn!("Failed to build DNS response from lookup result: {e}"); - continue; - } - } - } - Poll::Ready((Err(resolve_timeout), query)) => { - tracing::warn!(name = %query.name, server = %query.query.destination(), "DNS query timed out: {resolve_timeout}"); - continue; - } - Poll::Pending => {} - } + self.buffered_events + .push_back(ClientEvent::RefreshResources { connections }); - return Poll::Pending; + self.next_dns_refresh = Some(now + DNS_REFRESH_INTERVAL); + } + None => self.next_dns_refresh = Some(now + DNS_REFRESH_INTERVAL), + Some(_) => {} } } -} -fn create_resolvers( - sentinel_mapping: BiMap, -) -> HashMap { - sentinel_mapping - .into_iter() - .map(|(sentinel, srv)| { - let mut resolver_config = ResolverConfig::new(); - resolver_config.add_name_server(NameServerConfig::new(srv.address(), 
Protocol::Udp)); - ( - sentinel, - TokioAsyncResolver::tokio(resolver_config, Default::default()), - ) - }) - .collect() + pub fn poll_event(&mut self) -> Option { + self.buffered_events.pop_front() + } } impl Default for ClientState { fn default() -> Self { - // With this single timer this might mean that some DNS are refreshed too often - // however... this also mean any resource is refresh within a 5 mins interval - // therefore, only the first time it's added that happens, after that it doesn't matter. - let mut interval = tokio::time::interval(Duration::from_secs(300)); - interval.set_missed_tick_behavior(MissedTickBehavior::Delay); - Self { awaiting_connection: Default::default(), resources_gateways: Default::default(), - forwarded_dns_queries: FuturesTupleSet::new( - Duration::from_secs(60), - DNS_QUERIES_QUEUE_SIZE, - ), ip_provider: IpProvider::new( IPV4_RESOURCES.parse().unwrap(), IPV6_RESOURCES.parse().unwrap(), @@ -669,11 +617,12 @@ impl Default for ClientState { resource_ids: Default::default(), peers: Default::default(), deferred_dns_queries: Default::default(), - refresh_dns_timer: interval, dns_mapping: Default::default(), - dns_resolvers: Default::default(), buffered_events: Default::default(), interface_config: Default::default(), + buffered_packets: Default::default(), + buffered_dns_queries: Default::default(), + next_dns_refresh: Default::default(), } } } @@ -761,4 +710,35 @@ mod tests { fn ignores_ip6_multicast_all_routers() { assert!(is_definitely_not_a_resource("ff02::2".parse().unwrap())) } + + #[test] + fn initial_poll_timeout_is_none() { + let state = ClientState::default(); + + assert!(state.poll_timeout().is_none()) + } + + #[test] + fn first_timeout_is_after_dns_refresh_interval() { + let start = Instant::now(); + let mut state = ClientState::default(); + + state.handle_timeout(start); + + assert_eq!(state.poll_timeout().unwrap(), start + DNS_REFRESH_INTERVAL) + } + + #[test] + fn does_not_advance_time_before_timeout() { + let start = 
Instant::now(); + let mut state = ClientState::default(); + + state.handle_timeout(start); + + let before = state.poll_timeout().unwrap(); + state.handle_timeout(start + DNS_REFRESH_INTERVAL / 2); + let after = state.poll_timeout().unwrap(); + + assert_eq!(before, after) + } } diff --git a/rust/connlib/tunnel/src/control_protocol/client.rs b/rust/connlib/tunnel/src/control_protocol/client.rs index c1f567296..81975dc05 100644 --- a/rust/connlib/tunnel/src/control_protocol/client.rs +++ b/rust/connlib/tunnel/src/control_protocol/client.rs @@ -64,7 +64,7 @@ where return Ok(Request::ReuseConnection(connection)); } - if self.connections_state.node.is_expecting_answer(gateway_id) { + if self.node.is_expecting_answer(gateway_id) { return Err(Error::PendingConnection); } @@ -73,14 +73,10 @@ where .get_awaiting_connection(&resource_id)? .clone(); - let offer = self.connections_state.node.new_connection( + let offer = self.node.new_connection( gateway_id, - stun(&relays, |addr| { - self.connections_state.sockets.can_handle(addr) - }), - turn(&relays, |addr| { - self.connections_state.sockets.can_handle(addr) - }), + stun(&relays, |addr| self.io.sockets_ref().can_handle(addr)), + turn(&relays, |addr| self.io.sockets_ref().can_handle(addr)), awaiting_connection.last_intent_sent_at, Instant::now(), ); @@ -147,7 +143,7 @@ where .gateway_by_resource(&resource_id) .ok_or(Error::UnknownResource)?; - self.connections_state.node.accept_answer( + self.node.accept_answer( gateway_id, gateway_public_key, snownet::Answer { @@ -209,7 +205,7 @@ where send_dns_answer( &mut self.role_state, Rtype::Aaaa, - &self.device, + self.io.device_mut(), &resource_description, &addrs, ); @@ -217,7 +213,7 @@ where send_dns_answer( &mut self.role_state, Rtype::A, - &self.device, + self.io.device_mut(), &resource_description, &addrs, ); diff --git a/rust/connlib/tunnel/src/gateway.rs b/rust/connlib/tunnel/src/gateway.rs index 5efaa8a2f..cea6c61e3 100644 --- a/rust/connlib/tunnel/src/gateway.rs +++ 
b/rust/connlib/tunnel/src/gateway.rs @@ -14,13 +14,13 @@ use ip_network::IpNetwork; use secrecy::{ExposeSecret as _, Secret}; use snownet::Server; use std::collections::HashSet; -use std::task::{ready, Context, Poll}; use std::time::{Duration, Instant}; -use tokio::time::{interval, Interval, MissedTickBehavior}; const PEERS_IPV4: &str = "100.64.0.0/11"; const PEERS_IPV6: &str = "fd00:2021:1111::/107"; +const EXPIRE_RESOURCES_INTERVAL: Duration = Duration::from_secs(1); + /// Description of a resource that maps to a DNS record which had its domain already resolved. #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct ResolvedResourceDescriptionDns { @@ -46,14 +46,16 @@ where #[tracing::instrument(level = "trace", skip(self))] pub fn set_interface(&mut self, config: &InterfaceConfig) -> connlib_shared::Result<()> { // Note: the dns fallback strategy is irrelevant for gateways - self.device - .initialize(config, vec![], &self.callbacks().clone())?; - self.device.set_routes( + let callbacks = self.callbacks().clone(); + self.io + .device_mut() + .initialize(config, vec![], &callbacks)?; + self.io.device_mut().set_routes( HashSet::from([PEERS_IPV4.parse().unwrap(), PEERS_IPV6.parse().unwrap()]), - &self.callbacks, + &callbacks, )?; - let name = self.device.name().to_owned(); + let name = self.io.device_mut().name().to_owned(); tracing::debug!(ip4 = %config.ipv4, ip6 = %config.ipv6, %name, "TUN device initialized"); @@ -95,7 +97,7 @@ where ResourceDescription::Cidr(ref cidr) => vec![cidr.address], }; - let answer = self.connections_state.node.accept_connection( + let answer = self.node.accept_connection( client_id, snownet::Offer { session_key: key.expose_secret().0.into(), @@ -105,12 +107,8 @@ where }, }, client, - stun(&relays, |addr| { - self.connections_state.sockets.can_handle(addr) - }), - turn(&relays, |addr| { - self.connections_state.sockets.can_handle(addr) - }), + stun(&relays, |addr| self.io.sockets_ref().can_handle(addr)), + turn(&relays, |addr| 
self.io.sockets_ref().can_handle(addr)), Instant::now(), ); @@ -195,8 +193,7 @@ where } pub fn add_ice_candidate(&mut self, conn_id: ClientId, ice_candidate: String) { - self.connections_state - .node + self.node .add_remote_candidate(conn_id, ice_candidate, Instant::now()); } @@ -222,9 +219,10 @@ where } /// [`Tunnel`] state specific to gateways. +#[derive(Default)] pub struct GatewayState { pub peers: PeerStore, - expire_interval: Interval, + next_expiry_resources_check: Option, } impl GatewayState { @@ -240,27 +238,62 @@ impl GatewayState { Some((peer.conn_id, packet)) } - pub fn poll(&mut self, cx: &mut Context<'_>) -> Poll<()> { - ready!(self.expire_interval.poll_tick(cx)); - self.expire_resources(); - Poll::Ready(()) + pub fn poll_timeout(&self) -> Option { + // TODO: This should check when the next resource actually expires instead of doing it at a fixed interval. + self.next_expiry_resources_check } - fn expire_resources(&mut self) { - self.peers - .iter_mut() - .for_each(|p| p.transform.expire_resources()); - self.peers.retain(|_, p| !p.transform.is_emptied()); - } -} + pub fn handle_timeout(&mut self, now: Instant) { + match self.next_expiry_resources_check { + Some(next_expiry_resources_check) if now >= next_expiry_resources_check => { + self.peers + .iter_mut() + .for_each(|p| p.transform.expire_resources()); + self.peers.retain(|_, p| !p.transform.is_emptied()); -impl Default for GatewayState { - fn default() -> Self { - let mut expire_interval = interval(Duration::from_secs(1)); - expire_interval.set_missed_tick_behavior(MissedTickBehavior::Delay); - Self { - peers: Default::default(), - expire_interval, + self.next_expiry_resources_check = Some(now + EXPIRE_RESOURCES_INTERVAL); + } + None => self.next_expiry_resources_check = Some(now + EXPIRE_RESOURCES_INTERVAL), + Some(_) => {} } } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn initial_poll_timeout_is_none() { + let state = GatewayState::default(); + + 
assert!(state.poll_timeout().is_none()) + } + + #[test] + fn first_timeout_is_after_expire_resources_interval() { + let start = Instant::now(); + let mut state = GatewayState::default(); + + state.handle_timeout(start); + + assert_eq!( + state.poll_timeout().unwrap(), + start + EXPIRE_RESOURCES_INTERVAL + ) + } + + #[test] + fn does_not_advance_time_before_timeout() { + let start = Instant::now(); + let mut state = GatewayState::default(); + + state.handle_timeout(start); + + let before = state.poll_timeout().unwrap(); + state.handle_timeout(start + EXPIRE_RESOURCES_INTERVAL / 2); + let after = state.poll_timeout().unwrap(); + + assert_eq!(before, after) + } +} diff --git a/rust/connlib/tunnel/src/io.rs b/rust/connlib/tunnel/src/io.rs new file mode 100644 index 000000000..4720ba4b0 --- /dev/null +++ b/rust/connlib/tunnel/src/io.rs @@ -0,0 +1,189 @@ +use crate::{ + device_channel::Device, + dns::{self, DnsQuery}, + ip_packet::{IpPacket, MutableIpPacket}, + sockets::{Received, Sockets}, +}; +use connlib_shared::messages::DnsServer; +use futures_bounded::FuturesTupleSet; +use futures_util::FutureExt as _; +use hickory_resolver::{ + config::{NameServerConfig, Protocol, ResolverConfig}, + TokioAsyncResolver, +}; +use snownet::Transmit; +use std::{ + collections::HashMap, + io, + net::IpAddr, + pin::Pin, + task::{ready, Context, Poll}, + time::{Duration, Instant}, +}; + +const DNS_QUERIES_QUEUE_SIZE: usize = 100; + +pub struct Io { + device: Device, + timeout: Option>>, + sockets: Sockets, + + upstream_dns_servers: HashMap, + forwarded_dns_queries: FuturesTupleSet< + Result, + DnsQuery<'static>, + >, +} + +pub enum Input<'a, I> { + Timeout(Instant), + Device(MutableIpPacket<'a>), + Network(I), +} + +impl Io { + pub fn new() -> io::Result { + Ok(Self { + device: Device::new(), + timeout: None, + sockets: Sockets::new()?, + upstream_dns_servers: HashMap::default(), + forwarded_dns_queries: FuturesTupleSet::new( + Duration::from_secs(60), + DNS_QUERIES_QUEUE_SIZE, + ), + }) 
+ } + + pub fn poll<'b>( + &mut self, + cx: &mut Context<'_>, + ip4_buffer: &'b mut [u8], + ip6_buffer: &'b mut [u8], + device_buffer: &'b mut [u8], + ) -> Poll>>>> { + loop { + // FIXME: Building the DNS response in here isn't very clean because this should only be the IO component and not do business-logic. + // But it also seems weird to pass the DNS result out if we've got the device right here. + match self.forwarded_dns_queries.poll_unpin(cx) { + Poll::Ready((Ok(response), query)) => { + match dns::build_response_from_resolve_result(query.query, response) { + Ok(Some(packet)) => { + self.device.write(packet)?; + } + Ok(None) => {} + Err(e) => { + tracing::warn!("Failed to build DNS response from lookup result: {e}"); + } + } + + continue; + } + Poll::Ready((Err(resolve_timeout), query)) => { + tracing::warn!(name = %query.name, server = %query.query.destination(), "DNS query timed out: {resolve_timeout}"); + continue; + } + Poll::Pending => {} + } + + if let Some(timeout) = self.timeout.as_mut() { + if timeout.poll_unpin(cx).is_ready() { + return Poll::Ready(Ok(Input::Timeout(timeout.deadline().into()))); + } + } + + if let Poll::Ready(network) = self.sockets.poll_recv_from(ip4_buffer, ip6_buffer, cx)? { + return Poll::Ready(Ok(Input::Network(network))); + } + + ready!(self.sockets.poll_send_ready(cx))?; // Packets read from the device need to be written to a socket, let's make sure the socket can take more packets. + + if let Poll::Ready(packet) = self.device.poll_read(device_buffer, cx)? 
{ + return Poll::Ready(Ok(Input::Device(packet))); + } + + return Poll::Pending; + } + } + + pub fn device_mut(&mut self) -> &mut Device { + &mut self.device + } + + pub fn sockets_ref(&self) -> &Sockets { + &self.sockets + } + + pub fn set_upstream_dns_servers( + &mut self, + dns_servers: impl IntoIterator, + ) { + self.upstream_dns_servers = create_resolvers(dns_servers); + } + + pub fn perform_dns_query(&mut self, query: DnsQuery<'static>) { + let upstream = query.query.destination(); + let Some(resolver) = self.upstream_dns_servers.get(&upstream).cloned() else { + tracing::warn!(%upstream, "Dropping DNS query because of unknown upstream DNS server"); + return; + }; + + let query = query.into_owned(); + + if self + .forwarded_dns_queries + .try_push( + { + let name = query.name.clone(); + let record_type = query.record_type; + + async move { resolver.lookup(&name, record_type).await } + }, + query, + ) + .is_err() + { + tracing::warn!("Too many DNS queries, dropping existing one"); + } + } + + pub fn reset_timeout(&mut self, timeout: Instant) { + let timeout = tokio::time::Instant::from_std(timeout); + + match self.timeout.as_mut() { + Some(existing_timeout) if existing_timeout.deadline() != timeout => { + existing_timeout.as_mut().reset(timeout) + } + Some(_) => {} + None => self.timeout = Some(Box::pin(tokio::time::sleep_until(timeout))), + } + } + + pub fn send_network(&self, transmit: Transmit) -> io::Result<()> { + self.sockets.try_send(&transmit)?; + + Ok(()) + } + + pub fn send_device(&self, packet: IpPacket<'_>) -> io::Result<()> { + self.device.write(packet)?; + + Ok(()) + } +} + +fn create_resolvers( + dns_servers: impl IntoIterator, +) -> HashMap { + dns_servers + .into_iter() + .map(|(sentinel, srv)| { + let mut resolver_config = ResolverConfig::new(); + resolver_config.add_name_server(NameServerConfig::new(srv.address(), Protocol::Udp)); + ( + sentinel, + TokioAsyncResolver::tokio(resolver_config, Default::default()), + ) + }) + .collect() +} diff 
--git a/rust/connlib/tunnel/src/lib.rs b/rust/connlib/tunnel/src/lib.rs index 55cd241c1..1513966c1 100644 --- a/rust/connlib/tunnel/src/lib.rs +++ b/rust/connlib/tunnel/src/lib.rs @@ -8,28 +8,26 @@ use connlib_shared::{ messages::{ClientId, GatewayId, ResourceId, ReuseConnection}, CallbackErrorFacade, Callbacks, Error, Result, }; -use device_channel::Device; -use futures_util::FutureExt; -use peer::PacketTransform; -use peer_store::PeerStore; use snownet::{Node, Server}; -use sockets::{Received, Sockets}; +use sockets::Received; use std::{ collections::HashSet, fmt, hash::Hash, - io, - pin::Pin, - task::{ready, Context, Poll}, + task::{Context, Poll}, time::{Duration, Instant}, }; pub use client::ClientState; pub use control_protocol::client::Request; pub use gateway::{GatewayState, ResolvedResourceDescriptionDns}; -use ip_packet::IpPacket; +use io::Io; +use stats::Stats; +use utils::earliest; mod client; +mod io; +mod stats; mod control_protocol { pub mod client; } @@ -43,7 +41,6 @@ mod sockets; mod utils; const MAX_UDP_SIZE: usize = (1 << 16) - 1; -const DNS_QUERIES_QUEUE_SIZE: usize = 100; const REALM: &str = "firezone"; @@ -59,12 +56,15 @@ pub struct Tunnel { /// State that differs per role, i.e. clients vs gateways. 
role_state: TRoleState, + node: Node, - device: Device, + io: Io, + stats: Stats, - connections_state: ConnectionState, - - read_buf: Box<[u8; MAX_UDP_SIZE]>, + write_buf: Box<[u8; MAX_UDP_SIZE]>, + ip4_read_buf: Box<[u8; MAX_UDP_SIZE]>, + ip6_read_buf: Box<[u8; MAX_UDP_SIZE]>, + device_read_buf: Box<[u8; MAX_UDP_SIZE]>, } impl Tunnel @@ -72,62 +72,139 @@ where CB: Callbacks + 'static, { pub fn reconnect(&mut self) { - self.connections_state.node.reconnect(Instant::now()); + self.node.reconnect(Instant::now()); } - pub fn poll_next_event(&mut self, cx: &mut Context<'_>) -> Poll>> { - match self.role_state.poll_next_event(cx) { - Poll::Ready(Event::SendPacket(packet)) => { - self.device.write(packet)?; - cx.waker().wake_by_ref(); + pub fn poll_next_event(&mut self, cx: &mut Context<'_>) -> Poll> { + loop { + if let Some(other) = self.role_state.poll_event() { + return Poll::Ready(Ok(other)); } - Poll::Ready(other) => return Poll::Ready(Ok(other)), - _ => (), - } - match self.connections_state.poll_next_event(cx) { - Poll::Ready(Event::StopPeer(id)) => { - self.role_state.cleanup_connected_gateway(&id); - cx.waker().wake_by_ref(); + if let Some(packet) = self.role_state.poll_packets() { + self.io.send_device(packet)?; + continue; } - Poll::Ready(other) => return Poll::Ready(Ok(other)), - _ => (), - } - match self.connections_state.poll_sockets( - &mut self.device, - &mut self.role_state.peers, - cx, - )? { - Poll::Ready(()) => { - cx.waker().wake_by_ref(); + if let Some(transmit) = self.node.poll_transmit() { + self.io.send_network(transmit)?; + continue; } - Poll::Pending => {} - } - ready!(self.connections_state.sockets.poll_send_ready(cx))?; // Ensure socket is ready before we read from device. - - match self.device.poll_read(self.read_buf.as_mut(), cx)? 
{ - Poll::Ready(packet) => { - let Some((peer_id, packet)) = self.role_state.encapsulate(packet, Instant::now()) - else { - cx.waker().wake_by_ref(); - return Poll::Pending; - }; - - self.connections_state.send(peer_id, packet.as_immutable()); - - cx.waker().wake_by_ref(); + if let Some(dns_query) = self.role_state.poll_dns_queries() { + self.io.perform_dns_query(dns_query); + continue; } - Poll::Pending => {} - } - // After any state change, check what the new timeout is and reset it if necessary. - if self.connections_state.poll_timeout(cx).is_ready() { - cx.waker().wake_by_ref() - } + if let Some(event) = self.node.poll_event() { + match event { + snownet::Event::ConnectionFailed(id) => { + self.role_state.cleanup_connected_gateway(&id); + } + snownet::Event::SignalIceCandidate { + connection, + candidate, + } => { + return Poll::Ready(Ok(ClientEvent::SignalIceCandidate { + conn_id: connection, + candidate, + })); + } + _ => {} + } - Poll::Pending + continue; + } + + if let Some(timeout) = + earliest(self.node.poll_timeout(), self.role_state.poll_timeout()) + { + self.io.reset_timeout(timeout); + } + + match self.io.poll( + cx, + self.ip4_read_buf.as_mut(), + self.ip6_read_buf.as_mut(), + self.device_read_buf.as_mut(), + )? { + Poll::Ready(io::Input::Timeout(timeout)) => { + self.role_state.handle_timeout(timeout); + self.node.handle_timeout(timeout); + continue; + } + Poll::Ready(io::Input::Device(packet)) => { + let Some((peer_id, packet)) = + self.role_state.encapsulate(packet, Instant::now()) + else { + continue; + }; + + if let Some(transmit) = self.node.encapsulate( + peer_id, + packet.as_immutable().into(), + Instant::now(), + )? 
{ + self.io.send_network(transmit)?; + } + + continue; + } + Poll::Ready(io::Input::Network(packets)) => { + for received in packets { + let Received { + local, + from, + packet, + } = received; + + let (conn_id, packet) = match self.node.decapsulate( + local, + from, + packet.as_ref(), + std::time::Instant::now(), + self.write_buf.as_mut(), + ) { + Ok(Some(packet)) => packet, + Ok(None) => { + continue; + } + Err(e) => { + tracing::warn!(%local, %from, num_bytes = %packet.len(), "Failed to decapsulate incoming packet: {e}"); + + continue; + } + }; + + let Some(peer) = self.role_state.peers.get_mut(&conn_id) else { + tracing::error!(%conn_id, %local, %from, "Couldn't find connection"); + + continue; + }; + + let packet = match peer.untransform(packet.into()) { + Ok(packet) => packet, + Err(e) => { + tracing::warn!(%conn_id, %local, %from, "Failed to transform packet: {e}"); + + continue; + } + }; + + self.io.device_mut().write(packet.as_immutable())?; + } + + continue; + } + Poll::Pending => {} + } + + if self.stats.poll(&self.node, cx).is_ready() { + continue; + } + + return Poll::Pending; + } } } @@ -135,58 +212,120 @@ impl Tunnel where CB: Callbacks + 'static, { - pub fn poll_next_event(&mut self, cx: &mut Context<'_>) -> Poll>> { - match self.role_state.poll(cx) { - Poll::Ready(()) => { - cx.waker().wake_by_ref(); + pub fn poll_next_event(&mut self, cx: &mut Context<'_>) -> Poll> { + loop { + if let Some(transmit) = self.node.poll_transmit() { + self.io.send_network(transmit)?; + continue; } - Poll::Pending => {} - } - match self.connections_state.poll_next_event(cx) { - Poll::Ready(Event::StopPeer(id)) => { - self.role_state.peers.remove(&id); - cx.waker().wake_by_ref(); + if let Some(event) = self.node.poll_event() { + match event { + snownet::Event::ConnectionFailed(id) => { + self.role_state.peers.remove(&id); + } + snownet::Event::SignalIceCandidate { + connection, + candidate, + } => { + return Poll::Ready(Ok(GatewayEvent::SignalIceCandidate { + conn_id: 
connection, + candidate, + })); + } + _ => {} + } + + continue; } - Poll::Ready(other) => return Poll::Ready(Ok(other)), - _ => (), - } - match self.connections_state.poll_sockets( - &mut self.device, - &mut self.role_state.peers, - cx, - )? { - Poll::Ready(()) => { - cx.waker().wake_by_ref(); + if let Some(timeout) = + earliest(self.node.poll_timeout(), self.role_state.poll_timeout()) + { + self.io.reset_timeout(timeout); } - Poll::Pending => {} - } - ready!(self.connections_state.sockets.poll_send_ready(cx))?; // Ensure socket is ready before we read from device. + match self.io.poll( + cx, + self.ip4_read_buf.as_mut(), + self.ip6_read_buf.as_mut(), + self.device_read_buf.as_mut(), + )? { + Poll::Ready(io::Input::Timeout(timeout)) => { + self.role_state.handle_timeout(timeout); + self.node.handle_timeout(timeout); + continue; + } + Poll::Ready(io::Input::Device(packet)) => { + let Some((peer_id, packet)) = self.role_state.encapsulate(packet) else { + continue; + }; - match self.device.poll_read(self.read_buf.as_mut(), cx)? { - Poll::Ready(packet) => { - let Some((peer_id, packet)) = self.role_state.encapsulate(packet) else { - cx.waker().wake_by_ref(); - return Poll::Pending; - }; + if let Some(transmit) = self.node.encapsulate( + peer_id, + packet.as_immutable().into(), + Instant::now(), + )? 
{ + self.io.send_network(transmit)?; + } - self.connections_state.send(peer_id, packet.as_immutable()); + continue; + } + Poll::Ready(io::Input::Network(packets)) => { + for received in packets { + let Received { + local, + from, + packet, + } = received; - cx.waker().wake_by_ref(); + let (conn_id, packet) = match self.node.decapsulate( + local, + from, + packet.as_ref(), + std::time::Instant::now(), + self.write_buf.as_mut(), + ) { + Ok(Some(packet)) => packet, + Ok(None) => { + continue; + } + Err(e) => { + tracing::warn!(%local, %from, num_bytes = %packet.len(), "Failed to decapsulate incoming packet: {e}"); + + continue; + } + }; + + let Some(peer) = self.role_state.peers.get_mut(&conn_id) else { + tracing::error!(%conn_id, %local, %from, "Couldn't find connection"); + + continue; + }; + + let packet = match peer.untransform(packet.into()) { + Ok(packet) => packet, + Err(e) => { + tracing::warn!(%conn_id, %local, %from, "Failed to transform packet: {e}"); + + continue; + } + }; + + self.io.device_mut().write(packet.as_immutable())?; + } + + continue; + } + Poll::Pending => {} } - Poll::Pending => { - // device not ready for reading, moving on .. + + if self.stats.poll(&self.node, cx).is_ready() { + continue; } - } - // After any state change, check what the new timeout is and reset it if necessary. - if self.connections_state.poll_timeout(cx).is_ready() { - cx.waker().wake_by_ref() + return Poll::Pending; } - - Poll::Pending } } @@ -204,25 +343,29 @@ where #[tracing::instrument(level = "trace", skip(private_key, callbacks))] pub fn new(private_key: StaticSecret, callbacks: CB) -> Result { let callbacks = CallbackErrorFacade(callbacks); - let connections_state = ConnectionState::new(private_key)?; + let io = Io::new()?; // TODO: Eventually, this should move into the `connlib-client-android` crate. 
#[cfg(target_os = "android")] { - if let Some(ip4_socket) = connections_state.sockets.ip4_socket_fd() { + if let Some(ip4_socket) = io.sockets_ref().ip4_socket_fd() { callbacks.protect_file_descriptor(ip4_socket)?; } - if let Some(ip6_socket) = connections_state.sockets.ip6_socket_fd() { + if let Some(ip6_socket) = io.sockets_ref().ip6_socket_fd() { callbacks.protect_file_descriptor(ip6_socket)?; } } Ok(Self { - device: Device::new(), callbacks, role_state: Default::default(), - connections_state, - read_buf: Box::new([0u8; MAX_UDP_SIZE]), + node: Node::new(private_key), + write_buf: Box::new([0u8; MAX_UDP_SIZE]), + ip4_read_buf: Box::new([0u8; MAX_UDP_SIZE]), + ip6_read_buf: Box::new([0u8; MAX_UDP_SIZE]), + device_read_buf: Box::new([0u8; MAX_UDP_SIZE]), + io, + stats: Stats::new(Duration::from_secs(60)), }) } @@ -231,187 +374,9 @@ where } } -struct ConnectionState { - pub node: Node, - write_buf: Box<[u8; MAX_UDP_SIZE]>, - timeout: Option>>, - stats_timer: tokio::time::Interval, - sockets: Sockets, -} - -impl ConnectionState -where - TId: Eq + Hash + Copy + fmt::Display, -{ - fn new(private_key: StaticSecret) -> Result { - Ok(ConnectionState { - node: Node::new(private_key), - write_buf: Box::new([0; MAX_UDP_SIZE]), - sockets: Sockets::new()?, - stats_timer: tokio::time::interval(Duration::from_secs(60)), - timeout: None, - }) - } - - fn send(&mut self, id: TId, packet: IpPacket) { - let to = packet.destination(); - - if let Err(e) = self.try_send(id, packet) { - tracing::warn!(%to, %id, "Failed to send packet: {e}"); - } - } - - fn try_send(&mut self, id: TId, packet: IpPacket) -> Result<()> { - // TODO: handle NotConnected - let Some(transmit) = self.node.encapsulate(id, packet.into(), Instant::now())? 
else { - return Ok(()); - }; - - self.sockets.try_send(&transmit)?; - - Ok(()) - } - - // TODO: passing the peer_store looks weird, we can just remove ConnectionState and move everything into Tunnel, there's no Mutexes any longer that justify this separation - fn poll_sockets( - &mut self, - device: &mut Device, - peer_store: &mut PeerStore, - cx: &mut Context<'_>, - ) -> Poll> - where - TTransform: PacketTransform, - TResource: Clone, - { - let received = match ready!(self.sockets.poll_recv_from(cx)) { - Ok(received) => received, - Err(e) => { - tracing::warn!("Failed to read socket: {e}"); - - cx.waker().wake_by_ref(); // Immediately schedule a new wake-up. - return Poll::Pending; - } - }; - - for received in received { - let Received { - local, - from, - packet, - } = received; - - let (conn_id, packet) = match self.node.decapsulate( - local, - from, - packet.as_ref(), - std::time::Instant::now(), - self.write_buf.as_mut(), - ) { - Ok(Some(packet)) => packet, - Ok(None) => { - continue; - } - Err(e) => { - tracing::warn!(%local, %from, num_bytes = %packet.len(), "Failed to decapsulate incoming packet: {e}"); - - continue; - } - }; - - let Some(peer) = peer_store.get_mut(&conn_id) else { - tracing::error!(%conn_id, %local, %from, "Couldn't find connection"); - - continue; - }; - - let packet = match peer.untransform(packet.into()) { - Ok(packet) => packet, - Err(e) => { - tracing::warn!(%conn_id, %local, %from, "Failed to transform packet: {e}"); - - continue; - } - }; - - device.write(packet.as_immutable())?; - } - - Poll::Ready(Ok(())) - } - - fn poll_next_event(&mut self, cx: &mut Context<'_>) -> Poll> { - if self.stats_timer.poll_tick(cx).is_ready() { - let (node_stats, conn_stats) = self.node.stats(); - - tracing::debug!(target: "connlib::stats", "{node_stats:?}"); - - for (id, stats) in conn_stats { - tracing::debug!(target: "connlib::stats", %id, "{stats:?}"); - } - - cx.waker().wake_by_ref(); - } - - if let Err(e) = 
ready!(self.sockets.poll_send_ready(cx)) { - tracing::warn!("Failed to poll sockets for readiness: {e}"); - }; - - while let Some(transmit) = self.node.poll_transmit() { - if let Err(e) = self.sockets.try_send(&transmit) { - tracing::warn!(src = ?transmit.src, dst = %transmit.dst, "Failed to send UDP packet: {e}"); - } - } - - match self.node.poll_event() { - Some(snownet::Event::SignalIceCandidate { - connection, - candidate, - }) => { - return Poll::Ready(Event::SignalIceCandidate { - conn_id: connection, - candidate, - }); - } - Some(snownet::Event::ConnectionFailed(id)) => { - return Poll::Ready(Event::StopPeer(id)); - } - _ => {} - } - - Poll::Pending - } - - fn poll_timeout(&mut self, cx: &mut Context<'_>) -> Poll<()> { - if let Some(timeout) = self.node.poll_timeout() { - let timeout = tokio::time::Instant::from_std(timeout); - - match self.timeout.as_mut() { - Some(existing_timeout) if existing_timeout.deadline() != timeout => { - existing_timeout.as_mut().reset(timeout) - } - Some(_) => {} - None => self.timeout = Some(Box::pin(tokio::time::sleep_until(timeout))), - } - } - - if let Some(timeout) = self.timeout.as_mut() { - ready!(timeout.poll_unpin(cx)); - self.node.handle_timeout(timeout.deadline().into()); - - return Poll::Ready(()); - } - - // Technically, we should set a waker here because we don't have a timer. - // But the only place where we set a timer is a few lines up. - // That is the same path that will re-poll it so there is no point in using a waker. - // We might want to consider making a `MaybeSleep` type that encapsulates a waker so we don't need to think about it as hard. 
- Poll::Pending - } -} - -pub enum Event { +pub enum ClientEvent { SignalIceCandidate { - conn_id: TId, + conn_id: GatewayId, candidate: String, }, ConnectionIntent { @@ -421,6 +386,11 @@ pub enum Event { RefreshResources { connections: Vec, }, - SendPacket(IpPacket<'static>), - StopPeer(TId), +} + +pub enum GatewayEvent { + SignalIceCandidate { + conn_id: ClientId, + candidate: String, + }, } diff --git a/rust/connlib/tunnel/src/sockets.rs b/rust/connlib/tunnel/src/sockets.rs index 726c5f9e7..d0bc547d2 100644 --- a/rust/connlib/tunnel/src/sockets.rs +++ b/rust/connlib/tunnel/src/sockets.rs @@ -9,16 +9,16 @@ use std::{ }; use tokio::{io::Interest, net::UdpSocket}; -use crate::{Error, Result, MAX_UDP_SIZE}; +use crate::Result; use snownet::Transmit; pub struct Sockets { - socket_v4: Option>, - socket_v6: Option>, + socket_v4: Option, + socket_v6: Option, } impl Sockets { - pub fn new() -> crate::Result { + pub fn new() -> io::Result { let socket_v4 = Socket::ip4(); let socket_v6 = Socket::ip6(); @@ -33,10 +33,10 @@ impl Sockets { tracing::error!("Failed to bind IPv4 socket: {e4}"); tracing::error!("Failed to bind IPv6 socket: {e6}"); - return Err(Error::Io(io::Error::new( + return Err(io::Error::new( io::ErrorKind::AddrNotAvailable, "Unable to bind to interfaces", - ))); + )); } _ => (), } @@ -68,42 +68,58 @@ impl Sockets { self.socket_v6.as_ref().map(|s| s.socket.as_raw_fd()) } - pub fn poll_send_ready(&mut self, cx: &mut Context<'_>) -> Poll> { - if let Some(socket) = self.socket_v4.as_mut() { + pub fn poll_send_ready(&self, cx: &mut Context<'_>) -> Poll> { + if let Some(socket) = self.socket_v4.as_ref() { ready!(socket.poll_send_ready(cx))?; } - if let Some(socket) = self.socket_v6.as_mut() { + if let Some(socket) = self.socket_v6.as_ref() { ready!(socket.poll_send_ready(cx))?; } Poll::Ready(Ok(())) } - pub fn try_send(&mut self, transmit: &Transmit) -> Result { + pub fn try_send(&self, transmit: &Transmit) -> io::Result { match transmit.dst { SocketAddr::V4(_) 
=> { - let socket = self.socket_v4.as_ref().ok_or(Error::NoIpv4)?; + let socket = self.socket_v4.as_ref().ok_or(io::Error::new( + io::ErrorKind::NotConnected, + "no IPv4 socket", + ))?; Ok(socket.try_send_to(transmit.src, transmit.dst, &transmit.payload)?) } SocketAddr::V6(_) => { - let socket = self.socket_v6.as_ref().ok_or(Error::NoIpv6)?; + let socket = self.socket_v6.as_ref().ok_or(io::Error::new( + io::ErrorKind::NotConnected, + "no IPv6 socket", + ))?; Ok(socket.try_send_to(transmit.src, transmit.dst, &transmit.payload)?) } } } - pub fn poll_recv_from<'a>( - &'a mut self, + pub fn poll_recv_from<'b>( + &self, + ip4_buffer: &'b mut [u8], + ip6_buffer: &'b mut [u8], cx: &mut Context<'_>, - ) -> Poll>>> { + ) -> Poll>>> { let mut iter = PacketIter::new(); - if let Some(Poll::Ready(packets)) = self.socket_v4.as_mut().map(|s| s.poll_recv_from(cx)) { + if let Some(Poll::Ready(packets)) = self + .socket_v4 + .as_ref() + .map(|s| s.poll_recv_from(ip4_buffer, cx)) + { iter.ip4 = Some(packets?); } - if let Some(Poll::Ready(packets)) = self.socket_v6.as_mut().map(|s| s.poll_recv_from(cx)) { + if let Some(Poll::Ready(packets)) = self + .socket_v6 + .as_ref() + .map(|s| s.poll_recv_from(ip6_buffer, cx)) + { iter.ip6 = Some(packets?); } @@ -159,15 +175,14 @@ pub struct Received<'a> { pub packet: &'a [u8], } -struct Socket { +struct Socket { state: UdpSocketState, port: u16, socket: UdpSocket, - buffer: Box<[u8; N]>, } -impl Socket { - fn ip4() -> Result> { +impl Socket { + fn ip4() -> Result { let socket = make_socket(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0))?; let port = socket.local_addr()?.port(); @@ -175,11 +190,10 @@ impl Socket { state: UdpSocketState::new(UdpSockRef::from(&socket))?, port, socket: tokio::net::UdpSocket::from_std(socket)?, - buffer: Box::new([0u8; N]), }) } - fn ip6() -> Result> { + fn ip6() -> Result { let socket = make_socket(SocketAddrV6::new(Ipv6Addr::UNSPECIFIED, 0, 0, 0))?; let port = socket.local_addr()?.port(); @@ -187,23 +201,22 @@ impl 
Socket { state: UdpSocketState::new(UdpSockRef::from(&socket))?, port, socket: tokio::net::UdpSocket::from_std(socket)?, - buffer: Box::new([0u8; N]), }) } #[allow(clippy::type_complexity)] fn poll_recv_from<'b>( - &'b mut self, + &self, + buffer: &'b mut [u8], cx: &mut Context<'_>, ) -> Poll>>> { let Socket { port, socket, - buffer, state, } = self; - let bufs = &mut [IoSliceMut::new(buffer.as_mut())]; + let bufs = &mut [IoSliceMut::new(buffer)]; let mut meta = RecvMeta::default(); loop { @@ -241,7 +254,7 @@ impl Socket { } } - fn poll_send_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + fn poll_send_ready(&self, cx: &mut Context<'_>) -> Poll> { self.socket.poll_send_ready(cx) } diff --git a/rust/connlib/tunnel/src/stats.rs b/rust/connlib/tunnel/src/stats.rs new file mode 100644 index 000000000..71b47416b --- /dev/null +++ b/rust/connlib/tunnel/src/stats.rs @@ -0,0 +1,39 @@ +use core::fmt; +use std::hash::Hash; +use std::{ + task::{ready, Context, Poll}, + time::Duration, +}; + +pub struct Stats { + interval: tokio::time::Interval, +} + +impl Stats { + pub fn new(interval: Duration) -> Self { + Self { + interval: tokio::time::interval(interval), + } + } + + pub fn poll( + &mut self, + node: &snownet::Node, + cx: &mut Context<'_>, + ) -> Poll<()> + where + TId: fmt::Display + Copy + Eq + PartialEq + Hash, + { + ready!(self.interval.poll_tick(cx)); + + let (node_stats, conn_stats) = node.stats(); + + tracing::debug!(target: "connlib::stats", "{node_stats:?}"); + + for (id, stats) in conn_stats { + tracing::debug!(target: "connlib::stats", %id, "{stats:?}"); + } + + Poll::Ready(()) + } +} diff --git a/rust/connlib/tunnel/src/utils.rs b/rust/connlib/tunnel/src/utils.rs index df3597240..9904ab514 100644 --- a/rust/connlib/tunnel/src/utils.rs +++ b/rust/connlib/tunnel/src/utils.rs @@ -1,6 +1,6 @@ use crate::REALM; use connlib_shared::messages::Relay; -use std::{collections::HashSet, net::SocketAddr}; +use std::{collections::HashSet, net::SocketAddr, time::Instant}; 
pub fn stun(relays: &[Relay], predicate: impl Fn(&SocketAddr) -> bool) -> HashSet { relays @@ -37,3 +37,12 @@ pub fn turn( .filter(|(socket, _, _, _)| predicate(socket)) .collect() } + +pub fn earliest(left: Option, right: Option) -> Option { + match (left, right) { + (None, None) => None, + (Some(left), Some(right)) => Some(std::cmp::min(left, right)), + (Some(left), None) => Some(left), + (None, Some(right)) => Some(right), + } +} diff --git a/rust/gateway/src/eventloop.rs b/rust/gateway/src/eventloop.rs index 7653e7f78..213f92b30 100644 --- a/rust/gateway/src/eventloop.rs +++ b/rust/gateway/src/eventloop.rs @@ -12,7 +12,7 @@ use connlib_shared::{ #[cfg(not(target_os = "windows"))] use dns_lookup::{AddrInfoHints, AddrInfoIter, LookupError}; use either::Either; -use firezone_tunnel::{Event, GatewayTunnel, ResolvedResourceDescriptionDns}; +use firezone_tunnel::{GatewayTunnel, ResolvedResourceDescriptionDns}; use ip_network::IpNetwork; use phoenix_channel::PhoenixChannel; use std::convert::Infallible; @@ -48,11 +48,11 @@ impl Eventloop { #[tracing::instrument(name = "Eventloop::poll", skip_all, level = "debug")] pub fn poll(&mut self, cx: &mut Context<'_>) -> Poll> { loop { - match self.tunnel.poll_next_event(cx)? { - Poll::Ready(firezone_tunnel::Event::SignalIceCandidate { + match self.tunnel.poll_next_event(cx) { + Poll::Ready(Ok(firezone_tunnel::GatewayEvent::SignalIceCandidate { conn_id: client, candidate, - }) => { + })) => { tracing::debug!(%client, %candidate, "Sending ICE candidate to client"); self.portal.send( @@ -65,10 +65,10 @@ impl Eventloop { continue; } - Poll::Ready(Event::ConnectionIntent { .. }) => { - unreachable!("Not used on the gateway, split the events!") + Poll::Ready(Err(e)) => { + tracing::warn!("Tunnel error: {e}"); + continue; } - Poll::Ready(_) => continue, Poll::Pending => {} } @@ -97,17 +97,16 @@ impl Eventloop { ); // TODO: If outbound request fails, cleanup connection. 
- continue; } Err(e) => { let client = req.client.id; self.tunnel.cleanup_connection(&client); tracing::debug!(%client, "Connection request failed: {:#}", anyhow::Error::new(e)); - - continue; } } + + continue; } Poll::Ready((Ok(Ok(resource)), Either::Right(req))) => { let maybe_domain_response = self.tunnel.allow_access( @@ -127,8 +126,9 @@ impl Eventloop { ), }), ); - continue; } + + continue; } Poll::Ready((Ok(Err(dns_error)), Either::Left(req))) => { tracing::debug!(client = %req.client.id, reference = %req.reference, "Failed to resolve domains as part of connection request: {dns_error}"); diff --git a/scripts/tests/lib.sh b/scripts/tests/lib.sh index 3bf67daa1..ae2f98945 100755 --- a/scripts/tests/lib.sh +++ b/scripts/tests/lib.sh @@ -13,5 +13,5 @@ function remove_iptables_drop_rules() { } function client_curl_resource() { - docker compose exec -it client curl --fail -i 172.20.0.100 + docker compose exec -it client curl --max-time 30 --fail -i 172.20.0.100 }