refactor(connlib): rotate tunnel private key on reset (#6909)

With the new control protocol specified in #6461, the client will no
longer initiate new connections. Instead, the credentials are generated
deterministically by the portal based on the gateway's and the client's
public key. For as long as they use the same public key, they also have
the same in-memory state which makes creating connections idempotent.

What we didn't consider in the new design at first is that when clients
roam, they discard all connections but keep the same private key. As a
result, the portal would generate the same ICE credentials which means
the gateway thinks it can reuse the existing connection when new flows
get authorized. The client however discarded all connections (and
rotated its ports and maybe IPs), meaning the previous candidates sent
to the gateway are no longer valid and connectivity fails.

We fix this by also rotating the private keys upon reset. Rotating the
keys itself isn't enough, we also need to propagate the new public key
all the way "over" to the phoenix channel component which lives
separately from connlib's data plane.

To achieve this, we change `PhoenixChannel` to now start in the
"disconnected" state and require an explicit `connect` call. In
addition, the `LoginUrl` constructed by various components now acts
merely as a "prototype", which may require additional data to construct
a fully valid URL. In the case of client and gateway, this is the public
key of the `Node`. This additional parameter needs to be passed to
`PhoenixChannel` in the `connect` call, thus forming a type-safe
contract that ensures we never attempt to connect without providing a
public key.

For the relay, this doesn't apply.

Lastly, this allows us to tidy up the code a bit by:

a) generating the `Node`'s private key from the existing RNG
b) removing `ConnectArgs` which only had two members left

Related: #6461.
Related: #6732.
This commit is contained in:
Thomas Eizinger
2024-10-08 09:28:51 +11:00
committed by GitHub
parent 754cdf06e7
commit 2d4818e007
19 changed files with 200 additions and 202 deletions

View File

