refactor(connlib): move functionality onto ClientState (#4167)

With the move to SANS-IO, we will be able to write deterministic unit
tests for the tunnel logic. To actually do that, `ClientState` and
`GatewayState` need to encapsulate all the logic that we want to test.

This PR does some minor refactoring on the functions on `ClientTunnel`
and moves several of them onto `ClientState`. It doesn't touch
`add_resources` and `remove_resource` because those depend on #4156.
This commit is contained in:
Thomas Eizinger
2024-03-19 09:54:20 +10:00
committed by GitHub
parent 5a0aff1d75
commit 4e48884513

View File

@@ -15,7 +15,6 @@ use ip_network::IpNetwork;
use ip_network_table::IpNetworkTable;
use itertools::Itertools;
use crate::device_channel::Device;
use crate::utils::{earliest, stun, turn};
use crate::{ClientEvent, ClientTunnel};
use secrecy::{ExposeSecret as _, Secret};
@@ -250,193 +249,38 @@ where
gateway_id: GatewayId,
relays: Vec<Relay>,
) -> connlib_shared::Result<Request> {
tracing::trace!("request_connection");
if let Some(connection) = self
.role_state
.attempt_to_reuse_connection(resource_id, gateway_id)?
{
// TODO: now we send reuse connections before connection is established but after
// response is offered.
// We need to consider new race conditions, such as connection failed after
// reuse connection is sent.
// Though I believe everything will work just fine like this.
return Ok(Request::ReuseConnection(connection));
}
if self.role_state.node.is_expecting_answer(gateway_id) {
return Err(Error::PendingConnection);
}
let awaiting_connection = self
.role_state
.get_awaiting_connection(&resource_id)?
.clone();
let offer = self.role_state.node.new_connection(
self.role_state.create_or_reuse_connection(
resource_id,
gateway_id,
stun(&relays, |addr| self.io.sockets_ref().can_handle(addr)),
turn(&relays, |addr| self.io.sockets_ref().can_handle(addr)),
awaiting_connection.last_intent_sent_at,
Instant::now(),
);
Ok(Request::NewConnection(RequestConnection {
resource_id,
gateway_id,
client_preshared_key: Secret::new(Key(*offer.session_key.expose_secret())),
client_payload: ClientPayload {
ice_parameters: Offer {
username: offer.credentials.username,
password: offer.credentials.password,
},
domain: awaiting_connection.domain,
},
}))
}
fn new_peer(
&mut self,
resource_id: ResourceId,
gateway_id: GatewayId,
domain_response: Option<DomainResponse>,
) -> connlib_shared::Result<()> {
let ips = self.role_state.create_peer_config_for_new_connection(
resource_id,
&domain_response.as_ref().map(|d| d.domain.clone()),
)?;
let resource_ids = HashSet::from([resource_id]);
let mut peer: Peer<_, PacketTransformClient, _> =
Peer::new(gateway_id, Default::default(), &ips, resource_ids);
peer.transform.set_dns(self.role_state.dns_mapping());
self.role_state.peers.insert(peer, &[]);
let peer_ips = if let Some(domain_response) = domain_response {
self.dns_response(&resource_id, &domain_response, &gateway_id)?
} else {
ips
};
self.role_state
.peers
.add_ips_with_resource(&gateway_id, &peer_ips, &resource_id);
Ok(())
)
}
/// Called when a response to [ClientTunnel::request_connection] is ready.
///
/// Once this is called, if everything goes fine, a new tunnel should be started between the 2 peers.
#[tracing::instrument(level = "trace", skip(self, gateway_public_key, resource_id))]
pub fn received_offer_response(
&mut self,
resource_id: ResourceId,
rtc_ice_params: Answer,
answer: Answer,
domain_response: Option<DomainResponse>,
gateway_public_key: PublicKey,
) -> connlib_shared::Result<()> {
tracing::trace!("received offer response");
let gateway_id = self
.role_state
.gateway_by_resource(&resource_id)
.ok_or(Error::UnknownResource)?;
self.role_state.node.accept_answer(
gateway_id,
gateway_public_key,
snownet::Answer {
credentials: snownet::Credentials {
username: rtc_ice_params.username,
password: rtc_ice_params.password,
},
},
Instant::now(),
);
self.new_peer(resource_id, gateway_id, domain_response)?;
self.role_state
.accept_answer(answer, resource_id, gateway_public_key, domain_response)?;
Ok(())
}
fn dns_response(
&mut self,
resource_id: &ResourceId,
domain_response: &DomainResponse,
peer_id: &GatewayId,
) -> connlib_shared::Result<Vec<IpNetwork>> {
let peer = self
.role_state
.peers
.get_mut(peer_id)
.ok_or(Error::ControlProtocolError)?;
let resource_description = self
.role_state
.resource_ids
.get(resource_id)
.ok_or(Error::UnknownResource)?
.clone();
let ResourceDescription::Dns(resource_description) = resource_description else {
// We should never get a domain_response for a CIDR resource!
return Err(Error::ControlProtocolError);
};
let resource_description =
DnsResource::from_description(&resource_description, domain_response.domain.clone());
let addrs: HashSet<_> = domain_response
.address
.iter()
.filter_map(|external_ip| {
peer.transform
.get_or_assign_translation(external_ip, &mut self.role_state.ip_provider)
})
.collect();
self.role_state
.dns_resources_internal_ips
.insert(resource_description.clone(), addrs.clone());
let ips: Vec<IpNetwork> = addrs.iter().copied().map(Into::into).collect();
send_dns_answer(
&mut self.role_state,
Rtype::Aaaa,
self.io.device_mut(),
&resource_description,
&addrs,
);
send_dns_answer(
&mut self.role_state,
Rtype::A,
self.io.device_mut(),
&resource_description,
&addrs,
);
Ok(ips)
}
#[tracing::instrument(level = "trace", skip(self, resource_id))]
pub fn received_domain_parameters(
&mut self,
resource_id: ResourceId,
domain_response: DomainResponse,
) -> connlib_shared::Result<()> {
let gateway_id = self
.role_state
.gateway_by_resource(&resource_id)
.ok_or(Error::UnknownResource)?;
let peer_ips = self.dns_response(&resource_id, &domain_response, &gateway_id)?;
self.role_state
.peers
.add_ips_with_resource(&gateway_id, &peer_ips, &resource_id);
.received_domain_parameters(resource_id, domain_response)?;
Ok(())
}
@@ -451,7 +295,6 @@ pub enum Request {
fn send_dns_answer(
role_state: &mut ClientState,
qtype: Rtype,
device: &Device,
resource_description: &DnsResource,
addrs: &HashSet<IpAddr>,
) {
@@ -462,9 +305,7 @@ fn send_dns_answer(
let Some(packet) = dns::create_local_answer(addrs, packet) else {
return;
};
if let Err(e) = device.write(packet) {
tracing::error!(err = ?e, "error writing packet: {e:#?}");
}
role_state.buffered_packets.push_back(packet);
}
}
@@ -593,6 +434,188 @@ impl ClientState {
Some(packet.into_immutable())
}
#[tracing::instrument(level = "trace", skip_all, fields(%resource_id))]
fn accept_answer(
&mut self,
answer: Answer,
resource_id: ResourceId,
gateway: PublicKey,
domain_response: Option<DomainResponse>,
) -> connlib_shared::Result<()> {
let gateway_id = self
.gateway_by_resource(&resource_id)
.ok_or(Error::UnknownResource)?;
self.node.accept_answer(
gateway_id,
gateway,
snownet::Answer {
credentials: snownet::Credentials {
username: answer.username,
password: answer.password,
},
},
Instant::now(),
);
let desc = self
.resource_ids
.get(&resource_id)
.ok_or(Error::ControlProtocolError)?;
let ips = self.get_resource_ip(desc, &domain_response.as_ref().map(|d| d.domain.clone()));
// Tidy up state once everything succeeded.
self.awaiting_connection.remove(&resource_id);
let resource_ids = HashSet::from([resource_id]);
let mut peer: Peer<_, PacketTransformClient, _> =
Peer::new(gateway_id, Default::default(), &ips, resource_ids);
peer.transform.set_dns(self.dns_mapping());
self.peers.insert(peer, &[]);
let peer_ips = if let Some(domain_response) = domain_response {
self.dns_response(&resource_id, &domain_response, &gateway_id)?
} else {
ips
};
self.peers
.add_ips_with_resource(&gateway_id, &peer_ips, &resource_id);
Ok(())
}
fn create_or_reuse_connection(
&mut self,
resource_id: ResourceId,
gateway_id: GatewayId,
allowed_stun_servers: HashSet<SocketAddr>,
allowed_turn_servers: HashSet<(SocketAddr, String, String, String)>,
) -> connlib_shared::Result<Request> {
tracing::trace!("request_connection");
let desc = self
.resource_ids
.get(&resource_id)
.ok_or(Error::UnknownResource)?;
let domain = self.get_awaiting_connection(&resource_id)?.domain.clone();
if self.is_connected_to(resource_id, &domain) {
return Err(Error::UnexpectedConnectionDetails);
}
let awaiting_connection = self
.awaiting_connection
.get(&resource_id)
.ok_or(Error::UnexpectedConnectionDetails)?
.clone();
self.resources_gateways.insert(resource_id, gateway_id);
if self.peers.get(&gateway_id).is_some() {
self.peers.add_ips_with_resource(
&gateway_id,
&self.get_resource_ip(desc, &domain),
&resource_id,
);
self.awaiting_connection.remove(&resource_id);
return Ok(Request::ReuseConnection(ReuseConnection {
resource_id,
gateway_id,
payload: domain.clone(),
}));
};
if self.node.is_expecting_answer(gateway_id) {
return Err(Error::PendingConnection);
}
let offer = self.node.new_connection(
gateway_id,
allowed_stun_servers,
allowed_turn_servers,
awaiting_connection.last_intent_sent_at,
Instant::now(),
);
return Ok(Request::NewConnection(RequestConnection {
resource_id,
gateway_id,
client_preshared_key: Secret::new(Key(*offer.session_key.expose_secret())),
client_payload: ClientPayload {
ice_parameters: Offer {
username: offer.credentials.username,
password: offer.credentials.password,
},
domain: awaiting_connection.domain,
},
}));
}
fn received_domain_parameters(
&mut self,
resource_id: ResourceId,
domain_response: DomainResponse,
) -> connlib_shared::Result<()> {
let gateway_id = self
.gateway_by_resource(&resource_id)
.ok_or(Error::UnknownResource)?;
let peer_ips = self.dns_response(&resource_id, &domain_response, &gateway_id)?;
self.peers
.add_ips_with_resource(&gateway_id, &peer_ips, &resource_id);
Ok(())
}
fn dns_response(
&mut self,
resource_id: &ResourceId,
domain_response: &DomainResponse,
peer_id: &GatewayId,
) -> connlib_shared::Result<Vec<IpNetwork>> {
let peer = self
.peers
.get_mut(peer_id)
.ok_or(Error::ControlProtocolError)?;
let resource_description = self
.resource_ids
.get(resource_id)
.ok_or(Error::UnknownResource)?
.clone();
let ResourceDescription::Dns(resource_description) = resource_description else {
// We should never get a domain_response for a CIDR resource!
return Err(Error::ControlProtocolError);
};
let resource_description =
DnsResource::from_description(&resource_description, domain_response.domain.clone());
let addrs: HashSet<_> = domain_response
.address
.iter()
.filter_map(|external_ip| {
peer.transform
.get_or_assign_translation(external_ip, &mut self.ip_provider)
})
.collect();
self.dns_resources_internal_ips
.insert(resource_description.clone(), addrs.clone());
send_dns_answer(self, Rtype::Aaaa, &resource_description, &addrs);
send_dns_answer(self, Rtype::A, &resource_description, &addrs);
Ok(addrs.iter().copied().map(Into::into).collect())
}
/// Attempt to handle the given packet as a DNS packet.
///
/// Returns `Ok` if the packet is in fact a DNS query with an optional response to send back.
@@ -651,44 +674,6 @@ impl ClientState {
.ok_or(Error::UnexpectedConnectionDetails)
}
pub(crate) fn attempt_to_reuse_connection(
&mut self,
resource: ResourceId,
gateway: GatewayId,
) -> Result<Option<ReuseConnection>, ConnlibError> {
let desc = self
.resource_ids
.get(&resource)
.ok_or(Error::UnknownResource)?;
let domain = self.get_awaiting_connection(&resource)?.domain.clone();
if self.is_connected_to(resource, &domain) {
return Err(Error::UnexpectedConnectionDetails);
}
self.awaiting_connection
.get_mut(&resource)
.ok_or(Error::UnexpectedConnectionDetails)?;
self.resources_gateways.insert(resource, gateway);
if self.peers.get(&gateway).is_none() {
return Ok(None);
};
self.peers
.add_ips_with_resource(&gateway, &self.get_resource_ip(desc, &domain), &resource);
self.awaiting_connection.remove(&resource);
Ok(Some(ReuseConnection {
resource_id: resource,
gateway_id: gateway,
payload: domain.clone(),
}))
}
pub fn on_connection_failed(&mut self, resource: ResourceId) {
self.awaiting_connection.remove(&resource);
self.resources_gateways.remove(&resource);
@@ -769,24 +754,6 @@ impl ClientState {
});
}
pub fn create_peer_config_for_new_connection(
&mut self,
resource: ResourceId,
domain: &Option<Dname>,
) -> Result<Vec<IpNetwork>, ConnlibError> {
let desc = self
.resource_ids
.get(&resource)
.ok_or(Error::ControlProtocolError)?;
let ips = self.get_resource_ip(desc, domain);
// Tidy up state once everything succeeded.
self.awaiting_connection.remove(&resource);
Ok(ips)
}
pub fn gateway_by_resource(&self, resource: &ResourceId) -> Option<GatewayId> {
self.resources_gateways.get(resource).copied()
}