diff --git a/rust/connlib/snownet/src/allocation.rs b/rust/connlib/snownet/src/allocation.rs index ead557690..f9d7f511c 100644 --- a/rust/connlib/snownet/src/allocation.rs +++ b/rust/connlib/snownet/src/allocation.rs @@ -303,6 +303,18 @@ impl Allocation { Span::current().record("method", field::display(message.method())); Span::current().record("class", field::display(message.class())); + // Early return to avoid cryptographic work in case it isn't our message. + if !self.sent_requests.contains_key(&transaction_id) { + return false; + } + + let passed_message_integrity_check = self.check_message_integrity(&message); + + if message.method() != BINDING && !passed_message_integrity_check { + tracing::warn!("Message integrity check failed"); + return true; // The message still indicated that it was for this `Allocation`. + } + let Some((original_dst, original_request, sent_at, _, _)) = self.sent_requests.remove(&transaction_id) else { @@ -1041,6 +1053,31 @@ impl Allocation { backoff.clock.now = now; } } + + #[cfg(test)] + fn check_message_integrity(&self, _: &Message) -> bool { + true // In order to make the tests simpler, we skip the message integrity check there. + } + + #[cfg(not(test))] + fn check_message_integrity(&self, message: &Message) -> bool { + message + .get_attribute::() + .is_some_and(|mi| { + let Some(credentials) = &self.credentials else { + tracing::debug!("Cannot check message integrity without credentials"); + + return false; + }; + + mi.check_long_term_credential( + &credentials.username, + &credentials.realm, + &credentials.password, + ) + .is_ok() + }) + } } fn authenticate(message: Message, credentials: &Credentials) -> Message { diff --git a/rust/relay/src/auth.rs b/rust/relay/src/auth.rs index a775ef2e4..04ed9aa98 100644 --- a/rust/relay/src/auth.rs +++ b/rust/relay/src/auth.rs @@ -1,5 +1,6 @@ use base64::prelude::BASE64_STANDARD_NO_PAD; use base64::Engine; +use bytecodec::Encode; use once_cell::sync::Lazy; use secrecy::{ExposeSecret, SecretString}; use sha2::digest::FixedOutput; @@ -8,8 +9,11 @@ use std::collections::hash_map::Entry; use std::collections::HashMap; use std::time::{Duration, SystemTime}; use stun_codec::rfc5389::attributes::{MessageIntegrity, Realm, Username}; +use stun_codec::Message; use uuid::Uuid; +use crate::Attribute; + // TODO: Upstream a const constructor to `stun-codec`. pub static FIREZONE: Lazy = Lazy::new(|| Realm::new("firezone".to_owned()).expect("static realm is less than 128 chars")); @@ -51,6 +55,72 @@ impl MessageIntegrityExt for MessageIntegrity { } } +pub(crate) struct AuthenticatedMessage(Message); + +impl AuthenticatedMessage { + /// Creates a new [`AuthenticatedMessage`] that isn't actually authenticated. + /// + /// This should only be used in circumstances where we cannot authenticate the message because e.g. the original request wasn't authenticated either. + pub(crate) fn new_dangerous_unauthenticated(message: Message) -> Self { + Self(message) + } + + pub(crate) fn new( + relay_secret: &SecretString, + username: &str, + mut message: Message, + ) -> Result { + let (expiry_unix_timestamp, salt) = split_username(username)?; + let expired = systemtime_from_unix(expiry_unix_timestamp); + + let username = Username::new(format!("{}:{}", expiry_unix_timestamp, salt)) + .map_err(|_| Error::InvalidUsername)?; + let password = generate_password(relay_secret, expired, salt); + + let message_integrity = + MessageIntegrity::new_long_term_credential(&message, &username, &FIREZONE, &password)?; + + message.add_attribute(message_integrity); + + Ok(Self(message)) + } + + pub fn class(&self) -> stun_codec::MessageClass { + self.0.class() + } + + pub fn method(&self) -> stun_codec::Method { + self.0.method() + } + + pub fn get_attribute(&self) -> Option<&T> + where + T: stun_codec::Attribute, + Attribute: stun_codec::convert::TryAsRef, + { + self.0.get_attribute() + } +} + +#[derive(Debug, Default)] +pub(crate) struct MessageEncoder(stun_codec::MessageEncoder); + +impl Encode for MessageEncoder { + type Item = AuthenticatedMessage; + + fn encode(&mut self, buf: &mut [u8], eos: bytecodec::Eos) -> bytecodec::Result { + self.0.encode(buf, eos) + } + + fn start_encoding(&mut self, item: Self::Item) -> bytecodec::Result<()> { + self.0.start_encoding(item.0) + } + + fn requiring_bytes(&self) -> bytecodec::ByteCount { + self.0.requiring_bytes() + } +} + /// Tracks valid nonces for the TURN relay. /// /// The semantic nature of nonces is an implementation detail of the relay in TURN. @@ -92,7 +162,7 @@ impl Nonces { } } -#[derive(Debug, PartialEq, thiserror::Error)] +#[derive(Debug, thiserror::Error)] pub(crate) enum Error { #[error("expired")] Expired, @@ -102,6 +172,8 @@ pub(crate) enum Error { InvalidUsername, #[error("invalid nonce")] InvalidNonce, + #[error("cannot authenticate message")] + CannotAuthenticate(#[from] bytecodec::Error), } pub(crate) fn split_username(username: &str) -> Result<(u64, &str), Error> { @@ -207,7 +279,7 @@ mod tests { systemtime_from_unix(1685200000), ); - assert_eq!(result.unwrap_err(), Error::Expired) + assert!(matches!(result.unwrap_err(), Error::Expired)) } #[test] @@ -224,7 +296,7 @@ mod tests { systemtime_from_unix(168520000 + 1000), ); - assert_eq!(result.unwrap_err(), Error::InvalidPassword) + assert!(matches!(result.unwrap_err(), Error::InvalidPassword)) } #[test] @@ -241,7 +313,7 @@ mod tests { systemtime_from_unix(168520000 + 1000), ); - assert_eq!(result.unwrap_err(), Error::InvalidUsername) + assert!(matches!(result.unwrap_err(), Error::InvalidUsername)) } #[test] @@ -255,10 +327,10 @@ mod tests { nonces.handle_nonce_used(nonce).unwrap(); } - assert_eq!( + assert!(matches!( nonces.handle_nonce_used(nonce).unwrap_err(), Error::InvalidNonce - ); + )); } #[test] @@ -266,10 +338,10 @@ mod tests { let mut nonces = Nonces::default(); let nonce = Uuid::new_v4(); - assert_eq!( + assert!(matches!( nonces.handle_nonce_used(nonce).unwrap_err(), Error::InvalidNonce - ); + )); } fn message_integrity( diff --git a/rust/relay/src/server.rs b/rust/relay/src/server.rs index 2af67e16f..eda83e9ae 100644 --- a/rust/relay/src/server.rs +++ b/rust/relay/src/server.rs @@ -6,7 +6,7 @@ pub use crate::server::client_message::{ Allocate, Binding, ChannelBind, ClientMessage, CreatePermission, Refresh, }; -use crate::auth::{MessageIntegrityExt, Nonces, FIREZONE}; +use crate::auth::{self, AuthenticatedMessage, MessageIntegrityExt, Nonces, FIREZONE}; use crate::net_ext::IpAddrExt; use crate::{ClientSocket, IpStack, PeerSocket}; use anyhow::Result; @@ -27,7 +27,7 @@ use std::time::{Duration, Instant, SystemTime}; use stun_codec::rfc5389::attributes::{ ErrorCode, MessageIntegrity, Nonce, Realm, Software, Username, XorMappedAddress, }; -use stun_codec::rfc5389::errors::{BadRequest, StaleNonce, Unauthorized}; +use stun_codec::rfc5389::errors::{BadRequest, ServerError, StaleNonce, Unauthorized}; use stun_codec::rfc5389::methods::BINDING; use stun_codec::rfc5766::attributes::{ ChannelNumber, Lifetime, RequestedTransport, XorPeerAddress, XorRelayAddress, @@ -38,7 +38,7 @@ use stun_codec::rfc8656::attributes::{ AdditionalAddressFamily, AddressFamily, RequestedAddressFamily, }; use stun_codec::rfc8656::errors::{AddressFamilyNotSupported, PeerAddressFamilyMismatch}; -use stun_codec::{Message, MessageClass, MessageEncoder, Method, TransactionId}; +use stun_codec::{Message, MessageClass, Method, TransactionId}; use tracing::{field, Span}; use tracing_core::field::display; use uuid::Uuid; @@ -53,7 +53,7 @@ use uuid::Uuid; #[derive(Debug)] pub struct Server { decoder: client_message::Decoder, - encoder: MessageEncoder, + encoder: auth::MessageEncoder, public_address: IpStack, @@ -266,7 +266,10 @@ where Ok(Err(error_response)) => { tracing::warn!(target: "relay", %sender, method = %error_response.method(), "Failed to decode message"); - self.send_message(error_response, sender); + // This is fine, the original message failed to parse to we cannot respond with an authenticated reply. + let message = AuthenticatedMessage::new_dangerous_unauthenticated(error_response); + + self.send_message(message, sender); } // Parsing the bytes failed. Err(client_message::Error::BadChannelData(ref error)) => { @@ -292,6 +295,8 @@ where sender: ClientSocket, now: Instant, ) -> Option<(AllocationPort, PeerSocket)> { + let username = message.username().cloned(); + let result = match message { ClientMessage::Allocate(request) => self.handle_allocate_request(request, sender, now), ClientMessage::Refresh(request) => self.handle_refresh_request(request, sender, now), @@ -314,7 +319,21 @@ where return None; }; - self.send_message(error_response, sender); + let message = match username { + Some(username) => { + match AuthenticatedMessage::new(&self.auth_secret, username.name(), error_response) + { + Ok(message) => message, + Err(e) => { + tracing::warn!(target: "relay", error = std_dyn_err(&e), "Failed to create error response"); + return None; + } + } + } + None => AuthenticatedMessage::new_dangerous_unauthenticated(error_response), // We don't have a username so we can't authenticate the response. + }; + + self.send_message(message, sender); None } @@ -428,7 +447,10 @@ where tracing::info!("Handled BINDING request"); - self.send_message(message, sender); + self.send_message( + AuthenticatedMessage::new_dangerous_unauthenticated(message), + sender, + ); } /// Handle a TURN allocate request. @@ -441,7 +463,7 @@ where sender: ClientSocket, now: Instant, ) -> Result<(), Message> { - self.verify_auth(&request)?; + let username = self.verify_auth(&request)?; if let Some(allocation) = self.allocations.get(&sender) { Span::current().record("allocation", display(&allocation.port)); @@ -522,7 +544,7 @@ where family: second_relay_addr.family(), }); } - self.send_message(message, sender); + self.authenticate_and_send(username.name(), &request, message, sender); Span::current().record("allocation", display(&allocation.port)); @@ -560,7 +582,7 @@ where sender: ClientSocket, now: Instant, ) -> Result<(), Message> { - self.verify_auth(&request)?; + let username = self.verify_auth(&request)?; // TODO: Verify that this is the correct error code. let Some(allocation) = self.allocations.get_mut(&sender) else { @@ -579,7 +601,9 @@ where let port = allocation.port; self.delete_allocation(port); - self.send_message( + self.authenticate_and_send( + username.name(), + &request, refresh_success_response(effective_lifetime, request.transaction_id()), sender, ); @@ -591,7 +615,9 @@ where tracing::info!(target: "relay", "Refreshed allocation"); - self.send_message( + self.authenticate_and_send( + username.name(), + &request, refresh_success_response(effective_lifetime, request.transaction_id()), sender, ); @@ -609,7 +635,7 @@ where sender: ClientSocket, now: Instant, ) -> Result<(), Message> { - self.verify_auth(&request)?; + let username = self.verify_auth(&request)?; let Some(allocation) = self.allocations.get_mut(&sender) else { return Err(self.make_error_response( @@ -681,7 +707,9 @@ where tracing::info!(target: "relay", "Refreshed channel binding"); - self.send_message( + self.authenticate_and_send( + username.name(), + &request, channel_bind_success_response(request.transaction_id()), sender, ); @@ -696,7 +724,9 @@ where let port = allocation.port; self.create_channel_binding(sender, requested_channel, peer_address, port, now); - self.send_message( + self.authenticate_and_send( + username.name(), + &request, channel_bind_success_response(request.transaction_id()), sender, ); @@ -718,9 +748,11 @@ where request: CreatePermission, sender: ClientSocket, ) -> Result<(), Message> { - self.verify_auth(&request)?; + let username = self.verify_auth(&request)?; - self.send_message( + self.authenticate_and_send( + username.name(), + &request, create_permission_success_response(request.transaction_id()), sender, ); @@ -768,7 +800,7 @@ where fn verify_auth( &mut self, request: &(impl StunRequest + ProtectedRequest), - ) -> Result<(), Message> { + ) -> Result> { let message_integrity = request.message_integrity().ok_or_else(|| { self.make_error_response(Unauthorized, request, ResponseErrorLevel::Warn) })?; @@ -798,7 +830,7 @@ where self.make_error_response(Unauthorized, request, ResponseErrorLevel::Warn) })?; - Ok(()) + Ok(username.clone()) } fn create_new_allocation( @@ -867,7 +899,32 @@ where debug_assert!(existing.is_none()); } - fn send_message(&mut self, message: Message, recipient: ClientSocket) { + fn authenticate_and_send( + &mut self, + username: &str, + request: &impl StunRequest, + message: Message, + recipient: ClientSocket, + ) { + let authenticated_message = match AuthenticatedMessage::new( + &self.auth_secret, + username, + message, + ) { + Ok(message) => message, + Err(e) => { + tracing::warn!(target: "relay", error = std_dyn_err(&e), "Failed to authenticate message"); + let error_response = + self.make_error_response(ServerError, request, ResponseErrorLevel::Warn); + + AuthenticatedMessage::new_dangerous_unauthenticated(error_response) + } + }; + + self.send_message(authenticated_message, recipient); + } + + fn send_message(&mut self, message: AuthenticatedMessage, recipient: ClientSocket) { let method = message.method(); let class = message.class(); let error_code = message.get_attribute::().map(|e| e.code()); diff --git a/rust/relay/src/server/client_message.rs b/rust/relay/src/server/client_message.rs index fad192e89..87cd9db59 100644 --- a/rust/relay/src/server/client_message.rs +++ b/rust/relay/src/server/client_message.rs @@ -109,6 +109,16 @@ impl ClientMessage<'_> { ClientMessage::ChannelData(_) => None, } } + + pub fn username(&self) -> Option<&Username> { + match self { + ClientMessage::ChannelData(_) | ClientMessage::Binding(_) => None, + ClientMessage::Allocate(request) => request.username(), + ClientMessage::Refresh(request) => request.username(), + ClientMessage::ChannelBind(request) => request.username(), + ClientMessage::CreatePermission(request) => request.username(), + } + } } #[derive(Debug)] diff --git a/rust/relay/tests/regression.rs b/rust/relay/tests/regression.rs index 83583886d..060fb9cec 100644 --- a/rust/relay/tests/regression.rs +++ b/rust/relay/tests/regression.rs @@ -10,7 +10,9 @@ use secrecy::SecretString; use std::iter; use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6}; use std::time::{Duration, Instant, SystemTime}; -use stun_codec::rfc5389::attributes::{ErrorCode, Nonce, Realm, Username, XorMappedAddress}; +use stun_codec::rfc5389::attributes::{ + ErrorCode, MessageIntegrity, Nonce, Realm, Username, XorMappedAddress, +}; use stun_codec::rfc5389::errors::Unauthorized; use stun_codec::rfc5389::methods::BINDING; use stun_codec::rfc5766::attributes::{ChannelNumber, Lifetime, XorPeerAddress, XorRelayAddress}; @@ -763,16 +765,23 @@ impl TestServer { match (expected_output, actual_output) { ( - Output::SendMessage((to, message)), + Output::SendMessage((to, mut message)), Command::SendMessage { payload, recipient }, ) => { + let sent_message = parse_message(&payload); + + // In order to avoid simulating authentication, we copy the MessageIntegrity attribute. + if let Some(mi) = sent_message.get_attribute::() { + message.add_attribute(mi.clone()); + } + let expected_bytes = MessageEncoder::new() .encode_into_bytes(message.clone()) .unwrap(); if expected_bytes != payload { let expected_message = format!("{:?}", message); - let actual_message = format!("{:?}", parse_message(&payload)); + let actual_message = format!("{:?}", sent_message); difference::assert_diff!(&expected_message, &actual_message, "\n", 0); }