diff --git a/rust/relay/src/auth.rs b/rust/relay/src/auth.rs index ba9aa47f4..3b13acede 100644 --- a/rust/relay/src/auth.rs +++ b/rust/relay/src/auth.rs @@ -180,7 +180,7 @@ impl Nonces { /// Record the usage of a nonce in a request. pub(crate) fn handle_nonce_used(&mut self, nonce: Uuid) -> Result<(), Error> { let mut entry = match self.inner.entry(nonce) { - Entry::Vacant(_) => return Err(Error::InvalidNonce), + Entry::Vacant(_) => return Err(Error::UnknownNonce), Entry::Occupied(entry) => entry, }; @@ -189,7 +189,7 @@ impl Nonces { if *remaining_requests == 0 { entry.remove(); - return Err(Error::InvalidNonce); + return Err(Error::NonceUsedUp); } *remaining_requests -= 1; @@ -206,8 +206,10 @@ pub(crate) enum Error { InvalidPassword, #[error("invalid username")] InvalidUsername, - #[error("invalid nonce")] - InvalidNonce, + #[error("nonce has been used up")] + NonceUsedUp, + #[error("unknown nonce")] + UnknownNonce, #[error("cannot authenticate message")] CannotAuthenticate(#[from] bytecodec::Error), } @@ -365,7 +367,7 @@ mod tests { assert!(matches!( nonces.handle_nonce_used(nonce).unwrap_err(), - Error::InvalidNonce + Error::NonceUsedUp )); } @@ -376,7 +378,7 @@ mod tests { assert!(matches!( nonces.handle_nonce_used(nonce).unwrap_err(), - Error::InvalidNonce + Error::UnknownNonce )); } diff --git a/rust/relay/src/server.rs b/rust/relay/src/server.rs index eda83e9ae..04d41525c 100644 --- a/rust/relay/src/server.rs +++ b/rust/relay/src/server.rs @@ -315,10 +315,23 @@ where } }; - let Err(error_response) = result else { + let Err(mut error_response) = result else { return None; }; + let is_auth_error = error_response + .get_attribute::() + .is_some_and(|error_code| { + error_code == &ErrorCode::from(Unauthorized) + || error_code == &ErrorCode::from(StaleNonce) + }); + + // In case of a 401 or 438 response, attach a realm and nonce. + if is_auth_error { + error_response.add_attribute((*FIREZONE).clone()); + error_response.add_attribute(self.new_nonce_attribute()); + } + let message = match username { Some(username) => { match AuthenticatedMessage::new(&self.auth_secret, username.name(), error_response) @@ -467,31 +480,29 @@ where if let Some(allocation) = self.allocations.get(&sender) { Span::current().record("allocation", display(&allocation.port)); - tracing::warn!(target: "relay", "Client already has an allocation"); + let (error_response, msg) = make_error_response(AllocationMismatch, &request); - return Err(self.make_error_response( - AllocationMismatch, - &request, - ResponseErrorLevel::Warn, - )); + tracing::warn!(target: "relay", "{msg}: Client already has an allocation"); + + return Err(error_response); } let max_available_ports = self.max_available_ports() as usize; if self.clients_by_allocation.len() == max_available_ports { - tracing::warn!(target: "relay", %max_available_ports, "No more ports available"); + let (error_response, msg) = make_error_response(InsufficientCapacity, &request); - return Err(self.make_error_response( - InsufficientCapacity, - &request, - ResponseErrorLevel::Warn, - )); + tracing::warn!(target: "relay", %max_available_ports, "{msg}: No more ports available"); + + return Err(error_response); } let requested_protocol = request.requested_transport().protocol(); if requested_protocol != UDP_TRANSPORT { - tracing::warn!(target: "relay", %requested_protocol, "Unsupported protocol"); + let (error_response, msg) = make_error_response(BadRequest, &request); - return Err(self.make_error_response(BadRequest, &request, ResponseErrorLevel::Warn)); + tracing::warn!(target: "relay", %requested_protocol, "{msg}: Unsupported protocol"); + + return Err(error_response); } let (first_relay_address, maybe_second_relay_addr) = derive_relay_addresses( @@ -499,7 +510,12 @@ where request.requested_address_family(), request.additional_address_family(), ) - .map_err(|e| self.make_error_response(e, &request, ResponseErrorLevel::Warn))?; + .map_err(|e| { + let (error_response, msg) = make_error_response(e, &request); + tracing::warn!(target: "relay", "{msg}: Failed to derive relay addresses"); + + error_response + })?; // TODO: Do we need to handle DONT-FRAGMENT? // TODO: Do we need to handle EVEN/ODD-PORT? @@ -586,11 +602,10 @@ where // TODO: Verify that this is the correct error code. let Some(allocation) = self.allocations.get_mut(&sender) else { - return Err(self.make_error_response( - AllocationMismatch, - &request, - ResponseErrorLevel::Warn, - )); + let (error_response, msg) = make_error_response(AllocationMismatch, &request); + tracing::warn!(target: "relay", "{msg}: Sender doesn't have an allocation"); + + return Err(error_response); }; Span::current().record("allocation", display(&allocation.port)); @@ -638,11 +653,11 @@ where let username = self.verify_auth(&request)?; let Some(allocation) = self.allocations.get_mut(&sender) else { - return Err(self.make_error_response( - AllocationMismatch, - &request, - ResponseErrorLevel::Warn, - )); + let (error_response, msg) = make_error_response(AllocationMismatch, &request); + + tracing::warn!(target: "relay", "{msg}: Sender doesn't have an allocation"); + + return Err(error_response); }; // Note: `channel_number` is enforced to be in the correct range. @@ -655,13 +670,11 @@ where // Check that our allocation can handle the requested peer addr. if !allocation.can_relay_to(peer_address) { - tracing::warn!(target: "relay", "Allocation cannot relay to peer"); + let (error_response, msg) = make_error_response(PeerAddressFamilyMismatch, &request); - return Err(self.make_error_response( - PeerAddressFamilyMismatch, - &request, - ResponseErrorLevel::Warn, - )); + tracing::warn!(target: "relay", "{msg}: Allocation cannot relay to peer"); + + return Err(error_response); } // Ensure the same address isn't already bound to a different channel. @@ -670,13 +683,11 @@ where .get(&(sender, peer_address)) { if number != &requested_channel { - tracing::warn!(target: "relay", existing_channel = %number.value(), "Peer is already bound to another channel"); + let (error_response, msg) = make_error_response(BadRequest, &request); - return Err(self.make_error_response( - BadRequest, - &request, - ResponseErrorLevel::Warn, - )); + tracing::warn!(target: "relay", existing_channel = %number.value(), "{msg}: Peer is already bound to another channel"); + + return Err(error_response); } } @@ -686,13 +697,11 @@ where .get_mut(&(sender, requested_channel)) { if channel.peer_address != peer_address { - tracing::warn!(target: "relay", existing_peer = %channel.peer_address, "Channel is already bound to a different peer"); + let (error_response, msg) = make_error_response(BadRequest, &request); - return Err(self.make_error_response( - BadRequest, - &request, - ResponseErrorLevel::Warn, - )); + tracing::warn!(target: "relay", existing_peer = %channel.peer_address, "{msg}: Channel is already bound to a different peer"); + + return Err(error_response); } // Binding requests for existing channels act as a refresh for the binding. @@ -802,32 +811,48 @@ where request: &(impl StunRequest + ProtectedRequest), ) -> Result> { let message_integrity = request.message_integrity().ok_or_else(|| { - self.make_error_response(Unauthorized, request, ResponseErrorLevel::Warn) + let (error_response, msg) = make_error_response(Unauthorized, request); + tracing::warn!(target: "relay", "{msg}: Missing `MessageIntegrity` attribute"); + + error_response })?; let username = request.username().ok_or_else(|| { - self.make_error_response(Unauthorized, request, ResponseErrorLevel::Warn) + let (error_response, msg) = make_error_response(Unauthorized, request); + tracing::warn!(target: "relay", "{msg}: Missing `Username` attribute"); + + error_response })?; let nonce = request .nonce() .ok_or_else(|| { - self.make_error_response(Unauthorized, request, ResponseErrorLevel::Debug) + let (error_response, msg) = make_error_response(Unauthorized, request); + tracing::debug!(target: "relay", "{msg}: Missing `Nonce` attribute"); + + error_response })? .value() .parse::() .map_err(|e| { - tracing::debug!(target: "relay", error = std_dyn_err(&e), "failed to parse nonce"); + let (error_response, msg) = make_error_response(Unauthorized, request); + tracing::warn!(target: "relay", "{msg}: Failed to parse nonce: {e}"); - self.make_error_response(Unauthorized, request, ResponseErrorLevel::Warn) + error_response })?; - self.nonces.handle_nonce_used(nonce).map_err(|_| { - self.make_error_response(StaleNonce, request, ResponseErrorLevel::Debug) + self.nonces.handle_nonce_used(nonce).map_err(|e| { + let (error_response, msg) = make_error_response(StaleNonce, request); + tracing::debug!(target: "relay", "{msg}: Nonce is invalid: {e}"); + + error_response })?; message_integrity .verify(&self.auth_secret, username.name(), SystemTime::now()) // This is impure but we don't need to control this in our tests. - .map_err(|_| { - self.make_error_response(Unauthorized, request, ResponseErrorLevel::Warn) + .map_err(|e| { + let (error_response, msg) = make_error_response(Unauthorized, request); + tracing::warn!(target: "relay", "{msg}: MessageIntegrity check failed: {e}"); + + error_response })?; Ok(username.clone()) @@ -913,9 +938,8 @@ where ) { Ok(message) => message, Err(e) => { - tracing::warn!(target: "relay", error = std_dyn_err(&e), "Failed to authenticate message"); - let error_response = - self.make_error_response(ServerError, request, ResponseErrorLevel::Warn); + let (error_response, msg) = make_error_response(ServerError, request); + tracing::warn!(target: "relay", error = std_dyn_err(&e), "{msg}: Failed to authenticate message"); AuthenticatedMessage::new_dangerous_unauthenticated(error_response) } @@ -1048,55 +1072,34 @@ where tracing::info!(target: "relay", channel = %chan.value(), %client, %peer, %allocation, "Channel binding is now deleted (and can be rebound)"); } - fn make_error_response( - &mut self, - error_code: impl Into, - request: &impl StunRequest, - error_level: ResponseErrorLevel, - ) -> Message { - let error_code = error_code.into(); + fn new_nonce_attribute(&mut self) -> Nonce { + let new_nonce = Uuid::from_u128(self.rng.gen()); - match error_level { - ResponseErrorLevel::Warn => { - tracing::warn!(target: "relay", "{} failed: {}", request.method(), error_code.reason_phrase()); - } - ResponseErrorLevel::Debug => { - tracing::debug!(target: "relay", "{} failed: {}", request.method(), error_code.reason_phrase()); - } - } + self.add_nonce(new_nonce); - let mut message = Message::new( - MessageClass::ErrorResponse, - request.method(), - request.transaction_id(), - ); - - let is_auth_error = error_code == ErrorCode::from(Unauthorized) - || error_code == ErrorCode::from(StaleNonce); - - message.add_attribute(Attribute::from(error_code)); - - // In case of a 401 or 438 response, attach a realm and nonce. - if is_auth_error { - let new_nonce = Uuid::from_u128(self.rng.gen()); - - self.add_nonce(new_nonce); - - message.add_attribute( - Nonce::new(new_nonce.to_string()).expect( - "UUIDs are valid nonces because they are less than 128 characters long", - ), - ); - message.add_attribute((*FIREZONE).clone()); - } - - message + Nonce::new(new_nonce.to_string()) + .expect("UUIDs are valid nonces because they are less than 128 characters long") } } -enum ResponseErrorLevel { - Warn, - Debug, +fn make_error_response( + error_code: impl Into, + request: &impl StunRequest, +) -> (Message, String) { + let method = request.method(); + + let mut message = Message::new( + MessageClass::ErrorResponse, + method, + request.transaction_id(), + ); + let attribute = error_code.into(); + let reason = attribute.reason_phrase(); + let msg = format!("{method} failed with {reason}"); + + message.add_attribute(attribute); + + (message, msg) } fn refresh_success_response( diff --git a/rust/relay/tests/regression.rs b/rust/relay/tests/regression.rs index 060fb9cec..3802249b4 100644 --- a/rust/relay/tests/regression.rs +++ b/rust/relay/tests/regression.rs @@ -866,8 +866,8 @@ fn unauthorized_allocate_response( let mut message = Message::::new(MessageClass::ErrorResponse, ALLOCATE, transaction_id); message.add_attribute(ErrorCode::from(Unauthorized)); - message.add_attribute(Nonce::new(nonce.as_hyphenated().to_string()).unwrap()); message.add_attribute(Realm::new("firezone".to_owned()).unwrap()); + message.add_attribute(Nonce::new(nonce.as_hyphenated().to_string()).unwrap()); message }