diff --git a/rust/connlib/tunnel/src/gateway.rs b/rust/connlib/tunnel/src/gateway.rs index 65e479742..d197c3d53 100644 --- a/rust/connlib/tunnel/src/gateway.rs +++ b/rust/connlib/tunnel/src/gateway.rs @@ -10,12 +10,11 @@ use connlib_shared::messages::{ ResourceId, }; use connlib_shared::{Callbacks, DomainName, Error, Result, StaticSecret}; -use ip_network::IpNetwork; use ip_packet::{IpPacket, MutableIpPacket}; use secrecy::{ExposeSecret as _, Secret}; use snownet::{RelaySocket, ServerNode}; use std::collections::{HashSet, VecDeque}; -use std::net::SocketAddr; +use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr}; use std::time::{Duration, Instant}; const PEERS_IPV4: &str = "100.64.0.0/11"; @@ -54,7 +53,8 @@ where key: Secret, offer: Offer, client: PublicKey, - ips: Vec, + ipv4: Ipv4Addr, + ipv6: Ipv6Addr, relays: Vec, domain: Option, expires_at: Option>, @@ -70,7 +70,8 @@ where }, }, client, - ips, + ipv4, + ipv6, stun(&relays, |addr| self.io.sockets_ref().can_handle(addr)), turn(&relays), domain, @@ -231,7 +232,8 @@ impl GatewayState { client_id: ClientId, offer: snownet::Offer, client: PublicKey, - ips: Vec, + ipv4: Ipv4Addr, + ipv6: Ipv6Addr, stun_servers: HashSet, turn_servers: HashSet<(RelayId, RelaySocket, String, String, String)>, domain: Option, @@ -253,7 +255,7 @@ impl GatewayState { self.node .accept_connection(client_id, offer, client, stun_servers, turn_servers, now); - let mut peer = ClientOnGateway::new(client_id, &ips); + let mut peer = ClientOnGateway::new(client_id, ipv4, ipv6); peer.add_resource( resource.addresses(), @@ -262,7 +264,7 @@ impl GatewayState { expires_at, ); - self.peers.insert(peer, &ips); + self.peers.insert(peer, &[ipv4.into(), ipv6.into()]); Ok(ConnectionAccepted { ice_parameters: Answer { diff --git a/rust/connlib/tunnel/src/peer.rs b/rust/connlib/tunnel/src/peer.rs index 55a33d97c..d52ed5c1a 100644 --- a/rust/connlib/tunnel/src/peer.rs +++ b/rust/connlib/tunnel/src/peer.rs @@ -1,5 +1,5 @@ use std::collections::{HashMap, HashSet}; -use std::net::IpAddr; +use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; use std::time::Instant; use bimap::BiMap; @@ -169,20 +169,23 @@ impl GatewayOnClient { } impl ClientOnGateway { - pub(crate) fn new(id: ClientId, ips: &[IpNetwork]) -> ClientOnGateway { - let mut allowed_ips = IpNetworkTable::new(); - for ip in ips { - allowed_ips.insert(*ip, ()); - } - + pub(crate) fn new(id: ClientId, ipv4: Ipv4Addr, ipv6: Ipv6Addr) -> ClientOnGateway { ClientOnGateway { id, - allowed_ips, + ipv4, + ipv6, resources: HashMap::new(), filters: IpNetworkTable::new(), } } + /// A client is only allowed to send packets from their (portal-assigned) tunnel IPs. + /// + /// Failure to enforce this would allow one client to send traffic masquarading as a different client. + fn allowed_ips(&self) -> [IpAddr; 2] { + [IpAddr::from(self.ipv4), IpAddr::from(self.ipv6)] + } + pub(crate) fn is_emptied(&self) -> bool { self.resources.is_empty() } @@ -268,14 +271,10 @@ impl ClientOnGateway { &self, packet: &MutableIpPacket<'_>, ) -> Result<(), connlib_shared::Error> { - if self.allowed_ips.longest_match(packet.source()).is_none() { + if !self.allowed_ips().contains(&packet.source()) { return Err(connlib_shared::Error::UnallowedPacket { src: packet.source(), - allowed_ips: self - .allowed_ips - .iter() - .map(|(ip, &())| ip.network_address()) - .collect(), + allowed_ips: HashSet::from(self.allowed_ips()), }); } @@ -379,14 +378,18 @@ struct ResourceOnGateway { /// The state of one client on a gateway. pub struct ClientOnGateway { id: ClientId, - allowed_ips: IpNetworkTable<()>, + ipv4: Ipv4Addr, + ipv6: Ipv6Addr, resources: HashMap>, filters: IpNetworkTable, } #[cfg(test)] mod tests { - use std::{net::IpAddr, time::Duration}; + use std::{ + net::{Ipv4Addr, Ipv6Addr}, + time::Duration, + }; use chrono::Utc; use connlib_shared::messages::{ @@ -399,7 +402,7 @@ mod tests { #[test] fn gateway_filters_expire_individually() { - let mut peer = ClientOnGateway::new(client_id(), &[source_v4_addr().into()]); + let mut peer = ClientOnGateway::new(client_id(), source_v4_addr(), source_v6_addr()); let now = Utc::now(); let then = now + Duration::from_secs(10); let after_then = then + Duration::from_secs(10); @@ -424,7 +427,7 @@ mod tests { ); let tcp_packet = ip_packet::make::tcp_packet( - source_v4_addr(), + source_v4_addr().into(), cidr_v4_resource().hosts().next().unwrap().into(), 5401, 80, @@ -432,7 +435,7 @@ mod tests { ); let udp_packet = ip_packet::make::udp_packet( - source_v4_addr(), + source_v4_addr().into(), cidr_v4_resource().hosts().next().unwrap().into(), 5401, 80, @@ -464,10 +467,14 @@ mod tests { )); } - fn source_v4_addr() -> IpAddr { + fn source_v4_addr() -> Ipv4Addr { "100.64.0.1".parse().unwrap() } + fn source_v6_addr() -> Ipv6Addr { + "fd00:2021:1111::1".parse().unwrap() + } + fn cidr_v4_resource() -> Ipv4Network { "10.0.0.0/24".parse().unwrap() } @@ -509,17 +516,24 @@ mod proptests { fn gateway_accepts_allowed_packet( #[strategy(client_id())] client_id: ClientId, #[strategy(vec![resource_id(); 5])] resources_id: Vec, - #[strategy(source_resource_and_host_within())] config: (IpAddr, IpNetwork, IpAddr), + #[strategy(any::())] src_v4: Ipv4Addr, + #[strategy(any::())] src_v6: Ipv6Addr, + #[strategy(cidr_with_host())] config: (IpNetwork, IpAddr), #[strategy(collection::vec(filters_with_allowed_protocol(), 1..=5))] protocol_config: Vec< (Filters, Protocol), >, #[strategy(any::())] sport: u16, #[strategy(any::>())] payload: Vec, ) { - let (src, resource_addr, dest) = config; + let (resource_addr, dest) = config; + let src = if dest.is_ipv4() { + src_v4.into() + } else { + src_v6.into() + }; let mut filters = protocol_config.iter(); // This test could be extended to test multiple src - let mut peer = ClientOnGateway::new(client_id, &[src.into()]); + let mut peer = ClientOnGateway::new(client_id, src_v4, src_v6); let mut resource_addr = Some(resource_addr); let mut resources = 0; @@ -549,39 +563,31 @@ mod proptests { fn gateway_accepts_allowed_packet_multiple_ips_resource( #[strategy(client_id())] client_id: ClientId, #[strategy(resource_id())] resource_id: ResourceId, - #[strategy(collection::vec(source_resource_and_host_within(), 1..=5))] config: Vec<( - IpAddr, - IpNetwork, - IpAddr, - )>, + #[strategy(any::())] src_v4: Ipv4Addr, + #[strategy(any::())] src_v6: Ipv6Addr, + #[strategy(collection::vec(cidr_with_host(), 1..=5))] config: Vec<(IpNetwork, IpAddr)>, #[strategy(filters_with_allowed_protocol())] protocol_config: (Filters, Protocol), #[strategy(any::())] sport: u16, #[strategy(any::>())] payload: Vec, ) { - let (src, resource_addr, dest): (Vec<_>, Vec<_>, Vec<_>) = config.into_iter().multiunzip(); + let (resource_addr, dest): (Vec<_>, Vec<_>) = config.into_iter().unzip(); let (filters, protocol) = protocol_config; - let mut peer = ClientOnGateway::new( - client_id, - &src.clone().into_iter().map(Into::into).collect_vec(), - ); + let mut peer = ClientOnGateway::new(client_id, src_v4, src_v6); peer.add_resource(resource_addr, resource_id, filters, None); for dest in dest { - for src in &src { - if dest.is_ipv4() == src.is_ipv4() { - let packet = match protocol { - Protocol::Tcp { dport } => { - tcp_packet(*src, dest, sport, dport, payload.clone()) - } - Protocol::Udp { dport } => { - udp_packet(*src, dest, sport, dport, payload.clone()) - } - Protocol::Icmp => icmp_request_packet(*src, dest, 1, 0), - }; - assert!(peer.ensure_allowed(&packet).is_ok()); - } - } + let src = if dest.is_ipv4() { + src_v4.into() + } else { + src_v6.into() + }; + let packet = match protocol { + Protocol::Tcp { dport } => tcp_packet(src, dest, sport, dport, payload.clone()), + Protocol::Udp { dport } => udp_packet(src, dest, sport, dport, payload.clone()), + Protocol::Icmp => icmp_request_packet(src, dest, 1, 0), + }; + assert!(peer.ensure_allowed(&packet).is_ok()); } } @@ -589,13 +595,13 @@ mod proptests { fn gateway_accepts_allowed_packet_multiple_ips_resource_multiple_adds( #[strategy(client_id())] client_id: ClientId, #[strategy(resource_id())] resource_id: ResourceId, - #[strategy(collection::vec(source_resource_and_host_within(), 1..=5))] config_res_1: Vec<( - IpAddr, + #[strategy(any::())] src_v4: Ipv4Addr, + #[strategy(any::())] src_v6: Ipv6Addr, + #[strategy(collection::vec(cidr_with_host(), 1..=5))] config_res_1: Vec<( IpNetwork, IpAddr, )>, - #[strategy(collection::vec(source_resource_and_host_within(), 1..=5))] config_res_2: Vec<( - IpAddr, + #[strategy(collection::vec(cidr_with_host(), 1..=5))] config_res_2: Vec<( IpNetwork, IpAddr, )>, @@ -603,54 +609,40 @@ mod proptests { #[strategy(any::())] sport: u16, #[strategy(any::>())] payload: Vec, ) { - let (src_1, resource_addr_1, dest_1): (Vec<_>, Vec<_>, Vec<_>) = - config_res_1.into_iter().multiunzip(); - let (src_2, resource_addr_2, dest_2): (Vec<_>, Vec<_>, Vec<_>) = - config_res_2.into_iter().multiunzip(); + let (resource_addr_1, dest_1): (Vec<_>, Vec<_>) = config_res_1.into_iter().unzip(); + let (resource_addr_2, dest_2): (Vec<_>, Vec<_>) = config_res_2.into_iter().unzip(); let (filters, protocol) = protocol_config; - let mut src = Vec::new(); - src.extend(src_1); - src.extend(src_2); - let mut peer = ClientOnGateway::new( - client_id, - &src.clone().into_iter().map(Into::into).collect_vec(), - ); + let mut peer = ClientOnGateway::new(client_id, src_v4, src_v6); peer.add_resource(resource_addr_1, resource_id, filters.clone(), None); peer.add_resource(resource_addr_2, resource_id, filters, None); for dest in dest_1 { - for src in &src { - if dest.is_ipv4() == src.is_ipv4() { - let packet = match protocol { - Protocol::Tcp { dport } => { - tcp_packet(*src, dest, sport, dport, payload.clone()) - } - Protocol::Udp { dport } => { - udp_packet(*src, dest, sport, dport, payload.clone()) - } - Protocol::Icmp => icmp_request_packet(*src, dest, 1, 0), - }; - assert!(peer.ensure_allowed(&packet).is_ok()); - } - } + let src = if dest.is_ipv4() { + src_v4.into() + } else { + src_v6.into() + }; + let packet = match protocol { + Protocol::Tcp { dport } => tcp_packet(src, dest, sport, dport, payload.clone()), + Protocol::Udp { dport } => udp_packet(src, dest, sport, dport, payload.clone()), + Protocol::Icmp => icmp_request_packet(src, dest, 1, 0), + }; + assert!(peer.ensure_allowed(&packet).is_ok()); } for dest in dest_2 { - for src in &src { - if dest.is_ipv4() == src.is_ipv4() { - let packet = match protocol { - Protocol::Tcp { dport } => { - tcp_packet(*src, dest, sport, dport, payload.clone()) - } - Protocol::Udp { dport } => { - udp_packet(*src, dest, sport, dport, payload.clone()) - } - Protocol::Icmp => icmp_request_packet(*src, dest, 1, 0), - }; - assert!(peer.ensure_allowed(&packet).is_ok()); - } - } + let src = if dest.is_ipv4() { + src_v4.into() + } else { + src_v6.into() + }; + let packet = match protocol { + Protocol::Tcp { dport } => tcp_packet(src, dest, sport, dport, payload.clone()), + Protocol::Udp { dport } => udp_packet(src, dest, sport, dport, payload.clone()), + Protocol::Icmp => icmp_request_packet(src, dest, 1, 0), + }; + assert!(peer.ensure_allowed(&packet).is_ok()); } } @@ -658,15 +650,22 @@ mod proptests { fn gateway_accepts_different_resources_with_same_ip_packet( #[strategy(client_id())] client_id: ClientId, #[strategy(vec![resource_id(); 10])] resources_ids: Vec, - #[strategy(source_resource_and_host_within())] config: (IpAddr, IpNetwork, IpAddr), + #[strategy(any::())] src_v4: Ipv4Addr, + #[strategy(any::())] src_v6: Ipv6Addr, + #[strategy(cidr_with_host())] config: (IpNetwork, IpAddr), #[strategy(collection::vec(filters_with_allowed_protocol(), 1..=10))] protocol_config: Vec< (Filters, Protocol), >, #[strategy(any::())] sport: u16, #[strategy(any::>())] payload: Vec, ) { - let (src, resource_addr, dest) = config; - let mut peer = ClientOnGateway::new(client_id, &[src.into()]); + let (resource_addr, dest) = config; + let src = if dest.is_ipv4() { + src_v4.into() + } else { + src_v6.into() + }; + let mut peer = ClientOnGateway::new(client_id, src_v4, src_v6); let mut resources_ids = resources_ids.iter(); for (filters, _) in &protocol_config { // This test could be extended to test multiple src @@ -693,15 +692,22 @@ mod proptests { fn gateway_reject_unallowed_packet( #[strategy(client_id())] client_id: ClientId, #[strategy(resource_id())] resource_id: ResourceId, - #[strategy(source_resource_and_host_within())] config: (IpAddr, IpNetwork, IpAddr), + #[strategy(any::())] src_v4: Ipv4Addr, + #[strategy(any::())] src_v6: Ipv6Addr, + #[strategy(cidr_with_host())] config: (IpNetwork, IpAddr), #[strategy(filters_with_rejected_protocol())] protocol_config: (Filters, Protocol), #[strategy(any::())] sport: u16, #[strategy(any::>())] payload: Vec, ) { - let (src, resource_addr, dest) = config; + let (resource_addr, dest) = config; + let src = if dest.is_ipv4() { + src_v4.into() + } else { + src_v6.into() + }; let (filters, protocol) = protocol_config; // This test could be extended to test multiple src - let mut peer = ClientOnGateway::new(client_id, &[src.into()]); + let mut peer = ClientOnGateway::new(client_id, src_v4, src_v6); let packet = match protocol { Protocol::Tcp { dport } => tcp_packet(src, dest, sport, dport, payload), Protocol::Udp { dport } => udp_packet(src, dest, sport, dport, payload), @@ -721,7 +727,9 @@ mod proptests { #[strategy(client_id())] client_id: ClientId, #[strategy(resource_id())] resource_id_allowed: ResourceId, #[strategy(resource_id())] resource_id_removed: ResourceId, - #[strategy(source_resource_and_host_within())] config: (IpAddr, IpNetwork, IpAddr), + #[strategy(any::())] src_v4: Ipv4Addr, + #[strategy(any::())] src_v6: Ipv6Addr, + #[strategy(cidr_with_host())] config: (IpNetwork, IpAddr), #[strategy(non_overlapping_non_empty_filters_with_allowed_protocol())] protocol_config: ( (Filters, Protocol), (Filters, Protocol), @@ -729,11 +737,16 @@ mod proptests { #[strategy(any::())] sport: u16, #[strategy(any::>())] payload: Vec, ) { - let (src, resource_addr, dest) = config; + let (resource_addr, dest) = config; + let src = if dest.is_ipv4() { + src_v4.into() + } else { + src_v6.into() + }; let ((filters_allowed, protocol_allowed), (filters_removed, protocol_removed)) = protocol_config; // This test could be extended to test multiple src - let mut peer = ClientOnGateway::new(client_id, &[src.into()]); + let mut peer = ClientOnGateway::new(client_id, src_v4, src_v6); let packet_allowed = match protocol_allowed { Protocol::Tcp { dport } => tcp_packet(src, dest, sport, dport, payload.clone()), @@ -801,36 +814,22 @@ mod proptests { }) } - fn source_resource_and_host_within() -> impl Strategy { - any::().prop_flat_map(|is_v4| { - if is_v4 { - cidrv4_with_host() - .prop_flat_map(|(net, dst)| { - any::().prop_map(move |src| (src.into(), net.into(), dst.into())) - }) - .boxed() - } else { - cidrv6_with_host() - .prop_flat_map(|(net, dst)| { - any::().prop_map(move |src| (src.into(), net.into(), dst.into())) - }) - .boxed() - } - }) + fn cidr_with_host() -> impl Strategy { + prop_oneof![cidrv4_with_host(), cidrv6_with_host()] } // max netmask here picked arbitrarily since using max size made the tests run for too long - fn cidrv6_with_host() -> impl Strategy { + fn cidrv6_with_host() -> impl Strategy { (1usize..=8).prop_flat_map(|host_mask| { ip6_network(host_mask) - .prop_flat_map(|net| host_v6(net).prop_map(move |host| (net, host))) + .prop_flat_map(|net| host_v6(net).prop_map(move |host| (net.into(), host.into()))) }) } - fn cidrv4_with_host() -> impl Strategy { + fn cidrv4_with_host() -> impl Strategy { (1usize..=8).prop_flat_map(|host_mask| { ip4_network(host_mask) - .prop_flat_map(|net| host_v4(net).prop_map(move |host| (net, host))) + .prop_flat_map(|net| host_v4(net).prop_map(move |host| (net.into(), host.into()))) }) } diff --git a/rust/connlib/tunnel/src/tests.rs b/rust/connlib/tunnel/src/tests.rs index a0f6dfc6e..88044bcb0 100644 --- a/rust/connlib/tunnel/src/tests.rs +++ b/rust/connlib/tunnel/src/tests.rs @@ -868,10 +868,8 @@ impl TunnelTest { }, }, self.client.state.public_key(), - vec![ - self.client.tunnel_ip4.into(), - self.client.tunnel_ip6.into(), - ], + self.client.tunnel_ip4, + self.client.tunnel_ip6, HashSet::default(), HashSet::default(), new_connection.client_payload.domain, diff --git a/rust/gateway/src/eventloop.rs b/rust/gateway/src/eventloop.rs index 4e5f14684..f6baf5d96 100644 --- a/rust/gateway/src/eventloop.rs +++ b/rust/gateway/src/eventloop.rs @@ -219,14 +219,13 @@ impl Eventloop { .inspect_err(|e| tracing::debug!(client = %req.client.id, reference = %req.reference, "DNS resolution timed out as part of connection request: {e}")) .unwrap_or_default(); - let ips = req.client.peer.ips(); - match self.tunnel.accept( req.client.id, req.client.peer.preshared_key, req.client.payload.ice_parameters, PublicKey::from(req.client.peer.public_key.0), - ips, + req.client.peer.ipv4, + req.client.peer.ipv6, req.relays, req.client.payload.domain, req.expires_at,