diff --git a/rust/connlib/clients/android/src/lib.rs b/rust/connlib/clients/android/src/lib.rs index 7ec6b734e..165740c42 100644 --- a/rust/connlib/clients/android/src/lib.rs +++ b/rust/connlib/clients/android/src/lib.rs @@ -5,9 +5,7 @@ use crate::tun::Tun; use backoff::ExponentialBackoffBuilder; -use connlib_client_shared::{ - keypair, Callbacks, ConnectArgs, DisconnectError, Session, V4RouteList, V6RouteList, -}; +use connlib_client_shared::{Callbacks, DisconnectError, Session, V4RouteList, V6RouteList}; use connlib_model::{ResourceId, ResourceView}; use ip_network::{Ipv4Network, Ipv6Network}; use jni::{ @@ -355,13 +353,11 @@ fn connect( handle, }; - let (private_key, public_key) = keypair(); let url = LoginUrl::client( api_url.as_str(), &secret, device_id, Some(device_name), - public_key.to_bytes(), device_info, )?; @@ -374,13 +370,7 @@ fn connect( let tcp_socket_factory = Arc::new(protected_tcp_socket_factory(callbacks.clone())); - let args = ConnectArgs { - tcp_socket_factory: tcp_socket_factory.clone(), - udp_socket_factory: Arc::new(protected_udp_socket_factory(callbacks.clone())), - private_key, - callbacks, - }; - let portal = PhoenixChannel::connect( + let portal = PhoenixChannel::disconnected( Secret::new(url), get_user_agent(Some(os_version), env!("CARGO_PKG_VERSION")), "client", @@ -390,7 +380,13 @@ fn connect( .build(), tcp_socket_factory, )?; - let session = Session::connect(args, portal, runtime.handle().clone()); + let session = Session::connect( + Arc::new(protected_tcp_socket_factory(callbacks.clone())), + Arc::new(protected_udp_socket_factory(callbacks.clone())), + callbacks, + portal, + runtime.handle().clone(), + ); Ok(SessionWrapper { inner: session, diff --git a/rust/connlib/clients/apple/src/lib.rs b/rust/connlib/clients/apple/src/lib.rs index 4e1bcc0b5..3817f1053 100644 --- a/rust/connlib/clients/apple/src/lib.rs +++ b/rust/connlib/clients/apple/src/lib.rs @@ -6,9 +6,7 @@ mod tun; use anyhow::Result; use backoff::ExponentialBackoffBuilder; -use connlib_client_shared::{ - keypair, Callbacks, ConnectArgs, DisconnectError, Session, V4RouteList, V6RouteList, -}; +use connlib_client_shared::{Callbacks, DisconnectError, Session, V4RouteList, V6RouteList}; use connlib_model::ResourceView; use ip_network::{Ipv4Network, Ipv6Network}; use phoenix_channel::get_user_agent; @@ -198,13 +196,11 @@ impl WrappedSession { let secret = SecretString::from(token); let device_info = serde_json::from_str(&device_info).unwrap(); - let (private_key, public_key) = keypair(); let url = LoginUrl::client( api_url.as_str(), &secret, device_id, device_name_override, - public_key.to_bytes(), device_info, )?; @@ -215,15 +211,7 @@ impl WrappedSession { .build()?; let _guard = runtime.enter(); // Constructing `PhoenixChannel` requires a runtime context. - let args = ConnectArgs { - private_key, - callbacks: CallbackHandler { - inner: Arc::new(callback_handler), - }, - tcp_socket_factory: Arc::new(socket_factory::tcp), - udp_socket_factory: Arc::new(socket_factory::udp), - }; - let portal = PhoenixChannel::connect( + let portal = PhoenixChannel::disconnected( Secret::new(url), get_user_agent(os_version_override, env!("CARGO_PKG_VERSION")), "client", @@ -233,7 +221,15 @@ impl WrappedSession { .build(), Arc::new(socket_factory::tcp), )?; - let session = Session::connect(args, portal, runtime.handle().clone()); + let session = Session::connect( + Arc::new(socket_factory::tcp), + Arc::new(socket_factory::udp), + CallbackHandler { + inner: Arc::new(callback_handler), + }, + portal, + runtime.handle().clone(), + ); session.set_tun(Box::new(Tun::new()?)); Ok(Self { diff --git a/rust/connlib/clients/shared/src/eventloop.rs b/rust/connlib/clients/shared/src/eventloop.rs index 7ce627312..36a55b266 100644 --- a/rust/connlib/clients/shared/src/eventloop.rs +++ b/rust/connlib/clients/shared/src/eventloop.rs @@ -3,7 +3,7 @@ use anyhow::Result; use connlib_model::ResourceId; use firezone_tunnel::messages::{client::*, *}; use firezone_tunnel::ClientTunnel; -use phoenix_channel::{ErrorReply, OutboundRequestId, PhoenixChannel}; +use phoenix_channel::{ErrorReply, OutboundRequestId, PhoenixChannel, PublicKeyParam}; use std::{ collections::{BTreeMap, BTreeSet}, io, @@ -16,7 +16,7 @@ pub struct Eventloop { tunnel: ClientTunnel, callbacks: C, - portal: PhoenixChannel<(), IngressMessages, ReplyMessages>, + portal: PhoenixChannel<(), IngressMessages, ReplyMessages, PublicKeyParam>, rx: tokio::sync::mpsc::UnboundedReceiver, connection_intents: SentConnectionIntents, @@ -35,9 +35,11 @@ impl Eventloop { pub(crate) fn new( tunnel: ClientTunnel, callbacks: C, - portal: PhoenixChannel<(), IngressMessages, ReplyMessages>, + mut portal: PhoenixChannel<(), IngressMessages, ReplyMessages, PublicKeyParam>, rx: tokio::sync::mpsc::UnboundedReceiver, ) -> Self { + portal.connect(PublicKeyParam(tunnel.public_key().to_bytes())); + Self { tunnel, portal, @@ -70,8 +72,9 @@ where continue; } Poll::Ready(Some(Command::Reset)) => { - self.portal.reconnect(); self.tunnel.reset(); + self.portal + .connect(PublicKeyParam(self.tunnel.public_key().to_bytes())); continue; } diff --git a/rust/connlib/clients/shared/src/lib.rs b/rust/connlib/clients/shared/src/lib.rs index 892adcd70..0630f05f5 100644 --- a/rust/connlib/clients/shared/src/lib.rs +++ b/rust/connlib/clients/shared/src/lib.rs @@ -3,7 +3,6 @@ pub use crate::serde_routelist::{V4RouteList, V6RouteList}; pub use callbacks::{Callbacks, DisconnectError}; pub use connlib_model::StaticSecret; pub use eventloop::Eventloop; -pub use firezone_tunnel::keypair; pub use firezone_tunnel::messages::client::{ ResourceDescription, {IngressMessages, ReplyMessages}, }; @@ -12,7 +11,7 @@ use connlib_model::ResourceId; use eventloop::Command; use firezone_telemetry as telemetry; use firezone_tunnel::ClientTunnel; -use phoenix_channel::PhoenixChannel; +use phoenix_channel::{PhoenixChannel, PublicKeyParam}; use socket_factory::{SocketFactory, TcpSocket, UdpSocket}; use std::collections::{BTreeMap, BTreeSet}; use std::net::IpAddr; @@ -35,27 +34,26 @@ pub struct Session { channel: tokio::sync::mpsc::UnboundedSender, } -/// Arguments for `connect`, since Clippy said 8 args is too many -pub struct ConnectArgs { - pub tcp_socket_factory: Arc>, - pub udp_socket_factory: Arc>, - pub private_key: StaticSecret, - pub callbacks: CB, -} - impl Session { /// Creates a new [`Session`]. /// /// This connects to the portal using the given [`LoginUrl`](phoenix_channel::LoginUrl) and creates a wireguard tunnel using the provided private key. pub fn connect( - args: ConnectArgs, - portal: PhoenixChannel<(), IngressMessages, ReplyMessages>, + tcp_socket_factory: Arc>, + udp_socket_factory: Arc>, + callbacks: CB, + portal: PhoenixChannel<(), IngressMessages, ReplyMessages, PublicKeyParam>, handle: tokio::runtime::Handle, ) -> Self { let (tx, rx) = tokio::sync::mpsc::unbounded_channel(); - let callbacks = args.callbacks.clone(); - let connect_handle = handle.spawn(connect(args, portal, rx)); + let connect_handle = handle.spawn(connect( + tcp_socket_factory, + udp_socket_factory, + callbacks.clone(), + portal, + rx, + )); handle.spawn(connect_supervisor(connect_handle, callbacks)); Self { channel: tx } @@ -118,22 +116,16 @@ impl Session { /// /// When this function exits, the tunnel failed unrecoverably and you need to call it again. async fn connect( - args: ConnectArgs, - portal: PhoenixChannel<(), IngressMessages, ReplyMessages>, + tcp_socket_factory: Arc>, + udp_socket_factory: Arc>, + callbacks: CB, + portal: PhoenixChannel<(), IngressMessages, ReplyMessages, PublicKeyParam>, rx: UnboundedReceiver, ) -> Result<(), DisconnectError> where CB: Callbacks + 'static, { - let ConnectArgs { - private_key, - callbacks, - udp_socket_factory, - tcp_socket_factory, - } = args; - let tunnel = ClientTunnel::new( - private_key, tcp_socket_factory, udp_socket_factory, BTreeMap::from([(portal.server_host().to_owned(), portal.resolved_addresses())]), diff --git a/rust/connlib/snownet/src/node.rs b/rust/connlib/snownet/src/node.rs index 869cb897d..533ab753b 100644 --- a/rust/connlib/snownet/src/node.rs +++ b/rust/connlib/snownet/src/node.rs @@ -123,10 +123,13 @@ where TId: Eq + Hash + Copy + Ord + fmt::Display, RId: Copy + Eq + Hash + PartialEq + Ord + fmt::Debug + fmt::Display, { - pub fn new(private_key: StaticSecret, seed: [u8; 32]) -> Self { + pub fn new(seed: [u8; 32]) -> Self { + let mut rng = StdRng::from_seed(seed); + let private_key = StaticSecret::random_from_rng(&mut rng); let public_key = &(&private_key).into(); + Self { - rng: StdRng::from_seed(seed), // TODO: Use this seed for private key too. Requires refactoring of how we generate the login-url because that one needs to know the public key. + rng, session_id: SessionId::new(*public_key), private_key, public_key: *public_key, @@ -174,6 +177,11 @@ where self.connections.clear(); self.buffered_transmits.clear(); + self.private_key = StaticSecret::random_from_rng(&mut self.rng); + self.public_key = (&self.private_key).into(); + self.rate_limiter = Arc::new(RateLimiter::new(&self.public_key, HANDSHAKE_RATE_LIMIT)); + self.session_id = SessionId::new(self.public_key); + tracing::debug!(%num_connections, "Closed all connections as part of reconnecting"); } diff --git a/rust/connlib/snownet/tests/lib.rs b/rust/connlib/snownet/tests/lib.rs index 343ad6a3b..d3e2dfdbc 100644 --- a/rust/connlib/snownet/tests/lib.rs +++ b/rust/connlib/snownet/tests/lib.rs @@ -1,4 +1,3 @@ -use boringtun::x25519::StaticSecret; use snownet::{Answer, ClientNode, Event, ServerNode}; use std::{ iter, @@ -73,16 +72,10 @@ fn answer_after_stale_connection_does_not_panic() { fn only_generate_candidate_event_after_answer() { let local_candidate = SocketAddr::new(IpAddr::from(Ipv4Addr::LOCALHOST), 10000); - let mut alice = ClientNode::::new( - StaticSecret::random_from_rng(rand::thread_rng()), - rand::random(), - ); + let mut alice = ClientNode::::new(rand::random()); alice.add_local_host_candidate(local_candidate).unwrap(); - let mut bob = ServerNode::::new( - StaticSecret::random_from_rng(rand::thread_rng()), - rand::random(), - ); + let mut bob = ServerNode::::new(rand::random()); let offer = alice.new_connection(1, Instant::now(), Instant::now()); @@ -106,14 +99,8 @@ fn only_generate_candidate_event_after_answer() { } fn alice_and_bob() -> (ClientNode, ServerNode) { - let alice = ClientNode::new( - StaticSecret::random_from_rng(rand::thread_rng()), - rand::random(), - ); - let bob = ServerNode::new( - StaticSecret::random_from_rng(rand::thread_rng()), - rand::random(), - ); + let alice = ClientNode::new(rand::random()); + let bob = ServerNode::new(rand::random()); (alice, bob) } diff --git a/rust/connlib/tunnel/src/client.rs b/rust/connlib/tunnel/src/client.rs index 2184c75b1..a3fa5b63a 100644 --- a/rust/connlib/tunnel/src/client.rs +++ b/rust/connlib/tunnel/src/client.rs @@ -14,8 +14,8 @@ use crate::peer_store::PeerStore; use crate::{dns, TunConfig}; use anyhow::Context; use bimap::BiMap; +use connlib_model::PublicKey; use connlib_model::{GatewayId, RelayId, ResourceId, ResourceStatus, ResourceView}; -use connlib_model::{PublicKey, StaticSecret}; use connlib_model::{Site, SiteId}; use ip_network::{IpNetwork, Ipv4Network, Ipv6Network}; use ip_network_table::IpNetworkTable; @@ -279,11 +279,7 @@ struct AwaitingConnectionDetails { } impl ClientState { - pub(crate) fn new( - private_key: impl Into, - known_hosts: BTreeMap>, - seed: [u8; 32], - ) -> Self { + pub(crate) fn new(known_hosts: BTreeMap>, seed: [u8; 32]) -> Self { Self { awaiting_connection_details: Default::default(), resources_gateways: Default::default(), @@ -294,7 +290,7 @@ impl ClientState { buffered_events: Default::default(), tun_config: Default::default(), buffered_packets: Default::default(), - node: ClientNode::new(private_key.into(), seed), + node: ClientNode::new(seed), system_resolvers: Default::default(), sites_status: Default::default(), gateways_site: Default::default(), @@ -374,7 +370,6 @@ impl ClientState { } } - #[cfg(all(feature = "proptest", test))] pub(crate) fn public_key(&self) -> PublicKey { self.node.public_key() } @@ -1523,7 +1518,6 @@ impl IpProvider { #[cfg(test)] mod tests { use super::*; - use rand::rngs::OsRng; #[test] fn ignores_ip4_igmp_multicast() { @@ -1568,11 +1562,7 @@ mod tests { impl ClientState { pub fn for_test() -> ClientState { - ClientState::new( - StaticSecret::random_from_rng(OsRng), - BTreeMap::new(), - rand::random(), - ) + ClientState::new(BTreeMap::new(), rand::random()) } } diff --git a/rust/connlib/tunnel/src/gateway.rs b/rust/connlib/tunnel/src/gateway.rs index 6dee955c3..bbc904c6d 100644 --- a/rust/connlib/tunnel/src/gateway.rs +++ b/rust/connlib/tunnel/src/gateway.rs @@ -7,7 +7,7 @@ use crate::{GatewayEvent, GatewayTunnel}; use anyhow::Context; use boringtun::x25519::PublicKey; use chrono::{DateTime, Utc}; -use connlib_model::{ClientId, DomainName, RelayId, ResourceId, StaticSecret}; +use connlib_model::{ClientId, DomainName, RelayId, ResourceId}; use ip_network::{Ipv4Network, Ipv6Network}; use ip_packet::IpPacket; use secrecy::{ExposeSecret as _, Secret}; @@ -163,16 +163,15 @@ impl DnsResourceNatEntry { } impl GatewayState { - pub(crate) fn new(private_key: impl Into, seed: [u8; 32]) -> Self { + pub(crate) fn new(seed: [u8; 32]) -> Self { Self { peers: Default::default(), - node: ServerNode::new(private_key.into(), seed), + node: ServerNode::new(seed), next_expiry_resources_check: Default::default(), buffered_events: VecDeque::default(), } } - #[cfg(all(feature = "proptest", test))] pub(crate) fn public_key(&self) -> PublicKey { self.node.public_key() } diff --git a/rust/connlib/tunnel/src/lib.rs b/rust/connlib/tunnel/src/lib.rs index 7004c6f00..5e431c4ec 100644 --- a/rust/connlib/tunnel/src/lib.rs +++ b/rust/connlib/tunnel/src/lib.rs @@ -5,7 +5,6 @@ use crate::messages::{Offer, Relay, ResolveRequest, SecretKey}; use bimap::BiMap; -use boringtun::x25519::StaticSecret; use chrono::Utc; use connlib_model::{ ClientId, DomainName, GatewayId, PublicKey, RelayId, ResourceId, ResourceView, @@ -13,7 +12,6 @@ use connlib_model::{ use io::Io; use ip_network::{Ipv4Network, Ipv6Network}; use ip_packet::MAX_DATAGRAM_PAYLOAD; -use rand::rngs::OsRng; use socket_factory::{SocketFactory, TcpSocket, UdpSocket}; use std::{ collections::{BTreeMap, BTreeSet}, @@ -81,20 +79,23 @@ pub struct Tunnel { impl ClientTunnel { pub fn new( - private_key: StaticSecret, tcp_socket_factory: Arc>, udp_socket_factory: Arc>, known_hosts: BTreeMap>, ) -> Self { Self { io: Io::new(tcp_socket_factory, udp_socket_factory), - role_state: ClientState::new(private_key, known_hosts, rand::random()), + role_state: ClientState::new(known_hosts, rand::random()), ip4_read_buf: Box::new([0u8; MAX_UDP_SIZE]), ip6_read_buf: Box::new([0u8; MAX_UDP_SIZE]), encrypt_buf: EncryptBuffer::new(MAX_DATAGRAM_PAYLOAD), } } + pub fn public_key(&self) -> PublicKey { + self.role_state.public_key() + } + pub fn reset(&mut self) { self.role_state.reset(); self.io.rebind_sockets(); @@ -177,19 +178,22 @@ impl ClientTunnel { impl GatewayTunnel { pub fn new( - private_key: StaticSecret, tcp_socket_factory: Arc>, udp_socket_factory: Arc>, ) -> Self { Self { io: Io::new(tcp_socket_factory, udp_socket_factory), - role_state: GatewayState::new(private_key, rand::random()), + role_state: GatewayState::new(rand::random()), ip4_read_buf: Box::new([0u8; MAX_UDP_SIZE]), ip6_read_buf: Box::new([0u8; MAX_UDP_SIZE]), encrypt_buf: EncryptBuffer::new(MAX_DATAGRAM_PAYLOAD), } } + pub fn public_key(&self) -> PublicKey { + self.role_state.public_key() + } + pub fn update_relays(&mut self, to_remove: BTreeSet, to_add: Vec) { self.role_state .update_relays(to_remove, turn(&to_add), Instant::now()) @@ -341,13 +345,6 @@ pub enum GatewayEvent { }, } -pub fn keypair() -> (StaticSecret, PublicKey) { - let private_key = StaticSecret::random_from_rng(OsRng); - let public_key = PublicKey::from(&private_key); - - (private_key, public_key) -} - fn fmt_routes(routes: &BTreeSet, f: &mut fmt::Formatter) -> fmt::Result where T: fmt::Display, diff --git a/rust/connlib/tunnel/src/tests/reference.rs b/rust/connlib/tunnel/src/tests/reference.rs index 06ea6a17f..d9af5b2c6 100644 --- a/rust/connlib/tunnel/src/tests/reference.rs +++ b/rust/connlib/tunnel/src/tests/reference.rs @@ -2,9 +2,9 @@ use super::{ composite_strategy::CompositeStrategy, sim_client::*, sim_dns::*, sim_gateway::*, sim_net::*, strategies::*, stub_portal::StubPortal, transition::*, }; -use crate::{client, DomainName, StaticSecret}; +use crate::{client, DomainName}; use crate::{dns::is_subdomain, proptest::relay_id}; -use connlib_model::{GatewayId, RelayId, ResourceId}; +use connlib_model::{GatewayId, RelayId, ResourceId, StaticSecret}; use domain::base::Rtype; use proptest::{prelude::*, sample}; use proptest_state_machine::ReferenceStateMachine; diff --git a/rust/connlib/tunnel/src/tests/sim_client.rs b/rust/connlib/tunnel/src/tests/sim_client.rs index 5f33a751e..752171fde 100644 --- a/rust/connlib/tunnel/src/tests/sim_client.rs +++ b/rust/connlib/tunnel/src/tests/sim_client.rs @@ -331,7 +331,7 @@ impl RefClient { /// /// This simulates receiving the `init` message from the portal. pub(crate) fn init(self) -> SimClient { - let mut client_state = ClientState::new(self.key, self.known_hosts, self.key.0); // Cheating a bit here by reusing the key as seed. + let mut client_state = ClientState::new(self.known_hosts, self.key.0); // Cheating a bit here by reusing the key as seed. client_state.update_interface_config(Interface { ipv4: self.tunnel_ip4, ipv6: self.tunnel_ip6, diff --git a/rust/connlib/tunnel/src/tests/sim_gateway.rs b/rust/connlib/tunnel/src/tests/sim_gateway.rs index fc6005ac9..a845a0b25 100644 --- a/rust/connlib/tunnel/src/tests/sim_gateway.rs +++ b/rust/connlib/tunnel/src/tests/sim_gateway.rs @@ -143,7 +143,7 @@ impl RefGateway { /// /// This simulates receiving the `init` message from the portal. pub(crate) fn init(self, id: GatewayId) -> SimGateway { - SimGateway::new(id, GatewayState::new(self.key, self.key.0)) // Cheating a bit here by reusing the key as seed. + SimGateway::new(id, GatewayState::new(self.key.0)) // Cheating a bit here by reusing the key as seed. } } diff --git a/rust/gateway/src/eventloop.rs b/rust/gateway/src/eventloop.rs index 7b281057b..fe1e0acd4 100644 --- a/rust/gateway/src/eventloop.rs +++ b/rust/gateway/src/eventloop.rs @@ -14,7 +14,7 @@ use firezone_tunnel::messages::{ use firezone_tunnel::{DnsResourceNatEntry, GatewayTunnel}; use futures::channel::mpsc; use futures_bounded::Timeout; -use phoenix_channel::PhoenixChannel; +use phoenix_channel::{PhoenixChannel, PublicKeyParam}; use std::collections::BTreeSet; use std::convert::Infallible; use std::net::IpAddr; @@ -41,7 +41,7 @@ enum ResolveTrigger { pub struct Eventloop { tunnel: GatewayTunnel, - portal: PhoenixChannel<(), IngressMessages, ()>, + portal: PhoenixChannel<(), IngressMessages, (), PublicKeyParam>, tun_device_channel: mpsc::Sender, resolve_tasks: futures_bounded::FuturesTupleSet, ResolveTrigger>, @@ -50,9 +50,11 @@ pub struct Eventloop { impl Eventloop { pub(crate) fn new( tunnel: GatewayTunnel, - portal: PhoenixChannel<(), IngressMessages, ()>, + mut portal: PhoenixChannel<(), IngressMessages, (), PublicKeyParam>, tun_device_channel: mpsc::Sender, ) -> Self { + portal.connect(PublicKeyParam(tunnel.public_key().to_bytes())); + Self { tunnel, portal, diff --git a/rust/gateway/src/main.rs b/rust/gateway/src/main.rs index 3ca4fff5d..1a135f90d 100644 --- a/rust/gateway/src/main.rs +++ b/rust/gateway/src/main.rs @@ -2,20 +2,19 @@ use crate::eventloop::{Eventloop, PHOENIX_TOPIC}; use anyhow::{Context, Result}; use backoff::ExponentialBackoffBuilder; use clap::Parser; -use connlib_model::StaticSecret; use firezone_bin_shared::{ http_health_check, linux::{tcp_socket_factory, udp_socket_factory}, TunDeviceManager, }; use firezone_tunnel::messages::Interface; -use firezone_tunnel::{keypair, GatewayTunnel, IPV4_PEERS, IPV6_PEERS}; +use firezone_tunnel::{GatewayTunnel, IPV4_PEERS, IPV6_PEERS}; use phoenix_channel::get_user_agent; use phoenix_channel::LoginUrl; use futures::channel::mpsc; use futures::{future, StreamExt, TryFutureExt}; -use phoenix_channel::PhoenixChannel; +use phoenix_channel::{PhoenixChannel, PublicKeyParam}; use secrecy::{Secret, SecretString}; use std::convert::Infallible; use std::path::Path; @@ -55,16 +54,14 @@ async fn try_main() -> Result<()> { let firezone_id = get_firezone_id(cli.firezone_id).await .context("Couldn't read FIREZONE_ID or write it to disk: Please provide it through the env variable or provide rw access to /var/lib/firezone/")?; - let (private_key, public_key) = keypair(); let login = LoginUrl::gateway( cli.api_url, &SecretString::new(cli.token), firezone_id, cli.firezone_name, - public_key.to_bytes(), )?; - let task = tokio::spawn(run(login, private_key)).err_into(); + let task = tokio::spawn(run(login)).err_into(); let ctrl_c = pin!(ctrl_c().map_err(anyhow::Error::new)); @@ -107,13 +104,9 @@ async fn get_firezone_id(env_id: Option) -> Result { Ok(id) } -async fn run(login: LoginUrl, private_key: StaticSecret) -> Result { - let mut tunnel = GatewayTunnel::new( - private_key, - Arc::new(tcp_socket_factory), - Arc::new(udp_socket_factory), - ); - let portal = PhoenixChannel::connect( +async fn run(login: LoginUrl) -> Result { + let mut tunnel = GatewayTunnel::new(Arc::new(tcp_socket_factory), Arc::new(udp_socket_factory)); + let portal = PhoenixChannel::disconnected( Secret::new(login), get_user_agent(None, env!("CARGO_PKG_VERSION")), PHOENIX_TOPIC, diff --git a/rust/headless-client/src/ipc_service.rs b/rust/headless-client/src/ipc_service.rs index 962927d79..968e5ba86 100644 --- a/rust/headless-client/src/ipc_service.rs +++ b/rust/headless-client/src/ipc_service.rs @@ -4,7 +4,6 @@ use crate::{ }; use anyhow::{bail, Context as _, Result}; use clap::Parser; -use connlib_client_shared::{keypair, ConnectArgs}; use connlib_model::ResourceView; use firezone_bin_shared::{ platform::{tcp_socket_factory, udp_socket_factory, DnsControlMethod}, @@ -529,14 +528,12 @@ impl<'a> Handler<'a> { let transaction = firezone_telemetry::start_transaction(ctx); assert!(self.session.is_none()); let device_id = device_id::get_or_create().map_err(|e| Error::DeviceId(e.to_string()))?; - let (private_key, public_key) = keypair(); let url = LoginUrl::client( Url::parse(api_url).map_err(|e| Error::UrlParse(e.to_string()))?, &token, device_id.id, None, - public_key.to_bytes(), device_id::device_info(), ) .map_err(|e| Error::LoginUrl(e.to_string()))?; @@ -544,16 +541,10 @@ impl<'a> Handler<'a> { self.last_connlib_start_instant = Some(Instant::now()); let (cb_tx, cb_rx) = mpsc::channel(1_000); let callbacks = CallbackHandler { cb_tx }; - let args = ConnectArgs { - tcp_socket_factory: Arc::new(tcp_socket_factory), - udp_socket_factory: Arc::new(udp_socket_factory), - private_key, - callbacks, - }; // Synchronous DNS resolution here let phoenix_span = transaction.start_child("phoenix", "Resolve DNS for PhoenixChannel"); - let portal = PhoenixChannel::connect( + let portal = PhoenixChannel::disconnected( Secret::new(url), get_user_agent(None, env!("CARGO_PKG_VERSION")), "client", @@ -569,7 +560,9 @@ impl<'a> Handler<'a> { // Read the resolvers before starting connlib, in case connlib's startup interferes. let dns = self.dns_controller.system_resolvers(); let connlib = connlib_client_shared::Session::connect( - args, + Arc::new(tcp_socket_factory), + Arc::new(udp_socket_factory), + callbacks, portal, tokio::runtime::Handle::current(), ); diff --git a/rust/headless-client/src/main.rs b/rust/headless-client/src/main.rs index 2d1d6f8be..259fb9e60 100644 --- a/rust/headless-client/src/main.rs +++ b/rust/headless-client/src/main.rs @@ -3,7 +3,7 @@ use anyhow::{anyhow, Context as _, Result}; use backoff::ExponentialBackoffBuilder; use clap::Parser; -use connlib_client_shared::{keypair, ConnectArgs, Session}; +use connlib_client_shared::Session; use firezone_bin_shared::{ new_dns_notifier, new_network_notifier, platform::{tcp_socket_factory, udp_socket_factory}, @@ -177,13 +177,11 @@ fn main() -> Result<()> { }; firezone_telemetry::configure_scope(|scope| scope.set_tag("firezone_id", &firezone_id)); - let (private_key, public_key) = keypair(); let url = LoginUrl::client( cli.api_url, &token, firezone_id, cli.firezone_name, - public_key.to_bytes(), device_id::device_info(), )?; @@ -197,12 +195,6 @@ fn main() -> Result<()> { // The name matches that in `ipc_service.rs` let mut last_connlib_start_instant = Some(Instant::now()); - let args = ConnectArgs { - udp_socket_factory: Arc::new(udp_socket_factory), - tcp_socket_factory: Arc::new(tcp_socket_factory), - private_key, - callbacks, - }; let result = rt.block_on(async { let ctx = firezone_telemetry::TransactionContext::new( @@ -217,7 +209,7 @@ fn main() -> Result<()> { // for an Internet connection if it launches us at startup. // When running interactively, it is useful for the user to see that we can't reach the portal. let phoenix_span = transaction.start_child("phoenix", "Connect PhoenixChannel"); - let portal = PhoenixChannel::connect( + let portal = PhoenixChannel::disconnected( Secret::new(url), get_user_agent(None, env!("CARGO_PKG_VERSION")), "client", @@ -228,7 +220,13 @@ fn main() -> Result<()> { Arc::new(tcp_socket_factory), )?; phoenix_span.finish(); - let session = Session::connect(args, portal, rt.handle().clone()); + let session = Session::connect( + Arc::new(tcp_socket_factory), + Arc::new(udp_socket_factory), + callbacks, + portal, + rt.handle().clone(), + ); let mut terminate = signals::Terminate::new()?; let mut hangup = signals::Hangup::new()?; diff --git a/rust/phoenix-channel/src/lib.rs b/rust/phoenix-channel/src/lib.rs index 3f531b926..ca7d42726 100644 --- a/rust/phoenix-channel/src/lib.rs +++ b/rust/phoenix-channel/src/lib.rs @@ -3,7 +3,7 @@ mod heartbeat; mod login_url; use std::collections::{HashSet, VecDeque}; -use std::net::{IpAddr, SocketAddr}; +use std::net::{IpAddr, SocketAddr, ToSocketAddrs as _}; use std::sync::atomic::AtomicU64; use std::sync::Arc; use std::{fmt, future, marker::PhantomData}; @@ -29,9 +29,9 @@ use tokio_tungstenite::{ use url::{Host, Url}; pub use get_user_agent::get_user_agent; -pub use login_url::{DeviceInfo, LoginUrl, LoginUrlError}; +pub use login_url::{DeviceInfo, LoginUrl, LoginUrlError, NoParams, PublicKeyParam}; -pub struct PhoenixChannel { +pub struct PhoenixChannel { state: State, waker: Option, pending_messages: VecDeque, @@ -45,7 +45,8 @@ pub struct PhoenixChannel { pending_join_requests: HashSet, // Stored here to allow re-connecting. - url: Secret, + url_prototype: Secret>, + last_url: Option, user_agent: String, reconnect_backoff: ExponentialBackoff, @@ -66,7 +67,7 @@ enum State { impl State { fn connect( - url: Secret, + url: Url, user_agent: String, socket_factory: Arc>, ) -> Self { @@ -75,11 +76,13 @@ impl State { } async fn create_and_connect_websocket( - url: Secret, + url: Url, user_agent: String, socket_factory: Arc>, ) -> Result>, InternalError> { - let socket = make_socket(url.expose_secret().inner(), &*socket_factory).await?; + tracing::debug!(host = %url.host().unwrap(), %user_agent, "Connecting to portal"); + + let socket = make_socket(&url, &*socket_factory).await?; let (stream, _) = client_async_tls(make_request(url, user_agent), socket) .await @@ -214,18 +217,22 @@ impl fmt::Display for OutboundRequestId { #[error("Cannot close websocket while we are connecting")] pub struct Connecting; -impl PhoenixChannel +impl + PhoenixChannel where TInitReq: Serialize + Clone, TInboundMsg: DeserializeOwned, TOutboundRes: DeserializeOwned, + TFinish: IntoIterator, { - /// Creates a new [PhoenixChannel] to the given endpoint. + /// Creates a new [PhoenixChannel] to the given endpoint in the `disconnected` state. + /// + /// You must explicitly call [`PhoenixChannel::connect`] to establish a connection. /// /// The provided URL must contain a host. /// Additionally, you must already provide any query parameters required for authentication. - pub fn connect( - url: Secret, + pub fn disconnected( + url: Secret>, user_agent: String, login: &'static str, init_req: TInitReq, @@ -239,19 +246,16 @@ where // We expose them to other components that deal with DNS stuff to ensure our domain always resolves to these IPs. let resolved_addresses = url .expose_secret() - .inner() - .socket_addrs(|| None)? - .iter() + .host() + .to_socket_addrs()? .map(|addr| addr.ip()) .collect(); - tracing::debug!(host = %url.expose_secret().host(), %user_agent, "Connecting to portal"); - Ok(Self { reconnect_backoff, - url: url.clone(), - user_agent: user_agent.clone(), - state: State::connect(url, user_agent, socket_factory.clone()), + url_prototype: url, + user_agent, + state: State::Closed, socket_factory, waker: None, pending_messages: Default::default(), @@ -266,6 +270,7 @@ where login, init_req, resolved_addresses, + last_url: None, }) } @@ -276,7 +281,7 @@ where /// The host we are connecting / connected to. pub fn server_host(&self) -> &str { - self.url.expose_secret().host() + self.url_prototype.expose_secret().host() } /// Join the provided room. @@ -297,15 +302,16 @@ where id } - /// Reconnects to the portal. - pub fn reconnect(&mut self) { + /// Establishes a new connection, dropping the current one if any exists. + pub fn connect(&mut self, params: TFinish) { // 1. Reset the backoff. self.reconnect_backoff.reset(); // 2. Set state to `Connecting` without a timer. - let url = self.url.clone(); + let url = self.url_prototype.expose_secret().to_url(params); let user_agent = self.user_agent.clone(); - self.state = State::connect(url, user_agent, self.socket_factory.clone()); + self.state = State::connect(url.clone(), user_agent, self.socket_factory.clone()); + self.last_url = Some(url); // 3. In case we were already re-connecting, we need to wake the suspended task. if let Some(waker) = self.waker.take() { @@ -358,7 +364,7 @@ where self.heartbeat.reset(); self.state = State::Connected(stream); - let host = self.url.expose_secret().host(); + let host = self.url_prototype.expose_secret().host(); tracing::info!(%host, "Connected to portal"); self.join(self.login, self.init_req.clone()); @@ -376,7 +382,11 @@ where return Poll::Ready(Err(Error::MaxRetriesReached)); }; - let secret_url = self.url.clone(); + let secret_url = self + .last_url + .as_ref() + .expect("should have last URL if we receive connection error") + .clone(); let user_agent = self.user_agent.clone(); let socket_factory = self.socket_factory.clone(); @@ -736,22 +746,20 @@ impl PhoenixMessage { } // This is basically the same as tungstenite does but we add some new headers (namely user-agent) -fn make_request(url: Secret, user_agent: String) -> Request { - use secrecy::ExposeSecret as _; - +fn make_request(url: Url, user_agent: String) -> Request { let mut r = [0u8; 16]; OsRng.fill_bytes(&mut r); let key = base64::engine::general_purpose::STANDARD.encode(r); Request::builder() .method("GET") - .header("Host", url.expose_secret().host()) + .header("Host", url.host().unwrap().to_string()) .header("Connection", "Upgrade") .header("Upgrade", "websocket") .header("Sec-WebSocket-Version", "13") .header("Sec-WebSocket-Key", key) .header("User-Agent", user_agent) - .uri(url.expose_secret().inner().as_str()) + .uri(url.to_string()) .body(()) .expect("building static request always works") } diff --git a/rust/phoenix-channel/src/login_url.rs b/rust/phoenix-channel/src/login_url.rs index a47d9754b..2c2a1e7a0 100644 --- a/rust/phoenix-channel/src/login_url.rs +++ b/rust/phoenix-channel/src/login_url.rs @@ -2,7 +2,11 @@ use base64::{engine::general_purpose::STANDARD, Engine}; use secrecy::{CloneableSecret, ExposeSecret as _, SecretString, Zeroize}; use serde::Deserialize; use sha2::Digest as _; -use std::net::{Ipv4Addr, Ipv6Addr}; +use std::{ + iter, + marker::PhantomData, + net::{Ipv4Addr, Ipv6Addr}, +}; use url::Url; use uuid::Uuid; @@ -27,16 +31,18 @@ pub struct DeviceInfo { } #[derive(Clone)] -pub struct LoginUrl { +pub struct LoginUrl { url: Url, // Invariant: Must stay the same as the host in `url`. // This is duplicated here because `Url::host` is fallible. // If we don't duplicate it, we'd have to do extra error handling in several places instead of just one place. host: String, + + phantom: PhantomData, } -impl Zeroize for LoginUrl { +impl Zeroize for LoginUrl { fn zeroize(&mut self) { let placeholder = Url::parse("http://a.com") .expect("placeholder URL should always be valid, it's hard-coded"); @@ -44,15 +50,36 @@ impl Zeroize for LoginUrl { } } -impl CloneableSecret for LoginUrl {} +pub struct PublicKeyParam(pub [u8; 32]); -impl LoginUrl { +impl IntoIterator for PublicKeyParam { + type Item = (&'static str, String); + type IntoIter = std::iter::Once; + + fn into_iter(self) -> Self::IntoIter { + iter::once(("public_key", STANDARD.encode(self.0))) + } +} + +pub struct NoParams; + +impl IntoIterator for NoParams { + type Item = (&'static str, String); + type IntoIter = std::iter::Empty; + + fn into_iter(self) -> Self::IntoIter { + std::iter::empty() + } +} + +impl CloneableSecret for LoginUrl where TFinish: Clone {} + +impl LoginUrl { pub fn client( url: impl TryInto, firezone_token: &SecretString, device_id: String, device_name: Option, - public_key: [u8; 32], device_info: DeviceInfo, ) -> Result> { let external_id = hex::encode(sha2::Sha256::digest(device_id)); @@ -64,7 +91,6 @@ impl LoginUrl { url.try_into().map_err(LoginUrlError::InvalidUrl)?, firezone_token, "client", - Some(public_key), Some(external_id), Some(device_name), None, @@ -76,6 +102,7 @@ impl LoginUrl { Ok(LoginUrl { host: parse_host(&url)?, url, + phantom: PhantomData, }) } @@ -84,7 +111,6 @@ impl LoginUrl { firezone_token: &SecretString, device_id: String, device_name: Option, - public_key: [u8; 32], ) -> Result> { let external_id = hex::encode(sha2::Sha256::digest(device_id)); let device_name = device_name @@ -95,7 +121,6 @@ impl LoginUrl { url.try_into().map_err(LoginUrlError::InvalidUrl)?, firezone_token, "gateway", - Some(public_key), Some(external_id), Some(device_name), None, @@ -107,9 +132,12 @@ impl LoginUrl { Ok(LoginUrl { host: parse_host(&url)?, url, + phantom: PhantomData, }) } +} +impl LoginUrl { pub fn relay( url: impl TryInto, firezone_token: &SecretString, @@ -123,7 +151,6 @@ impl LoginUrl { firezone_token, "relay", None, - None, device_name, Some(listen_port), ipv4_address, @@ -134,15 +161,25 @@ impl LoginUrl { Ok(LoginUrl { host: parse_host(&url)?, url, + phantom: PhantomData, }) } +} - // TODO: Only temporarily public until we delete other phoenix-channel impl. - pub fn inner(&self) -> &Url { - &self.url +impl LoginUrl +where + TFinish: IntoIterator, +{ + pub fn to_url(&self, params: TFinish) -> Url { + let mut url = self.url.clone(); + + url.query_pairs_mut().extend_pairs(params); + + url } +} - // TODO: Only temporarily public until we delete other phoenix-channel impl. +impl LoginUrl { pub fn host(&self) -> &str { &self.host } @@ -190,7 +227,6 @@ fn get_websocket_path( mut api_url: Url, token: &SecretString, mode: &str, - public_key: Option<[u8; 32]>, external_id: Option, name: Option, port: Option, @@ -215,9 +251,6 @@ fn get_websocket_path( query_pairs.clear(); query_pairs.append_pair("token", token.expose_secret()); - if let Some(public_key) = public_key { - query_pairs.append_pair("public_key", &STANDARD.encode(public_key)); - } if let Some(external_id) = external_id { query_pairs.append_pair("external_id", &external_id); } diff --git a/rust/relay/src/main.rs b/rust/relay/src/main.rs index 148dd569c..3dd4c3f08 100644 --- a/rust/relay/src/main.rs +++ b/rust/relay/src/main.rs @@ -8,7 +8,7 @@ use firezone_relay::{ PeerSocket, Server, Sleep, }; use futures::{future, FutureExt}; -use phoenix_channel::{Event, LoginUrl, PhoenixChannel}; +use phoenix_channel::{Event, LoginUrl, NoParams, PhoenixChannel}; use rand::rngs::StdRng; use rand::{Rng, SeedableRng}; use secrecy::{Secret, SecretString}; @@ -139,7 +139,7 @@ async fn main() -> Result<()> { args.public_ip6_addr, )?; - Some(PhoenixChannel::connect( + let mut channel = PhoenixChannel::disconnected( Secret::new(login), format!("relay/{}", env!("CARGO_PKG_VERSION")), "relay", @@ -150,7 +150,10 @@ async fn main() -> Result<()> { .with_max_elapsed_time(Some(MAX_PARTITION_TIME)) .build(), Arc::new(socket_factory::tcp), - )?) + )?; + channel.connect(NoParams); + + Some(channel) } else { tracing::warn!(target: "relay", "No portal token supplied, starting standalone mode"); @@ -305,7 +308,7 @@ struct Eventloop { sockets: Sockets, server: Server, - channel: Option>, + channel: Option>, sleep: Sleep, sigterm: unix::Signal, @@ -325,7 +328,7 @@ where { fn new( server: Server, - channel: Option>, + channel: Option>, public_address: IpStack, last_heartbeat_sent: Arc>>, ) -> Result {