feat(connlib): validate integrity of all relay responses (#7378)

In order to avoid processing of responses of relays that somehow got
altered on the network path, we now use the client's `password` as a
shared secret for the relay to also authenticate its responses. This
means that not all message can be authenticated. In particular, BINDING
requests will still be unauthenticated.

Performing this validation now requires every component that crafts
input to the `Allocation` to include a valid `MessageIntegrity`
attribute. This is somewhat problematic for the regression tests of the
relay and the unit tests of `Allocation`. In both cases, we implement
workarounds so we don't have to actually compute a valid
`MessageIntegrity`. This is deemed acceptable because:

- Both of these are just tests.
- We do test the validation path using `tunnel_test` because there we
run an actual relay.
This commit is contained in:
Thomas Eizinger
2024-11-19 18:32:33 +00:00
committed by GitHub
parent ecec00afed
commit 56db250e2c
5 changed files with 216 additions and 31 deletions

View File

@@ -303,6 +303,18 @@ impl Allocation {
Span::current().record("method", field::display(message.method()));
Span::current().record("class", field::display(message.class()));
// Early return to avoid cryptographic work in case it isn't our message.
if !self.sent_requests.contains_key(&transaction_id) {
return false;
}
let passed_message_integrity_check = self.check_message_integrity(&message);
if message.method() != BINDING && !passed_message_integrity_check {
tracing::warn!("Message integrity check failed");
return true; // The message still indicated that it was for this `Allocation`.
}
let Some((original_dst, original_request, sent_at, _, _)) =
self.sent_requests.remove(&transaction_id)
else {
@@ -1041,6 +1053,31 @@ impl Allocation {
backoff.clock.now = now;
}
}
#[cfg(test)]
fn check_message_integrity(&self, _: &Message<Attribute>) -> bool {
true // In order to make the tests simpler, we skip the message integrity check there.
}
#[cfg(not(test))]
fn check_message_integrity(&self, message: &Message<Attribute>) -> bool {
message
.get_attribute::<MessageIntegrity>()
.is_some_and(|mi| {
let Some(credentials) = &self.credentials else {
tracing::debug!("Cannot check message integrity without credentials");
return false;
};
mi.check_long_term_credential(
&credentials.username,
&credentials.realm,
&credentials.password,
)
.is_ok()
})
}
}
fn authenticate(message: Message<Attribute>, credentials: &Credentials) -> Message<Attribute> {

View File

@@ -1,5 +1,6 @@
use base64::prelude::BASE64_STANDARD_NO_PAD;
use base64::Engine;
use bytecodec::Encode;
use once_cell::sync::Lazy;
use secrecy::{ExposeSecret, SecretString};
use sha2::digest::FixedOutput;
@@ -8,8 +9,11 @@ use std::collections::hash_map::Entry;
use std::collections::HashMap;
use std::time::{Duration, SystemTime};
use stun_codec::rfc5389::attributes::{MessageIntegrity, Realm, Username};
use stun_codec::Message;
use uuid::Uuid;
use crate::Attribute;
// TODO: Upstream a const constructor to `stun-codec`.
pub static FIREZONE: Lazy<Realm> =
Lazy::new(|| Realm::new("firezone".to_owned()).expect("static realm is less than 128 chars"));
@@ -51,6 +55,72 @@ impl MessageIntegrityExt for MessageIntegrity {
}
}
pub(crate) struct AuthenticatedMessage(Message<Attribute>);
impl AuthenticatedMessage {
/// Creates a new [`AuthenticatedMessage`] that isn't actually authenticated.
///
/// This should only be used in circumstances where we cannot authenticate the message because e.g. the original request wasn't authenticated either.
pub(crate) fn new_dangerous_unauthenticated(message: Message<Attribute>) -> Self {
Self(message)
}
pub(crate) fn new(
relay_secret: &SecretString,
username: &str,
mut message: Message<Attribute>,
) -> Result<Self, Error> {
let (expiry_unix_timestamp, salt) = split_username(username)?;
let expired = systemtime_from_unix(expiry_unix_timestamp);
let username = Username::new(format!("{}:{}", expiry_unix_timestamp, salt))
.map_err(|_| Error::InvalidUsername)?;
let password = generate_password(relay_secret, expired, salt);
let message_integrity =
MessageIntegrity::new_long_term_credential(&message, &username, &FIREZONE, &password)?;
message.add_attribute(message_integrity);
Ok(Self(message))
}
pub fn class(&self) -> stun_codec::MessageClass {
self.0.class()
}
pub fn method(&self) -> stun_codec::Method {
self.0.method()
}
pub fn get_attribute<T>(&self) -> Option<&T>
where
T: stun_codec::Attribute,
Attribute: stun_codec::convert::TryAsRef<T>,
{
self.0.get_attribute()
}
}
#[derive(Debug, Default)]
pub(crate) struct MessageEncoder(stun_codec::MessageEncoder<Attribute>);
impl Encode for MessageEncoder {
type Item = AuthenticatedMessage;
fn encode(&mut self, buf: &mut [u8], eos: bytecodec::Eos) -> bytecodec::Result<usize> {
self.0.encode(buf, eos)
}
fn start_encoding(&mut self, item: Self::Item) -> bytecodec::Result<()> {
self.0.start_encoding(item.0)
}
fn requiring_bytes(&self) -> bytecodec::ByteCount {
self.0.requiring_bytes()
}
}
/// Tracks valid nonces for the TURN relay.
///
/// The semantic nature of nonces is an implementation detail of the relay in TURN.
@@ -92,7 +162,7 @@ impl Nonces {
}
}
#[derive(Debug, PartialEq, thiserror::Error)]
#[derive(Debug, thiserror::Error)]
pub(crate) enum Error {
#[error("expired")]
Expired,
@@ -102,6 +172,8 @@ pub(crate) enum Error {
InvalidUsername,
#[error("invalid nonce")]
InvalidNonce,
#[error("cannot authenticate message")]
CannotAuthenticate(#[from] bytecodec::Error),
}
pub(crate) fn split_username(username: &str) -> Result<(u64, &str), Error> {
@@ -207,7 +279,7 @@ mod tests {
systemtime_from_unix(1685200000),
);
assert_eq!(result.unwrap_err(), Error::Expired)
assert!(matches!(result.unwrap_err(), Error::Expired))
}
#[test]
@@ -224,7 +296,7 @@ mod tests {
systemtime_from_unix(168520000 + 1000),
);
assert_eq!(result.unwrap_err(), Error::InvalidPassword)
assert!(matches!(result.unwrap_err(), Error::InvalidPassword))
}
#[test]
@@ -241,7 +313,7 @@ mod tests {
systemtime_from_unix(168520000 + 1000),
);
assert_eq!(result.unwrap_err(), Error::InvalidUsername)
assert!(matches!(result.unwrap_err(), Error::InvalidUsername))
}
#[test]
@@ -255,10 +327,10 @@ mod tests {
nonces.handle_nonce_used(nonce).unwrap();
}
assert_eq!(
assert!(matches!(
nonces.handle_nonce_used(nonce).unwrap_err(),
Error::InvalidNonce
);
));
}
#[test]
@@ -266,10 +338,10 @@ mod tests {
let mut nonces = Nonces::default();
let nonce = Uuid::new_v4();
assert_eq!(
assert!(matches!(
nonces.handle_nonce_used(nonce).unwrap_err(),
Error::InvalidNonce
);
));
}
fn message_integrity(

View File

@@ -6,7 +6,7 @@ pub use crate::server::client_message::{
Allocate, Binding, ChannelBind, ClientMessage, CreatePermission, Refresh,
};
use crate::auth::{MessageIntegrityExt, Nonces, FIREZONE};
use crate::auth::{self, AuthenticatedMessage, MessageIntegrityExt, Nonces, FIREZONE};
use crate::net_ext::IpAddrExt;
use crate::{ClientSocket, IpStack, PeerSocket};
use anyhow::Result;
@@ -27,7 +27,7 @@ use std::time::{Duration, Instant, SystemTime};
use stun_codec::rfc5389::attributes::{
ErrorCode, MessageIntegrity, Nonce, Realm, Software, Username, XorMappedAddress,
};
use stun_codec::rfc5389::errors::{BadRequest, StaleNonce, Unauthorized};
use stun_codec::rfc5389::errors::{BadRequest, ServerError, StaleNonce, Unauthorized};
use stun_codec::rfc5389::methods::BINDING;
use stun_codec::rfc5766::attributes::{
ChannelNumber, Lifetime, RequestedTransport, XorPeerAddress, XorRelayAddress,
@@ -38,7 +38,7 @@ use stun_codec::rfc8656::attributes::{
AdditionalAddressFamily, AddressFamily, RequestedAddressFamily,
};
use stun_codec::rfc8656::errors::{AddressFamilyNotSupported, PeerAddressFamilyMismatch};
use stun_codec::{Message, MessageClass, MessageEncoder, Method, TransactionId};
use stun_codec::{Message, MessageClass, Method, TransactionId};
use tracing::{field, Span};
use tracing_core::field::display;
use uuid::Uuid;
@@ -53,7 +53,7 @@ use uuid::Uuid;
#[derive(Debug)]
pub struct Server<R> {
decoder: client_message::Decoder,
encoder: MessageEncoder<Attribute>,
encoder: auth::MessageEncoder,
public_address: IpStack,
@@ -266,7 +266,10 @@ where
Ok(Err(error_response)) => {
tracing::warn!(target: "relay", %sender, method = %error_response.method(), "Failed to decode message");
self.send_message(error_response, sender);
// This is fine, the original message failed to parse to we cannot respond with an authenticated reply.
let message = AuthenticatedMessage::new_dangerous_unauthenticated(error_response);
self.send_message(message, sender);
}
// Parsing the bytes failed.
Err(client_message::Error::BadChannelData(ref error)) => {
@@ -292,6 +295,8 @@ where
sender: ClientSocket,
now: Instant,
) -> Option<(AllocationPort, PeerSocket)> {
let username = message.username().cloned();
let result = match message {
ClientMessage::Allocate(request) => self.handle_allocate_request(request, sender, now),
ClientMessage::Refresh(request) => self.handle_refresh_request(request, sender, now),
@@ -314,7 +319,21 @@ where
return None;
};
self.send_message(error_response, sender);
let message = match username {
Some(username) => {
match AuthenticatedMessage::new(&self.auth_secret, username.name(), error_response)
{
Ok(message) => message,
Err(e) => {
tracing::warn!(target: "relay", error = std_dyn_err(&e), "Failed to create error response");
return None;
}
}
}
None => AuthenticatedMessage::new_dangerous_unauthenticated(error_response), // We don't have a username so we can't authenticate the response.
};
self.send_message(message, sender);
None
}
@@ -428,7 +447,10 @@ where
tracing::info!("Handled BINDING request");
self.send_message(message, sender);
self.send_message(
AuthenticatedMessage::new_dangerous_unauthenticated(message),
sender,
);
}
/// Handle a TURN allocate request.
@@ -441,7 +463,7 @@ where
sender: ClientSocket,
now: Instant,
) -> Result<(), Message<Attribute>> {
self.verify_auth(&request)?;
let username = self.verify_auth(&request)?;
if let Some(allocation) = self.allocations.get(&sender) {
Span::current().record("allocation", display(&allocation.port));
@@ -522,7 +544,7 @@ where
family: second_relay_addr.family(),
});
}
self.send_message(message, sender);
self.authenticate_and_send(username.name(), &request, message, sender);
Span::current().record("allocation", display(&allocation.port));
@@ -560,7 +582,7 @@ where
sender: ClientSocket,
now: Instant,
) -> Result<(), Message<Attribute>> {
self.verify_auth(&request)?;
let username = self.verify_auth(&request)?;
// TODO: Verify that this is the correct error code.
let Some(allocation) = self.allocations.get_mut(&sender) else {
@@ -579,7 +601,9 @@ where
let port = allocation.port;
self.delete_allocation(port);
self.send_message(
self.authenticate_and_send(
username.name(),
&request,
refresh_success_response(effective_lifetime, request.transaction_id()),
sender,
);
@@ -591,7 +615,9 @@ where
tracing::info!(target: "relay", "Refreshed allocation");
self.send_message(
self.authenticate_and_send(
username.name(),
&request,
refresh_success_response(effective_lifetime, request.transaction_id()),
sender,
);
@@ -609,7 +635,7 @@ where
sender: ClientSocket,
now: Instant,
) -> Result<(), Message<Attribute>> {
self.verify_auth(&request)?;
let username = self.verify_auth(&request)?;
let Some(allocation) = self.allocations.get_mut(&sender) else {
return Err(self.make_error_response(
@@ -681,7 +707,9 @@ where
tracing::info!(target: "relay", "Refreshed channel binding");
self.send_message(
self.authenticate_and_send(
username.name(),
&request,
channel_bind_success_response(request.transaction_id()),
sender,
);
@@ -696,7 +724,9 @@ where
let port = allocation.port;
self.create_channel_binding(sender, requested_channel, peer_address, port, now);
self.send_message(
self.authenticate_and_send(
username.name(),
&request,
channel_bind_success_response(request.transaction_id()),
sender,
);
@@ -718,9 +748,11 @@ where
request: CreatePermission,
sender: ClientSocket,
) -> Result<(), Message<Attribute>> {
self.verify_auth(&request)?;
let username = self.verify_auth(&request)?;
self.send_message(
self.authenticate_and_send(
username.name(),
&request,
create_permission_success_response(request.transaction_id()),
sender,
);
@@ -768,7 +800,7 @@ where
fn verify_auth(
&mut self,
request: &(impl StunRequest + ProtectedRequest),
) -> Result<(), Message<Attribute>> {
) -> Result<Username, Message<Attribute>> {
let message_integrity = request.message_integrity().ok_or_else(|| {
self.make_error_response(Unauthorized, request, ResponseErrorLevel::Warn)
})?;
@@ -798,7 +830,7 @@ where
self.make_error_response(Unauthorized, request, ResponseErrorLevel::Warn)
})?;
Ok(())
Ok(username.clone())
}
fn create_new_allocation(
@@ -867,7 +899,32 @@ where
debug_assert!(existing.is_none());
}
fn send_message(&mut self, message: Message<Attribute>, recipient: ClientSocket) {
fn authenticate_and_send(
&mut self,
username: &str,
request: &impl StunRequest,
message: Message<Attribute>,
recipient: ClientSocket,
) {
let authenticated_message = match AuthenticatedMessage::new(
&self.auth_secret,
username,
message,
) {
Ok(message) => message,
Err(e) => {
tracing::warn!(target: "relay", error = std_dyn_err(&e), "Failed to authenticate message");
let error_response =
self.make_error_response(ServerError, request, ResponseErrorLevel::Warn);
AuthenticatedMessage::new_dangerous_unauthenticated(error_response)
}
};
self.send_message(authenticated_message, recipient);
}
fn send_message(&mut self, message: AuthenticatedMessage, recipient: ClientSocket) {
let method = message.method();
let class = message.class();
let error_code = message.get_attribute::<ErrorCode>().map(|e| e.code());

View File

@@ -109,6 +109,16 @@ impl ClientMessage<'_> {
ClientMessage::ChannelData(_) => None,
}
}
pub fn username(&self) -> Option<&Username> {
match self {
ClientMessage::ChannelData(_) | ClientMessage::Binding(_) => None,
ClientMessage::Allocate(request) => request.username(),
ClientMessage::Refresh(request) => request.username(),
ClientMessage::ChannelBind(request) => request.username(),
ClientMessage::CreatePermission(request) => request.username(),
}
}
}
#[derive(Debug)]

View File

@@ -10,7 +10,9 @@ use secrecy::SecretString;
use std::iter;
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6};
use std::time::{Duration, Instant, SystemTime};
use stun_codec::rfc5389::attributes::{ErrorCode, Nonce, Realm, Username, XorMappedAddress};
use stun_codec::rfc5389::attributes::{
ErrorCode, MessageIntegrity, Nonce, Realm, Username, XorMappedAddress,
};
use stun_codec::rfc5389::errors::Unauthorized;
use stun_codec::rfc5389::methods::BINDING;
use stun_codec::rfc5766::attributes::{ChannelNumber, Lifetime, XorPeerAddress, XorRelayAddress};
@@ -763,16 +765,23 @@ impl TestServer {
match (expected_output, actual_output) {
(
Output::SendMessage((to, message)),
Output::SendMessage((to, mut message)),
Command::SendMessage { payload, recipient },
) => {
let sent_message = parse_message(&payload);
// In order to avoid simulating authentication, we copy the MessageIntegrity attribute.
if let Some(mi) = sent_message.get_attribute::<MessageIntegrity>() {
message.add_attribute(mi.clone());
}
let expected_bytes = MessageEncoder::new()
.encode_into_bytes(message.clone())
.unwrap();
if expected_bytes != payload {
let expected_message = format!("{:?}", message);
let actual_message = format!("{:?}", parse_message(&payload));
let actual_message = format!("{:?}", sent_message);
difference::assert_diff!(&expected_message, &actual_message, "\n", 0);
}