feat(connlib): react to config updates (#4322)

* Move the resource changes to `ClientState` to unit test easier
* Add unit tests
* Set new config on update from portal
* Set parameters as told by portal on re-init

Fixes: #2728
This commit is contained in:
Gabi
2024-03-27 22:28:11 -03:00
committed by GitHub
parent fab95483e8
commit f879b430e4
3 changed files with 388 additions and 104 deletions

View File

@@ -25,7 +25,6 @@ use url::Url;
pub struct Eventloop<C: Callbacks> {
tunnel: ClientTunnel<C>,
tunnel_init: bool,
portal: PhoenixChannel<(), IngressMessages, ReplyMessages>,
rx: tokio::sync::mpsc::UnboundedReceiver<Command>,
@@ -50,7 +49,6 @@ impl<C: Callbacks> Eventloop<C> {
Self {
tunnel,
portal,
tunnel_init: false,
connection_intents: SentConnectionIntents::default(),
log_upload_interval: upload_interval(),
rx,
@@ -172,8 +170,10 @@ where
fn handle_portal_inbound_message(&mut self, msg: IngressMessages) {
match msg {
IngressMessages::ConfigChanged(_) => {
tracing::warn!("Config changes are not yet implemented");
IngressMessages::ConfigChanged(config) => {
if let Err(e) = self.tunnel.set_interface(config.interface.clone()) {
tracing::warn!(?config, "Failed to update configuration: {e:?}");
}
}
IngressMessages::IceCandidates(GatewayIceCandidates {
gateway_id,
@@ -187,18 +187,13 @@ where
interface,
resources,
}) => {
if !self.tunnel_init {
if let Err(e) = self.tunnel.set_interface(interface) {
tracing::warn!("Failed to set interface on tunnel: {e}");
return;
}
self.tunnel_init = true;
tracing::info!("Firezone Started!");
let _ = self.tunnel.add_resources(&resources);
} else {
tracing::info!("Firezone reinitializated");
if let Err(e) = self.tunnel.set_interface(interface) {
tracing::warn!("Failed to set interface on tunnel: {e}");
return;
}
tracing::info!("Firezone Started!");
let _ = self.tunnel.set_resources(resources);
}
IngressMessages::ResourceCreatedOrUpdated(resource) => {
let resource_id = resource.id();
@@ -208,7 +203,7 @@ where
}
}
IngressMessages::ResourceDeleted(RemoveResource(resource)) => {
self.tunnel.remove_resource(resource);
self.tunnel.remove_resources(&[resource]);
}
}
}

View File

@@ -140,7 +140,7 @@ impl PartialEq for RequestConnection {
impl Eq for RequestConnection {}
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)]
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq, Hash)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ResourceDescription<TDNS = ResourceDescriptionDns> {
Dns(TDNS),
@@ -285,7 +285,7 @@ impl ResourceDescription {
}
/// Description of a resource that maps to a CIDR.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)]
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq, Hash)]
pub struct ResourceDescriptionCidr {
/// Resource's id.
pub id: ResourceId,

View File

@@ -58,111 +58,44 @@ impl<CB> ClientTunnel<CB>
where
CB: Callbacks + 'static,
{
pub fn set_resources(
&mut self,
resources: Vec<ResourceDescription>,
) -> connlib_shared::Result<()> {
self.role_state.set_resources(resources);
self.update_routes()?;
self.update_resource_list();
Ok(())
}
/// Adds a the given resource to the tunnel.
pub fn add_resources(
&mut self,
resources: &[ResourceDescription],
) -> connlib_shared::Result<()> {
for resource_description in resources {
if let Some(resource) = self.role_state.resource_ids.get(&resource_description.id()) {
if resource.has_different_address(resource) {
self.remove_resource(resource.id());
}
}
self.role_state.add_resources(resources);
match &resource_description {
ResourceDescription::Dns(dns) => {
self.role_state
.dns_resources
.insert(dns.address.clone(), dns.clone());
}
ResourceDescription::Cidr(cidr) => {
self.role_state
.cidr_resources
.insert(cidr.address, cidr.clone());
}
}
self.role_state
.resource_ids
.insert(resource_description.id(), resource_description.clone());
}
self.update_resource_list();
self.update_routes()?;
self.update_resource_list();
Ok(())
}
#[tracing::instrument(level = "debug", skip_all, fields(%id))]
pub fn remove_resource(&mut self, id: ResourceId) {
self.role_state.awaiting_connection.remove(&id);
self.role_state
.dns_resources_internal_ips
.retain(|r, _| r.id != id);
self.role_state.dns_resources.retain(|_, r| r.id != id);
self.role_state.cidr_resources.retain(|_, r| r.id != id);
self.role_state
.deferred_dns_queries
.retain(|(r, _), _| r.id != id);
self.role_state.resource_ids.remove(&id);
pub fn remove_resources(&mut self, ids: &[ResourceId]) {
self.role_state.remove_resources(ids);
if let Err(err) = self.update_routes() {
tracing::error!(%id, "Failed to update routes: {err:?}");
tracing::error!(?ids, "Failed to update routes: {err:?}");
}
self.update_resource_list();
let Some(gateway_id) = self.role_state.resources_gateways.remove(&id) else {
tracing::debug!("No gateway associated with resource");
return;
};
let Some(peer) = self.role_state.peers.get_mut(&gateway_id) else {
return;
};
// First we remove the id from all allowed ips
for (network, resources) in peer
.allowed_ips
.iter_mut()
.filter(|(_, resources)| resources.contains(&id))
{
resources.remove(&id);
if !resources.is_empty() {
continue;
}
// If the allowed_ips doesn't correspond to any resource anymore we
// clean up any related translation.
peer.transform
.translations
.remove_by_left(&network.network_address());
}
// We remove all empty allowed ips entry since there's no resource that corresponds to it
peer.allowed_ips.retain(|_, r| !r.is_empty());
// If there's no allowed ip left we remove the whole peer because there's no point on keeping it around
if peer.allowed_ips.is_empty() {
self.role_state.peers.remove(&gateway_id);
// TODO: should we have a Node::remove_connection?
}
tracing::debug!("Resource removed")
}
fn update_resource_list(&self) {
self.callbacks.on_update_resources(
self.role_state
.resource_ids
.values()
.sorted()
.cloned()
.collect_vec(),
);
self.callbacks
.on_update_resources(self.role_state.resources());
}
/// Updates the system's dns
@@ -373,6 +306,10 @@ impl ClientState {
}
}
fn resources(&self) -> Vec<ResourceDescription> {
self.resource_ids.values().sorted().cloned().collect_vec()
}
pub(crate) fn encapsulate<'s>(
&'s mut self,
packet: MutableIpPacket<'_>,
@@ -916,6 +853,100 @@ impl ClientState {
self.node.poll_transmit()
}
fn set_resources(&mut self, new_resources: Vec<ResourceDescription>) {
self.remove_resources(
&HashSet::from_iter(self.resource_ids.keys().copied())
.difference(&HashSet::<ResourceId>::from_iter(
new_resources.iter().map(|r| r.id()),
))
.copied()
.collect_vec(),
);
self.add_resources(
&HashSet::from_iter(new_resources.iter().cloned())
.difference(&HashSet::<ResourceDescription>::from_iter(
self.resource_ids.values().cloned(),
))
.cloned()
.collect_vec(),
);
}
fn add_resources(&mut self, resources: &[ResourceDescription]) {
for resource_description in resources {
if let Some(resource) = self.resource_ids.get(&resource_description.id()) {
if resource.has_different_address(resource_description) {
self.remove_resources(&[resource.id()]);
}
}
match &resource_description {
ResourceDescription::Dns(dns) => {
self.dns_resources.insert(dns.address.clone(), dns.clone());
}
ResourceDescription::Cidr(cidr) => {
self.cidr_resources.insert(cidr.address, cidr.clone());
}
}
self.resource_ids
.insert(resource_description.id(), resource_description.clone());
}
}
#[tracing::instrument(level = "debug", skip_all, fields(?ids))]
fn remove_resources(&mut self, ids: &[ResourceId]) {
for id in ids {
self.awaiting_connection.remove(id);
self.dns_resources_internal_ips.retain(|r, _| r.id != *id);
self.dns_resources.retain(|_, r| r.id != *id);
self.cidr_resources.retain(|_, r| r.id != *id);
self.deferred_dns_queries.retain(|(r, _), _| r.id != *id);
self.resource_ids.remove(id);
let Some(gateway_id) = self.resources_gateways.remove(id) else {
tracing::debug!("No gateway associated with resource");
continue;
};
let Some(peer) = self.peers.get_mut(&gateway_id) else {
continue;
};
// First we remove the id from all allowed ips
for (network, resources) in peer
.allowed_ips
.iter_mut()
.filter(|(_, resources)| resources.contains(id))
{
resources.remove(id);
if !resources.is_empty() {
continue;
}
// If the allowed_ips doesn't correspond to any resource anymore we
// clean up any related translation.
peer.transform
.translations
.remove_by_left(&network.network_address());
}
// We remove all empty allowed ips entry since there's no resource that corresponds to it
peer.allowed_ips.retain(|_, r| !r.is_empty());
// If there's no allowed ip left we remove the whole peer because there's no point on keeping it around
if peer.allowed_ips.is_empty() {
self.peers.remove(&gateway_id);
// TODO: should we have a Node::remove_connection?
}
}
tracing::debug!("Resources removed")
}
fn update_dns_mapping(&mut self) -> bool {
let Some(config) = &self.interface_config else {
return false;
@@ -1257,6 +1288,203 @@ mod tests {
)
}
#[test]
fn add_resources_works() {
let mut client_state = ClientState::for_test();
client_state.add_resources(&[
cidr_foo_resource("10.0.0.0/24"),
dns_bar_resource("baz.com"),
]);
assert_eq!(
hashset(client_state.resources().iter()),
hashset(
[
cidr_foo_resource("10.0.0.0/24"),
dns_bar_resource("baz.com")
]
.iter()
)
);
assert_eq!(
hashset(client_state.routes()),
expected_routes(vec![IpNetwork::from_str("10.0.0.0/24").unwrap()])
);
client_state.add_resources(&[cidr_baz_resource("11.0.0.0/24")]);
assert_eq!(
hashset(client_state.resources().iter()),
hashset(
[
cidr_foo_resource("10.0.0.0/24"),
dns_bar_resource("baz.com"),
cidr_baz_resource("11.0.0.0/24")
]
.iter()
)
);
assert_eq!(
hashset(client_state.routes()),
expected_routes(vec![
IpNetwork::from_str("10.0.0.0/24").unwrap(),
IpNetwork::from_str("11.0.0.0/24").unwrap()
])
);
}
#[test]
fn add_resources_update_works_cidr() {
let mut client_state = ClientState::for_test();
client_state.add_resources(&[
cidr_foo_resource("10.0.0.0/24"),
dns_bar_resource("baz.com"),
]);
client_state.add_resources(&[cidr_foo_resource("11.0.0.0/24")]);
assert_eq!(
hashset(client_state.resources().iter()),
hashset(
[
cidr_foo_resource("11.0.0.0/24"),
dns_bar_resource("baz.com")
]
.iter()
)
);
assert_eq!(
HashSet::<IpNetwork>::from_iter(client_state.routes()),
expected_routes(vec![IpNetwork::from_str("11.0.0.0/24").unwrap()])
);
}
#[test]
fn add_resources_update_works_to_dns() {
let mut client_state = ClientState::for_test();
client_state.add_resources(&[
cidr_foo_resource("10.0.0.0/24"),
dns_bar_resource("baz.com"),
]);
client_state.add_resources(&[cidr_bar_id("11.0.0.0/24")]);
assert_eq!(
hashset(client_state.resources().iter()),
hashset([cidr_bar_id("11.0.0.0/24"), cidr_foo_resource("10.0.0.0/24")].iter())
);
assert_eq!(
hashset(client_state.routes()),
expected_routes(vec![
IpNetwork::from_str("10.0.0.0/24").unwrap(),
IpNetwork::from_str("11.0.0.0/24").unwrap()
])
);
}
#[test]
fn remove_resources_works() {
let mut client_state = ClientState::for_test();
client_state.add_resources(&[
cidr_foo_resource("10.0.0.0/24"),
dns_bar_resource("baz.com"),
]);
client_state.remove_resources(&[cidr_foo_id()]);
assert_eq!(
hashset(client_state.resources().iter()),
hashset([dns_bar_resource("baz.com")].iter())
);
assert_eq!(hashset(client_state.routes()), expected_routes(vec![]));
}
#[test]
fn set_resource_works() {
let mut client_state = ClientState::for_test();
client_state.set_resources(vec![
cidr_foo_resource("10.0.0.0/24"),
dns_bar_resource("baz.com"),
]);
assert_eq!(
hashset(client_state.resources().iter()),
hashset(
[
cidr_foo_resource("10.0.0.0/24"),
dns_bar_resource("baz.com")
]
.iter()
)
);
assert_eq!(
HashSet::<IpNetwork>::from_iter(client_state.routes()),
expected_routes(vec![IpNetwork::from_str("10.0.0.0/24").unwrap()])
);
}
#[test]
fn set_resource_replaces_old_resources() {
let mut client_state = ClientState::for_test();
client_state.set_resources(vec![
cidr_foo_resource("10.0.0.0/24"),
dns_bar_resource("baz.com"),
]);
client_state.set_resources(vec![cidr_baz_resource("11.0.0.0/24")]);
assert_eq!(
hashset(client_state.resources().iter()),
hashset([cidr_baz_resource("11.0.0.0/24")].iter())
);
assert_eq!(
hashset(client_state.routes()),
expected_routes(vec![IpNetwork::from_str("11.0.0.0/24").unwrap()])
);
}
#[test]
fn set_resource_updates_old_resource_with_same_id() {
let mut client_state = ClientState::for_test();
client_state.set_resources(vec![
cidr_foo_resource("10.0.0.0/24"),
dns_bar_resource("baz.com"),
]);
client_state.set_resources(vec![cidr_foo_resource("11.0.0.0/24")]);
assert_eq!(
hashset(client_state.resources().iter()),
hashset([cidr_foo_resource("11.0.0.0/24")].iter())
);
assert_eq!(
hashset(client_state.routes()),
expected_routes(vec![IpNetwork::from_str("11.0.0.0/24").unwrap()])
);
}
#[test]
fn set_resource_keeps_resource_if_unchanged() {
let mut client_state = ClientState::for_test();
client_state.set_resources(vec![
cidr_foo_resource("10.0.0.0/24"),
dns_bar_resource("baz.com"),
]);
client_state.set_resources(vec![cidr_foo_resource("10.0.0.0/24")]);
assert_eq!(
hashset(client_state.resources().iter()),
hashset([cidr_foo_resource("10.0.0.0/24")].iter())
);
assert_eq!(
hashset(client_state.routes()),
expected_routes(vec![IpNetwork::from_str("10.0.0.0/24").unwrap()])
);
}
impl ClientState {
fn for_test() -> ClientState {
ClientState::new(StaticSecret::random_from_rng(OsRng))
@@ -1307,7 +1535,68 @@ mod tests {
})
}
fn cidr_foo_resource(addr: &str) -> ResourceDescription {
ResourceDescription::Cidr(ResourceDescriptionCidr {
id: cidr_foo_id(),
address: addr.parse().unwrap(),
name: "foo".to_string(),
})
}
fn cidr_bar_id(addr: &str) -> ResourceDescription {
ResourceDescription::Cidr(ResourceDescriptionCidr {
id: dns_bar_id(),
address: addr.parse().unwrap(),
name: "foo".to_string(),
})
}
fn dns_bar_resource(addr: &str) -> ResourceDescription {
ResourceDescription::Dns(ResourceDescriptionDns {
id: dns_bar_id(),
address: addr.to_string(),
name: "bar".to_string(),
})
}
fn cidr_baz_resource(addr: &str) -> ResourceDescription {
ResourceDescription::Cidr(ResourceDescriptionCidr {
id: cidr_baz_id(),
address: addr.parse().unwrap(),
name: "baz".to_string(),
})
}
fn cidr_foo_id() -> ResourceId {
resource_id("fb51081a-2e06-4b59-b5a8-33592de9ebb1")
}
fn cidr_baz_id() -> ResourceId {
resource_id("4e0bf4ea-4175-4cdb-a7c2-cbeffa8ccc5d")
}
fn dns_bar_id() -> ResourceId {
resource_id("868483b6-431e-484d-bdd6-dad60ed26418")
}
fn ip(addr: &str) -> IpAddr {
addr.parse().unwrap()
}
fn resource_id(id: &str) -> ResourceId {
id.parse().unwrap()
}
fn expected_routes(resource_routes: Vec<IpNetwork>) -> HashSet<IpNetwork> {
HashSet::from_iter(
resource_routes
.into_iter()
.chain(iter::once(IpNetwork::from_str(IPV4_RESOURCES).unwrap()))
.chain(iter::once(IpNetwork::from_str(IPV6_RESOURCES).unwrap())),
)
}
fn hashset<T: std::hash::Hash + Eq>(val: impl Iterator<Item = T>) -> HashSet<T> {
HashSet::from_iter(val)
}
}