From 2d4818e0070b80bd2f1f4ad4709a6963ff63acbe Mon Sep 17 00:00:00 2001 From: Thomas Eizinger Date: Tue, 8 Oct 2024 09:28:51 +1100 Subject: [PATCH] refactor(connlib): rotate tunnel private key on `reset` (#6909) With the new control protocol specified in #6461, the client will no longer initiate new connections. Instead, the credentials are generated deterministically by the portal based on the gateway's and the client's public key. For as long as they use the same public key, they also have the same in-memory state which makes creating connections idempotent. What we didn't consider in the new design at first is that when clients roam, they discard all connections but keep the same private key. As a result, the portal would generate the same ICE credentials which means the gateway thinks it can reuse the existing connection when new flows get authorized. The client however discarded all connections (and rotated its ports and maybe IPs), meaning the previous candidates sent to the gateway are no longer valid and connectivity fails. We fix this by also rotating the private keys upon reset. Rotating the keys itself isn't enough, we also need to propagate the new public key all the way "over" to the phoenix channel component which lives separately from connlib's data plane. To achieve this, we change `PhoenixChannel` to now start in the "disconnected" state and require an explicit `connect` call. In addition, the `LoginUrl` constructed by various components now acts merely as a "prototype", which may require additional data to construct a fully valid URL. In the case of client and gateway, this is the public key of the `Node`. This additional parameter needs to be passed to `PhoenixChannel` in the `connect` call, thus forming a type-safe contract that ensures we never attempt to connect without providing a public key. For the relay, this doesn't apply. Lastly, this allows us to tidy up the code a bit by: a) generating the `Node`'s private key from the existing RNG b) removing `ConnectArgs` which only had two members left Related: #6461. Related: #6732. --- rust/connlib/clients/android/src/lib.rs | 22 +++--- rust/connlib/clients/apple/src/lib.rs | 26 +++----- rust/connlib/clients/shared/src/eventloop.rs | 11 +-- rust/connlib/clients/shared/src/lib.rs | 40 +++++------ rust/connlib/snownet/src/node.rs | 12 +++- rust/connlib/snownet/tests/lib.rs | 21 ++---- rust/connlib/tunnel/src/client.rs | 18 ++--- rust/connlib/tunnel/src/gateway.rs | 7 +- rust/connlib/tunnel/src/lib.rs | 23 +++---- rust/connlib/tunnel/src/tests/reference.rs | 4 +- rust/connlib/tunnel/src/tests/sim_client.rs | 2 +- rust/connlib/tunnel/src/tests/sim_gateway.rs | 2 +- rust/gateway/src/eventloop.rs | 8 ++- rust/gateway/src/main.rs | 19 ++---- rust/headless-client/src/ipc_service.rs | 15 ++--- rust/headless-client/src/main.rs | 20 +++--- rust/phoenix-channel/src/lib.rs | 70 +++++++++++--------- rust/phoenix-channel/src/login_url.rs | 69 ++++++++++++++----- rust/relay/src/main.rs | 13 ++-- 19 files changed, 200 insertions(+), 202 deletions(-) diff --git a/rust/connlib/clients/android/src/lib.rs b/rust/connlib/clients/android/src/lib.rs index 7ec6b734e..165740c42 100644 --- a/rust/connlib/clients/android/src/lib.rs +++ b/rust/connlib/clients/android/src/lib.rs @@ -5,9 +5,7 @@ use crate::tun::Tun; use backoff::ExponentialBackoffBuilder; -use connlib_client_shared::{ - keypair, Callbacks, ConnectArgs, DisconnectError, Session, V4RouteList, V6RouteList, -}; +use connlib_client_shared::{Callbacks, DisconnectError, Session, V4RouteList, V6RouteList}; use connlib_model::{ResourceId, ResourceView}; use ip_network::{Ipv4Network, Ipv6Network}; use jni::{ @@ -355,13 +353,11 @@ fn connect( handle, }; - let (private_key, public_key) = keypair(); let url = LoginUrl::client( api_url.as_str(), &secret, device_id, Some(device_name), - public_key.to_bytes(), device_info, )?; @@ -374,13 +370,7 @@ fn connect( let tcp_socket_factory = Arc::new(protected_tcp_socket_factory(callbacks.clone())); - let args = ConnectArgs { - tcp_socket_factory: tcp_socket_factory.clone(), - udp_socket_factory: Arc::new(protected_udp_socket_factory(callbacks.clone())), - private_key, - callbacks, - }; - let portal = PhoenixChannel::connect( + let portal = PhoenixChannel::disconnected( Secret::new(url), get_user_agent(Some(os_version), env!("CARGO_PKG_VERSION")), "client", @@ -390,7 +380,13 @@ fn connect( .build(), tcp_socket_factory, )?; - let session = Session::connect(args, portal, runtime.handle().clone()); + let session = Session::connect( + Arc::new(protected_tcp_socket_factory(callbacks.clone())), + Arc::new(protected_udp_socket_factory(callbacks.clone())), + callbacks, + portal, + runtime.handle().clone(), + ); Ok(SessionWrapper { inner: session, diff --git a/rust/connlib/clients/apple/src/lib.rs b/rust/connlib/clients/apple/src/lib.rs index 4e1bcc0b5..3817f1053 100644 --- a/rust/connlib/clients/apple/src/lib.rs +++ b/rust/connlib/clients/apple/src/lib.rs @@ -6,9 +6,7 @@ mod tun; use anyhow::Result; use backoff::ExponentialBackoffBuilder; -use connlib_client_shared::{ - keypair, Callbacks, ConnectArgs, DisconnectError, Session, V4RouteList, V6RouteList, -}; +use connlib_client_shared::{Callbacks, DisconnectError, Session, V4RouteList, V6RouteList}; use connlib_model::ResourceView; use ip_network::{Ipv4Network, Ipv6Network}; use phoenix_channel::get_user_agent; @@ -198,13 +196,11 @@ impl WrappedSession { let secret = SecretString::from(token); let device_info = serde_json::from_str(&device_info).unwrap(); - let (private_key, public_key) = keypair(); let url = LoginUrl::client( api_url.as_str(), &secret, device_id, device_name_override, - public_key.to_bytes(), device_info, )?; @@ -215,15 +211,7 @@ impl WrappedSession { .build()?; let _guard = runtime.enter(); // Constructing `PhoenixChannel` requires a runtime context. - let args = ConnectArgs { - private_key, - callbacks: CallbackHandler { - inner: Arc::new(callback_handler), - }, - tcp_socket_factory: Arc::new(socket_factory::tcp), - udp_socket_factory: Arc::new(socket_factory::udp), - }; - let portal = PhoenixChannel::connect( + let portal = PhoenixChannel::disconnected( Secret::new(url), get_user_agent(os_version_override, env!("CARGO_PKG_VERSION")), "client", @@ -233,7 +221,15 @@ impl WrappedSession { .build(), Arc::new(socket_factory::tcp), )?; - let session = Session::connect(args, portal, runtime.handle().clone()); + let session = Session::connect( + Arc::new(socket_factory::tcp), + Arc::new(socket_factory::udp), + CallbackHandler { + inner: Arc::new(callback_handler), + }, + portal, + runtime.handle().clone(), + ); session.set_tun(Box::new(Tun::new()?)); Ok(Self { diff --git a/rust/connlib/clients/shared/src/eventloop.rs b/rust/connlib/clients/shared/src/eventloop.rs index 7ce627312..36a55b266 100644 --- a/rust/connlib/clients/shared/src/eventloop.rs +++ b/rust/connlib/clients/shared/src/eventloop.rs @@ -3,7 +3,7 @@ use anyhow::Result; use connlib_model::ResourceId; use firezone_tunnel::messages::{client::*, *}; use firezone_tunnel::ClientTunnel; -use phoenix_channel::{ErrorReply, OutboundRequestId, PhoenixChannel}; +use phoenix_channel::{ErrorReply, OutboundRequestId, PhoenixChannel, PublicKeyParam}; use std::{ collections::{BTreeMap, BTreeSet}, io, @@ -16,7 +16,7 @@ pub struct Eventloop { tunnel: ClientTunnel, callbacks: C, - portal: PhoenixChannel<(), IngressMessages, ReplyMessages>, + portal: PhoenixChannel<(), IngressMessages, ReplyMessages, PublicKeyParam>, rx: tokio::sync::mpsc::UnboundedReceiver, connection_intents: SentConnectionIntents, @@ -35,9 +35,11 @@ impl Eventloop { pub(crate) fn new( tunnel: ClientTunnel, callbacks: C, - portal: PhoenixChannel<(), IngressMessages, ReplyMessages>, + mut portal: PhoenixChannel<(), IngressMessages, ReplyMessages, PublicKeyParam>, rx: tokio::sync::mpsc::UnboundedReceiver, ) -> Self { + portal.connect(PublicKeyParam(tunnel.public_key().to_bytes())); + Self { tunnel, portal, @@ -70,8 +72,9 @@ where continue; } Poll::Ready(Some(Command::Reset)) => { - self.portal.reconnect(); self.tunnel.reset(); + self.portal + .connect(PublicKeyParam(self.tunnel.public_key().to_bytes())); continue; } diff --git a/rust/connlib/clients/shared/src/lib.rs b/rust/connlib/clients/shared/src/lib.rs index 892adcd70..0630f05f5 100644 --- a/rust/connlib/clients/shared/src/lib.rs +++ b/rust/connlib/clients/shared/src/lib.rs @@ -3,7 +3,6 @@ pub use crate::serde_routelist::{V4RouteList, V6RouteList}; pub use callbacks::{Callbacks, DisconnectError}; pub use connlib_model::StaticSecret; pub use eventloop::Eventloop; -pub use firezone_tunnel::keypair; pub use firezone_tunnel::messages::client::{ ResourceDescription, {IngressMessages, ReplyMessages}, }; @@ -12,7 +11,7 @@ use connlib_model::ResourceId; use eventloop::Command; use firezone_telemetry as telemetry; use firezone_tunnel::ClientTunnel; -use phoenix_channel::PhoenixChannel; +use phoenix_channel::{PhoenixChannel, PublicKeyParam}; use socket_factory::{SocketFactory, TcpSocket, UdpSocket}; use std::collections::{BTreeMap, BTreeSet}; use std::net::IpAddr; @@ -35,27 +34,26 @@ pub struct Session { channel: tokio::sync::mpsc::UnboundedSender, } -/// Arguments for `connect`, since Clippy said 8 args is too many -pub struct ConnectArgs { - pub tcp_socket_factory: Arc>, - pub udp_socket_factory: Arc>, - pub private_key: StaticSecret, - pub callbacks: CB, -} - impl Session { /// Creates a new [`Session`]. /// /// This connects to the portal using the given [`LoginUrl`](phoenix_channel::LoginUrl) and creates a wireguard tunnel using the provided private key. pub fn connect( - args: ConnectArgs, - portal: PhoenixChannel<(), IngressMessages, ReplyMessages>, + tcp_socket_factory: Arc>, + udp_socket_factory: Arc>, + callbacks: CB, + portal: PhoenixChannel<(), IngressMessages, ReplyMessages, PublicKeyParam>, handle: tokio::runtime::Handle, ) -> Self { let (tx, rx) = tokio::sync::mpsc::unbounded_channel(); - let callbacks = args.callbacks.clone(); - let connect_handle = handle.spawn(connect(args, portal, rx)); + let connect_handle = handle.spawn(connect( + tcp_socket_factory, + udp_socket_factory, + callbacks.clone(), + portal, + rx, + )); handle.spawn(connect_supervisor(connect_handle, callbacks)); Self { channel: tx } @@ -118,22 +116,16 @@ impl Session { /// /// When this function exits, the tunnel failed unrecoverably and you need to call it again. async fn connect( - args: ConnectArgs, - portal: PhoenixChannel<(), IngressMessages, ReplyMessages>, + tcp_socket_factory: Arc>, + udp_socket_factory: Arc>, + callbacks: CB, + portal: PhoenixChannel<(), IngressMessages, ReplyMessages, PublicKeyParam>, rx: UnboundedReceiver, ) -> Result<(), DisconnectError> where CB: Callbacks + 'static, { - let ConnectArgs { - private_key, - callbacks, - udp_socket_factory, - tcp_socket_factory, - } = args; - let tunnel = ClientTunnel::new( - private_key, tcp_socket_factory, udp_socket_factory, BTreeMap::from([(portal.server_host().to_owned(), portal.resolved_addresses())]), diff --git a/rust/connlib/snownet/src/node.rs b/rust/connlib/snownet/src/node.rs index 869cb897d..533ab753b 100644 --- a/rust/connlib/snownet/src/node.rs +++ b/rust/connlib/snownet/src/node.rs @@ -123,10 +123,13 @@ where TId: Eq + Hash + Copy + Ord + fmt::Display, RId: Copy + Eq + Hash + PartialEq + Ord + fmt::Debug + fmt::Display, { - pub fn new(private_key: StaticSecret, seed: [u8; 32]) -> Self { + pub fn new(seed: [u8; 32]) -> Self { + let mut rng = StdRng::from_seed(seed); + let private_key = StaticSecret::random_from_rng(&mut rng); let public_key = &(&private_key).into(); + Self { - rng: StdRng::from_seed(seed), // TODO: Use this seed for private key too. Requires refactoring of how we generate the login-url because that one needs to know the public key. + rng, session_id: SessionId::new(*public_key), private_key, public_key: *public_key, @@ -174,6 +177,11 @@ where self.connections.clear(); self.buffered_transmits.clear(); + self.private_key = StaticSecret::random_from_rng(&mut self.rng); + self.public_key = (&self.private_key).into(); + self.rate_limiter = Arc::new(RateLimiter::new(&self.public_key, HANDSHAKE_RATE_LIMIT)); + self.session_id = SessionId::new(self.public_key); + tracing::debug!(%num_connections, "Closed all connections as part of reconnecting"); } diff --git a/rust/connlib/snownet/tests/lib.rs b/rust/connlib/snownet/tests/lib.rs index 343ad6a3b..d3e2dfdbc 100644 --- a/rust/connlib/snownet/tests/lib.rs +++ b/rust/connlib/snownet/tests/lib.rs @@ -1,4 +1,3 @@ -use boringtun::x25519::StaticSecret; use snownet::{Answer, ClientNode, Event, ServerNode}; use std::{ iter, @@ -73,16 +72,10 @@ fn answer_after_stale_connection_does_not_panic() { fn only_generate_candidate_event_after_answer() { let local_candidate = SocketAddr::new(IpAddr::from(Ipv4Addr::LOCALHOST), 10000); - let mut alice = ClientNode::::new( - StaticSecret::random_from_rng(rand::thread_rng()), - rand::random(), - ); + let mut alice = ClientNode::::new(rand::random()); alice.add_local_host_candidate(local_candidate).unwrap(); - let mut bob = ServerNode::::new( - StaticSecret::random_from_rng(rand::thread_rng()), - rand::random(), - ); + let mut bob = ServerNode::::new(rand::random()); let offer = alice.new_connection(1, Instant::now(), Instant::now()); @@ -106,14 +99,8 @@ fn only_generate_candidate_event_after_answer() { } fn alice_and_bob() -> (ClientNode, ServerNode) { - let alice = ClientNode::new( - StaticSecret::random_from_rng(rand::thread_rng()), - rand::random(), - ); - let bob = ServerNode::new( - StaticSecret::random_from_rng(rand::thread_rng()), - rand::random(), - ); + let alice = ClientNode::new(rand::random()); + let bob = ServerNode::new(rand::random()); (alice, bob) } diff --git a/rust/connlib/tunnel/src/client.rs b/rust/connlib/tunnel/src/client.rs index 2184c75b1..a3fa5b63a 100644 --- a/rust/connlib/tunnel/src/client.rs +++ b/rust/connlib/tunnel/src/client.rs @@ -14,8 +14,8 @@ use crate::peer_store::PeerStore; use crate::{dns, TunConfig}; use anyhow::Context; use bimap::BiMap; +use connlib_model::PublicKey; use connlib_model::{GatewayId, RelayId, ResourceId, ResourceStatus, ResourceView}; -use connlib_model::{PublicKey, StaticSecret}; use connlib_model::{Site, SiteId}; use ip_network::{IpNetwork, Ipv4Network, Ipv6Network}; use ip_network_table::IpNetworkTable; @@ -279,11 +279,7 @@ struct AwaitingConnectionDetails { } impl ClientState { - pub(crate) fn new( - private_key: impl Into, - known_hosts: BTreeMap>, - seed: [u8; 32], - ) -> Self { + pub(crate) fn new(known_hosts: BTreeMap>, seed: [u8; 32]) -> Self { Self { awaiting_connection_details: Default::default(), resources_gateways: Default::default(), @@ -294,7 +290,7 @@ impl ClientState { buffered_events: Default::default(), tun_config: Default::default(), buffered_packets: Default::default(), - node: ClientNode::new(private_key.into(), seed), + node: ClientNode::new(seed), system_resolvers: Default::default(), sites_status: Default::default(), gateways_site: Default::default(), @@ -374,7 +370,6 @@ impl ClientState { } } - #[cfg(all(feature = "proptest", test))] pub(crate) fn public_key(&self) -> PublicKey { self.node.public_key() } @@ -1523,7 +1518,6 @@ impl IpProvider { #[cfg(test)] mod tests { use super::*; - use rand::rngs::OsRng; #[test] fn ignores_ip4_igmp_multicast() { @@ -1568,11 +1562,7 @@ mod tests { impl ClientState { pub fn for_test() -> ClientState { - ClientState::new( - StaticSecret::random_from_rng(OsRng), - BTreeMap::new(), - rand::random(), - ) + ClientState::new(BTreeMap::new(), rand::random()) } } diff --git a/rust/connlib/tunnel/src/gateway.rs b/rust/connlib/tunnel/src/gateway.rs index 6dee955c3..bbc904c6d 100644 --- a/rust/connlib/tunnel/src/gateway.rs +++ b/rust/connlib/tunnel/src/gateway.rs @@ -7,7 +7,7 @@ use crate::{GatewayEvent, GatewayTunnel}; use anyhow::Context; use boringtun::x25519::PublicKey; use chrono::{DateTime, Utc}; -use connlib_model::{ClientId, DomainName, RelayId, ResourceId, StaticSecret}; +use connlib_model::{ClientId, DomainName, RelayId, ResourceId}; use ip_network::{Ipv4Network, Ipv6Network}; use ip_packet::IpPacket; use secrecy::{ExposeSecret as _, Secret}; @@ -163,16 +163,15 @@ impl DnsResourceNatEntry { } impl GatewayState { - pub(crate) fn new(private_key: impl Into, seed: [u8; 32]) -> Self { + pub(crate) fn new(seed: [u8; 32]) -> Self { Self { peers: Default::default(), - node: ServerNode::new(private_key.into(), seed), + node: ServerNode::new(seed), next_expiry_resources_check: Default::default(), buffered_events: VecDeque::default(), } } - #[cfg(all(feature = "proptest", test))] pub(crate) fn public_key(&self) -> PublicKey { self.node.public_key() } diff --git a/rust/connlib/tunnel/src/lib.rs b/rust/connlib/tunnel/src/lib.rs index 7004c6f00..5e431c4ec 100644 --- a/rust/connlib/tunnel/src/lib.rs +++ b/rust/connlib/tunnel/src/lib.rs @@ -5,7 +5,6 @@ use crate::messages::{Offer, Relay, ResolveRequest, SecretKey}; use bimap::BiMap; -use boringtun::x25519::StaticSecret; use chrono::Utc; use connlib_model::{ ClientId, DomainName, GatewayId, PublicKey, RelayId, ResourceId, ResourceView, @@ -13,7 +12,6 @@ use connlib_model::{ use io::Io; use ip_network::{Ipv4Network, Ipv6Network}; use ip_packet::MAX_DATAGRAM_PAYLOAD; -use rand::rngs::OsRng; use socket_factory::{SocketFactory, TcpSocket, UdpSocket}; use std::{ collections::{BTreeMap, BTreeSet}, @@ -81,20 +79,23 @@ pub struct Tunnel { impl ClientTunnel { pub fn new( - private_key: StaticSecret, tcp_socket_factory: Arc>, udp_socket_factory: Arc>, known_hosts: BTreeMap>, ) -> Self { Self { io: Io::new(tcp_socket_factory, udp_socket_factory), - role_state: ClientState::new(private_key, known_hosts, rand::random()), + role_state: ClientState::new(known_hosts, rand::random()), ip4_read_buf: Box::new([0u8; MAX_UDP_SIZE]), ip6_read_buf: Box::new([0u8; MAX_UDP_SIZE]), encrypt_buf: EncryptBuffer::new(MAX_DATAGRAM_PAYLOAD), } } + pub fn public_key(&self) -> PublicKey { + self.role_state.public_key() + } + pub fn reset(&mut self) { self.role_state.reset(); self.io.rebind_sockets(); @@ -177,19 +178,22 @@ impl ClientTunnel { impl GatewayTunnel { pub fn new( - private_key: StaticSecret, tcp_socket_factory: Arc>, udp_socket_factory: Arc>, ) -> Self { Self { io: Io::new(tcp_socket_factory, udp_socket_factory), - role_state: GatewayState::new(private_key, rand::random()), + role_state: GatewayState::new(rand::random()), ip4_read_buf: Box::new([0u8; MAX_UDP_SIZE]), ip6_read_buf: Box::new([0u8; MAX_UDP_SIZE]), encrypt_buf: EncryptBuffer::new(MAX_DATAGRAM_PAYLOAD), } } + pub fn public_key(&self) -> PublicKey { + self.role_state.public_key() + } + pub fn update_relays(&mut self, to_remove: BTreeSet, to_add: Vec) { self.role_state .update_relays(to_remove, turn(&to_add), Instant::now()) @@ -341,13 +345,6 @@ pub enum GatewayEvent { }, } -pub fn keypair() -> (StaticSecret, PublicKey) { - let private_key = StaticSecret::random_from_rng(OsRng); - let public_key = PublicKey::from(&private_key); - - (private_key, public_key) -} - fn fmt_routes(routes: &BTreeSet, f: &mut fmt::Formatter) -> fmt::Result where T: fmt::Display, diff --git a/rust/connlib/tunnel/src/tests/reference.rs b/rust/connlib/tunnel/src/tests/reference.rs index 06ea6a17f..d9af5b2c6 100644 --- a/rust/connlib/tunnel/src/tests/reference.rs +++ b/rust/connlib/tunnel/src/tests/reference.rs @@ -2,9 +2,9 @@ use super::{ composite_strategy::CompositeStrategy, sim_client::*, sim_dns::*, sim_gateway::*, sim_net::*, strategies::*, stub_portal::StubPortal, transition::*, }; -use crate::{client, DomainName, StaticSecret}; +use crate::{client, DomainName}; use crate::{dns::is_subdomain, proptest::relay_id}; -use connlib_model::{GatewayId, RelayId, ResourceId}; +use connlib_model::{GatewayId, RelayId, ResourceId, StaticSecret}; use domain::base::Rtype; use proptest::{prelude::*, sample}; use proptest_state_machine::ReferenceStateMachine; diff --git a/rust/connlib/tunnel/src/tests/sim_client.rs b/rust/connlib/tunnel/src/tests/sim_client.rs index 5f33a751e..752171fde 100644 --- a/rust/connlib/tunnel/src/tests/sim_client.rs +++ b/rust/connlib/tunnel/src/tests/sim_client.rs @@ -331,7 +331,7 @@ impl RefClient { /// /// This simulates receiving the `init` message from the portal. pub(crate) fn init(self) -> SimClient { - let mut client_state = ClientState::new(self.key, self.known_hosts, self.key.0); // Cheating a bit here by reusing the key as seed. + let mut client_state = ClientState::new(self.known_hosts, self.key.0); // Cheating a bit here by reusing the key as seed. client_state.update_interface_config(Interface { ipv4: self.tunnel_ip4, ipv6: self.tunnel_ip6, diff --git a/rust/connlib/tunnel/src/tests/sim_gateway.rs b/rust/connlib/tunnel/src/tests/sim_gateway.rs index fc6005ac9..a845a0b25 100644 --- a/rust/connlib/tunnel/src/tests/sim_gateway.rs +++ b/rust/connlib/tunnel/src/tests/sim_gateway.rs @@ -143,7 +143,7 @@ impl RefGateway { /// /// This simulates receiving the `init` message from the portal. pub(crate) fn init(self, id: GatewayId) -> SimGateway { - SimGateway::new(id, GatewayState::new(self.key, self.key.0)) // Cheating a bit here by reusing the key as seed. + SimGateway::new(id, GatewayState::new(self.key.0)) // Cheating a bit here by reusing the key as seed. } } diff --git a/rust/gateway/src/eventloop.rs b/rust/gateway/src/eventloop.rs index 7b281057b..fe1e0acd4 100644 --- a/rust/gateway/src/eventloop.rs +++ b/rust/gateway/src/eventloop.rs @@ -14,7 +14,7 @@ use firezone_tunnel::messages::{ use firezone_tunnel::{DnsResourceNatEntry, GatewayTunnel}; use futures::channel::mpsc; use futures_bounded::Timeout; -use phoenix_channel::PhoenixChannel; +use phoenix_channel::{PhoenixChannel, PublicKeyParam}; use std::collections::BTreeSet; use std::convert::Infallible; use std::net::IpAddr; @@ -41,7 +41,7 @@ enum ResolveTrigger { pub struct Eventloop { tunnel: GatewayTunnel, - portal: PhoenixChannel<(), IngressMessages, ()>, + portal: PhoenixChannel<(), IngressMessages, (), PublicKeyParam>, tun_device_channel: mpsc::Sender, resolve_tasks: futures_bounded::FuturesTupleSet, ResolveTrigger>, @@ -50,9 +50,11 @@ pub struct Eventloop { impl Eventloop { pub(crate) fn new( tunnel: GatewayTunnel, - portal: PhoenixChannel<(), IngressMessages, ()>, + mut portal: PhoenixChannel<(), IngressMessages, (), PublicKeyParam>, tun_device_channel: mpsc::Sender, ) -> Self { + portal.connect(PublicKeyParam(tunnel.public_key().to_bytes())); + Self { tunnel, portal, diff --git a/rust/gateway/src/main.rs b/rust/gateway/src/main.rs index 3ca4fff5d..1a135f90d 100644 --- a/rust/gateway/src/main.rs +++ b/rust/gateway/src/main.rs @@ -2,20 +2,19 @@ use crate::eventloop::{Eventloop, PHOENIX_TOPIC}; use anyhow::{Context, Result}; use backoff::ExponentialBackoffBuilder; use clap::Parser; -use connlib_model::StaticSecret; use firezone_bin_shared::{ http_health_check, linux::{tcp_socket_factory, udp_socket_factory}, TunDeviceManager, }; use firezone_tunnel::messages::Interface; -use firezone_tunnel::{keypair, GatewayTunnel, IPV4_PEERS, IPV6_PEERS}; +use firezone_tunnel::{GatewayTunnel, IPV4_PEERS, IPV6_PEERS}; use phoenix_channel::get_user_agent; use phoenix_channel::LoginUrl; use futures::channel::mpsc; use futures::{future, StreamExt, TryFutureExt}; -use phoenix_channel::PhoenixChannel; +use phoenix_channel::{PhoenixChannel, PublicKeyParam}; use secrecy::{Secret, SecretString}; use std::convert::Infallible; use std::path::Path; @@ -55,16 +54,14 @@ async fn try_main() -> Result<()> { let firezone_id = get_firezone_id(cli.firezone_id).await .context("Couldn't read FIREZONE_ID or write it to disk: Please provide it through the env variable or provide rw access to /var/lib/firezone/")?; - let (private_key, public_key) = keypair(); let login = LoginUrl::gateway( cli.api_url, &SecretString::new(cli.token), firezone_id, cli.firezone_name, - public_key.to_bytes(), )?; - let task = tokio::spawn(run(login, private_key)).err_into(); + let task = tokio::spawn(run(login)).err_into(); let ctrl_c = pin!(ctrl_c().map_err(anyhow::Error::new)); @@ -107,13 +104,9 @@ async fn get_firezone_id(env_id: Option) -> Result { Ok(id) } -async fn run(login: LoginUrl, private_key: StaticSecret) -> Result { - let mut tunnel = GatewayTunnel::new( - private_key, - Arc::new(tcp_socket_factory), - Arc::new(udp_socket_factory), - ); - let portal = PhoenixChannel::connect( +async fn run(login: LoginUrl) -> Result { + let mut tunnel = GatewayTunnel::new(Arc::new(tcp_socket_factory), Arc::new(udp_socket_factory)); + let portal = PhoenixChannel::disconnected( Secret::new(login), get_user_agent(None, env!("CARGO_PKG_VERSION")), PHOENIX_TOPIC, diff --git a/rust/headless-client/src/ipc_service.rs b/rust/headless-client/src/ipc_service.rs index 962927d79..968e5ba86 100644 --- a/rust/headless-client/src/ipc_service.rs +++ b/rust/headless-client/src/ipc_service.rs @@ -4,7 +4,6 @@ use crate::{ }; use anyhow::{bail, Context as _, Result}; use clap::Parser; -use connlib_client_shared::{keypair, ConnectArgs}; use connlib_model::ResourceView; use firezone_bin_shared::{ platform::{tcp_socket_factory, udp_socket_factory, DnsControlMethod}, @@ -529,14 +528,12 @@ impl<'a> Handler<'a> { let transaction = firezone_telemetry::start_transaction(ctx); assert!(self.session.is_none()); let device_id = device_id::get_or_create().map_err(|e| Error::DeviceId(e.to_string()))?; - let (private_key, public_key) = keypair(); let url = LoginUrl::client( Url::parse(api_url).map_err(|e| Error::UrlParse(e.to_string()))?, &token, device_id.id, None, - public_key.to_bytes(), device_id::device_info(), ) .map_err(|e| Error::LoginUrl(e.to_string()))?; @@ -544,16 +541,10 @@ impl<'a> Handler<'a> { self.last_connlib_start_instant = Some(Instant::now()); let (cb_tx, cb_rx) = mpsc::channel(1_000); let callbacks = CallbackHandler { cb_tx }; - let args = ConnectArgs { - tcp_socket_factory: Arc::new(tcp_socket_factory), - udp_socket_factory: Arc::new(udp_socket_factory), - private_key, - callbacks, - }; // Synchronous DNS resolution here let phoenix_span = transaction.start_child("phoenix", "Resolve DNS for PhoenixChannel"); - let portal = PhoenixChannel::connect( + let portal = PhoenixChannel::disconnected( Secret::new(url), get_user_agent(None, env!("CARGO_PKG_VERSION")), "client", @@ -569,7 +560,9 @@ impl<'a> Handler<'a> { // Read the resolvers before starting connlib, in case connlib's startup interferes. let dns = self.dns_controller.system_resolvers(); let connlib = connlib_client_shared::Session::connect( - args, + Arc::new(tcp_socket_factory), + Arc::new(udp_socket_factory), + callbacks, portal, tokio::runtime::Handle::current(), ); diff --git a/rust/headless-client/src/main.rs b/rust/headless-client/src/main.rs index 2d1d6f8be..259fb9e60 100644 --- a/rust/headless-client/src/main.rs +++ b/rust/headless-client/src/main.rs @@ -3,7 +3,7 @@ use anyhow::{anyhow, Context as _, Result}; use backoff::ExponentialBackoffBuilder; use clap::Parser; -use connlib_client_shared::{keypair, ConnectArgs, Session}; +use connlib_client_shared::Session; use firezone_bin_shared::{ new_dns_notifier, new_network_notifier, platform::{tcp_socket_factory, udp_socket_factory}, @@ -177,13 +177,11 @@ fn main() -> Result<()> { }; firezone_telemetry::configure_scope(|scope| scope.set_tag("firezone_id", &firezone_id)); - let (private_key, public_key) = keypair(); let url = LoginUrl::client( cli.api_url, &token, firezone_id, cli.firezone_name, - public_key.to_bytes(), device_id::device_info(), )?; @@ -197,12 +195,6 @@ fn main() -> Result<()> { // The name matches that in `ipc_service.rs` let mut last_connlib_start_instant = Some(Instant::now()); - let args = ConnectArgs { - udp_socket_factory: Arc::new(udp_socket_factory), - tcp_socket_factory: Arc::new(tcp_socket_factory), - private_key, - callbacks, - }; let result = rt.block_on(async { let ctx = firezone_telemetry::TransactionContext::new( @@ -217,7 +209,7 @@ fn main() -> Result<()> { // for an Internet connection if it launches us at startup. // When running interactively, it is useful for the user to see that we can't reach the portal. let phoenix_span = transaction.start_child("phoenix", "Connect PhoenixChannel"); - let portal = PhoenixChannel::connect( + let portal = PhoenixChannel::disconnected( Secret::new(url), get_user_agent(None, env!("CARGO_PKG_VERSION")), "client", @@ -228,7 +220,13 @@ fn main() -> Result<()> { Arc::new(tcp_socket_factory), )?; phoenix_span.finish(); - let session = Session::connect(args, portal, rt.handle().clone()); + let session = Session::connect( + Arc::new(tcp_socket_factory), + Arc::new(udp_socket_factory), + callbacks, + portal, + rt.handle().clone(), + ); let mut terminate = signals::Terminate::new()?; let mut hangup = signals::Hangup::new()?; diff --git a/rust/phoenix-channel/src/lib.rs b/rust/phoenix-channel/src/lib.rs index 3f531b926..ca7d42726 100644 --- a/rust/phoenix-channel/src/lib.rs +++ b/rust/phoenix-channel/src/lib.rs @@ -3,7 +3,7 @@ mod heartbeat; mod login_url; use std::collections::{HashSet, VecDeque}; -use std::net::{IpAddr, SocketAddr}; +use std::net::{IpAddr, SocketAddr, ToSocketAddrs as _}; use std::sync::atomic::AtomicU64; use std::sync::Arc; use std::{fmt, future, marker::PhantomData}; @@ -29,9 +29,9 @@ use tokio_tungstenite::{ use url::{Host, Url}; pub use get_user_agent::get_user_agent; -pub use login_url::{DeviceInfo, LoginUrl, LoginUrlError}; +pub use login_url::{DeviceInfo, LoginUrl, LoginUrlError, NoParams, PublicKeyParam}; -pub struct PhoenixChannel { +pub struct PhoenixChannel { state: State, waker: Option, pending_messages: VecDeque, @@ -45,7 +45,8 @@ pub struct PhoenixChannel { pending_join_requests: HashSet, // Stored here to allow re-connecting. - url: Secret, + url_prototype: Secret>, + last_url: Option, user_agent: String, reconnect_backoff: ExponentialBackoff, @@ -66,7 +67,7 @@ enum State { impl State { fn connect( - url: Secret, + url: Url, user_agent: String, socket_factory: Arc>, ) -> Self { @@ -75,11 +76,13 @@ impl State { } async fn create_and_connect_websocket( - url: Secret, + url: Url, user_agent: String, socket_factory: Arc>, ) -> Result>, InternalError> { - let socket = make_socket(url.expose_secret().inner(), &*socket_factory).await?; + tracing::debug!(host = %url.host().unwrap(), %user_agent, "Connecting to portal"); + + let socket = make_socket(&url, &*socket_factory).await?; let (stream, _) = client_async_tls(make_request(url, user_agent), socket) .await @@ -214,18 +217,22 @@ impl fmt::Display for OutboundRequestId { #[error("Cannot close websocket while we are connecting")] pub struct Connecting; -impl PhoenixChannel +impl + PhoenixChannel where TInitReq: Serialize + Clone, TInboundMsg: DeserializeOwned, TOutboundRes: DeserializeOwned, + TFinish: IntoIterator, { - /// Creates a new [PhoenixChannel] to the given endpoint. + /// Creates a new [PhoenixChannel] to the given endpoint in the `disconnected` state. + /// + /// You must explicitly call [`PhoenixChannel::connect`] to establish a connection. /// /// The provided URL must contain a host. /// Additionally, you must already provide any query parameters required for authentication. - pub fn connect( - url: Secret, + pub fn disconnected( + url: Secret>, user_agent: String, login: &'static str, init_req: TInitReq, @@ -239,19 +246,16 @@ where // We expose them to other components that deal with DNS stuff to ensure our domain always resolves to these IPs. let resolved_addresses = url .expose_secret() - .inner() - .socket_addrs(|| None)? - .iter() + .host() + .to_socket_addrs()? .map(|addr| addr.ip()) .collect(); - tracing::debug!(host = %url.expose_secret().host(), %user_agent, "Connecting to portal"); - Ok(Self { reconnect_backoff, - url: url.clone(), - user_agent: user_agent.clone(), - state: State::connect(url, user_agent, socket_factory.clone()), + url_prototype: url, + user_agent, + state: State::Closed, socket_factory, waker: None, pending_messages: Default::default(), @@ -266,6 +270,7 @@ where login, init_req, resolved_addresses, + last_url: None, }) } @@ -276,7 +281,7 @@ where /// The host we are connecting / connected to. pub fn server_host(&self) -> &str { - self.url.expose_secret().host() + self.url_prototype.expose_secret().host() } /// Join the provided room. @@ -297,15 +302,16 @@ where id } - /// Reconnects to the portal. - pub fn reconnect(&mut self) { + /// Establishes a new connection, dropping the current one if any exists. + pub fn connect(&mut self, params: TFinish) { // 1. Reset the backoff. self.reconnect_backoff.reset(); // 2. Set state to `Connecting` without a timer. - let url = self.url.clone(); + let url = self.url_prototype.expose_secret().to_url(params); let user_agent = self.user_agent.clone(); - self.state = State::connect(url, user_agent, self.socket_factory.clone()); + self.state = State::connect(url.clone(), user_agent, self.socket_factory.clone()); + self.last_url = Some(url); // 3. In case we were already re-connecting, we need to wake the suspended task. if let Some(waker) = self.waker.take() { @@ -358,7 +364,7 @@ where self.heartbeat.reset(); self.state = State::Connected(stream); - let host = self.url.expose_secret().host(); + let host = self.url_prototype.expose_secret().host(); tracing::info!(%host, "Connected to portal"); self.join(self.login, self.init_req.clone()); @@ -376,7 +382,11 @@ where return Poll::Ready(Err(Error::MaxRetriesReached)); }; - let secret_url = self.url.clone(); + let secret_url = self + .last_url + .as_ref() + .expect("should have last URL if we receive connection error") + .clone(); let user_agent = self.user_agent.clone(); let socket_factory = self.socket_factory.clone(); @@ -736,22 +746,20 @@ impl PhoenixMessage { } // This is basically the same as tungstenite does but we add some new headers (namely user-agent) -fn make_request(url: Secret, user_agent: String) -> Request { - use secrecy::ExposeSecret as _; - +fn make_request(url: Url, user_agent: String) -> Request { let mut r = [0u8; 16]; OsRng.fill_bytes(&mut r); let key = base64::engine::general_purpose::STANDARD.encode(r); Request::builder() .method("GET") - .header("Host", url.expose_secret().host()) + .header("Host", url.host().unwrap().to_string()) .header("Connection", "Upgrade") .header("Upgrade", "websocket") .header("Sec-WebSocket-Version", "13") .header("Sec-WebSocket-Key", key) .header("User-Agent", user_agent) - .uri(url.expose_secret().inner().as_str()) + .uri(url.to_string()) .body(()) .expect("building static request always works") } diff --git a/rust/phoenix-channel/src/login_url.rs b/rust/phoenix-channel/src/login_url.rs index a47d9754b..2c2a1e7a0 100644 --- a/rust/phoenix-channel/src/login_url.rs +++ b/rust/phoenix-channel/src/login_url.rs @@ -2,7 +2,11 @@ use base64::{engine::general_purpose::STANDARD, Engine}; use secrecy::{CloneableSecret, ExposeSecret as _, SecretString, Zeroize}; use serde::Deserialize; use sha2::Digest as _; -use std::net::{Ipv4Addr, Ipv6Addr}; +use std::{ + iter, + marker::PhantomData, + net::{Ipv4Addr, Ipv6Addr}, +}; use url::Url; use uuid::Uuid; @@ -27,16 +31,18 @@ pub struct DeviceInfo { } #[derive(Clone)] -pub struct LoginUrl { +pub struct LoginUrl { url: Url, // Invariant: Must stay the same as the host in `url`. // This is duplicated here because `Url::host` is fallible. // If we don't duplicate it, we'd have to do extra error handling in several places instead of just one place. host: String, + + phantom: PhantomData, } -impl Zeroize for LoginUrl { +impl Zeroize for LoginUrl { fn zeroize(&mut self) { let placeholder = Url::parse("http://a.com") .expect("placeholder URL should always be valid, it's hard-coded"); @@ -44,15 +50,36 @@ impl Zeroize for LoginUrl { } } -impl CloneableSecret for LoginUrl {} +pub struct PublicKeyParam(pub [u8; 32]); -impl LoginUrl { +impl IntoIterator for PublicKeyParam { + type Item = (&'static str, String); + type IntoIter = std::iter::Once; + + fn into_iter(self) -> Self::IntoIter { + iter::once(("public_key", STANDARD.encode(self.0))) + } +} + +pub struct NoParams; + +impl IntoIterator for NoParams { + type Item = (&'static str, String); + type IntoIter = std::iter::Empty; + + fn into_iter(self) -> Self::IntoIter { + std::iter::empty() + } +} + +impl CloneableSecret for LoginUrl where TFinish: Clone {} + +impl LoginUrl { pub fn client( url: impl TryInto, firezone_token: &SecretString, device_id: String, device_name: Option, - public_key: [u8; 32], device_info: DeviceInfo, ) -> Result> { let external_id = hex::encode(sha2::Sha256::digest(device_id)); @@ -64,7 +91,6 @@ impl LoginUrl { url.try_into().map_err(LoginUrlError::InvalidUrl)?, firezone_token, "client", - Some(public_key), Some(external_id), Some(device_name), None, @@ -76,6 +102,7 @@ impl LoginUrl { Ok(LoginUrl { host: parse_host(&url)?, url, + phantom: PhantomData, }) } @@ -84,7 +111,6 @@ impl LoginUrl { firezone_token: &SecretString, device_id: String, device_name: Option, - public_key: [u8; 32], ) -> Result> { let external_id = hex::encode(sha2::Sha256::digest(device_id)); let device_name = device_name @@ -95,7 +121,6 @@ impl LoginUrl { url.try_into().map_err(LoginUrlError::InvalidUrl)?, firezone_token, "gateway", - Some(public_key), Some(external_id), Some(device_name), None, @@ -107,9 +132,12 @@ impl LoginUrl { Ok(LoginUrl { host: parse_host(&url)?, url, + phantom: PhantomData, }) } +} +impl LoginUrl { pub fn relay( url: impl TryInto, firezone_token: &SecretString, @@ -123,7 +151,6 @@ impl LoginUrl { firezone_token, "relay", None, - None, device_name, Some(listen_port), ipv4_address, @@ -134,15 +161,25 @@ impl LoginUrl { Ok(LoginUrl { host: parse_host(&url)?, url, + phantom: PhantomData, }) } +} - // TODO: Only temporarily public until we delete other phoenix-channel impl. - pub fn inner(&self) -> &Url { - &self.url +impl LoginUrl +where + TFinish: IntoIterator, +{ + pub fn to_url(&self, params: TFinish) -> Url { + let mut url = self.url.clone(); + + url.query_pairs_mut().extend_pairs(params); + + url } +} - // TODO: Only temporarily public until we delete other phoenix-channel impl. +impl LoginUrl { pub fn host(&self) -> &str { &self.host } @@ -190,7 +227,6 @@ fn get_websocket_path( mut api_url: Url, token: &SecretString, mode: &str, - public_key: Option<[u8; 32]>, external_id: Option, name: Option, port: Option, @@ -215,9 +251,6 @@ fn get_websocket_path( query_pairs.clear(); query_pairs.append_pair("token", token.expose_secret()); - if let Some(public_key) = public_key { - query_pairs.append_pair("public_key", &STANDARD.encode(public_key)); - } if let Some(external_id) = external_id { query_pairs.append_pair("external_id", &external_id); } diff --git a/rust/relay/src/main.rs b/rust/relay/src/main.rs index 148dd569c..3dd4c3f08 100644 --- a/rust/relay/src/main.rs +++ b/rust/relay/src/main.rs @@ -8,7 +8,7 @@ use firezone_relay::{ PeerSocket, Server, Sleep, }; use futures::{future, FutureExt}; -use phoenix_channel::{Event, LoginUrl, PhoenixChannel}; +use phoenix_channel::{Event, LoginUrl, NoParams, PhoenixChannel}; use rand::rngs::StdRng; use rand::{Rng, SeedableRng}; use secrecy::{Secret, SecretString}; @@ -139,7 +139,7 @@ async fn main() -> Result<()> { args.public_ip6_addr, )?; - Some(PhoenixChannel::connect( + let mut channel = PhoenixChannel::disconnected( Secret::new(login), format!("relay/{}", env!("CARGO_PKG_VERSION")), "relay", @@ -150,7 +150,10 @@ async fn main() -> Result<()> { .with_max_elapsed_time(Some(MAX_PARTITION_TIME)) .build(), Arc::new(socket_factory::tcp), - )?) + )?; + channel.connect(NoParams); + + Some(channel) } else { tracing::warn!(target: "relay", "No portal token supplied, starting standalone mode"); @@ -305,7 +308,7 @@ struct Eventloop { sockets: Sockets, server: Server, - channel: Option>, + channel: Option>, sleep: Sleep, sigterm: unix::Signal, @@ -325,7 +328,7 @@ where { fn new( server: Server, - channel: Option>, + channel: Option>, public_address: IpStack, last_heartbeat_sent: Arc>>, ) -> Result {