feat(connlib): Wrap secrets in Secret to minimize chance of leakage (#2159)

Fixes #2085

---------

Co-authored-by: Thomas Eizinger <thomas@eizinger.io>
This commit is contained in:
Jamil
2023-09-28 11:35:16 -07:00
committed by GitHub
parent 0ceecc0c0e
commit 3baf2ee1bb
31 changed files with 233 additions and 83 deletions

21
rust/Cargo.lock generated
View File

@@ -739,6 +739,7 @@ dependencies = [
"ip_network",
"jni",
"log",
"secrecy",
"serde_json",
"thiserror",
"tracing",
@@ -756,6 +757,7 @@ dependencies = [
"firezone-client-connlib",
"ip_network",
"libc",
"secrecy",
"serde_json",
"swift-bridge",
"swift-bridge-build",
@@ -1152,6 +1154,7 @@ dependencies = [
"firezone-tunnel",
"libs-common",
"rand",
"secrecy",
"serde",
"serde_json",
"tokio",
@@ -1176,6 +1179,7 @@ dependencies = [
"firezone-tunnel",
"libs-common",
"rand",
"secrecy",
"serde",
"serde_json",
"tokio",
@@ -1208,6 +1212,7 @@ dependencies = [
"pnet_packet",
"rand_core",
"rtnetlink",
"secrecy",
"serde",
"thiserror",
"tokio",
@@ -1329,6 +1334,7 @@ dependencies = [
"clap",
"firezone-gateway-connlib",
"headless-utils",
"secrecy",
"tracing",
"tracing-subscriber",
]
@@ -1429,6 +1435,7 @@ dependencies = [
"clap",
"firezone-client-connlib",
"headless-utils",
"secrecy",
"tracing",
"tracing-subscriber",
]
@@ -1841,6 +1848,7 @@ dependencies = [
"rand_core",
"ring",
"rtnetlink",
"secrecy",
"serde",
"serde_json",
"smbios-lib",
@@ -2395,6 +2403,7 @@ dependencies = [
"base64 0.21.4",
"futures",
"rand_core",
"secrecy",
"serde",
"serde_json",
"thiserror",
@@ -2738,6 +2747,7 @@ dependencies = [
"proptest",
"rand",
"redis",
"secrecy",
"serde",
"sha2",
"socket2 0.5.4",
@@ -3023,6 +3033,17 @@ dependencies = [
"zeroize",
]
[[package]]
name = "secrecy"
version = "0.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9bd1c54ea06cfd2f6b63219704de0b9b4f72dcc2b8fdef820be6cd799780e91e"
dependencies = [
"bytes",
"serde",
"zeroize",
]
[[package]]
name = "security-framework"
version = "2.9.2"

View File

@@ -20,6 +20,7 @@ swift-bridge = "0.1.52"
backoff = { version = "0.4", features = ["tokio"] }
tracing = { version = "0.1.37" }
tracing-subscriber = { version = "0.3.17", features = ["parking_lot"] }
secrecy = "0.8"
# Patched to use https://github.com/rust-lang/cc-rs/pull/708
# (the `patch` section can't be used for build deps...)

View File

@@ -12,6 +12,7 @@ doc = false
mock = ["firezone-client-connlib/mock"]
[dependencies]
secrecy = { workspace = true }
tracing-android = "0.2"
tracing = { workspace = true, features = ["std", "attributes"] }
tracing-subscriber = { workspace = true }

View File

@@ -10,6 +10,7 @@ use jni::{
strings::JNIString,
JNIEnv, JavaVM,
};
use secrecy::SecretString;
use std::sync::OnceLock;
use std::{
net::{Ipv4Addr, Ipv6Addr},
@@ -343,7 +344,7 @@ fn connect(
callback_handler: GlobalRef,
) -> Result<Session<CallbackHandler>, ConnectError> {
let portal_url = string_from_jstring!(env, portal_url);
let portal_token = string_from_jstring!(env, portal_token);
let secret = SecretString::from(string_from_jstring!(env, portal_token));
let device_id = string_from_jstring!(env, device_id);
let log_dir = string_from_jstring!(env, log_dir);
let log_filter = string_from_jstring!(env, log_filter);
@@ -354,12 +355,7 @@ fn connect(
init_logging(log_dir.into(), log_filter);
let session = Session::connect(
portal_url.as_str(),
portal_token,
device_id,
callback_handler,
)?;
let session = Session::connect(portal_url.as_str(), secret, device_id, callback_handler)?;
Ok(session)
}

View File

@@ -13,6 +13,7 @@ swift-bridge-build = "0.1.52"
walkdir = "2.3.3"
[dependencies]
secrecy = { workspace = true }
ip_network = "0.4"
libc = "0.2"
swift-bridge = { workspace = true }

View File

@@ -3,6 +3,7 @@
use firezone_client_connlib::{file_logger, Callbacks, Error, ResourceDescription, Session};
use ip_network::IpNetwork;
use secrecy::SecretString;
use std::{
net::{Ipv4Addr, Ipv6Addr},
os::fd::RawFd,
@@ -162,10 +163,11 @@ impl WrappedSession {
callback_handler: ffi::CallbackHandler,
) -> Result<Self, String> {
let _guard = init_logging(log_dir.into(), log_filter);
let secret = SecretString::from(token);
let session = Session::connect(
portal_url.as_str(),
token,
secret,
device_id,
CallbackHandler(callback_handler.into()),
)

View File

@@ -6,6 +6,7 @@ edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
secrecy = { workspace = true }
firezone-client-connlib = { path = "../../libs/client" }
headless-utils = { path = "../../headless-utils" }
anyhow = { version = "1.0" }

View File

@@ -2,6 +2,7 @@ use anyhow::Result;
use clap::Parser;
use firezone_client_connlib::{file_logger, get_device_id, Callbacks, Session};
use headless_utils::{block_on_ctrl_c, setup_global_subscriber, CommonArgs};
use secrecy::SecretString;
use std::path::PathBuf;
fn main() -> Result<()> {
@@ -15,7 +16,7 @@ fn main() -> Result<()> {
let mut session = Session::connect(
cli.common.url,
cli.common.secret,
SecretString::from(cli.common.secret),
device_id,
CallbackHandler,
)

View File

@@ -6,6 +6,7 @@ edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
secrecy = { workspace = true }
firezone-gateway-connlib = { path = "../libs/gateway" }
headless-utils = { path = "../headless-utils" }
anyhow = { version = "1.0" }

View File

@@ -2,6 +2,7 @@ use anyhow::Result;
use clap::Parser;
use firezone_gateway_connlib::{get_device_id, Callbacks, Session};
use headless_utils::{block_on_ctrl_c, setup_global_subscriber, CommonArgs};
use secrecy::SecretString;
use tracing_subscriber::layer;
fn main() -> Result<()> {
@@ -11,7 +12,7 @@ fn main() -> Result<()> {
let device_id = get_device_id();
let mut session = Session::connect(
cli.common.url,
cli.common.secret,
SecretString::from(cli.common.secret),
device_id,
CallbackHandler,
)

View File

@@ -7,6 +7,7 @@ edition = "2021"
mock = ["libs-common/mock"]
[dependencies]
secrecy = { workspace = true }
tokio = { version = "1.32", default-features = false, features = ["sync"] }
tracing = { workspace = true }
tracing-subscriber = { workspace = true, features = ["env-filter"] }

View File

@@ -8,6 +8,7 @@ use backoff::{backoff::Backoff, ExponentialBackoffBuilder};
use boringtun::x25519::{PublicKey, StaticSecret};
use control::ControlPlane;
use firezone_tunnel::Tunnel;
use libs_common::control::SecureUrl;
use libs_common::{
control::PhoenixChannel, get_websocket_path, messages::Key, sha256, CallbackErrorFacade, Result,
};
@@ -15,6 +16,7 @@ use messages::IngressMessages;
use messages::Messages;
use messages::ReplyMessages;
use rand::{distributions::Alphanumeric, thread_rng, Rng};
use secrecy::{Secret, SecretString};
use std::sync::Arc;
use std::time::Duration;
use tokio::{runtime::Runtime, sync::Mutex};
@@ -63,7 +65,7 @@ where
// TODO: token should be something like SecretString but we need to think about FFI compatibility
pub fn connect(
portal_url: impl TryInto<Url>,
token: String,
token: SecretString,
device_id: String,
callbacks: CB,
) -> Result<Self> {
@@ -120,7 +122,7 @@ where
runtime: &Runtime,
runtime_stopper: tokio::sync::mpsc::Sender<StopRuntime>,
portal_url: Url,
token: String,
token: SecretString,
device_id: String,
callbacks: CallbackErrorFacade<CB>,
) {
@@ -140,7 +142,7 @@ where
// to force queue ordering.
let (control_plane_sender, mut control_plane_receiver) = tokio::sync::mpsc::channel(1);
let mut connection = PhoenixChannel::<_, IngressMessages, ReplyMessages, Messages>::new(connect_url, move |msg, reference| {
let mut connection = PhoenixChannel::<_, IngressMessages, ReplyMessages, Messages>::new(Secret::new(SecureUrl::from_url(connect_url)), move |msg, reference| {
let control_plane_sender = control_plane_sender.clone();
async move {
tracing::trace!("Received message: {msg:?}");

View File

@@ -8,6 +8,7 @@ edition = "2021"
mock = []
[dependencies]
secrecy = { workspace = true, features = ["serde", "bytes"] }
base64 = { version = "0.21", default-features = false, features = ["std"] }
boringtun = { workspace = true }
chrono = { workspace = true }

View File

@@ -13,6 +13,7 @@ use futures::{
};
use futures_util::{Future, SinkExt, StreamExt, TryFutureExt};
use rand_core::{OsRng, RngCore};
use secrecy::Secret;
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use tokio_stream::StreamExt as _;
use tokio_tungstenite::{
@@ -30,6 +31,22 @@ const HEARTBEAT_TIMEOUT: Duration = Duration::from_secs(35);
pub type Reference = String;
// TODO: Refactor this PhoenixChannel to use the top-level phoenix-channel crate instead.
// See https://github.com/firezone/firezone/issues/2158
pub struct SecureUrl {
inner: Url,
}
impl SecureUrl {
pub fn from_url(url: Url) -> Self {
Self { inner: url }
}
}
impl secrecy::Zeroize for SecureUrl {
fn zeroize(&mut self) {
let placeholder = Url::parse("http://a.com").expect("placeholder URL to be valid");
let _ = std::mem::replace(&mut self.inner, placeholder);
}
}
/// Main struct to interact with the control-protocol channel.
///
/// After creating a new `PhoenixChannel` using [PhoenixChannel::new] you need to
@@ -46,7 +63,7 @@ pub type Reference = String;
/// The future returned by [PhoenixChannel::start] will finish when the websocket closes (by an error), meaning that if you
/// `await` it, it will block until you use `close` in a [PhoenixSender], the portal close the connection or something goes wrong.
pub struct PhoenixChannel<F, I, R, M> {
uri: Url,
secret_url: Secret<SecureUrl>,
handler: F,
sender: Sender<Message>,
receiver: Receiver<Message>,
@@ -54,9 +71,15 @@ pub struct PhoenixChannel<F, I, R, M> {
}
// This is basically the same as tungstenite does but we add some new headers (namely user-agent)
fn make_request(uri: &Url) -> Result<Request> {
let host = uri.host().ok_or(Error::UriError)?;
let host = if let Some(port) = uri.port() {
fn make_request(secret_url: &Secret<SecureUrl>) -> Result<Request> {
use secrecy::ExposeSecret;
let host = secret_url
.expose_secret()
.inner
.host()
.ok_or(Error::UriError)?;
let host = if let Some(port) = secret_url.expose_secret().inner.port() {
format!("{host}:{port}")
} else {
host.to_string()
@@ -74,7 +97,7 @@ fn make_request(uri: &Url) -> Result<Request> {
.header("Sec-WebSocket-Version", "13")
.header("Sec-WebSocket-Key", key)
.header("User-Agent", get_user_agent())
.uri(uri.as_str())
.uri(secret_url.expose_secret().inner.as_str())
.body(())?;
Ok(req)
}
@@ -102,9 +125,9 @@ where
topics: Vec<String>,
after_connection_ends: impl FnOnce(),
) -> Result<()> {
tracing::trace!("Trying to connect to portal URL {}...", self.uri);
tracing::trace!("Trying to connect to portal...");
let (ws_stream, _) = connect_async(make_request(&self.uri)?).await?;
let (ws_stream, _) = connect_async(make_request(&self.secret_url)?).await?;
tracing::trace!("Successfully connected to portal");
@@ -240,17 +263,17 @@ where
/// Creates a new [PhoenixChannel] not started yet.
///
/// # Parameters:
/// - `uri`: Portal's websocket uri
/// - `secret_url`: Portal's websocket uri
/// - `handler`: The handle that will be called for each received message.
///
/// For more info see [struct-level docs][PhoenixChannel].
pub fn new(uri: Url, handler: F) -> Self {
pub fn new(secret_url: Secret<SecureUrl>, handler: F) -> Self {
let (sender, receiver) = channel(CHANNEL_SIZE);
Self {
sender,
receiver,
uri,
secret_url,
handler,
_phantom: PhantomData,
}

View File

@@ -16,6 +16,7 @@ pub use error::Result;
use messages::Key;
use ring::digest::{Context, SHA256};
use secrecy::{ExposeSecret, SecretString};
use std::net::Ipv4Addr;
use url::Url;
@@ -79,7 +80,7 @@ pub fn sha256(input: String) -> String {
pub fn get_websocket_path(
mut url: Url,
secret: String,
secret: SecretString,
mode: &str,
public_key: &Key,
external_id: &str,
@@ -97,7 +98,7 @@ pub fn get_websocket_path(
{
let mut query_pairs = url.query_pairs_mut();
query_pairs.clear();
query_pairs.append_pair("token", &secret);
query_pairs.append_pair("token", secret.expose_secret());
query_pairs.append_pair("public_key", &public_key.to_string());
query_pairs.append_pair("external_id", external_id);
query_pairs.append_pair("name_suffix", name_suffix);

View File

@@ -10,7 +10,7 @@ use webrtc::peer_connection::sdp::session_description::RTCSessionDescription;
mod key;
pub use key::Key;
pub use key::{Key, SecretKey};
#[derive(Hash, Debug, Deserialize, Serialize, Clone, Copy, PartialEq, Eq)]
pub struct GatewayId(Uuid);
@@ -44,7 +44,7 @@ impl fmt::Display for ResourceId {
}
/// Represents a wireguard peer.
#[derive(Debug, PartialEq, Eq, Deserialize, Serialize, Clone)]
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct Peer {
/// Keepalive: How often to send a keep alive message.
pub persistent_keepalive: Option<u16>,
@@ -55,13 +55,22 @@ pub struct Peer {
/// Peer's Ipv6 (only 1 ipv6 per peer for now and mandatory).
pub ipv6: Ipv6Addr,
/// Preshared key for the given peer.
pub preshared_key: Key,
pub preshared_key: SecretKey,
}
impl PartialEq for Peer {
fn eq(&self, other: &Self) -> bool {
self.persistent_keepalive.eq(&other.persistent_keepalive)
&& self.public_key.eq(&other.public_key)
&& self.ipv4.eq(&other.ipv4)
&& self.ipv6.eq(&other.ipv6)
}
}
/// Represent a connection request from a client to a given resource.
///
/// While this is a client-only message it's hosted in common since the tunnel
/// make use of this message type.
/// makes use of this message type.
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct RequestConnection {
/// Gateway id for the connection
@@ -69,7 +78,7 @@ pub struct RequestConnection {
/// Resource id the request is for.
pub resource_id: ResourceId,
/// The preshared key the client generated for the connection that it is trying to establish.
pub client_preshared_key: Key,
pub client_preshared_key: SecretKey,
/// Client's local RTC Session Description that the client will use for this connection.
pub client_rtc_session_description: RTCSessionDescription,
}
@@ -90,7 +99,6 @@ pub struct ReuseConnection {
impl PartialEq for RequestConnection {
fn eq(&self, other: &Self) -> bool {
self.resource_id == other.resource_id
&& self.client_preshared_key == other.client_preshared_key
}
}

View File

@@ -1,5 +1,6 @@
use base64::{display::Base64Display, engine::general_purpose::STANDARD, Engine};
use boringtun::x25519::PublicKey;
use secrecy::{CloneableSecret, DebugSecret, Secret, SerializableSecret, Zeroize};
use serde::{de, Deserialize, Deserializer, Serialize, Serializer};
use std::{fmt, str::FromStr};
@@ -14,6 +15,8 @@ const KEY_SIZE: usize = 32;
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub struct Key(pub [u8; KEY_SIZE]);
impl DebugSecret for Key {}
impl FromStr for Key {
type Err = Error;
@@ -38,6 +41,12 @@ impl FromStr for Key {
}
}
impl Zeroize for Key {
fn zeroize(&mut self) {
self.0.zeroize();
}
}
impl From<PublicKey> for Key {
fn from(value: PublicKey) -> Self {
Self(value.to_bytes())
@@ -69,6 +78,11 @@ impl Serialize for Key {
}
}
impl CloneableSecret for Key {}
impl SerializableSecret for Key {}
pub type SecretKey = Secret<Key>;
#[cfg(test)]
mod test {
use boringtun::x25519::{PublicKey, StaticSecret};
@@ -86,7 +100,7 @@ mod test {
#[test]
fn can_serialize_from_private_key_and_back() {
let private_key = StaticSecret::random_from_rng(OsRng);
let expected_public_key = PublicKey::from(&private_key);
let expected_public_key = PublicKey::from(private_key.to_bytes());
let public_key = Key(expected_public_key.to_bytes());
let public_key_string = serde_json::to_string(&public_key).unwrap();
let actual_key: Key = serde_json::from_str(&public_key_string).unwrap();

View File

@@ -4,6 +4,7 @@ version = "0.1.0"
edition = "2021"
[dependencies]
secrecy = { workspace = true }
libs-common = { path = "../common" }
async-trait = { version = "0.1", default-features = false }
firezone-tunnel = { path = "../tunnel" }

View File

@@ -6,11 +6,13 @@ use backoff::{backoff::Backoff, ExponentialBackoffBuilder};
use boringtun::x25519::{PublicKey, StaticSecret};
use control::ControlPlane;
use firezone_tunnel::Tunnel;
use libs_common::control::SecureUrl;
use libs_common::{
control::PhoenixChannel, get_websocket_path, messages::Key, sha256, CallbackErrorFacade, Result,
};
use messages::IngressMessages;
use rand::{distributions::Alphanumeric, thread_rng, Rng};
use secrecy::{Secret, SecretString};
use std::sync::Arc;
use std::time::Duration;
use tokio::runtime::Runtime;
@@ -58,7 +60,7 @@ where
// TODO: token should be something like SecretString but we need to think about FFI compatibility
pub fn connect(
portal_url: impl TryInto<Url>,
token: String,
token: SecretString,
device_id: String,
callbacks: CB,
) -> Result<Self> {
@@ -115,7 +117,7 @@ where
runtime: &Runtime,
runtime_stopper: tokio::sync::mpsc::Sender<StopRuntime>,
portal_url: Url,
token: String,
token: SecretString,
device_id: String,
callbacks: CallbackErrorFacade<CB>,
) {
@@ -136,7 +138,7 @@ where
// to force queue ordering.
let (control_plane_sender, mut control_plane_receiver) = tokio::sync::mpsc::channel(1);
let mut connection = PhoenixChannel::<_, IngressMessages, IngressMessages, IngressMessages>::new(connect_url, move |msg, reference| {
let mut connection = PhoenixChannel::<_, IngressMessages, IngressMessages, IngressMessages>::new(Secret::new(SecureUrl::from_url(connect_url)), move |msg, reference| {
let control_plane_sender = control_plane_sender.clone();
async move {
tracing::trace!("Received message: {msg:?}");

View File

@@ -4,6 +4,7 @@ version = "0.1.0"
edition = "2021"
[dependencies]
secrecy = { workspace = true }
async-trait = { version = "0.1", default-features = false }
tokio = { version = "1.32", default-features = false, features = ["rt", "rt-multi-thread", "sync"] }
thiserror = { version = "1.0", default-features = false }

View File

@@ -1,5 +1,6 @@
use boringtun::noise::Tunn;
use chrono::{DateTime, Utc};
use secrecy::ExposeSecret;
use std::sync::Arc;
use tracing::instrument;
@@ -92,7 +93,7 @@ where
let tunn = Tunn::new(
self.private_key.clone(),
peer_config.public_key,
Some(peer_config.preshared_key.to_bytes()),
Some(peer_config.preshared_key.expose_secret().0),
peer_config.persistent_keepalive,
index,
None,

View File

@@ -2,6 +2,7 @@ use std::sync::Arc;
use boringtun::x25519::{PublicKey, StaticSecret};
use chrono::{DateTime, Utc};
use libs_common::messages::SecretKey;
use libs_common::{
control::Reference,
messages::{
@@ -11,6 +12,7 @@ use libs_common::{
Callbacks,
};
use rand_core::OsRng;
use secrecy::Secret;
use webrtc::{
data_channel::data_channel_init::RTCDataChannelInit,
peer_connection::{
@@ -209,7 +211,7 @@ where
persistent_keepalive: None,
public_key: gateway_public_key,
ips: resource_description.ips(),
preshared_key: p_key,
preshared_key: SecretKey::new(Key(p_key.to_bytes())),
};
if let Err(e) = tunnel
@@ -237,7 +239,7 @@ where
Ok(Request::NewConnection(RequestConnection {
resource_id,
gateway_id,
client_preshared_key: Key(preshared_key.to_bytes()),
client_preshared_key: Secret::new(Key(preshared_key.to_bytes())),
client_rtc_session_description: offer,
}))
}

View File

@@ -44,6 +44,7 @@ pub use control_protocol::Request;
pub use webrtc::peer_connection::sdp::session_description::RTCSessionDescription;
use index::IndexLfsr;
use libs_common::messages::SecretKey;
mod control_protocol;
mod device_channel;
@@ -96,7 +97,7 @@ pub struct PeerConfig {
pub(crate) persistent_keepalive: Option<u16>,
pub(crate) public_key: PublicKey,
pub(crate) ips: Vec<IpNetwork>,
pub(crate) preshared_key: StaticSecret,
pub(crate) preshared_key: SecretKey,
}
impl From<libs_common::messages::Peer> for PeerConfig {
@@ -105,7 +106,7 @@ impl From<libs_common::messages::Peer> for PeerConfig {
persistent_keepalive: value.persistent_keepalive,
public_key: value.public_key.0.into(),
ips: vec![value.ipv4.into(), value.ipv6.into()],
preshared_key: value.preshared_key.0.into(),
preshared_key: value.preshared_key,
}
}
}

View File

@@ -6,6 +6,7 @@ edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
secrecy = { workspace = true }
tokio-tungstenite = { version = "0.19.0", features = ["rustls-tls-native-roots"] }
futures = "0.3.28"
base64 = "0.21.4"

View File

@@ -4,6 +4,7 @@ use std::{fmt, marker::PhantomData, time::Duration};
use base64::Engine;
use futures::{FutureExt, SinkExt, StreamExt};
use rand_core::{OsRng, RngCore};
use secrecy::Secret;
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use std::pin::Pin;
use std::task::{Context, Poll};
@@ -18,6 +19,8 @@ use url::Url;
const HEARTBEAT_INTERVAL: Duration = Duration::from_secs(30);
// TODO: Refactor this PhoenixChannel to be compatible with the needs of the client and gateway
// See https://github.com/firezone/firezone/issues/2158
pub struct PhoenixChannel<TInboundMsg, TOutboundRes> {
stream: WebSocketStream<MaybeTlsStream<TcpStream>>,
pending_messages: Vec<Message>,
@@ -60,6 +63,23 @@ impl fmt::Display for InboundRequestId {
}
}
pub struct SecureUrl {
inner: Url,
}
impl SecureUrl {
pub fn from_url(url: Url) -> Self {
Self { inner: url }
}
}
impl secrecy::Zeroize for SecureUrl {
fn zeroize(&mut self) {
let placeholder = Url::parse("http://a.com").expect("placeholder URL to be valid");
let _ = std::mem::replace(&mut self.inner, placeholder);
}
}
impl<TInboundMsg, TOutboundRes> PhoenixChannel<TInboundMsg, TOutboundRes>
where
TInboundMsg: DeserializeOwned,
@@ -69,10 +89,10 @@ where
///
/// The provided URL must contain a host.
/// Additionally, you must already provide any query parameters required for authentication.
pub async fn connect(url: Url, user_agent: String) -> Result<Self, Error> {
pub async fn connect(secret_url: Secret<SecureUrl>, user_agent: String) -> Result<Self, Error> {
tracing::trace!("Trying to connect to the portal...");
let (stream, _) = connect_async(make_request(&url, user_agent)?).await?;
let (stream, _) = connect_async(make_request(secret_url, user_agent)?).await?;
tracing::trace!("Successfully connected to portal");
@@ -299,9 +319,15 @@ impl<T, R> PhoenixMessage<T, R> {
}
// This is basically the same as tungstenite does but we add some new headers (namely user-agent)
fn make_request(uri: &Url, user_agent: String) -> Result<Request, Error> {
let host = uri.host().ok_or(Error::MissingHost)?;
let host = if let Some(port) = uri.port() {
fn make_request(secret_url: Secret<SecureUrl>, user_agent: String) -> Result<Request, Error> {
use secrecy::ExposeSecret;
let host = secret_url
.expose_secret()
.inner
.host()
.ok_or(Error::MissingHost)?;
let host = if let Some(port) = secret_url.expose_secret().inner.port() {
format!("{host}:{port}")
} else {
host.to_string()
@@ -319,7 +345,7 @@ fn make_request(uri: &Url, user_agent: String) -> Result<Request, Error> {
.header("Sec-WebSocket-Version", "13")
.header("Sec-WebSocket-Key", key)
.header("User-Agent", user_agent)
.uri(uri.as_str())
.uri(secret_url.expose_secret().inner.as_str())
.body(())
.expect("building static request always works");

View File

@@ -4,6 +4,7 @@ version = "0.1.0"
edition = "2021"
[dependencies]
secrecy = { workspace = true }
anyhow = "1.0.75"
clap = { version = "4.4.4", features = ["derive", "env"] }
bytecodec = "0.4.15"

View File

@@ -1,6 +1,7 @@
use base64::prelude::BASE64_STANDARD_NO_PAD;
use base64::Engine;
use once_cell::sync::Lazy;
use secrecy::{ExposeSecret, SecretString};
use sha2::digest::FixedOutput;
use sha2::Sha256;
use std::borrow::ToOwned;
@@ -14,11 +15,21 @@ use uuid::Uuid;
pub static FIREZONE: Lazy<Realm> = Lazy::new(|| Realm::new("firezone".to_owned()).unwrap());
pub trait MessageIntegrityExt {
fn verify(&self, relay_secret: &str, username: &str, now: SystemTime) -> Result<(), Error>;
fn verify(
&self,
relay_secret: &SecretString,
username: &str,
now: SystemTime,
) -> Result<(), Error>;
}
impl MessageIntegrityExt for MessageIntegrity {
fn verify(&self, relay_secret: &str, username: &str, now: SystemTime) -> Result<(), Error> {
fn verify(
&self,
relay_secret: &SecretString,
username: &str,
now: SystemTime,
) -> Result<(), Error> {
let (expiry_unix_timestamp, salt) = split_username(username)?;
let expired = systemtime_from_unix(expiry_unix_timestamp);
@@ -102,7 +113,7 @@ pub(crate) fn split_username(username: &str) -> Result<(u64, &str), Error> {
}
pub(crate) fn generate_password(
relay_secret: &str,
relay_secret: &SecretString,
expiry: SystemTime,
username_salt: &str,
) -> String {
@@ -117,7 +128,7 @@ pub(crate) fn generate_password(
hasher.update(format!("{expiry_secs}"));
hasher.update(":");
hasher.update(relay_secret);
hasher.update(relay_secret.expose_secret().as_str());
hasher.update(":");
hasher.update(username_salt);
@@ -146,7 +157,7 @@ mod tests {
fn generate_password_test_vector() {
let expiry = systemtime_from_unix(60 * 60 * 24 * 365 * 60);
let password = generate_password(RELAY_SECRET_1, expiry, SAMPLE_USERNAME);
let password = generate_password(&RELAY_SECRET_1.parse().unwrap(), expiry, SAMPLE_USERNAME);
assert_eq!(password, "00hqldgk5xLeKKOB+xls9mHMVtgqzie9DulfgQwMv68")
}
@@ -155,7 +166,7 @@ mod tests {
fn generate_password_test_vector_elixir() {
let expiry = systemtime_from_unix(1685984278);
let password = generate_password(
"1cab293a-4032-46f4-862a-40e5d174b0d2",
&"1cab293a-4032-46f4-862a-40e5d174b0d2".parse().unwrap(),
expiry,
"uvdgKvS9GXYZ_vmv",
);
@@ -164,10 +175,14 @@ mod tests {
#[test]
fn smoke() {
let message_integrity = message_integrity(RELAY_SECRET_1, 1685200000, "n23JJ2wKKtt30oXi");
let message_integrity = message_integrity(
&RELAY_SECRET_1.parse().unwrap(),
1685200000,
"n23JJ2wKKtt30oXi",
);
let result = message_integrity.verify(
RELAY_SECRET_1,
&RELAY_SECRET_1.parse().unwrap(),
"1685200000:n23JJ2wKKtt30oXi",
systemtime_from_unix(1685200000 - 1000),
);
@@ -177,11 +192,14 @@ mod tests {
#[test]
fn expired_is_not_valid() {
let message_integrity =
message_integrity(RELAY_SECRET_1, 1685200000 - 1000, "n23JJ2wKKtt30oXi");
let message_integrity = message_integrity(
&RELAY_SECRET_1.parse().unwrap(),
1685200000 - 1000,
"n23JJ2wKKtt30oXi",
);
let result = message_integrity.verify(
RELAY_SECRET_1,
&RELAY_SECRET_1.parse().unwrap(),
"1685199000:n23JJ2wKKtt30oXi",
systemtime_from_unix(1685200000),
);
@@ -191,10 +209,14 @@ mod tests {
#[test]
fn different_relay_secret_makes_password_invalid() {
let message_integrity = message_integrity(RELAY_SECRET_2, 1685200000, "n23JJ2wKKtt30oXi");
let message_integrity = message_integrity(
&RELAY_SECRET_2.parse().unwrap(),
1685200000,
"n23JJ2wKKtt30oXi",
);
let result = message_integrity.verify(
RELAY_SECRET_1,
&RELAY_SECRET_1.parse().unwrap(),
"1685200000:n23JJ2wKKtt30oXi",
systemtime_from_unix(168520000 + 1000),
);
@@ -204,10 +226,14 @@ mod tests {
#[test]
fn invalid_username_format_fails() {
let message_integrity = message_integrity(RELAY_SECRET_2, 1685200000, "n23JJ2wKKtt30oXi");
let message_integrity = message_integrity(
&RELAY_SECRET_2.parse().unwrap(),
1685200000,
"n23JJ2wKKtt30oXi",
);
let result = message_integrity.verify(
RELAY_SECRET_1,
&RELAY_SECRET_1.parse().unwrap(),
"foobar",
systemtime_from_unix(168520000 + 1000),
);
@@ -244,7 +270,7 @@ mod tests {
}
fn message_integrity(
relay_secret: &str,
relay_secret: &SecretString,
username_expiry: u64,
username_salt: &str,
) -> MessageIntegrity {

View File

@@ -4,13 +4,14 @@ use futures::channel::mpsc;
use futures::{future, FutureExt, SinkExt, StreamExt};
use opentelemetry::{sdk, KeyValue};
use opentelemetry_otlp::WithExportConfig;
use phoenix_channel::{Error, Event, PhoenixChannel};
use phoenix_channel::{Error, Event, PhoenixChannel, SecureUrl};
use rand::rngs::StdRng;
use rand::{Rng, SeedableRng};
use relay::{
AddressFamily, Allocation, AllocationId, Command, IpStack, Server, Sleep, SocketAddrExt,
UdpSocket,
};
use secrecy::{Secret, SecretString};
use std::collections::hash_map::Entry;
use std::collections::HashMap;
use std::convert::Infallible;
@@ -51,7 +52,7 @@ struct Args {
///
/// If omitted, we won't connect to the portal on startup.
#[arg(long, env)]
portal_token: Option<String>,
portal_token: Option<SecretString>,
/// A seed to use for all randomness operations.
///
/// Only available in debug builds.
@@ -106,12 +107,12 @@ async fn main() -> Result<()> {
);
let channel = if let Some(token) = args.portal_token.as_ref() {
let url = args.portal_ws_url.clone();
let stamp_secret = server.auth_secret().to_string();
let base_url = args.portal_ws_url.clone();
let stamp_secret = server.auth_secret();
let span = tracing::error_span!("connect_to_portal", config_url = %url);
let span = tracing::error_span!("connect_to_portal", config_url = %base_url);
connect_to_portal(&args, token, url, stamp_secret)
connect_to_portal(&args, token, base_url, stamp_secret)
.instrument(span)
.await?
} else {
@@ -237,16 +238,19 @@ fn env_filter() -> EnvFilter {
async fn connect_to_portal(
args: &Args,
token: &str,
token: &SecretString,
mut url: Url,
stamp_secret: String,
stamp_secret: &SecretString,
) -> Result<Option<PhoenixChannel<InboundPortalMessage, ()>>> {
use secrecy::ExposeSecret;
if !url.path().is_empty() {
tracing::warn!("Overwriting path component of portal URL with '/relay/websocket'");
}
url.set_path("relay/websocket");
url.query_pairs_mut().append_pair("token", token);
url.query_pairs_mut()
.append_pair("token", token.expose_secret().as_str());
if let Some(public_ip4_addr) = args.public_ip4_addr {
url.query_pairs_mut()
@@ -258,7 +262,7 @@ async fn connect_to_portal(
}
let mut channel = PhoenixChannel::<InboundPortalMessage, ()>::connect(
url,
Secret::from(SecureUrl::from_url(url)),
format!("relay/{}", env!("CARGO_PKG_VERSION")),
)
.await
@@ -266,7 +270,12 @@ async fn connect_to_portal(
tracing::info!("Connected to portal, waiting for init message",);
channel.join("relay", JoinMessage { stamp_secret });
channel.join(
"relay",
JoinMessage {
stamp_secret: stamp_secret.expose_secret().to_string(),
},
);
loop {
match future::poll_fn(|cx| channel.poll(cx))

View File

@@ -15,6 +15,7 @@ use core::fmt;
use opentelemetry::metrics::{Counter, Unit, UpDownCounter};
use opentelemetry::KeyValue;
use rand::Rng;
use secrecy::SecretString;
use std::collections::{HashMap, VecDeque};
use std::hash::Hash;
use std::net::{IpAddr, SocketAddr};
@@ -66,7 +67,7 @@ pub struct Server<R> {
rng: R,
auth_secret: String,
auth_secret: SecretString,
nonces: Nonces,
@@ -187,7 +188,7 @@ where
channel_numbers_by_peer: Default::default(),
pending_commands: Default::default(),
next_allocation_id: AllocationId(1),
auth_secret: hex::encode(rng.gen::<[u8; 32]>()),
auth_secret: SecretString::from(hex::encode(rng.gen::<[u8; 32]>())),
rng,
time_events: TimeEvents::default(),
nonces: Default::default(),
@@ -197,7 +198,7 @@ where
}
}
pub fn auth_secret(&self) -> &str {
pub fn auth_secret(&self) -> &SecretString {
&self.auth_secret
}
@@ -326,7 +327,7 @@ where
let Some(channel) = self.channels_by_number.get(channel_number) else {
debug_assert!(false, "unknown channel {}", channel_number);
return
return;
};
if !channel.bound {

View File

@@ -3,6 +3,7 @@ use crate::server::channel_data::ChannelData;
use crate::server::UDP_TRANSPORT;
use crate::Attribute;
use bytecodec::DecodeExt;
use secrecy::SecretString;
use std::io;
use std::time::Duration;
use stun_codec::rfc5389::attributes::{ErrorCode, MessageIntegrity, Nonce, Username};
@@ -146,7 +147,7 @@ impl Allocate {
transaction_id: TransactionId,
lifetime: Option<Lifetime>,
username: Username,
relay_secret: &str,
relay_secret: &SecretString,
nonce: Uuid,
) -> Self {
let (requested_transport, nonce, message_integrity) = Self::make_attributes(
@@ -174,7 +175,7 @@ impl Allocate {
transaction_id: TransactionId,
lifetime: Option<Lifetime>,
username: Username,
relay_secret: &str,
relay_secret: &SecretString,
nonce: Uuid,
) -> Self {
let requested_address_family = RequestedAddressFamily::new(AddressFamily::V6);
@@ -230,7 +231,7 @@ impl Allocate {
transaction_id: TransactionId,
lifetime: &Option<Lifetime>,
username: &Username,
relay_secret: &str,
relay_secret: &SecretString,
nonce: Uuid,
requested_address_family: Option<RequestedAddressFamily>,
) -> (RequestedTransport, Nonce, MessageIntegrity) {
@@ -333,7 +334,7 @@ impl Refresh {
transaction_id: TransactionId,
lifetime: Option<Lifetime>,
username: Username,
relay_secret: &str,
relay_secret: &SecretString,
nonce: Uuid,
) -> Self {
let nonce = Nonce::new(nonce.as_hyphenated().to_string()).expect("len(uuid) < 128");
@@ -416,7 +417,7 @@ impl ChannelBind {
channel_number: ChannelNumber,
xor_peer_address: XorPeerAddress,
username: Username,
relay_secret: &str,
relay_secret: &SecretString,
nonce: Uuid,
) -> Self {
let nonce = Nonce::new(nonce.as_hyphenated().to_string()).expect("len(uuid) < 128");

View File

@@ -4,6 +4,7 @@ use relay::{
AddressFamily, Allocate, AllocationId, Attribute, Binding, ChannelBind, ChannelData,
ClientMessage, Command, IpStack, Refresh, Server,
};
use secrecy::SecretString;
use std::collections::HashMap;
use std::iter;
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4};
@@ -437,7 +438,7 @@ impl TestServer {
self
}
fn auth_secret(&self) -> &str {
fn auth_secret(&self) -> &SecretString {
self.server.auth_secret()
}