feat(connlib): allow users to disable resources (#6164)

This is just the API part for #6074

We expose a new API `set_disabled_resources` which given a set of
resource ids it does the following:
* Disconnect any active connection depending only on this resource
* Prevent any new connection with that resource id being established

The `set_disabled_resources` API is purposely not stateful. In other
words, resources cannot be incrementally enabled or disabled. Instead,
clients always need to send the latest state, i.e. all resources that
should be disabled. `connlib` will figure out the diff and correctly
enable / disable resources as necessary. Thus, enabling a resource is
done by calling `set_disabled_resources` without the previously disabled
resource ID.

Initially, this will only be used for the internet resource but the use
can be expanded for any other resource.
This commit is contained in:
Gabi
2024-08-05 18:13:04 -03:00
committed by GitHub
parent 023d05ece1
commit 181b81d24a
7 changed files with 119 additions and 18 deletions

File diff suppressed because one or more lines are too long

View File

@@ -82,6 +82,18 @@ impl ClientTunnel {
});
}
pub fn set_disabled_resources(&mut self, new_disabled_resources: HashSet<ResourceId>) {
self.role_state
.set_disabled_resource(new_disabled_resources);
self.role_state
.buffered_events
.push_back(ClientEvent::TunRoutesUpdated {
ip4: self.role_state.routes().filter_map(utils::ipv4).collect(),
ip6: self.role_state.routes().filter_map(utils::ipv6).collect(),
});
}
pub fn set_tun(&mut self, tun: Box<dyn Tun>) {
self.io.device_mut().set_tun(tun);
}
@@ -254,7 +266,7 @@ pub struct ClientState {
sites_status: HashMap<SiteId, Status>,
/// All CIDR resources we know about, indexed by the IP range they cover (like `1.1.0.0/8`).
cidr_resources: IpNetworkTable<ResourceDescriptionCidr>,
active_cidr_resources: IpNetworkTable<ResourceDescriptionCidr>,
/// All resources indexed by their ID.
resources_by_id: HashMap<ResourceId, ResourceDescription>,
@@ -276,6 +288,9 @@ pub struct ClientState {
/// Configuration of the TUN device, when it is up.
interface_config: Option<InterfaceConfig>,
/// Resources that have been disabled by the UI
disabled_resources: HashSet<ResourceId>,
buffered_events: VecDeque<ClientEvent>,
buffered_packets: VecDeque<IpPacket<'static>>,
}
@@ -296,7 +311,7 @@ impl ClientState {
Self {
awaiting_connection_details: Default::default(),
resources_gateways: Default::default(),
cidr_resources: IpNetworkTable::new(),
active_cidr_resources: IpNetworkTable::new(),
resources_by_id: Default::default(),
peers: Default::default(),
dns_mapping: Default::default(),
@@ -310,6 +325,7 @@ impl ClientState {
gateways_site: Default::default(),
mangled_dns_queries: Default::default(),
stub_resolver: StubResolver::new(known_hosts),
disabled_resources: Default::default(),
}
}
@@ -633,7 +649,7 @@ impl ClientState {
// In case the DNS server is a CIDR resource, it needs to go through the tunnel.
if self.is_upstream_set_by_the_portal()
&& self.cidr_resources.longest_match(ip).is_some()
&& self.active_cidr_resources.longest_match(ip).is_some()
{
return Err((packet, ip));
}
@@ -769,6 +785,25 @@ impl ClientState {
self.mangled_dns_queries.clear();
}
pub fn set_disabled_resource(&mut self, new_disabled_resources: HashSet<ResourceId>) {
let current_disabled_resources = self.disabled_resources.clone();
// We set disabled_resources before anything else so that add_resource knows what resources are enabled right now.
self.disabled_resources = new_disabled_resources.clone();
for re_enabled_resource in current_disabled_resources.difference(&new_disabled_resources) {
let Some(resource) = self.resources_by_id.get(re_enabled_resource) else {
continue;
};
self.add_resource(resource.clone());
}
for disabled_resource in &new_disabled_resources {
self.disable_resource(*disabled_resource);
}
}
pub fn dns_mapping(&self) -> BiMap<IpAddr, DnsServer> {
self.dns_mapping.clone()
}
@@ -781,7 +816,7 @@ impl ClientState {
}
fn routes(&self) -> impl Iterator<Item = IpNetwork> + '_ {
self.cidr_resources
self.active_cidr_resources
.iter()
.map(|(ip, _)| ip)
.chain(iter::once(IPV4_RESOURCES.into()))
@@ -789,15 +824,22 @@ impl ClientState {
.chain(self.dns_mapping.left_values().copied().map(Into::into))
}
fn is_resource_enabled(&self, resource: &ResourceId) -> bool {
!self.disabled_resources.contains(resource) && self.resources_by_id.contains_key(resource)
}
fn get_resource_by_destination(&self, destination: IpAddr) -> Option<ResourceId> {
let maybe_dns_resource_id = self.stub_resolver.resolve_resource_by_ip(&destination);
let maybe_cidr_resource_id = self
.cidr_resources
.active_cidr_resources
.longest_match(destination)
.map(|(_, res)| res.id);
maybe_dns_resource_id.or(maybe_cidr_resource_id)
// We need to filter disabled resources because we never remove resources from the stub_resolver
maybe_dns_resource_id
.or(maybe_cidr_resource_id)
.filter(|resource| self.is_resource_enabled(resource))
}
#[must_use]
@@ -955,12 +997,21 @@ impl ClientState {
}
}
self.resources_by_id
.insert(new_resource.id(), new_resource.clone());
if !self.is_resource_enabled(&(new_resource.id())) {
return;
}
match &new_resource {
ResourceDescription::Dns(dns) => {
self.stub_resolver.add_resource(dns.id, dns.address.clone());
}
ResourceDescription::Cidr(cidr) => {
let existing = self.cidr_resources.insert(cidr.address, cidr.clone());
let existing = self
.active_cidr_resources
.insert(cidr.address, cidr.clone());
match existing {
Some(existing) if existing.id != cidr.id => {
@@ -974,15 +1025,18 @@ impl ClientState {
}
ResourceDescription::Internet(_) => {}
}
self.resources_by_id.insert(new_resource.id(), new_resource);
}
#[tracing::instrument(level = "debug", skip_all, fields(?id))]
pub(crate) fn remove_resource(&mut self, id: ResourceId) {
self.disable_resource(id);
self.resources_by_id.remove(&id);
}
fn disable_resource(&mut self, id: ResourceId) {
self.awaiting_connection_details.remove(&id);
self.stub_resolver.remove_resource(id);
self.cidr_resources.retain(|_, r| {
self.active_cidr_resources.retain(|_, r| {
if r.id == id {
tracing::info!(address = %r.address, name = %r.name, "Deactivating CIDR resource");
return false;
@@ -991,8 +1045,6 @@ impl ClientState {
true
});
self.resources_by_id.remove(&id);
let Some(peer) = peer_by_resource_mut(&self.resources_gateways, &mut self.peers, id) else {
return;
};

View File

@@ -190,6 +190,10 @@ impl StubResolver {
Some(domain.clone())
}
fn knows_resource(&self, resource: &ResourceId) -> bool {
self.dns_resources.values().contains(resource)
}
// TODO: we can save a few allocations here still
// We don't need to support multiple questions/qname in a single query because
// nobody does it and since this run with each packet we want to squeeze as much optimization
@@ -229,6 +233,13 @@ impl StubResolver {
let maybe_resource = self.match_resource(&domain);
let resource_records = match (qtype, maybe_resource) {
(_, Some(resource)) if !self.knows_resource(&resource) => {
return Some(ResolveStrategy::ForwardQuery(DnsQuery {
name: domain,
record_type: u16::from(qtype).into(),
query: packet,
}))
}
(Rtype::A, Some(resource)) => self.get_or_assign_a_records(domain.clone(), resource),
(Rtype::AAAA, Some(resource)) => {
self.get_or_assign_aaaa_records(domain.clone(), resource)

View File

@@ -172,6 +172,11 @@ impl ReferenceStateMachine for ReferenceState {
.with(1, Just(Transition::PartitionRelaysFromPortal))
.with(1, Just(Transition::ReconnectPortal))
.with(1, Just(Transition::Idle))
.with_if_not_empty(1, state.client.inner().all_resource_ids(), |resources_id| {
sample::subsequence(resources_id.clone(), resources_id.len()).prop_map(
|resources_id| Transition::DisableResources(HashSet::from_iter(resources_id)),
)
})
.with_if_not_empty(
10,
state.client.inner().ipv4_cidr_resource_dsts(),
@@ -298,10 +303,16 @@ impl ReferenceStateMachine for ReferenceState {
client.cidr_resources.retain(|_, r| &r.id != id);
client.dns_resources.remove(id);
client.connected_cidr_resources.remove(id);
client.connected_dns_resources.retain(|(r, _)| r != id);
client.disconnect_resource(id)
});
}
Transition::DisableResources(resources) => state.client.exec_mut(|client| {
client.disabled_resources = resources.clone();
for id in resources {
client.disconnect_resource(id)
}
}),
Transition::SendDnsQuery {
domain,
r_type,
@@ -439,6 +450,7 @@ impl ReferenceStateMachine for ReferenceState {
true
}
Transition::DisableResources(_) => true,
Transition::SendICMPPacketToNonResourceIp {
dst,
seq,

View File

@@ -266,6 +266,10 @@ pub struct RefClient {
#[derivative(Debug = "ignore")]
pub(crate) connected_dns_resources: HashSet<(ResourceId, DomainName)>,
/// Actively disabled resources by the UI
#[derivative(Debug = "ignore")]
pub(crate) disabled_resources: HashSet<ResourceId>,
/// The expected ICMP handshakes.
///
/// This is indexed by gateway because our assertions rely on the order of the sent packets.
@@ -293,6 +297,11 @@ impl RefClient {
SimClient::new(self.id, client_state)
}
pub(crate) fn disconnect_resource(&mut self, resource: &ResourceId) {
self.connected_cidr_resources.remove(resource);
self.connected_dns_resources.retain(|(r, _)| r != resource);
}
pub(crate) fn reset_connections(&mut self) {
self.connected_cidr_resources.clear();
self.connected_dns_resources.clear();
@@ -367,6 +376,10 @@ impl RefClient {
};
tracing::Span::current().record("resource", tracing::field::display(rid));
if self.disabled_resources.contains(&rid) {
return;
}
let Some(gateway) = gateway_by_resource(rid) else {
tracing::error!("No gateway for resource");
return;
@@ -428,7 +441,9 @@ impl RefClient {
);
tracing::debug!("Not connected to resource, expecting to trigger connection intent");
self.connected_dns_resources.insert((resource, dst));
if !self.disabled_resources.contains(&resource) {
self.connected_dns_resources.insert((resource, dst));
}
}
pub(crate) fn ipv4_cidr_resource_dsts(&self) -> Vec<Ipv4Network> {
@@ -460,7 +475,7 @@ impl RefClient {
.sorted_by_key(|r| r.address.len())
.rev()
.map(|r| r.id)
.next()
.find(|id| !self.disabled_resources.contains(id))
}
fn resolved_domains(&self) -> impl Iterator<Item = (DomainName, HashSet<RecordType>)> + '_ {
@@ -541,7 +556,10 @@ impl RefClient {
}
pub(crate) fn cidr_resource_by_ip(&self, ip: IpAddr) -> Option<ResourceId> {
self.cidr_resources.longest_match(ip).map(|(_, r)| r.id)
self.cidr_resources
.longest_match(ip)
.map(|(_, r)| r.id)
.filter(|id| !self.disabled_resources.contains(id))
}
pub(crate) fn resolved_ip4_for_non_resources(
@@ -736,6 +754,7 @@ fn ref_client(
connected_internet_resources: Default::default(),
expected_icmp_handshakes: Default::default(),
expected_dns_handshakes: Default::default(),
disabled_resources: Default::default(),
},
)
}

View File

@@ -158,6 +158,9 @@ impl StateMachineTest for TunnelTest {
Transition::DeactivateResource(id) => {
state.client.exec_mut(|c| c.sut.remove_resource(id))
}
Transition::DisableResources(resources) => state
.client
.exec_mut(|c| c.sut.set_disabled_resource(resources)),
Transition::SendICMPPacketToNonResourceIp {
src,
dst,

View File

@@ -6,7 +6,7 @@ use connlib_shared::{
use hickory_proto::rr::RecordType;
use proptest::{prelude::*, sample};
use std::{
collections::BTreeMap,
collections::{BTreeMap, HashSet},
net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr},
};
@@ -19,6 +19,8 @@ pub(crate) enum Transition {
ActivateResource(ResourceDescription),
/// Deactivate a resource on the client.
DeactivateResource(ResourceId),
/// Client-side disable resource
DisableResources(HashSet<ResourceId>),
/// Send an ICMP packet to non-resource IP.
SendICMPPacketToNonResourceIp {
src: IpAddr,