diff --git a/rust/connlib/tunnel/src/client.rs b/rust/connlib/tunnel/src/client.rs index 81fdf9c8b..c084d53f3 100644 --- a/rust/connlib/tunnel/src/client.rs +++ b/rust/connlib/tunnel/src/client.rs @@ -206,7 +206,15 @@ where self.role_state.system_resolvers.clone(), ); - let dns_mapping = sentinel_dns_mapping(&effective_dns_servers); + let dns_mapping = sentinel_dns_mapping( + &effective_dns_servers, + self.role_state + .dns_mapping() + .left_values() + .copied() + .map(Into::into) + .collect_vec(), + ); self.role_state.set_dns_mapping(dns_mapping.clone()); self.io.set_upstream_dns_servers(dns_mapping.clone()); @@ -976,8 +984,11 @@ fn effective_dns_servers( .collect() } -fn sentinel_dns_mapping(dns: &[DnsServer]) -> BiMap { - let mut ip_provider = IpProvider::for_stub_dns_servers(); +fn sentinel_dns_mapping( + dns: &[DnsServer], + old_sentinels: Vec, +) -> BiMap { + let mut ip_provider = IpProvider::for_stub_dns_servers(old_sentinels); dns.iter() .cloned() @@ -1025,36 +1036,34 @@ impl IpProvider { IpProvider::new( IPV4_RESOURCES.parse().unwrap(), IPV6_RESOURCES.parse().unwrap(), - Some(DNS_SENTINELS_V4.parse().unwrap()), - Some(DNS_SENTINELS_V6.parse().unwrap()), + vec![ + DNS_SENTINELS_V4.parse().unwrap(), + DNS_SENTINELS_V6.parse().unwrap(), + ], ) } - pub fn for_stub_dns_servers() -> Self { + pub fn for_stub_dns_servers(exclusions: Vec) -> Self { IpProvider::new( DNS_SENTINELS_V4.parse().unwrap(), DNS_SENTINELS_V6.parse().unwrap(), - None, - None, + exclusions, ) } - fn new( - ipv4: Ipv4Network, - ipv6: Ipv6Network, - exclusion_v4: Option, - exclusion_v6: Option, - ) -> Self { + fn new(ipv4: Ipv4Network, ipv6: Ipv6Network, exclusions: Vec) -> Self { Self { - ipv4: Box::new( + ipv4: Box::new({ + let exclusions = exclusions.clone(); ipv4.hosts() - .filter(move |ip| !exclusion_v4.is_some_and(|e| e.contains(*ip))), - ), - ipv6: Box::new( + .filter(move |ip| !exclusions.iter().any(|e| e.contains(*ip))) + }), + ipv6: Box::new({ + let exclusions = exclusions.clone(); ipv6.subnets_with_prefix(128) .map(|ip| ip.network_address()) - .filter(move |ip| !exclusion_v6.is_some_and(|e| e.contains(*ip))), - ), + .filter(move |ip| !exclusions.iter().any(|e| e.contains(*ip))) + }), } } @@ -1152,12 +1161,64 @@ mod tests { assert_eq!(client_state.poll_event(), Some(Event::RefreshInterface)); } + #[test] + fn sentinel_dns_works() { + let servers = dns_list(); + let sentinel_dns = sentinel_dns_mapping(&servers, vec![]); + + for server in servers { + assert!(sentinel_dns + .get_by_right(&server) + .is_some_and(|s| sentinel_ranges().iter().any(|e| e.contains(*s)))) + } + } + + #[test] + fn sentinel_dns_excludes_old_ones() { + let servers = dns_list(); + let sentinel_dns_old = sentinel_dns_mapping(&servers, vec![]); + let sentinel_dns_new = sentinel_dns_mapping( + &servers, + sentinel_dns_old + .left_values() + .copied() + .map(Into::into) + .collect_vec(), + ); + + assert!( + HashSet::<&IpAddr>::from_iter(sentinel_dns_old.left_values()) + .is_disjoint(&HashSet::from_iter(sentinel_dns_new.left_values())) + ) + } + impl ClientState { fn for_test() -> ClientState { ClientState::new(StaticSecret::random_from_rng(OsRng)) } } + fn sentinel_ranges() -> Vec { + vec![ + IpNetwork::from_str(DNS_SENTINELS_V4).unwrap(), + IpNetwork::from_str(DNS_SENTINELS_V6).unwrap(), + ] + } + + fn dns_list() -> Vec { + vec![ + DnsServer::IpPort(IpDnsServer { + address: "1.1.1.1:53".parse().unwrap(), + }), + DnsServer::IpPort(IpDnsServer { + address: "1.0.0.1:53".parse().unwrap(), + }), + DnsServer::IpPort(IpDnsServer { + address: "[2606:4700:4700::1111]:53".parse().unwrap(), + }), + ] + } + fn ip(addr: &str) -> IpAddr { addr.parse().unwrap() }