diff --git a/rust/connlib/tunnel/src/client.rs b/rust/connlib/tunnel/src/client.rs index 0b5f48eaf..d21cfce83 100644 --- a/rust/connlib/tunnel/src/client.rs +++ b/rust/connlib/tunnel/src/client.rs @@ -15,7 +15,6 @@ use ip_network::IpNetwork; use ip_network_table::IpNetworkTable; use itertools::Itertools; -use crate::device_channel::Device; use crate::utils::{earliest, stun, turn}; use crate::{ClientEvent, ClientTunnel}; use secrecy::{ExposeSecret as _, Secret}; @@ -250,193 +249,38 @@ where gateway_id: GatewayId, relays: Vec, ) -> connlib_shared::Result { - tracing::trace!("request_connection"); - - if let Some(connection) = self - .role_state - .attempt_to_reuse_connection(resource_id, gateway_id)? - { - // TODO: now we send reuse connections before connection is established but after - // response is offered. - // We need to consider new race conditions, such as connection failed after - // reuse connection is sent. - // Though I believe everything will work just fine like this. - return Ok(Request::ReuseConnection(connection)); - } - - if self.role_state.node.is_expecting_answer(gateway_id) { - return Err(Error::PendingConnection); - } - - let awaiting_connection = self - .role_state - .get_awaiting_connection(&resource_id)? - .clone(); - - let offer = self.role_state.node.new_connection( + self.role_state.create_or_reuse_connection( + resource_id, gateway_id, stun(&relays, |addr| self.io.sockets_ref().can_handle(addr)), turn(&relays, |addr| self.io.sockets_ref().can_handle(addr)), - awaiting_connection.last_intent_sent_at, - Instant::now(), - ); - - Ok(Request::NewConnection(RequestConnection { - resource_id, - gateway_id, - client_preshared_key: Secret::new(Key(*offer.session_key.expose_secret())), - client_payload: ClientPayload { - ice_parameters: Offer { - username: offer.credentials.username, - password: offer.credentials.password, - }, - domain: awaiting_connection.domain, - }, - })) - } - - fn new_peer( - &mut self, - resource_id: ResourceId, - gateway_id: GatewayId, - domain_response: Option, - ) -> connlib_shared::Result<()> { - let ips = self.role_state.create_peer_config_for_new_connection( - resource_id, - &domain_response.as_ref().map(|d| d.domain.clone()), - )?; - - let resource_ids = HashSet::from([resource_id]); - let mut peer: Peer<_, PacketTransformClient, _> = - Peer::new(gateway_id, Default::default(), &ips, resource_ids); - peer.transform.set_dns(self.role_state.dns_mapping()); - self.role_state.peers.insert(peer, &[]); - - let peer_ips = if let Some(domain_response) = domain_response { - self.dns_response(&resource_id, &domain_response, &gateway_id)? - } else { - ips - }; - - self.role_state - .peers - .add_ips_with_resource(&gateway_id, &peer_ips, &resource_id); - - Ok(()) + ) } /// Called when a response to [ClientTunnel::request_connection] is ready. /// /// Once this is called, if everything goes fine, a new tunnel should be started between the 2 peers. - #[tracing::instrument(level = "trace", skip(self, gateway_public_key, resource_id))] pub fn received_offer_response( &mut self, resource_id: ResourceId, - rtc_ice_params: Answer, + answer: Answer, domain_response: Option, gateway_public_key: PublicKey, ) -> connlib_shared::Result<()> { - tracing::trace!("received offer response"); - - let gateway_id = self - .role_state - .gateway_by_resource(&resource_id) - .ok_or(Error::UnknownResource)?; - - self.role_state.node.accept_answer( - gateway_id, - gateway_public_key, - snownet::Answer { - credentials: snownet::Credentials { - username: rtc_ice_params.username, - password: rtc_ice_params.password, - }, - }, - Instant::now(), - ); - - self.new_peer(resource_id, gateway_id, domain_response)?; + self.role_state + .accept_answer(answer, resource_id, gateway_public_key, domain_response)?; Ok(()) } - fn dns_response( - &mut self, - resource_id: &ResourceId, - domain_response: &DomainResponse, - peer_id: &GatewayId, - ) -> connlib_shared::Result> { - let peer = self - .role_state - .peers - .get_mut(peer_id) - .ok_or(Error::ControlProtocolError)?; - - let resource_description = self - .role_state - .resource_ids - .get(resource_id) - .ok_or(Error::UnknownResource)? - .clone(); - - let ResourceDescription::Dns(resource_description) = resource_description else { - // We should never get a domain_response for a CIDR resource! - return Err(Error::ControlProtocolError); - }; - - let resource_description = - DnsResource::from_description(&resource_description, domain_response.domain.clone()); - - let addrs: HashSet<_> = domain_response - .address - .iter() - .filter_map(|external_ip| { - peer.transform - .get_or_assign_translation(external_ip, &mut self.role_state.ip_provider) - }) - .collect(); - - self.role_state - .dns_resources_internal_ips - .insert(resource_description.clone(), addrs.clone()); - - let ips: Vec = addrs.iter().copied().map(Into::into).collect(); - - send_dns_answer( - &mut self.role_state, - Rtype::Aaaa, - self.io.device_mut(), - &resource_description, - &addrs, - ); - - send_dns_answer( - &mut self.role_state, - Rtype::A, - self.io.device_mut(), - &resource_description, - &addrs, - ); - - Ok(ips) - } - #[tracing::instrument(level = "trace", skip(self, resource_id))] pub fn received_domain_parameters( &mut self, resource_id: ResourceId, domain_response: DomainResponse, ) -> connlib_shared::Result<()> { - let gateway_id = self - .role_state - .gateway_by_resource(&resource_id) - .ok_or(Error::UnknownResource)?; - - let peer_ips = self.dns_response(&resource_id, &domain_response, &gateway_id)?; - self.role_state - .peers - .add_ips_with_resource(&gateway_id, &peer_ips, &resource_id); + .received_domain_parameters(resource_id, domain_response)?; Ok(()) } @@ -451,7 +295,6 @@ pub enum Request { fn send_dns_answer( role_state: &mut ClientState, qtype: Rtype, - device: &Device, resource_description: &DnsResource, addrs: &HashSet, ) { @@ -462,9 +305,7 @@ fn send_dns_answer( let Some(packet) = dns::create_local_answer(addrs, packet) else { return; }; - if let Err(e) = device.write(packet) { - tracing::error!(err = ?e, "error writing packet: {e:#?}"); - } + role_state.buffered_packets.push_back(packet); } } @@ -593,6 +434,188 @@ impl ClientState { Some(packet.into_immutable()) } + #[tracing::instrument(level = "trace", skip_all, fields(%resource_id))] + fn accept_answer( + &mut self, + answer: Answer, + resource_id: ResourceId, + gateway: PublicKey, + domain_response: Option, + ) -> connlib_shared::Result<()> { + let gateway_id = self + .gateway_by_resource(&resource_id) + .ok_or(Error::UnknownResource)?; + + self.node.accept_answer( + gateway_id, + gateway, + snownet::Answer { + credentials: snownet::Credentials { + username: answer.username, + password: answer.password, + }, + }, + Instant::now(), + ); + + let desc = self + .resource_ids + .get(&resource_id) + .ok_or(Error::ControlProtocolError)?; + + let ips = self.get_resource_ip(desc, &domain_response.as_ref().map(|d| d.domain.clone())); + + // Tidy up state once everything succeeded. + self.awaiting_connection.remove(&resource_id); + + let resource_ids = HashSet::from([resource_id]); + let mut peer: Peer<_, PacketTransformClient, _> = + Peer::new(gateway_id, Default::default(), &ips, resource_ids); + peer.transform.set_dns(self.dns_mapping()); + self.peers.insert(peer, &[]); + + let peer_ips = if let Some(domain_response) = domain_response { + self.dns_response(&resource_id, &domain_response, &gateway_id)? + } else { + ips + }; + + self.peers + .add_ips_with_resource(&gateway_id, &peer_ips, &resource_id); + + Ok(()) + } + + fn create_or_reuse_connection( + &mut self, + resource_id: ResourceId, + gateway_id: GatewayId, + allowed_stun_servers: HashSet, + allowed_turn_servers: HashSet<(SocketAddr, String, String, String)>, + ) -> connlib_shared::Result { + tracing::trace!("request_connection"); + + let desc = self + .resource_ids + .get(&resource_id) + .ok_or(Error::UnknownResource)?; + + let domain = self.get_awaiting_connection(&resource_id)?.domain.clone(); + + if self.is_connected_to(resource_id, &domain) { + return Err(Error::UnexpectedConnectionDetails); + } + + let awaiting_connection = self + .awaiting_connection + .get(&resource_id) + .ok_or(Error::UnexpectedConnectionDetails)? + .clone(); + + self.resources_gateways.insert(resource_id, gateway_id); + + if self.peers.get(&gateway_id).is_some() { + self.peers.add_ips_with_resource( + &gateway_id, + &self.get_resource_ip(desc, &domain), + &resource_id, + ); + + self.awaiting_connection.remove(&resource_id); + + return Ok(Request::ReuseConnection(ReuseConnection { + resource_id, + gateway_id, + payload: domain.clone(), + })); + }; + + if self.node.is_expecting_answer(gateway_id) { + return Err(Error::PendingConnection); + } + + let offer = self.node.new_connection( + gateway_id, + allowed_stun_servers, + allowed_turn_servers, + awaiting_connection.last_intent_sent_at, + Instant::now(), + ); + + return Ok(Request::NewConnection(RequestConnection { + resource_id, + gateway_id, + client_preshared_key: Secret::new(Key(*offer.session_key.expose_secret())), + client_payload: ClientPayload { + ice_parameters: Offer { + username: offer.credentials.username, + password: offer.credentials.password, + }, + domain: awaiting_connection.domain, + }, + })); + } + + fn received_domain_parameters( + &mut self, + resource_id: ResourceId, + domain_response: DomainResponse, + ) -> connlib_shared::Result<()> { + let gateway_id = self + .gateway_by_resource(&resource_id) + .ok_or(Error::UnknownResource)?; + + let peer_ips = self.dns_response(&resource_id, &domain_response, &gateway_id)?; + + self.peers + .add_ips_with_resource(&gateway_id, &peer_ips, &resource_id); + + Ok(()) + } + + fn dns_response( + &mut self, + resource_id: &ResourceId, + domain_response: &DomainResponse, + peer_id: &GatewayId, + ) -> connlib_shared::Result> { + let peer = self + .peers + .get_mut(peer_id) + .ok_or(Error::ControlProtocolError)?; + + let resource_description = self + .resource_ids + .get(resource_id) + .ok_or(Error::UnknownResource)? + .clone(); + + let ResourceDescription::Dns(resource_description) = resource_description else { + // We should never get a domain_response for a CIDR resource! + return Err(Error::ControlProtocolError); + }; + + let resource_description = + DnsResource::from_description(&resource_description, domain_response.domain.clone()); + + let addrs: HashSet<_> = domain_response + .address + .iter() + .filter_map(|external_ip| { + peer.transform + .get_or_assign_translation(external_ip, &mut self.ip_provider) + }) + .collect(); + + self.dns_resources_internal_ips + .insert(resource_description.clone(), addrs.clone()); + + send_dns_answer(self, Rtype::Aaaa, &resource_description, &addrs); + send_dns_answer(self, Rtype::A, &resource_description, &addrs); + + Ok(addrs.iter().copied().map(Into::into).collect()) + } + /// Attempt to handle the given packet as a DNS packet. /// /// Returns `Ok` if the packet is in fact a DNS query with an optional response to send back. @@ -651,44 +674,6 @@ impl ClientState { .ok_or(Error::UnexpectedConnectionDetails) } - pub(crate) fn attempt_to_reuse_connection( - &mut self, - resource: ResourceId, - gateway: GatewayId, - ) -> Result, ConnlibError> { - let desc = self - .resource_ids - .get(&resource) - .ok_or(Error::UnknownResource)?; - - let domain = self.get_awaiting_connection(&resource)?.domain.clone(); - - if self.is_connected_to(resource, &domain) { - return Err(Error::UnexpectedConnectionDetails); - } - - self.awaiting_connection - .get_mut(&resource) - .ok_or(Error::UnexpectedConnectionDetails)?; - - self.resources_gateways.insert(resource, gateway); - - if self.peers.get(&gateway).is_none() { - return Ok(None); - }; - - self.peers - .add_ips_with_resource(&gateway, &self.get_resource_ip(desc, &domain), &resource); - - self.awaiting_connection.remove(&resource); - - Ok(Some(ReuseConnection { - resource_id: resource, - gateway_id: gateway, - payload: domain.clone(), - })) - } - pub fn on_connection_failed(&mut self, resource: ResourceId) { self.awaiting_connection.remove(&resource); self.resources_gateways.remove(&resource); @@ -769,24 +754,6 @@ impl ClientState { }); } - pub fn create_peer_config_for_new_connection( - &mut self, - resource: ResourceId, - domain: &Option, - ) -> Result, ConnlibError> { - let desc = self - .resource_ids - .get(&resource) - .ok_or(Error::ControlProtocolError)?; - - let ips = self.get_resource_ip(desc, domain); - - // Tidy up state once everything succeeded. - self.awaiting_connection.remove(&resource); - - Ok(ips) - } - pub fn gateway_by_resource(&self, resource: &ResourceId) -> Option { self.resources_gateways.get(resource).copied() }