refactor(connlib): introduce LoginUrl component (#4048)

Currently, we are passing a lot of data into `Session::connect`. Half of
this data is only needed to construct the URL we will use to connect to
the portal. We can simplify this by extracting a dedicated `LoginUrl`
component that captures and validates this data early.

Not only does this reduce the number of parameters we pass to
`Session::connect`, it also reduces the number of failure cases we have
to deal with in `Session::connect`. Any time the session fails, we have
to call `onDisconnected` to inform the client. Thus, we should perform
as much validation as we can early on. In other words, once
`Session::connect` returns, the client should be able to expect that the
tunnel is starting.
This commit is contained in:
Thomas Eizinger
2024-03-09 20:35:15 +11:00
committed by GitHub
parent a2f289f0b1
commit fdb33674cd
21 changed files with 389 additions and 329 deletions

9
rust/Cargo.lock generated
View File

@@ -1081,6 +1081,7 @@ dependencies = [
"tracing-android",
"tracing-appender",
"tracing-subscriber",
"url",
]
[[package]]
@@ -1116,6 +1117,7 @@ dependencies = [
"firezone-tunnel",
"ip_network",
"parking_lot",
"phoenix-channel",
"reqwest",
"secrecy",
"serde",
@@ -1146,7 +1148,6 @@ dependencies = [
"futures",
"futures-util",
"hickory-resolver",
"hostname",
"ip_network",
"itertools 0.12.1",
"known-folders",
@@ -1154,6 +1155,7 @@ dependencies = [
"log",
"os_info",
"parking_lot",
"phoenix-channel",
"rand 0.8.5",
"rand_core 0.6.4",
"resolv-conf",
@@ -4387,15 +4389,20 @@ dependencies = [
"backoff",
"base64 0.22.0",
"futures",
"hex",
"hostname",
"libc",
"rand_core 0.6.4",
"secrecy",
"serde",
"serde_json",
"sha2",
"thiserror",
"tokio",
"tokio-tungstenite",
"tracing",
"url",
"uuid",
]
[[package]]

View File

@@ -24,6 +24,7 @@ ip_network = "0.4"
log = "0.4"
serde_json = "1"
thiserror = "1"
url = "2.4.0"
[target.'cfg(target_os = "android")'.dependencies]
tracing-android = "0.2"

View File

@@ -3,7 +3,9 @@
// However, this consideration has made it idiomatic for Java FFI in the Rust
// ecosystem, so it's used here for consistency.
use connlib_client_shared::{file_logger, Callbacks, Error, ResourceDescription, Session};
use connlib_client_shared::{
file_logger, keypair, Callbacks, Error, LoginUrl, LoginUrlError, ResourceDescription, Session,
};
use ip_network::IpNetwork;
use jni::{
objects::{GlobalRef, JByteArray, JClass, JObject, JObjectArray, JString, JValue, JValueGen},
@@ -366,6 +368,8 @@ enum ConnectError {
GetJavaVmFailed(#[source] jni::errors::Error),
#[error(transparent)]
ConnectFailed(#[from] Error),
#[error(transparent)]
InvalidLoginUrl(#[from] LoginUrlError<url::ParseError>),
}
macro_rules! string_from_jstring {
@@ -411,11 +415,18 @@ fn connect(
handle,
};
let session = Session::connect(
let (private_key, public_key) = keypair();
let login = LoginUrl::client(
api_url.as_str(),
secret,
&secret,
device_id,
Some(device_name),
public_key.to_bytes(),
)?;
let session = Session::connect(
login,
private_key,
Some(os_version),
callback_handler,
Some(MAX_PARTITION_TIME),

View File

@@ -1,7 +1,9 @@
// Swift bridge generated code triggers this below
#![allow(clippy::unnecessary_cast, improper_ctypes, non_camel_case_types)]
use connlib_client_shared::{file_logger, Callbacks, Error, ResourceDescription, Session};
use connlib_client_shared::{
file_logger, keypair, Callbacks, Error, LoginUrl, ResourceDescription, Session,
};
use ip_network::IpNetwork;
use secrecy::SecretString;
use std::{
@@ -191,11 +193,19 @@ impl WrappedSession {
) -> Result<Self, String> {
let secret = SecretString::from(token);
let session = Session::connect(
let (private_key, public_key) = keypair();
let login = LoginUrl::client(
api_url.as_str(),
secret,
&secret,
device_id,
device_name_override,
public_key.to_bytes(),
)
.map_err(|e| e.to_string())?;
let session = Session::connect(
login,
private_key,
os_version_override,
CallbackHandler {
inner: Arc::new(callback_handler),

View File

@@ -29,6 +29,8 @@ async-compression = { version = "0.4.6", features = ["tokio", "gzip"] }
parking_lot = "0.12"
bimap = "0.6"
ip_network = { version = "0.4", default-features = false }
phoenix-channel = { workspace = true }
[target.'cfg(target_os = "android")'.dependencies]
tracing = { workspace = true, features = ["std", "attributes"] }

View File

@@ -1,22 +1,21 @@
//! Main connlib library for clients.
pub use connlib_shared::messages::ResourceDescription;
pub use connlib_shared::{Callbacks, Error};
pub use connlib_shared::{keypair, Callbacks, Error, LoginUrl, LoginUrlError};
pub use tracing_appender::non_blocking::WorkerGuard;
use backoff::{backoff::Backoff, ExponentialBackoffBuilder};
use connlib_shared::control::SecureUrl;
use connlib_shared::{control::PhoenixChannel, login_url, CallbackErrorFacade, Mode, Result};
use connlib_shared::StaticSecret;
use connlib_shared::{control::PhoenixChannel, CallbackErrorFacade, Result};
use control::ControlPlane;
use firezone_tunnel::Tunnel;
use messages::IngressMessages;
use messages::Messages;
use messages::ReplyMessages;
use secrecy::{Secret, SecretString};
use secrecy::Secret;
use std::future::poll_fn;
use std::time::Duration;
use tokio::time::{Interval, MissedTickBehavior};
use tokio::{runtime::Runtime, time::Instant};
use url::Url;
mod control;
pub mod file_logger;
@@ -67,10 +66,8 @@ where
/// * `device_id` - The cleartext device ID. connlib will obscure this with a hash internally.
// TODO: token should be something like SecretString but we need to think about FFI compatibility
pub fn connect(
api_url: impl TryInto<Url>,
token: SecretString,
device_id: String,
device_name_override: Option<String>,
url: LoginUrl,
private_key: StaticSecret,
os_version_override: Option<String>,
callbacks: CB,
max_partition_time: Option<Duration>,
@@ -114,10 +111,8 @@ where
Self::connect_inner(
&runtime,
tx.clone(),
api_url.try_into().map_err(|_| Error::UriError)?,
token,
device_id,
device_name_override,
url,
private_key,
os_version_override,
callbacks.clone(),
max_partition_time,
@@ -139,27 +134,19 @@ where
fn connect_inner(
runtime: &Runtime,
runtime_stopper: tokio::sync::mpsc::Sender<StopRuntime>,
api_url: Url,
token: SecretString,
device_id: String,
device_name_override: Option<String>,
url: LoginUrl,
private_key: StaticSecret,
os_version_override: Option<String>,
callbacks: CallbackErrorFacade<CB>,
max_partition_time: Option<Duration>,
) {
runtime.spawn(async move {
let (connect_url, private_key) = fatal_error!(
login_url(Mode::Client, api_url, token, device_id, device_name_override),
runtime_stopper,
&callbacks
);
// This is kinda hacky, the buffer size is 1 so that we make sure that we
// process one message at a time, blocking if a previous message haven't been processed
// to force queue ordering.
let (control_plane_sender, mut control_plane_receiver) = tokio::sync::mpsc::channel(1);
let mut connection = PhoenixChannel::<_, IngressMessages, ReplyMessages, Messages>::new(Secret::new(SecureUrl::from_url(connect_url)), os_version_override, move |msg, reference, topic| {
let mut connection = PhoenixChannel::<_, IngressMessages, ReplyMessages, Messages>::new(Secret::new(url), os_version_override, move |msg, reference, topic| {
let control_plane_sender = control_plane_sender.clone();
async move {
tracing::trace!(?msg);

View File

@@ -18,8 +18,6 @@ boringtun = { workspace = true }
chrono = { workspace = true }
futures = { version = "0.3", default-features = false, features = ["std", "async-await", "executor"] }
futures-util = { version = "0.3", default-features = false, features = ["std", "async-await", "async-await-macro"] }
# Hickory already depends on `hostname` so this isn't new
hostname = "0.3.1"
ip_network = { version = "0.4", default-features = false, features = ["serde"] }
os_info = { version = "3", default-features = false }
parking_lot = "0.12"
@@ -43,6 +41,7 @@ libc = "0.2"
dns-lookup = { workspace = true }
known-folders = "1.1.0"
snownet = { workspace = true }
phoenix-channel = { workspace = true }
# Needed for Android logging until tracing is working
log = "0.4"

View File

@@ -21,9 +21,9 @@ use tokio_tungstenite::{
tungstenite::{self, handshake::client::Request},
};
use tungstenite::Message;
use url::Url;
use crate::{get_user_agent, Error, Result};
use phoenix_channel::LoginUrl;
const CHANNEL_SIZE: usize = 1_000;
const HEARTBEAT: Duration = Duration::from_secs(30);
@@ -31,22 +31,6 @@ const HEARTBEAT_TIMEOUT: Duration = Duration::from_secs(35);
pub type Reference = String;
// TODO: Refactor this PhoenixChannel to use the top-level phoenix-channel crate instead.
// See https://github.com/firezone/firezone/issues/2158
pub struct SecureUrl {
pub inner: Url,
}
impl SecureUrl {
pub fn from_url(url: Url) -> Self {
Self { inner: url }
}
}
impl secrecy::Zeroize for SecureUrl {
fn zeroize(&mut self) {
let placeholder = Url::parse("http://a.com").expect("placeholder URL to be valid");
let _ = std::mem::replace(&mut self.inner, placeholder);
}
}
/// Main struct to interact with the control-protocol channel.
///
/// After creating a new `PhoenixChannel` using [PhoenixChannel::new] you need to
@@ -63,7 +47,7 @@ impl secrecy::Zeroize for SecureUrl {
/// The future returned by [PhoenixChannel::start] will finish when the websocket closes (by an error), meaning that if you
/// `await` it, it will block until you use `close` in a [PhoenixSender], the portal close the connection or something goes wrong.
pub struct PhoenixChannel<F, I, R, M> {
secret_url: Secret<SecureUrl>,
secret_url: Secret<LoginUrl>,
os_version_override: Option<String>,
handler: F,
sender: Sender<Message>,
@@ -73,21 +57,12 @@ pub struct PhoenixChannel<F, I, R, M> {
// This is basically the same as tungstenite does but we add some new headers (namely user-agent)
fn make_request(
secret_url: &Secret<SecureUrl>,
secret_url: &Secret<LoginUrl>,
os_version_override: Option<String>,
) -> Result<Request> {
use secrecy::ExposeSecret;
let host = secret_url
.expose_secret()
.inner
.host()
.ok_or(Error::UriError)?;
let host = if let Some(port) = secret_url.expose_secret().inner.port() {
format!("{host}:{port}")
} else {
host.to_string()
};
let host = secret_url.expose_secret().host();
let mut r = [0u8; 16];
OsRng.fill_bytes(&mut r);
@@ -101,7 +76,7 @@ fn make_request(
.header("Sec-WebSocket-Version", "13")
.header("Sec-WebSocket-Key", key)
.header("User-Agent", get_user_agent(os_version_override))
.uri(secret_url.expose_secret().inner.as_str())
.uri(secret_url.expose_secret().inner().as_ref())
.body(())?;
Ok(req)
}
@@ -282,7 +257,7 @@ where
///
/// For more info see [struct-level docs][PhoenixChannel].
pub fn new(
secret_url: Secret<SecureUrl>,
secret_url: Secret<LoginUrl>,
os_version_override: Option<String>,
handler: F,
) -> Self {

View File

@@ -20,21 +20,20 @@ pub mod linux;
#[cfg(target_os = "windows")]
pub mod windows;
pub use boringtun::x25519::PublicKey;
pub use boringtun::x25519::StaticSecret;
pub use callbacks::Callbacks;
pub use callbacks_error_facade::CallbackErrorFacade;
pub use error::ConnlibError as Error;
pub use error::Result;
pub use phoenix_channel::{LoginUrl, LoginUrlError};
use boringtun::x25519::{PublicKey, StaticSecret};
use ip_network::Ipv4Network;
use ip_network::Ipv6Network;
use messages::Key;
use ring::digest::{Context, SHA256};
use secrecy::{ExposeSecret, SecretString};
use rand_core::OsRng;
use std::net::IpAddr;
use std::net::Ipv4Addr;
use std::net::Ipv6Addr;
use url::Url;
pub type Dname = domain::base::Dname<Vec<u8>>;
@@ -56,51 +55,11 @@ pub const BUNDLE_ID: &str = "dev.firezone.client";
const VERSION: &str = env!("CARGO_PKG_VERSION");
const LIB_NAME: &str = "connlib";
// From https://man7.org/linux/man-pages/man2/gethostname.2.html
// SUSv2 guarantees that "Host names are limited to 255 bytes".
// POSIX.1 guarantees that "Host names (not including the
// terminating null byte) are limited to HOST_NAME_MAX bytes". On
// Linux, HOST_NAME_MAX is defined with the value 64, which has been
// the limit since Linux 1.0 (earlier kernels imposed a limit of 8
// bytes)
//
// We are counting the nul-byte
#[cfg(not(target_os = "windows"))]
const HOST_NAME_MAX: usize = 256;
pub fn keypair() -> (StaticSecret, PublicKey) {
let private_key = StaticSecret::random_from_rng(OsRng);
let public_key = PublicKey::from(&private_key);
/// Creates a new login URL to use with the portal.
pub fn login_url(
mode: Mode,
api_url: Url,
token: SecretString,
device_id: String,
firezone_name: Option<String>,
) -> Result<(Url, StaticSecret)> {
let private_key = StaticSecret::random_from_rng(rand::rngs::OsRng);
let name = firezone_name
.or(get_host_name())
.unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
let external_id = sha256(device_id);
let url = get_websocket_path(
api_url,
token,
match mode {
Mode::Client => "client",
Mode::Gateway => "gateway",
},
&Key(PublicKey::from(&private_key).to_bytes()),
&external_id,
&name,
)?;
Ok((url, private_key))
}
// FIXME: This is a terrible name :(
pub enum Mode {
Client,
Gateway,
(private_key, public_key)
}
pub struct IpProvider {
@@ -208,73 +167,3 @@ fn kernel_version() -> Option<String> {
String::from_utf8(version).ok()
}
#[cfg(not(target_os = "windows"))]
fn get_host_name() -> Option<String> {
let mut buf = [0; HOST_NAME_MAX];
// SAFETY: we just allocated a buffer with that size
if unsafe { libc::gethostname(buf.as_mut_ptr() as *mut _, HOST_NAME_MAX) } != 0 {
return None;
}
String::from_utf8(buf.split(|c| *c == 0).next()?.to_vec()).ok()
}
/// Returns the hostname, or `None` if it's not valid UTF-8
#[cfg(target_os = "windows")]
fn get_host_name() -> Option<String> {
hostname::get().ok().and_then(|x| x.into_string().ok())
}
fn set_ws_scheme(url: &mut Url) -> Result<()> {
let scheme = match url.scheme() {
"http" | "ws" => "ws",
"https" | "wss" => "wss",
_ => return Err(Error::UriScheme),
};
url.set_scheme(scheme)
.expect("Developer error: the match before this should make sure we can set this");
Ok(())
}
fn sha256(input: String) -> String {
let mut ctx = Context::new(&SHA256);
ctx.update(input.as_bytes());
let digest = ctx.finish();
digest.as_ref().iter().fold(String::new(), |mut output, b| {
use std::fmt::Write;
let _ = write!(output, "{b:02x}");
output
})
}
fn get_websocket_path(
mut api_url: Url,
secret: SecretString,
mode: &str,
public_key: &Key,
external_id: &str,
name: &str,
) -> Result<Url> {
set_ws_scheme(&mut api_url)?;
{
let mut paths = api_url.path_segments_mut().map_err(|_| Error::UriError)?;
paths.pop_if_empty();
paths.push(mode);
paths.push("websocket");
}
{
let mut query_pairs = api_url.query_pairs_mut();
query_pairs.clear();
query_pairs.append_pair("token", secret.expose_secret());
query_pairs.append_pair("public_key", &public_key.to_string());
query_pairs.append_pair("external_id", external_id);
query_pairs.append_pair("name", name);
}
Ok(api_url)
}

View File

@@ -2,13 +2,11 @@ use crate::eventloop::{Eventloop, PHOENIX_TOPIC};
use crate::messages::InitGateway;
use anyhow::{Context, Result};
use backoff::ExponentialBackoffBuilder;
use boringtun::x25519::StaticSecret;
use clap::Parser;
use connlib_shared::{get_user_agent, login_url, Callbacks, Mode};
use connlib_shared::{get_user_agent, keypair, Callbacks, LoginUrl, StaticSecret};
use firezone_cli_utils::{setup_global_subscriber, CommonArgs};
use firezone_tunnel::GatewayTunnel;
use futures::{future, TryFutureExt};
use phoenix_channel::SecureUrl;
use secrecy::{Secret, SecretString};
use std::convert::Infallible;
use std::path::Path;
@@ -16,7 +14,6 @@ use std::pin::pin;
use tokio::io::AsyncWriteExt;
use tokio::signal::ctrl_c;
use tracing_subscriber::layer;
use url::Url;
use uuid::Uuid;
mod eventloop;
@@ -43,15 +40,17 @@ async fn try_main() -> Result<()> {
let firezone_id = get_firezone_id(cli.firezone_id).await
.context("Couldn't read FIREZONE_ID or write it to disk: Please provide it through the env variable or provide rw access to /var/lib/firezone/")?;
let (connect_url, private_key) = login_url(
Mode::Gateway,
let (private_key, public_key) = keypair();
let login = LoginUrl::gateway(
cli.common.api_url,
SecretString::new(cli.common.token),
&SecretString::new(cli.common.token),
firezone_id,
cli.common.firezone_name,
public_key.to_bytes(),
)?;
let task = tokio::spawn(run(connect_url, private_key)).err_into();
let task = tokio::spawn(run(login, private_key)).err_into();
let ctrl_c = pin!(ctrl_c().map_err(anyhow::Error::new));
@@ -89,11 +88,11 @@ async fn get_firezone_id(env_id: Option<String>) -> Result<String> {
Ok(id)
}
async fn run(connect_url: Url, private_key: StaticSecret) -> Result<Infallible> {
async fn run(login: LoginUrl, private_key: StaticSecret) -> Result<Infallible> {
let mut tunnel = GatewayTunnel::new(private_key, CallbackHandler)?;
let (portal, init) = phoenix_channel::init::<_, InitGateway, _, _>(
Secret::new(SecureUrl::from_url(connect_url.clone())),
Secret::new(login),
get_user_agent(None),
PHOENIX_TOPIC,
(),

View File

@@ -1,9 +1,8 @@
//! Fulfills <https://github.com/firezone/firezone/issues/2823>
use crate::client::known_dirs;
use connlib_shared::control::SecureUrl;
use rand::{thread_rng, RngCore};
use secrecy::{ExposeSecret, Secret, SecretString};
use secrecy::{ExposeSecret, SecretString};
use std::path::PathBuf;
use subtle::ConstantTimeEq;
use url::Url;
@@ -64,13 +63,14 @@ pub(crate) struct Request {
}
impl Request {
pub fn to_url(&self, auth_base_url: &Url) -> Secret<SecureUrl> {
pub fn to_url(&self, auth_base_url: &Url) -> SecretString {
let mut url = auth_base_url.clone();
url.query_pairs_mut()
.append_pair("as", "client")
.append_pair("nonce", self.nonce.expose_secret())
.append_pair("state", self.state.expose_secret());
Secret::from(SecureUrl::from_url(url))
SecretString::new(url.to_string())
}
}
@@ -309,9 +309,8 @@ mod tests {
state: bogus_secret("some_state"),
};
assert_eq!(
req.to_url(&auth_base_url).expose_secret().inner,
Url::parse("https://app.firez.one?as=client&nonce=some_nonce&state=some_state")
.unwrap()
req.to_url(&auth_base_url).expose_secret(),
"https://app.firez.one/?as=client&nonce=some_nonce&state=some_state"
);
}

View File

@@ -1,9 +1,9 @@
//! A module for registering, catching, and parsing deep links that are sent over to the app's already-running instance
use crate::client::auth::Response as AuthResponse;
use connlib_shared::control::SecureUrl;
use secrecy::{ExposeSecret, Secret, SecretString};
use secrecy::{ExposeSecret, SecretString};
use std::io;
use url::Url;
pub(crate) const FZ_SCHEME: &str = "firezone-fd0020211111";
@@ -54,8 +54,8 @@ pub enum Error {
pub(crate) use imp::{open, register, Server};
pub(crate) fn parse_auth_callback(url: &Secret<SecureUrl>) -> Option<AuthResponse> {
let url = &url.expose_secret().inner;
pub(crate) fn parse_auth_callback(url: &SecretString) -> Option<AuthResponse> {
let url = Url::parse(url.expose_secret()).ok()?;
match url.host() {
Some(url::Host::Domain("handle_client_sign_in_callback")) => {}
_ => return None,
@@ -105,14 +105,13 @@ pub(crate) fn parse_auth_callback(url: &Secret<SecureUrl>) -> Option<AuthRespons
#[cfg(test)]
mod tests {
use anyhow::Result;
use connlib_shared::control::SecureUrl;
use secrecy::{ExposeSecret, Secret};
use secrecy::{ExposeSecret, SecretString};
#[test]
fn parse_auth_callback() -> Result<()> {
// Positive cases
let input = "firezone://handle_client_sign_in_callback/?actor_name=Reactor+Scram&fragment=a_very_secret_string&state=a_less_secret_string&identity_provider_identifier=12345";
let actual = parse_callback_wrapper(input)?.unwrap();
let actual = parse_callback_wrapper(input).unwrap();
assert_eq!(actual.actor_name, "Reactor Scram");
assert_eq!(actual.fragment.expose_secret(), "a_very_secret_string");
@@ -120,7 +119,7 @@ mod tests {
// Empty string "" `actor_name` is fine
let input = "firezone://handle_client_sign_in_callback/?actor_name=&fragment=&state=&identity_provider_identifier=12345";
let actual = parse_callback_wrapper(input)?.unwrap();
let actual = parse_callback_wrapper(input).unwrap();
assert_eq!(actual.actor_name, "");
assert_eq!(actual.fragment.expose_secret(), "");
@@ -130,24 +129,23 @@ mod tests {
// URL host is wrong
let input = "firezone://not_handle_client_sign_in_callback/?actor_name=Reactor+Scram&fragment=a_very_secret_string&state=a_less_secret_string&identity_provider_identifier=12345";
let actual = parse_callback_wrapper(input)?;
let actual = parse_callback_wrapper(input);
assert!(actual.is_none());
// `actor_name` is not just blank but totally missing
let input = "firezone://handle_client_sign_in_callback/?fragment=&state=&identity_provider_identifier=12345";
let actual = parse_callback_wrapper(input)?;
let actual = parse_callback_wrapper(input);
assert!(actual.is_none());
// URL is nonsense
let input = "?????????";
let actual_result = parse_callback_wrapper(input);
assert!(actual_result.is_err());
assert!(actual_result.is_none());
Ok(())
}
fn parse_callback_wrapper(s: &str) -> Result<Option<super::AuthResponse>> {
let url = Secret::new(SecureUrl::from_url(url::Url::parse(s)?));
Ok(super::parse_auth_callback(&url))
fn parse_callback_wrapper(s: &str) -> Option<super::AuthResponse> {
super::parse_auth_callback(&SecretString::new(s.to_owned()))
}
}

View File

@@ -1,8 +1,7 @@
//! TODO: Not implemented for Linux yet
use super::Error;
use connlib_shared::control::SecureUrl;
use secrecy::Secret;
use secrecy::SecretString;
pub(crate) struct Server {}
@@ -13,7 +12,7 @@ impl Server {
Ok(Self {})
}
pub(crate) async fn accept(self) -> Result<Secret<SecureUrl>, Error> {
pub(crate) async fn accept(self) -> Result<SecretString, Error> {
tracing::warn!("Deep links not implemented yet on Linux");
futures::future::pending().await
}

View File

@@ -1,8 +1,7 @@
//! Placeholder
use super::Error;
use connlib_shared::control::SecureUrl;
use secrecy::Secret;
use secrecy::{Secret, SecretString};
pub(crate) struct Server {}
@@ -13,7 +12,7 @@ impl Server {
Ok(Self {})
}
pub(crate) async fn accept(self) -> Result<Secret<SecureUrl>, Error> {
pub(crate) async fn accept(self) -> Result<SecretString, Error> {
futures::future::pending().await
}
}

View File

@@ -2,7 +2,7 @@
//! Based on reading some of the Windows code from <https://github.com/FabianLars/tauri-plugin-deep-link>, which is licensed "MIT OR Apache-2.0"
use super::{Error, FZ_SCHEME};
use connlib_shared::{control::SecureUrl, BUNDLE_ID};
use connlib_shared::BUNDLE_ID;
use secrecy::{ExposeSecret, Secret, SecretString};
use std::{ffi::c_void, io, path::Path, str::FromStr};
use tokio::{io::AsyncReadExt, io::AsyncWriteExt, net::windows::named_pipe};
@@ -75,7 +75,7 @@ impl Server {
/// I assume this is based on the underlying Windows API.
/// I tried re-using the server and it acted strange. The official Tokio
/// examples are not clear on this.
pub(crate) async fn accept(mut self) -> Result<Secret<SecureUrl>, Error> {
pub(crate) async fn accept(mut self) -> Result<SecretString, Error> {
self.inner
.connect()
.await
@@ -99,9 +99,8 @@ impl Server {
std::str::from_utf8(bytes.expose_secret()).map_err(Error::LinkNotUtf8)?,
)
.expect("Infallible");
let url = Secret::new(SecureUrl::from_url(url::Url::parse(s.expose_secret())?));
Ok(url)
Ok(s)
}
}

View File

@@ -11,8 +11,8 @@ use crate::client::{
use anyhow::{anyhow, bail, Context, Result};
use arc_swap::ArcSwap;
use connlib_client_shared::{file_logger, ResourceDescription};
use connlib_shared::{control::SecureUrl, messages::ResourceId, BUNDLE_ID};
use secrecy::{ExposeSecret, Secret, SecretString};
use connlib_shared::{keypair, messages::ResourceId, LoginUrl, BUNDLE_ID};
use secrecy::{ExposeSecret, SecretString};
use std::{net::IpAddr, path::PathBuf, str::FromStr, sync::Arc, time::Duration};
use system_tray_menu::Event as TrayMenuEvent;
use tauri::{Manager, SystemTray, SystemTrayEvent};
@@ -423,7 +423,7 @@ pub(crate) enum ControllerRequest {
},
Fail(Failure),
GetAdvancedSettings(oneshot::Sender<AdvancedSettings>),
SchemeRequest(Secret<SecureUrl>),
SchemeRequest(SecretString),
SystemTrayMenu(TrayMenuEvent),
TunnelReady,
UpdateAvailable(client::updates::Release),
@@ -527,11 +527,17 @@ impl Controller {
api_url = api_url.to_string(),
"Calling connlib Session::connect"
);
let connlib = connlib_client_shared::Session::connect(
api_url,
token,
let (private_key, public_key) = keypair();
let login = LoginUrl::client(
api_url.as_str(),
&token,
self.device_id.clone(),
None, // `get_host_name` over in connlib gets the system's name automatically
None,
public_key.to_bytes(),
)?;
let connlib = connlib_client_shared::Session::connect(
login,
private_key,
None,
callback_handler.clone(),
Some(MAX_PARTITION_TIME),
@@ -564,7 +570,7 @@ impl Controller {
Ok(())
}
async fn handle_deep_link(&mut self, url: &Secret<SecureUrl>) -> Result<()> {
async fn handle_deep_link(&mut self, url: &SecretString) -> Result<()> {
let auth_response =
client::deep_link::parse_auth_callback(url).context("Couldn't parse scheme request")?;
@@ -639,11 +645,7 @@ impl Controller {
if let Some(req) = self.auth.start_sign_in()? {
let url = req.to_url(&self.advanced_settings.auth_base_url);
self.refresh_system_tray_menu()?;
tauri::api::shell::open(
&self.app.shell_scope(),
&url.expose_secret().inner,
None,
)?;
tauri::api::shell::open(&self.app.shell_scope(), url.expose_secret(), None)?;
}
}
Req::SystemTrayMenu(TrayMenuEvent::SignOut) => {

View File

@@ -1,7 +1,11 @@
use anyhow::{Context, Result};
use clap::Parser;
use connlib_client_shared::{file_logger, Callbacks, Session};
use connlib_shared::linux::{etc_resolv_conf, get_dns_control_from_env, DnsControlMethod};
use connlib_shared::{
keypair,
linux::{etc_resolv_conf, get_dns_control_from_env, DnsControlMethod},
LoginUrl,
};
use firezone_cli_utils::{block_on_ctrl_c, setup_global_subscriber, CommonArgs};
use secrecy::SecretString;
use std::{net::IpAddr, path::PathBuf, str::FromStr};
@@ -25,16 +29,17 @@ fn main() -> Result<()> {
None => connlib_shared::device_id::get().context("Could not get `firezone_id` from CLI, could not read it from disk, could not generate it and save it to disk")?,
};
let mut session = Session::connect(
let (private_key, public_key) = keypair();
let login = LoginUrl::client(
cli.common.api_url,
SecretString::from(cli.common.token),
&SecretString::from(cli.common.token),
firezone_id,
None,
None,
callbacks,
max_partition_time,
)
.unwrap();
public_key.to_bytes(),
)?;
let mut session =
Session::connect(login, private_key, None, callbacks, max_partition_time).unwrap();
block_on_ctrl_c();

View File

@@ -20,6 +20,11 @@ thiserror = "1.0.50"
tokio = { version = "1.36.0", features = ["net", "time"] }
backoff = "0.4.0"
anyhow = "1"
uuid = { version = "1.7", default-features = false, features = ["std", "v4"] }
sha2 = "0.10.8"
hex = "0.4"
libc = "0.2"
hostname = "0.3.1" # Hickory already depends on `hostname` so this isn't new
[dev-dependencies]
tokio = { version = "1.36.0", features = ["macros", "rt"] }

View File

@@ -1,4 +1,5 @@
mod heartbeat;
mod login_url;
use std::collections::{HashSet, VecDeque};
use std::{fmt, future, marker::PhantomData};
@@ -10,7 +11,7 @@ use futures::future::BoxFuture;
use futures::{FutureExt, SinkExt, StreamExt};
use heartbeat::{Heartbeat, MissedLastHeartbeat};
use rand_core::{OsRng, RngCore};
use secrecy::{CloneableSecret, ExposeSecret as _, Secret};
use secrecy::{ExposeSecret as _, Secret};
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use std::task::{ready, Context, Poll};
use tokio::net::TcpStream;
@@ -20,7 +21,8 @@ use tokio_tungstenite::{
tungstenite::{handshake::client::Request, Message},
MaybeTlsStream, WebSocketStream,
};
use url::Url;
pub use login_url::{LoginUrl, LoginUrlError};
// TODO: Refactor this PhoenixChannel to be compatible with the needs of the client and gateway
// See https://github.com/firezone/firezone/issues/2158
@@ -36,7 +38,7 @@ pub struct PhoenixChannel<TInitReq, TInboundMsg, TOutboundRes> {
pending_join_requests: HashSet<OutboundRequestId>,
// Stored here to allow re-connecting.
secret_url: Secret<SecureUrl>,
url: Secret<LoginUrl>,
user_agent: String,
reconnect_backoff: ExponentialBackoff,
@@ -57,7 +59,7 @@ enum State {
/// Additionally, you must already provide any query parameters required for authentication.
#[allow(clippy::type_complexity)]
pub async fn init<TInitReq, TInitRes, TInboundMsg, TOutboundRes>(
secret_url: Secret<SecureUrl>,
url: Secret<LoginUrl>,
user_agent: String,
login_topic: &'static str,
payload: TInitReq,
@@ -79,7 +81,7 @@ where
TOutboundRes: DeserializeOwned,
{
let mut channel = PhoenixChannel::<_, InitMessage<TInitRes>, ()>::connect(
secret_url,
url,
user_agent,
login_topic,
payload,
@@ -194,37 +196,6 @@ impl fmt::Display for InboundRequestId {
}
}
#[derive(Clone)]
pub struct SecureUrl {
inner: Url,
}
impl SecureUrl {
pub fn from_url(url: Url) -> Self {
Self { inner: url }
}
/// Exposes the `host` of the URL.
///
/// The host doesn't contain any secrets.
pub fn host(&self) -> Option<&str> {
self.inner.host_str()
}
pub fn port(&self) -> Option<u16> {
self.inner.port()
}
}
impl CloneableSecret for SecureUrl {}
impl secrecy::Zeroize for SecureUrl {
fn zeroize(&mut self) {
let placeholder = Url::parse("http://a.com").expect("placeholder URL to be valid");
let _ = std::mem::replace(&mut self.inner, placeholder);
}
}
impl<TInitReq, TInboundMsg, TOutboundRes> PhoenixChannel<TInitReq, TInboundMsg, TOutboundRes>
where
TInitReq: Serialize + Clone,
@@ -238,7 +209,7 @@ where
///
/// Once the connection is established,
pub fn connect(
secret_url: Secret<SecureUrl>,
url: Secret<LoginUrl>,
user_agent: String,
login: &'static str,
init_req: TInitReq,
@@ -246,10 +217,10 @@ where
) -> Self {
Self {
reconnect_backoff,
secret_url: secret_url.clone(),
url: url.clone(),
user_agent: user_agent.clone(),
state: State::Connecting(Box::pin(async move {
let (stream, _) = connect_async(make_request(secret_url, user_agent))
let (stream, _) = connect_async(make_request(url, user_agent))
.await
.map_err(InternalError::WebSocket)?;
@@ -292,11 +263,7 @@ where
self.reconnect_backoff.reset();
self.state = State::Connected(stream);
let host = self
.secret_url
.expose_secret()
.host()
.expect("always has host");
let host = self.url.expose_secret().host();
tracing::info!(%host, "Connected to portal");
self.join(self.login, self.init_req.clone());
@@ -314,7 +281,7 @@ where
return Poll::Ready(Err(Error::MaxRetriesReached));
};
let secret_url = self.secret_url.clone();
let secret_url = self.url.clone();
let user_agent = self.user_agent.clone();
tracing::debug!(?backoff, max_elapsed_time = ?self.reconnect_backoff.max_elapsed_time, "Reconnecting to portal on transient client error: {e}");
@@ -521,7 +488,7 @@ where
heartbeat: self.heartbeat,
_phantom: PhantomData,
pending_join_requests: self.pending_join_requests,
secret_url: self.secret_url,
url: self.url,
user_agent: self.user_agent,
reconnect_backoff: self.reconnect_backoff,
login: self.login,
@@ -628,33 +595,22 @@ impl<T, R> PhoenixMessage<T, R> {
}
// This is basically the same as tungstenite does but we add some new headers (namely user-agent)
fn make_request(secret_url: Secret<SecureUrl>, user_agent: String) -> Request {
use secrecy::ExposeSecret;
fn make_request(url: Secret<LoginUrl>, user_agent: String) -> Request {
use secrecy::ExposeSecret as _;
let mut r = [0u8; 16];
OsRng.fill_bytes(&mut r);
let key = base64::engine::general_purpose::STANDARD.encode(r);
let mut req_builder = Request::builder()
Request::builder()
.method("GET")
.header("Host", url.expose_secret().host())
.header("Connection", "Upgrade")
.header("Upgrade", "websocket")
.header("Sec-WebSocket-Version", "13")
.header("Sec-WebSocket-Key", key)
.header("User-Agent", user_agent)
.uri(secret_url.expose_secret().inner.as_str());
if let Some(host) = secret_url.expose_secret().host() {
let host = secret_url
.expose_secret()
.port()
.map(|port| format!("{host}:{port}"))
.unwrap_or(host.to_string());
req_builder = req_builder.header("Host", host);
}
req_builder
.uri(url.expose_secret().inner().as_str())
.body(())
.expect("building static request always works")
}

View File

@@ -0,0 +1,230 @@
use base64::{engine::general_purpose::STANDARD, Engine};
use secrecy::{CloneableSecret, ExposeSecret as _, SecretString, Zeroize};
use sha2::Digest as _;
use std::net::{Ipv4Addr, Ipv6Addr};
use url::Url;
use uuid::Uuid;
// From https://man7.org/linux/man-pages/man2/gethostname.2.html
// SUSv2 guarantees that "Host names are limited to 255 bytes".
// POSIX.1 guarantees that "Host names (not including the
// terminating null byte) are limited to HOST_NAME_MAX bytes". On
// Linux, HOST_NAME_MAX is defined with the value 64, which has been
// the limit since Linux 1.0 (earlier kernels imposed a limit of 8
// bytes)
//
// We are counting the nul-byte
#[cfg(not(target_os = "windows"))]
const HOST_NAME_MAX: usize = 256;
#[derive(Clone)]
pub struct LoginUrl {
url: Url,
// Invariant: Must stay the same as the host in `url`.
// This is duplicated here because `Url::host` is fallible.
// If we don't duplicate it, we'd have to do extra error handling in several places instead of just one place.
host: String,
}
impl Zeroize for LoginUrl {
fn zeroize(&mut self) {
let placeholder = Url::parse("http://a.com")
.expect("placeholder URL should always be valid, it's hard-coded");
let _ = std::mem::replace(&mut self.url, placeholder);
}
}
impl CloneableSecret for LoginUrl {}
impl LoginUrl {
pub fn client<E>(
url: impl TryInto<Url, Error = E>,
firezone_token: &SecretString,
device_id: String,
device_name: Option<String>,
public_key: [u8; 32],
) -> std::result::Result<Self, LoginUrlError<E>> {
let external_id = hex::encode(sha2::Sha256::digest(device_id));
let device_name = device_name
.or(get_host_name())
.unwrap_or_else(|| Uuid::new_v4().to_string());
let url = get_websocket_path(
url.try_into().map_err(LoginUrlError::InvalidUrl)?,
firezone_token,
"client",
Some(public_key),
Some(external_id),
Some(device_name),
None,
None,
)?;
Ok(LoginUrl {
host: parse_host(&url)?,
url,
})
}
pub fn gateway<E>(
url: impl TryInto<Url, Error = E>,
firezone_token: &SecretString,
device_id: String,
device_name: Option<String>,
public_key: [u8; 32],
) -> std::result::Result<Self, LoginUrlError<E>> {
let external_id = hex::encode(sha2::Sha256::digest(device_id));
let device_name = device_name
.or(get_host_name())
.unwrap_or_else(|| Uuid::new_v4().to_string());
let url = get_websocket_path(
url.try_into().map_err(LoginUrlError::InvalidUrl)?,
firezone_token,
"gateway",
Some(public_key),
Some(external_id),
Some(device_name),
None,
None,
)?;
Ok(LoginUrl {
host: parse_host(&url)?,
url,
})
}
pub fn relay<E>(
url: impl TryInto<Url, Error = E>,
firezone_token: &SecretString,
device_name: Option<String>,
ipv4_address: Option<Ipv4Addr>,
ipv6_address: Option<Ipv6Addr>,
) -> std::result::Result<Self, LoginUrlError<E>> {
let url = get_websocket_path(
url.try_into().map_err(LoginUrlError::InvalidUrl)?,
firezone_token,
"relay",
None,
None,
device_name,
ipv4_address,
ipv6_address,
)?;
Ok(LoginUrl {
host: parse_host(&url)?,
url,
})
}
// TODO: Only temporarily public until we delete other phoenix-channel impl.
pub fn inner(&self) -> &Url {
&self.url
}
// TODO: Only temporarily public until we delete other phoenix-channel impl.
pub fn host(&self) -> &str {
&self.host
}
}
/// Parse the host from a URL, including port if present. e.g. `example.com:8080`.
fn parse_host<E>(url: &Url) -> Result<String, LoginUrlError<E>> {
let host = url.host_str().ok_or(LoginUrlError::MissingHost)?;
Ok(match url.port() {
Some(p) => format!("{host}:{p}"),
None => host.to_owned(),
})
}
#[derive(Debug, thiserror::Error)]
pub enum LoginUrlError<E> {
#[error("invalid scheme `{0}`; only http(s) and ws(s) are allowed")]
InvalidUrlScheme(String),
#[error("failed to parse URL: {0}")]
InvalidUrl(E),
#[error("the url is missing a host")]
MissingHost,
}
#[cfg(not(target_os = "windows"))]
fn get_host_name() -> Option<String> {
let mut buf = [0; HOST_NAME_MAX];
// SAFETY: we just allocated a buffer with that size
if unsafe { libc::gethostname(buf.as_mut_ptr() as *mut _, HOST_NAME_MAX) } != 0 {
return None;
}
String::from_utf8(buf.split(|c| *c == 0).next()?.to_vec()).ok()
}
/// Returns the hostname, or `None` if it's not valid UTF-8
#[cfg(target_os = "windows")]
fn get_host_name() -> Option<String> {
hostname::get().ok().and_then(|x| x.into_string().ok())
}
#[allow(clippy::too_many_arguments)]
fn get_websocket_path<E>(
mut api_url: Url,
token: &SecretString,
mode: &str,
public_key: Option<[u8; 32]>,
external_id: Option<String>,
name: Option<String>,
ipv4_address: Option<Ipv4Addr>,
ipv6_address: Option<Ipv6Addr>,
) -> std::result::Result<Url, LoginUrlError<E>> {
set_ws_scheme(&mut api_url)?;
{
let mut paths = api_url
.path_segments_mut()
.expect("scheme guarantees valid URL");
paths.pop_if_empty();
paths.push(mode);
paths.push("websocket");
}
{
let mut query_pairs = api_url.query_pairs_mut();
query_pairs.clear();
query_pairs.append_pair("token", token.expose_secret());
if let Some(public_key) = public_key {
query_pairs.append_pair("public_key", &STANDARD.encode(public_key));
}
if let Some(external_id) = external_id {
query_pairs.append_pair("external_id", &external_id);
}
if let Some(name) = name {
query_pairs.append_pair("name", &name);
}
if let Some(ipv4_address) = ipv4_address {
query_pairs.append_pair("ipv4", &ipv4_address.to_string());
}
if let Some(ipv4_address) = ipv6_address {
query_pairs.append_pair("ipv6", &ipv4_address.to_string());
}
}
Ok(api_url)
}
fn set_ws_scheme<E>(url: &mut Url) -> std::result::Result<(), LoginUrlError<E>> {
let scheme = match url.scheme() {
"http" | "ws" => "ws",
"https" | "wss" => "wss",
other => return Err(LoginUrlError::InvalidUrlScheme(other.to_owned())),
};
url.set_scheme(scheme)
.expect("Developer error: the match before this should make sure we can set this");
Ok(())
}

View File

@@ -9,7 +9,7 @@ use futures::channel::mpsc;
use futures::{future, FutureExt, SinkExt, StreamExt};
use opentelemetry::{sdk, KeyValue};
use opentelemetry_otlp::WithExportConfig;
use phoenix_channel::{Event, PhoenixChannel, SecureUrl};
use phoenix_channel::{Event, LoginUrl, PhoenixChannel};
use rand::rngs::StdRng;
use rand::{Rng, SeedableRng};
use secrecy::{Secret, SecretString};
@@ -250,33 +250,21 @@ fn env_filter() -> EnvFilter {
async fn connect_to_portal(
args: &Args,
token: &SecretString,
mut url: Url,
url: Url,
stamp_secret: &SecretString,
) -> Result<Option<PhoenixChannel<JoinMessage, (), ()>>> {
use secrecy::ExposeSecret;
if !url.path().is_empty() {
tracing::warn!(target: "relay", "Overwriting path component of portal URL with '/relay/websocket'");
}
url.set_path("relay/websocket");
url.query_pairs_mut()
.append_pair("token", token.expose_secret().as_str());
if let Some(public_ip4_addr) = args.public_ip4_addr {
url.query_pairs_mut()
.append_pair("ipv4", &public_ip4_addr.to_string());
}
if let Some(public_ip6_addr) = args.public_ip6_addr {
url.query_pairs_mut()
.append_pair("ipv6", &public_ip6_addr.to_string());
}
if let Some(name) = args.name.as_ref() {
url.query_pairs_mut().append_pair("name", name);
}
let login = LoginUrl::relay(
url,
token,
args.name.clone(),
args.public_ip4_addr,
args.public_ip6_addr,
)?;
let (channel, Init {}) = phoenix_channel::init::<_, Init, _, _>(
Secret::from(SecureUrl::from_url(url)),
Secret::new(login),
format!("relay/{}", env!("CARGO_PKG_VERSION")),
"relay",
JoinMessage {