@@ -5,9 +5,7 @@
use crate::tun::Tun;
use backoff::ExponentialBackoffBuilder;
use connlib_client_shared::{
keypair, Callbacks, ConnectArgs, DisconnectError, Session, V4RouteList, V6RouteList,
};
use connlib_client_shared::{Callbacks, DisconnectError, Session, V4RouteList, V6RouteList};
use connlib_model::{ResourceId, ResourceView};
use ip_network::{Ipv4Network, Ipv6Network};
use jni::{
@@ -355,13 +353,11 @@ fn connect(
handle,
};
let (private_key, public_key) = keypair();
let url = LoginUrl::client(
api_url.as_str(),
&secret,
device_id,
Some(device_name),
public_key.to_bytes(),
device_info,
)?;
@@ -374,13 +370,7 @@ fn connect(
let tcp_socket_factory = Arc::new(protected_tcp_socket_factory(callbacks.clone()));
let args = ConnectArgs {
tcp_socket_factory: tcp_socket_factory.clone(),
udp_socket_factory: Arc::new(protected_udp_socket_factory(callbacks.clone())),
private_key,
callbacks,
};
let portal = PhoenixChannel::connect(
let portal = PhoenixChannel::disconnected(
Secret::new(url),
get_user_agent(Some(os_version), env!("CARGO_PKG_VERSION")),
"client",
@@ -390,7 +380,13 @@ fn connect(
.build(),
tcp_socket_factory,
)?;
let session = Session::connect(args, portal, runtime.handle().clone());
let session = Session::connect(
Arc::new(protected_tcp_socket_factory(callbacks.clone())),
Arc::new(protected_udp_socket_factory(callbacks.clone())),
callbacks,
portal,
runtime.handle().clone(),
);
Ok(SessionWrapper {
inner: session,

View File

@@ -6,9 +6,7 @@ mod tun;
use anyhow::Result;
use backoff::ExponentialBackoffBuilder;
use connlib_client_shared::{
keypair, Callbacks, ConnectArgs, DisconnectError, Session, V4RouteList, V6RouteList,
};
use connlib_client_shared::{Callbacks, DisconnectError, Session, V4RouteList, V6RouteList};
use connlib_model::ResourceView;
use ip_network::{Ipv4Network, Ipv6Network};
use phoenix_channel::get_user_agent;
@@ -198,13 +196,11 @@ impl WrappedSession {
let secret = SecretString::from(token);
let device_info = serde_json::from_str(&device_info).unwrap();
let (private_key, public_key) = keypair();
let url = LoginUrl::client(
api_url.as_str(),
&secret,
device_id,
device_name_override,
public_key.to_bytes(),
device_info,
)?;
@@ -215,15 +211,7 @@ impl WrappedSession {
.build()?;
let _guard = runtime.enter(); // Constructing `PhoenixChannel` requires a runtime context.
let args = ConnectArgs {
private_key,
callbacks: CallbackHandler {
inner: Arc::new(callback_handler),
},
tcp_socket_factory: Arc::new(socket_factory::tcp),
udp_socket_factory: Arc::new(socket_factory::udp),
};
let portal = PhoenixChannel::connect(
let portal = PhoenixChannel::disconnected(
Secret::new(url),
get_user_agent(os_version_override, env!("CARGO_PKG_VERSION")),
"client",
@@ -233,7 +221,15 @@ impl WrappedSession {
.build(),
Arc::new(socket_factory::tcp),
)?;
let session = Session::connect(args, portal, runtime.handle().clone());
let session = Session::connect(
Arc::new(socket_factory::tcp),
Arc::new(socket_factory::udp),
CallbackHandler {
inner: Arc::new(callback_handler),
},
portal,
runtime.handle().clone(),
);
session.set_tun(Box::new(Tun::new()?));
Ok(Self {

View File

@@ -3,7 +3,7 @@ use anyhow::Result;
use connlib_model::ResourceId;
use firezone_tunnel::messages::{client::*, *};
use firezone_tunnel::ClientTunnel;
use phoenix_channel::{ErrorReply, OutboundRequestId, PhoenixChannel};
use phoenix_channel::{ErrorReply, OutboundRequestId, PhoenixChannel, PublicKeyParam};
use std::{
collections::{BTreeMap, BTreeSet},
io,
@@ -16,7 +16,7 @@ pub struct Eventloop<C: Callbacks> {
tunnel: ClientTunnel,
callbacks: C,
portal: PhoenixChannel<(), IngressMessages, ReplyMessages>,
portal: PhoenixChannel<(), IngressMessages, ReplyMessages, PublicKeyParam>,
rx: tokio::sync::mpsc::UnboundedReceiver<Command>,
connection_intents: SentConnectionIntents,
@@ -35,9 +35,11 @@ impl<C: Callbacks> Eventloop<C> {
pub(crate) fn new(
tunnel: ClientTunnel,
callbacks: C,
portal: PhoenixChannel<(), IngressMessages, ReplyMessages>,
mut portal: PhoenixChannel<(), IngressMessages, ReplyMessages, PublicKeyParam>,
rx: tokio::sync::mpsc::UnboundedReceiver<Command>,
) -> Self {
portal.connect(PublicKeyParam(tunnel.public_key().to_bytes()));
Self {
tunnel,
portal,
@@ -70,8 +72,9 @@ where
continue;
}
Poll::Ready(Some(Command::Reset)) => {
self.portal.reconnect();
self.tunnel.reset();
self.portal
.connect(PublicKeyParam(self.tunnel.public_key().to_bytes()));
continue;
}

View File

@@ -3,7 +3,6 @@ pub use crate::serde_routelist::{V4RouteList, V6RouteList};
pub use callbacks::{Callbacks, DisconnectError};
pub use connlib_model::StaticSecret;
pub use eventloop::Eventloop;
pub use firezone_tunnel::keypair;
pub use firezone_tunnel::messages::client::{
ResourceDescription, {IngressMessages, ReplyMessages},
};
@@ -12,7 +11,7 @@ use connlib_model::ResourceId;
use eventloop::Command;
use firezone_telemetry as telemetry;
use firezone_tunnel::ClientTunnel;
use phoenix_channel::PhoenixChannel;
use phoenix_channel::{PhoenixChannel, PublicKeyParam};
use socket_factory::{SocketFactory, TcpSocket, UdpSocket};
use std::collections::{BTreeMap, BTreeSet};
use std::net::IpAddr;
@@ -35,27 +34,26 @@ pub struct Session {
channel: tokio::sync::mpsc::UnboundedSender<Command>,
}
/// Arguments for `connect`, since Clippy said 8 args is too many
pub struct ConnectArgs<CB> {
pub tcp_socket_factory: Arc<dyn SocketFactory<TcpSocket>>,
pub udp_socket_factory: Arc<dyn SocketFactory<UdpSocket>>,
pub private_key: StaticSecret,
pub callbacks: CB,
}
impl Session {
/// Creates a new [`Session`].
///
/// This connects to the portal using the given [`LoginUrl`](phoenix_channel::LoginUrl) and creates a wireguard tunnel using the provided private key.
pub fn connect<CB: Callbacks + 'static>(
args: ConnectArgs<CB>,
portal: PhoenixChannel<(), IngressMessages, ReplyMessages>,
tcp_socket_factory: Arc<dyn SocketFactory<TcpSocket>>,
udp_socket_factory: Arc<dyn SocketFactory<UdpSocket>>,
callbacks: CB,
portal: PhoenixChannel<(), IngressMessages, ReplyMessages, PublicKeyParam>,
handle: tokio::runtime::Handle,
) -> Self {
let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
let callbacks = args.callbacks.clone();
let connect_handle = handle.spawn(connect(args, portal, rx));
let connect_handle = handle.spawn(connect(
tcp_socket_factory,
udp_socket_factory,
callbacks.clone(),
portal,
rx,
));
handle.spawn(connect_supervisor(connect_handle, callbacks));
Self { channel: tx }
@@ -118,22 +116,16 @@ impl Session {
///
/// When this function exits, the tunnel failed unrecoverably and you need to call it again.
async fn connect<CB>(
args: ConnectArgs<CB>,
portal: PhoenixChannel<(), IngressMessages, ReplyMessages>,
tcp_socket_factory: Arc<dyn SocketFactory<TcpSocket>>,
udp_socket_factory: Arc<dyn SocketFactory<UdpSocket>>,
callbacks: CB,
portal: PhoenixChannel<(), IngressMessages, ReplyMessages, PublicKeyParam>,
rx: UnboundedReceiver<Command>,
) -> Result<(), DisconnectError>
where
CB: Callbacks + 'static,
{
let ConnectArgs {
private_key,
callbacks,
udp_socket_factory,
tcp_socket_factory,
} = args;
let tunnel = ClientTunnel::new(
private_key,
tcp_socket_factory,
udp_socket_factory,
BTreeMap::from([(portal.server_host().to_owned(), portal.resolved_addresses())]),

View File

@@ -123,10 +123,13 @@ where
TId: Eq + Hash + Copy + Ord + fmt::Display,
RId: Copy + Eq + Hash + PartialEq + Ord + fmt::Debug + fmt::Display,
{
pub fn new(private_key: StaticSecret, seed: [u8; 32]) -> Self {
pub fn new(seed: [u8; 32]) -> Self {
let mut rng = StdRng::from_seed(seed);
let private_key = StaticSecret::random_from_rng(&mut rng);
let public_key = &(&private_key).into();
Self {
rng: StdRng::from_seed(seed), // TODO: Use this seed for private key too. Requires refactoring of how we generate the login-url because that one needs to know the public key.
rng,
session_id: SessionId::new(*public_key),
private_key,
public_key: *public_key,
@@ -174,6 +177,11 @@ where
self.connections.clear();
self.buffered_transmits.clear();
self.private_key = StaticSecret::random_from_rng(&mut self.rng);
self.public_key = (&self.private_key).into();
self.rate_limiter = Arc::new(RateLimiter::new(&self.public_key, HANDSHAKE_RATE_LIMIT));
self.session_id = SessionId::new(self.public_key);
tracing::debug!(%num_connections, "Closed all connections as part of reconnecting");
}

View File

@@ -1,4 +1,3 @@
use boringtun::x25519::StaticSecret;
use snownet::{Answer, ClientNode, Event, ServerNode};
use std::{
iter,
@@ -73,16 +72,10 @@ fn answer_after_stale_connection_does_not_panic() {
fn only_generate_candidate_event_after_answer() {
let local_candidate = SocketAddr::new(IpAddr::from(Ipv4Addr::LOCALHOST), 10000);
let mut alice = ClientNode::<u64, u64>::new(
StaticSecret::random_from_rng(rand::thread_rng()),
rand::random(),
);
let mut alice = ClientNode::<u64, u64>::new(rand::random());
alice.add_local_host_candidate(local_candidate).unwrap();
let mut bob = ServerNode::<u64, u64>::new(
StaticSecret::random_from_rng(rand::thread_rng()),
rand::random(),
);
let mut bob = ServerNode::<u64, u64>::new(rand::random());
let offer = alice.new_connection(1, Instant::now(), Instant::now());
@@ -106,14 +99,8 @@ fn only_generate_candidate_event_after_answer() {
}
fn alice_and_bob() -> (ClientNode<u64, u64>, ServerNode<u64, u64>) {
let alice = ClientNode::new(
StaticSecret::random_from_rng(rand::thread_rng()),
rand::random(),
);
let bob = ServerNode::new(
StaticSecret::random_from_rng(rand::thread_rng()),
rand::random(),
);
let alice = ClientNode::new(rand::random());
let bob = ServerNode::new(rand::random());
(alice, bob)
}

View File

@@ -14,8 +14,8 @@ use crate::peer_store::PeerStore;
use crate::{dns, TunConfig};
use anyhow::Context;
use bimap::BiMap;
use connlib_model::PublicKey;
use connlib_model::{GatewayId, RelayId, ResourceId, ResourceStatus, ResourceView};
use connlib_model::{PublicKey, StaticSecret};
use connlib_model::{Site, SiteId};
use ip_network::{IpNetwork, Ipv4Network, Ipv6Network};
use ip_network_table::IpNetworkTable;
@@ -279,11 +279,7 @@ struct AwaitingConnectionDetails {
}
impl ClientState {
pub(crate) fn new(
private_key: impl Into<StaticSecret>,
known_hosts: BTreeMap<String, Vec<IpAddr>>,
seed: [u8; 32],
) -> Self {
pub(crate) fn new(known_hosts: BTreeMap<String, Vec<IpAddr>>, seed: [u8; 32]) -> Self {
Self {
awaiting_connection_details: Default::default(),
resources_gateways: Default::default(),
@@ -294,7 +290,7 @@ impl ClientState {
buffered_events: Default::default(),
tun_config: Default::default(),
buffered_packets: Default::default(),
node: ClientNode::new(private_key.into(), seed),
node: ClientNode::new(seed),
system_resolvers: Default::default(),
sites_status: Default::default(),
gateways_site: Default::default(),
@@ -374,7 +370,6 @@ impl ClientState {
}
}
#[cfg(all(feature = "proptest", test))]
pub(crate) fn public_key(&self) -> PublicKey {
self.node.public_key()
}
@@ -1523,7 +1518,6 @@ impl IpProvider {
#[cfg(test)]
mod tests {
use super::*;
use rand::rngs::OsRng;
#[test]
fn ignores_ip4_igmp_multicast() {
@@ -1568,11 +1562,7 @@ mod tests {
impl ClientState {
pub fn for_test() -> ClientState {
ClientState::new(
StaticSecret::random_from_rng(OsRng),
BTreeMap::new(),
rand::random(),
)
ClientState::new(BTreeMap::new(), rand::random())
}
}

View File

@@ -7,7 +7,7 @@ use crate::{GatewayEvent, GatewayTunnel};
use anyhow::Context;
use boringtun::x25519::PublicKey;
use chrono::{DateTime, Utc};
use connlib_model::{ClientId, DomainName, RelayId, ResourceId, StaticSecret};
use connlib_model::{ClientId, DomainName, RelayId, ResourceId};
use ip_network::{Ipv4Network, Ipv6Network};
use ip_packet::IpPacket;
use secrecy::{ExposeSecret as _, Secret};
@@ -163,16 +163,15 @@ impl DnsResourceNatEntry {
}
impl GatewayState {
pub(crate) fn new(private_key: impl Into<StaticSecret>, seed: [u8; 32]) -> Self {
pub(crate) fn new(seed: [u8; 32]) -> Self {
Self {
peers: Default::default(),
node: ServerNode::new(private_key.into(), seed),
node: ServerNode::new(seed),
next_expiry_resources_check: Default::default(),
buffered_events: VecDeque::default(),
}
}
#[cfg(all(feature = "proptest", test))]
pub(crate) fn public_key(&self) -> PublicKey {
self.node.public_key()
}

View File

@@ -5,7 +5,6 @@
use crate::messages::{Offer, Relay, ResolveRequest, SecretKey};
use bimap::BiMap;
use boringtun::x25519::StaticSecret;
use chrono::Utc;
use connlib_model::{
ClientId, DomainName, GatewayId, PublicKey, RelayId, ResourceId, ResourceView,
@@ -13,7 +12,6 @@ use connlib_model::{
use io::Io;
use ip_network::{Ipv4Network, Ipv6Network};
use ip_packet::MAX_DATAGRAM_PAYLOAD;
use rand::rngs::OsRng;
use socket_factory::{SocketFactory, TcpSocket, UdpSocket};
use std::{
collections::{BTreeMap, BTreeSet},
@@ -81,20 +79,23 @@ pub struct Tunnel<TRoleState> {
impl ClientTunnel {
pub fn new(
private_key: StaticSecret,
tcp_socket_factory: Arc<dyn SocketFactory<TcpSocket>>,
udp_socket_factory: Arc<dyn SocketFactory<UdpSocket>>,
known_hosts: BTreeMap<String, Vec<IpAddr>>,
) -> Self {
Self {
io: Io::new(tcp_socket_factory, udp_socket_factory),
role_state: ClientState::new(private_key, known_hosts, rand::random()),
role_state: ClientState::new(known_hosts, rand::random()),
ip4_read_buf: Box::new([0u8; MAX_UDP_SIZE]),
ip6_read_buf: Box::new([0u8; MAX_UDP_SIZE]),
encrypt_buf: EncryptBuffer::new(MAX_DATAGRAM_PAYLOAD),
}
}
pub fn public_key(&self) -> PublicKey {
self.role_state.public_key()
}
pub fn reset(&mut self) {
self.role_state.reset();
self.io.rebind_sockets();
@@ -177,19 +178,22 @@ impl ClientTunnel {
impl GatewayTunnel {
pub fn new(
private_key: StaticSecret,
tcp_socket_factory: Arc<dyn SocketFactory<TcpSocket>>,
udp_socket_factory: Arc<dyn SocketFactory<UdpSocket>>,
) -> Self {
Self {
io: Io::new(tcp_socket_factory, udp_socket_factory),
role_state: GatewayState::new(private_key, rand::random()),
role_state: GatewayState::new(rand::random()),
ip4_read_buf: Box::new([0u8; MAX_UDP_SIZE]),
ip6_read_buf: Box::new([0u8; MAX_UDP_SIZE]),
encrypt_buf: EncryptBuffer::new(MAX_DATAGRAM_PAYLOAD),
}
}
pub fn public_key(&self) -> PublicKey {
self.role_state.public_key()
}
pub fn update_relays(&mut self, to_remove: BTreeSet<RelayId>, to_add: Vec<Relay>) {
self.role_state
.update_relays(to_remove, turn(&to_add), Instant::now())
@@ -341,13 +345,6 @@ pub enum GatewayEvent {
},
}
pub fn keypair() -> (StaticSecret, PublicKey) {
let private_key = StaticSecret::random_from_rng(OsRng);
let public_key = PublicKey::from(&private_key);
(private_key, public_key)
}
fn fmt_routes<T>(routes: &BTreeSet<T>, f: &mut fmt::Formatter) -> fmt::Result
where
T: fmt::Display,

View File

@@ -2,9 +2,9 @@ use super::{
composite_strategy::CompositeStrategy, sim_client::*, sim_dns::*, sim_gateway::*, sim_net::*,
strategies::*, stub_portal::StubPortal, transition::*,
};
use crate::{client, DomainName, StaticSecret};
use crate::{client, DomainName};
use crate::{dns::is_subdomain, proptest::relay_id};
use connlib_model::{GatewayId, RelayId, ResourceId};
use connlib_model::{GatewayId, RelayId, ResourceId, StaticSecret};
use domain::base::Rtype;
use proptest::{prelude::*, sample};
use proptest_state_machine::ReferenceStateMachine;

View File

@@ -331,7 +331,7 @@ impl RefClient {
///
/// This simulates receiving the `init` message from the portal.
pub(crate) fn init(self) -> SimClient {
let mut client_state = ClientState::new(self.key, self.known_hosts, self.key.0); // Cheating a bit here by reusing the key as seed.
let mut client_state = ClientState::new(self.known_hosts, self.key.0); // Cheating a bit here by reusing the key as seed.
client_state.update_interface_config(Interface {
ipv4: self.tunnel_ip4,
ipv6: self.tunnel_ip6,

View File

@@ -143,7 +143,7 @@ impl RefGateway {
///
/// This simulates receiving the `init` message from the portal.
pub(crate) fn init(self, id: GatewayId) -> SimGateway {
SimGateway::new(id, GatewayState::new(self.key, self.key.0)) // Cheating a bit here by reusing the key as seed.
SimGateway::new(id, GatewayState::new(self.key.0)) // Cheating a bit here by reusing the key as seed.
}
}

View File

@@ -14,7 +14,7 @@ use firezone_tunnel::messages::{
use firezone_tunnel::{DnsResourceNatEntry, GatewayTunnel};
use futures::channel::mpsc;
use futures_bounded::Timeout;
use phoenix_channel::PhoenixChannel;
use phoenix_channel::{PhoenixChannel, PublicKeyParam};
use std::collections::BTreeSet;
use std::convert::Infallible;
use std::net::IpAddr;
@@ -41,7 +41,7 @@ enum ResolveTrigger {
pub struct Eventloop {
tunnel: GatewayTunnel,
portal: PhoenixChannel<(), IngressMessages, ()>,
portal: PhoenixChannel<(), IngressMessages, (), PublicKeyParam>,
tun_device_channel: mpsc::Sender<Interface>,
resolve_tasks: futures_bounded::FuturesTupleSet<Vec<IpAddr>, ResolveTrigger>,
@@ -50,9 +50,11 @@ pub struct Eventloop {
impl Eventloop {
pub(crate) fn new(
tunnel: GatewayTunnel,
portal: PhoenixChannel<(), IngressMessages, ()>,
mut portal: PhoenixChannel<(), IngressMessages, (), PublicKeyParam>,
tun_device_channel: mpsc::Sender<Interface>,
) -> Self {
portal.connect(PublicKeyParam(tunnel.public_key().to_bytes()));
Self {
tunnel,
portal,

View File

@@ -2,20 +2,19 @@ use crate::eventloop::{Eventloop, PHOENIX_TOPIC};
use anyhow::{Context, Result};
use backoff::ExponentialBackoffBuilder;
use clap::Parser;
use connlib_model::StaticSecret;
use firezone_bin_shared::{
http_health_check,
linux::{tcp_socket_factory, udp_socket_factory},
TunDeviceManager,
};
use firezone_tunnel::messages::Interface;
use firezone_tunnel::{keypair, GatewayTunnel, IPV4_PEERS, IPV6_PEERS};
use firezone_tunnel::{GatewayTunnel, IPV4_PEERS, IPV6_PEERS};
use phoenix_channel::get_user_agent;
use phoenix_channel::LoginUrl;
use futures::channel::mpsc;
use futures::{future, StreamExt, TryFutureExt};
use phoenix_channel::PhoenixChannel;
use phoenix_channel::{PhoenixChannel, PublicKeyParam};
use secrecy::{Secret, SecretString};
use std::convert::Infallible;
use std::path::Path;
@@ -55,16 +54,14 @@ async fn try_main() -> Result<()> {
let firezone_id = get_firezone_id(cli.firezone_id).await
.context("Couldn't read FIREZONE_ID or write it to disk: Please provide it through the env variable or provide rw access to /var/lib/firezone/")?;
let (private_key, public_key) = keypair();
let login = LoginUrl::gateway(
cli.api_url,
&SecretString::new(cli.token),
firezone_id,
cli.firezone_name,
public_key.to_bytes(),
)?;
let task = tokio::spawn(run(login, private_key)).err_into();
let task = tokio::spawn(run(login)).err_into();
let ctrl_c = pin!(ctrl_c().map_err(anyhow::Error::new));
@@ -107,13 +104,9 @@ async fn get_firezone_id(env_id: Option<String>) -> Result<String> {
Ok(id)
}
async fn run(login: LoginUrl, private_key: StaticSecret) -> Result<Infallible> {
let mut tunnel = GatewayTunnel::new(
private_key,
Arc::new(tcp_socket_factory),
Arc::new(udp_socket_factory),
);
let portal = PhoenixChannel::connect(
async fn run(login: LoginUrl<PublicKeyParam>) -> Result<Infallible> {
let mut tunnel = GatewayTunnel::new(Arc::new(tcp_socket_factory), Arc::new(udp_socket_factory));
let portal = PhoenixChannel::disconnected(
Secret::new(login),
get_user_agent(None, env!("CARGO_PKG_VERSION")),
PHOENIX_TOPIC,

View File

@@ -4,7 +4,6 @@ use crate::{
};
use anyhow::{bail, Context as _, Result};
use clap::Parser;
use connlib_client_shared::{keypair, ConnectArgs};
use connlib_model::ResourceView;
use firezone_bin_shared::{
platform::{tcp_socket_factory, udp_socket_factory, DnsControlMethod},
@@ -529,14 +528,12 @@ impl<'a> Handler<'a> {
let transaction = firezone_telemetry::start_transaction(ctx);
assert!(self.session.is_none());
let device_id = device_id::get_or_create().map_err(|e| Error::DeviceId(e.to_string()))?;
let (private_key, public_key) = keypair();
let url = LoginUrl::client(
Url::parse(api_url).map_err(|e| Error::UrlParse(e.to_string()))?,
&token,
device_id.id,
None,
public_key.to_bytes(),
device_id::device_info(),
)
.map_err(|e| Error::LoginUrl(e.to_string()))?;
@@ -544,16 +541,10 @@ impl<'a> Handler<'a> {
self.last_connlib_start_instant = Some(Instant::now());
let (cb_tx, cb_rx) = mpsc::channel(1_000);
let callbacks = CallbackHandler { cb_tx };
let args = ConnectArgs {
tcp_socket_factory: Arc::new(tcp_socket_factory),
udp_socket_factory: Arc::new(udp_socket_factory),
private_key,
callbacks,
};
// Synchronous DNS resolution here
let phoenix_span = transaction.start_child("phoenix", "Resolve DNS for PhoenixChannel");
let portal = PhoenixChannel::connect(
let portal = PhoenixChannel::disconnected(
Secret::new(url),
get_user_agent(None, env!("CARGO_PKG_VERSION")),
"client",
@@ -569,7 +560,9 @@ impl<'a> Handler<'a> {
// Read the resolvers before starting connlib, in case connlib's startup interferes.
let dns = self.dns_controller.system_resolvers();
let connlib = connlib_client_shared::Session::connect(
args,
Arc::new(tcp_socket_factory),
Arc::new(udp_socket_factory),
callbacks,
portal,
tokio::runtime::Handle::current(),
);

View File

@@ -3,7 +3,7 @@
use anyhow::{anyhow, Context as _, Result};
use backoff::ExponentialBackoffBuilder;
use clap::Parser;
use connlib_client_shared::{keypair, ConnectArgs, Session};
use connlib_client_shared::Session;
use firezone_bin_shared::{
new_dns_notifier, new_network_notifier,
platform::{tcp_socket_factory, udp_socket_factory},
@@ -177,13 +177,11 @@ fn main() -> Result<()> {
};
firezone_telemetry::configure_scope(|scope| scope.set_tag("firezone_id", &firezone_id));
let (private_key, public_key) = keypair();
let url = LoginUrl::client(
cli.api_url,
&token,
firezone_id,
cli.firezone_name,
public_key.to_bytes(),
device_id::device_info(),
)?;
@@ -197,12 +195,6 @@ fn main() -> Result<()> {
// The name matches that in `ipc_service.rs`
let mut last_connlib_start_instant = Some(Instant::now());
let args = ConnectArgs {
udp_socket_factory: Arc::new(udp_socket_factory),
tcp_socket_factory: Arc::new(tcp_socket_factory),
private_key,
callbacks,
};
let result = rt.block_on(async {
let ctx = firezone_telemetry::TransactionContext::new(
@@ -217,7 +209,7 @@ fn main() -> Result<()> {
// for an Internet connection if it launches us at startup.
// When running interactively, it is useful for the user to see that we can't reach the portal.
let phoenix_span = transaction.start_child("phoenix", "Connect PhoenixChannel");
let portal = PhoenixChannel::connect(
let portal = PhoenixChannel::disconnected(
Secret::new(url),
get_user_agent(None, env!("CARGO_PKG_VERSION")),
"client",
@@ -228,7 +220,13 @@ fn main() -> Result<()> {
Arc::new(tcp_socket_factory),
)?;
phoenix_span.finish();
let session = Session::connect(args, portal, rt.handle().clone());
let session = Session::connect(
Arc::new(tcp_socket_factory),
Arc::new(udp_socket_factory),
callbacks,
portal,
rt.handle().clone(),
);
let mut terminate = signals::Terminate::new()?;
let mut hangup = signals::Hangup::new()?;

View File

@@ -3,7 +3,7 @@ mod heartbeat;
mod login_url;
use std::collections::{HashSet, VecDeque};
use std::net::{IpAddr, SocketAddr};
use std::net::{IpAddr, SocketAddr, ToSocketAddrs as _};
use std::sync::atomic::AtomicU64;
use std::sync::Arc;
use std::{fmt, future, marker::PhantomData};
@@ -29,9 +29,9 @@ use tokio_tungstenite::{
use url::{Host, Url};
pub use get_user_agent::get_user_agent;
pub use login_url::{DeviceInfo, LoginUrl, LoginUrlError};
pub use login_url::{DeviceInfo, LoginUrl, LoginUrlError, NoParams, PublicKeyParam};
pub struct PhoenixChannel<TInitReq, TInboundMsg, TOutboundRes> {
pub struct PhoenixChannel<TInitReq, TInboundMsg, TOutboundRes, TFinish> {
state: State,
waker: Option<Waker>,
pending_messages: VecDeque<String>,
@@ -45,7 +45,8 @@ pub struct PhoenixChannel<TInitReq, TInboundMsg, TOutboundRes> {
pending_join_requests: HashSet<OutboundRequestId>,
// Stored here to allow re-connecting.
url: Secret<LoginUrl>,
url_prototype: Secret<LoginUrl<TFinish>>,
last_url: Option<Url>,
user_agent: String,
reconnect_backoff: ExponentialBackoff,
@@ -66,7 +67,7 @@ enum State {
impl State {
fn connect(
url: Secret<LoginUrl>,
url: Url,
user_agent: String,
socket_factory: Arc<dyn SocketFactory<TcpSocket>>,
) -> Self {
@@ -75,11 +76,13 @@ impl State {
}
async fn create_and_connect_websocket(
url: Secret<LoginUrl>,
url: Url,
user_agent: String,
socket_factory: Arc<dyn SocketFactory<TcpSocket>>,
) -> Result<WebSocketStream<MaybeTlsStream<TcpStream>>, InternalError> {
let socket = make_socket(url.expose_secret().inner(), &*socket_factory).await?;
tracing::debug!(host = %url.host().unwrap(), %user_agent, "Connecting to portal");
let socket = make_socket(&url, &*socket_factory).await?;
let (stream, _) = client_async_tls(make_request(url, user_agent), socket)
.await
@@ -214,18 +217,22 @@ impl fmt::Display for OutboundRequestId {
#[error("Cannot close websocket while we are connecting")]
pub struct Connecting;
impl<TInitReq, TInboundMsg, TOutboundRes> PhoenixChannel<TInitReq, TInboundMsg, TOutboundRes>
impl<TInitReq, TInboundMsg, TOutboundRes, TFinish>
PhoenixChannel<TInitReq, TInboundMsg, TOutboundRes, TFinish>
where
TInitReq: Serialize + Clone,
TInboundMsg: DeserializeOwned,
TOutboundRes: DeserializeOwned,
TFinish: IntoIterator<Item = (&'static str, String)>,
{
/// Creates a new [PhoenixChannel] to the given endpoint.
/// Creates a new [PhoenixChannel] to the given endpoint in the `disconnected` state.
///
/// You must explicitly call [`PhoenixChannel::connect`] to establish a connection.
///
/// The provided URL must contain a host.
/// Additionally, you must already provide any query parameters required for authentication.
pub fn connect(
url: Secret<LoginUrl>,
pub fn disconnected(
url: Secret<LoginUrl<TFinish>>,
user_agent: String,
login: &'static str,
init_req: TInitReq,
@@ -239,19 +246,16 @@ where
// We expose them to other components that deal with DNS stuff to ensure our domain always resolves to these IPs.
let resolved_addresses = url
.expose_secret()
.inner()
.socket_addrs(|| None)?
.iter()
.host()
.to_socket_addrs()?
.map(|addr| addr.ip())
.collect();
tracing::debug!(host = %url.expose_secret().host(), %user_agent, "Connecting to portal");
Ok(Self {
reconnect_backoff,
url: url.clone(),
user_agent: user_agent.clone(),
state: State::connect(url, user_agent, socket_factory.clone()),
url_prototype: url,
user_agent,
state: State::Closed,
socket_factory,
waker: None,
pending_messages: Default::default(),
@@ -266,6 +270,7 @@ where
login,
init_req,
resolved_addresses,
last_url: None,
})
}
@@ -276,7 +281,7 @@ where
/// The host we are connecting / connected to.
pub fn server_host(&self) -> &str {
self.url.expose_secret().host()
self.url_prototype.expose_secret().host()
}
/// Join the provided room.
@@ -297,15 +302,16 @@ where
id
}
/// Reconnects to the portal.
pub fn reconnect(&mut self) {
/// Establishes a new connection, dropping the current one if any exists.
pub fn connect(&mut self, params: TFinish) {
// 1. Reset the backoff.
self.reconnect_backoff.reset();
// 2. Set state to `Connecting` without a timer.
let url = self.url.clone();
let url = self.url_prototype.expose_secret().to_url(params);
let user_agent = self.user_agent.clone();
self.state = State::connect(url, user_agent, self.socket_factory.clone());
self.state = State::connect(url.clone(), user_agent, self.socket_factory.clone());
self.last_url = Some(url);
// 3. In case we were already re-connecting, we need to wake the suspended task.
if let Some(waker) = self.waker.take() {
@@ -358,7 +364,7 @@ where
self.heartbeat.reset();
self.state = State::Connected(stream);
let host = self.url.expose_secret().host();
let host = self.url_prototype.expose_secret().host();
tracing::info!(%host, "Connected to portal");
self.join(self.login, self.init_req.clone());
@@ -376,7 +382,11 @@ where
return Poll::Ready(Err(Error::MaxRetriesReached));
};
let secret_url = self.url.clone();
let secret_url = self
.last_url
.as_ref()
.expect("should have last URL if we receive connection error")
.clone();
let user_agent = self.user_agent.clone();
let socket_factory = self.socket_factory.clone();
@@ -736,22 +746,20 @@ impl<T, R> PhoenixMessage<T, R> {
}
// This is basically the same as tungstenite does but we add some new headers (namely user-agent)
fn make_request(url: Secret<LoginUrl>, user_agent: String) -> Request {
use secrecy::ExposeSecret as _;
fn make_request(url: Url, user_agent: String) -> Request {
let mut r = [0u8; 16];
OsRng.fill_bytes(&mut r);
let key = base64::engine::general_purpose::STANDARD.encode(r);
Request::builder()
.method("GET")
.header("Host", url.expose_secret().host())
.header("Host", url.host().unwrap().to_string())
.header("Connection", "Upgrade")
.header("Upgrade", "websocket")
.header("Sec-WebSocket-Version", "13")
.header("Sec-WebSocket-Key", key)
.header("User-Agent", user_agent)
.uri(url.expose_secret().inner().as_str())
.uri(url.to_string())
.body(())
.expect("building static request always works")
}

View File

@@ -2,7 +2,11 @@ use base64::{engine::general_purpose::STANDARD, Engine};
use secrecy::{CloneableSecret, ExposeSecret as _, SecretString, Zeroize};
use serde::Deserialize;
use sha2::Digest as _;
use std::net::{Ipv4Addr, Ipv6Addr};
use std::{
iter,
marker::PhantomData,
net::{Ipv4Addr, Ipv6Addr},
};
use url::Url;
use uuid::Uuid;
@@ -27,16 +31,18 @@ pub struct DeviceInfo {
}
#[derive(Clone)]
pub struct LoginUrl {
pub struct LoginUrl<TFinish> {
url: Url,
// Invariant: Must stay the same as the host in `url`.
// This is duplicated here because `Url::host` is fallible.
// If we don't duplicate it, we'd have to do extra error handling in several places instead of just one place.
host: String,
phantom: PhantomData<TFinish>,
}
impl Zeroize for LoginUrl {
impl<TFinish> Zeroize for LoginUrl<TFinish> {
fn zeroize(&mut self) {
let placeholder = Url::parse("http://a.com")
.expect("placeholder URL should always be valid, it's hard-coded");
@@ -44,15 +50,36 @@ impl Zeroize for LoginUrl {
}
}
impl CloneableSecret for LoginUrl {}
pub struct PublicKeyParam(pub [u8; 32]);
impl LoginUrl {
impl IntoIterator for PublicKeyParam {
type Item = (&'static str, String);
type IntoIter = std::iter::Once<Self::Item>;
fn into_iter(self) -> Self::IntoIter {
iter::once(("public_key", STANDARD.encode(self.0)))
}
}
pub struct NoParams;
impl IntoIterator for NoParams {
type Item = (&'static str, String);
type IntoIter = std::iter::Empty<Self::Item>;
fn into_iter(self) -> Self::IntoIter {
std::iter::empty()
}
}
impl<TFinish> CloneableSecret for LoginUrl<TFinish> where TFinish: Clone {}
impl LoginUrl<PublicKeyParam> {
pub fn client<E>(
url: impl TryInto<Url, Error = E>,
firezone_token: &SecretString,
device_id: String,
device_name: Option<String>,
public_key: [u8; 32],
device_info: DeviceInfo,
) -> Result<Self, LoginUrlError<E>> {
let external_id = hex::encode(sha2::Sha256::digest(device_id));
@@ -64,7 +91,6 @@ impl LoginUrl {
url.try_into().map_err(LoginUrlError::InvalidUrl)?,
firezone_token,
"client",
Some(public_key),
Some(external_id),
Some(device_name),
None,
@@ -76,6 +102,7 @@ impl LoginUrl {
Ok(LoginUrl {
host: parse_host(&url)?,
url,
phantom: PhantomData,
})
}
@@ -84,7 +111,6 @@ impl LoginUrl {
firezone_token: &SecretString,
device_id: String,
device_name: Option<String>,
public_key: [u8; 32],
) -> Result<Self, LoginUrlError<E>> {
let external_id = hex::encode(sha2::Sha256::digest(device_id));
let device_name = device_name
@@ -95,7 +121,6 @@ impl LoginUrl {
url.try_into().map_err(LoginUrlError::InvalidUrl)?,
firezone_token,
"gateway",
Some(public_key),
Some(external_id),
Some(device_name),
None,
@@ -107,9 +132,12 @@ impl LoginUrl {
Ok(LoginUrl {
host: parse_host(&url)?,
url,
phantom: PhantomData,
})
}
}
impl LoginUrl<NoParams> {
pub fn relay<E>(
url: impl TryInto<Url, Error = E>,
firezone_token: &SecretString,
@@ -123,7 +151,6 @@ impl LoginUrl {
firezone_token,
"relay",
None,
None,
device_name,
Some(listen_port),
ipv4_address,
@@ -134,15 +161,25 @@ impl LoginUrl {
Ok(LoginUrl {
host: parse_host(&url)?,
url,
phantom: PhantomData,
})
}
}
// TODO: Only temporarily public until we delete other phoenix-channel impl.
pub fn inner(&self) -> &Url {
&self.url
impl<TFinish> LoginUrl<TFinish>
where
TFinish: IntoIterator<Item = (&'static str, String)>,
{
pub fn to_url(&self, params: TFinish) -> Url {
let mut url = self.url.clone();
url.query_pairs_mut().extend_pairs(params);
url
}
}
// TODO: Only temporarily public until we delete other phoenix-channel impl.
impl<TFinish> LoginUrl<TFinish> {
pub fn host(&self) -> &str {
&self.host
}
@@ -190,7 +227,6 @@ fn get_websocket_path<E>(
mut api_url: Url,
token: &SecretString,
mode: &str,
public_key: Option<[u8; 32]>,
external_id: Option<String>,
name: Option<String>,
port: Option<u16>,
@@ -215,9 +251,6 @@ fn get_websocket_path<E>(
query_pairs.clear();
query_pairs.append_pair("token", token.expose_secret());
if let Some(public_key) = public_key {
query_pairs.append_pair("public_key", &STANDARD.encode(public_key));
}
if let Some(external_id) = external_id {
query_pairs.append_pair("external_id", &external_id);
}

View File

@@ -8,7 +8,7 @@ use firezone_relay::{
PeerSocket, Server, Sleep,
};
use futures::{future, FutureExt};
use phoenix_channel::{Event, LoginUrl, PhoenixChannel};
use phoenix_channel::{Event, LoginUrl, NoParams, PhoenixChannel};
use rand::rngs::StdRng;
use rand::{Rng, SeedableRng};
use secrecy::{Secret, SecretString};
@@ -139,7 +139,7 @@ async fn main() -> Result<()> {
args.public_ip6_addr,
)?;
Some(PhoenixChannel::connect(
let mut channel = PhoenixChannel::disconnected(
Secret::new(login),
format!("relay/{}", env!("CARGO_PKG_VERSION")),
"relay",
@@ -150,7 +150,10 @@ async fn main() -> Result<()> {
.with_max_elapsed_time(Some(MAX_PARTITION_TIME))
.build(),
Arc::new(socket_factory::tcp),
)?)
)?;
channel.connect(NoParams);
Some(channel)
} else {
tracing::warn!(target: "relay", "No portal token supplied, starting standalone mode");
@@ -305,7 +308,7 @@ struct Eventloop<R> {
sockets: Sockets,
server: Server<R>,
channel: Option<PhoenixChannel<JoinMessage, IngressMessage, ()>>,
channel: Option<PhoenixChannel<JoinMessage, IngressMessage, (), NoParams>>,
sleep: Sleep,
sigterm: unix::Signal,
@@ -325,7 +328,7 @@ where
{
fn new(
server: Server<R>,
channel: Option<PhoenixChannel<JoinMessage, IngressMessage, ()>>,
channel: Option<PhoenixChannel<JoinMessage, IngressMessage, (), NoParams>>,
public_address: IpStack,
last_heartbeat_sent: Arc<Mutex<Option<Instant>>>,
) -> Result<Self> {