From 8900e263cadf458e9147106e888ef3338b4ffe9a Mon Sep 17 00:00:00 2001 From: Thomas Eizinger Date: Tue, 9 Apr 2024 09:37:19 +1000 Subject: [PATCH] refactor(relay): favor `Instant` over `SystemTime` (#4468) This one is a bit tricky. Our auth scheme requires me to know the current time as a UNIX timestamp and that I can only get from `SystemTime` but not `Instant`. The `Server` is meant to be SANS-IO, including the current time so technically, I would have to pass that in as a parameter. I ended up settling on a compromise of making the auth verification impure and internally calling `SystemTime::now`. That results in a much nicer API and allows us to use `Instant` for everything else, e.g. expiry of channel bindings, allocations etc. Resolves: #4464. --- rust/relay/src/main.rs | 17 +++++---- rust/relay/src/proptest.rs | 12 +------ rust/relay/src/server.rs | 53 ++++++++++++++-------------- rust/relay/src/sleep.rs | 44 ++++++++++++------------ rust/relay/tests/regression.rs | 63 +++++++++++++++++++--------------- 5 files changed, 90 insertions(+), 99 deletions(-) diff --git a/rust/relay/src/main.rs b/rust/relay/src/main.rs index fa2c5e42f..5d0d6857e 100644 --- a/rust/relay/src/main.rs +++ b/rust/relay/src/main.rs @@ -16,7 +16,7 @@ use secrecy::{Secret, SecretString}; use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr}; use std::pin::Pin; use std::task::Poll; -use std::time::{Duration, SystemTime}; +use std::time::{Duration, Instant}; use tracing::{level_filters::LevelFilter, Instrument, Subscriber}; use tracing_core::Dispatch; use tracing_stackdriver::CloudTraceConfiguration; @@ -350,8 +350,6 @@ where } fn poll(&mut self, cx: &mut std::task::Context<'_>) -> Poll> { - let now = SystemTime::now(); - loop { // Priority 1: Execute the pending commands of the server. if let Some(next_command) = self.server.next_command() { @@ -417,10 +415,11 @@ where from, packet, })) => { - if let Some((port, peer)) = - self.server - .handle_client_input(packet, ClientSocket::new(from), now) - { + if let Some((port, peer)) = self.server.handle_client_input( + packet, + ClientSocket::new(from), + Instant::now(), + ) { // Re-parse as `ChannelData` if we should relay it. let payload = ChannelData::parse(packet) .expect("valid ChannelData if we should relay it") @@ -476,8 +475,8 @@ where } // Priority 4: Handle time-sensitive tasks: - if self.sleep.poll_unpin(cx).is_ready() { - self.server.handle_timeout(now); + if let Poll::Ready(deadline) = self.sleep.poll_unpin(cx) { + self.server.handle_timeout(deadline); continue; // Handle potentially new commands. } diff --git a/rust/relay/src/proptest.rs b/rust/relay/src/proptest.rs index abe9efb12..0f0b64298 100644 --- a/rust/relay/src/proptest.rs +++ b/rust/relay/src/proptest.rs @@ -4,8 +4,7 @@ use proptest::arbitrary::any; use proptest::strategy::Just; use proptest::strategy::Strategy; use proptest::string::string_regex; -use std::ops::Add; -use std::time::{Duration, SystemTime}; +use std::time::Duration; use stun_codec::rfc5766::attributes::{ChannelNumber, Lifetime, RequestedTransport}; use stun_codec::TransactionId; use uuid::Uuid; @@ -65,12 +64,3 @@ pub fn username_salt() -> impl Strategy { pub fn nonce() -> impl Strategy { any::().prop_map(Uuid::from_u128) } - -/// We let "now" begin somewhere around 2000 up until 2100. -pub fn now() -> impl Strategy { - const YEAR: u64 = 60 * 60 * 24 * 365; - - (30 * YEAR..100 * YEAR) - .prop_map(Duration::from_secs) - .prop_map(|duration| SystemTime::UNIX_EPOCH.add(duration)) -} diff --git a/rust/relay/src/server.rs b/rust/relay/src/server.rs index a765d5e0a..35e6306fa 100644 --- a/rust/relay/src/server.rs +++ b/rust/relay/src/server.rs @@ -19,7 +19,7 @@ use secrecy::SecretString; use std::collections::{HashMap, VecDeque}; use std::hash::Hash; use std::net::{IpAddr, SocketAddr}; -use std::time::{Duration, SystemTime}; +use std::time::{Duration, Instant, SystemTime}; use stun_codec::rfc5389::attributes::{ ErrorCode, MessageIntegrity, Nonce, Realm, Username, XorMappedAddress, }; @@ -230,7 +230,7 @@ where &mut self, bytes: &[u8], sender: ClientSocket, - now: SystemTime, + now: Instant, ) -> Option<(AllocationPort, PeerSocket)> { tracing::trace!(target: "wire", num_bytes = %bytes.len()); @@ -268,7 +268,7 @@ where &mut self, message: ClientMessage, sender: ClientSocket, - now: SystemTime, + now: Instant, ) -> Option<(AllocationPort, PeerSocket)> { let result = match message { ClientMessage::Allocate(request) => self.handle_allocate_request(request, sender, now), @@ -277,14 +277,14 @@ where self.handle_channel_bind_request(request, sender, now) } ClientMessage::CreatePermission(request) => { - self.handle_create_permission_request(request, sender, now) + self.handle_create_permission_request(request, sender) } ClientMessage::Binding(request) => { self.handle_binding_request(request, sender); return None; } ClientMessage::ChannelData(msg) => { - return self.handle_channel_data_message(msg, sender, now); + return self.handle_channel_data_message(msg, sender); } }; @@ -368,7 +368,7 @@ where } // TODO: It might be worth to do some caching here? - pub fn poll_timeout(&self) -> Option { + pub fn poll_timeout(&self) -> Option { let channel_expiries = self.channels_by_client_and_number.values().map(|c| { if c.bound { c.expiry @@ -383,7 +383,7 @@ where .fold(None, |current, next| earliest(current, Some(next))) } - pub fn handle_timeout(&mut self, now: SystemTime) { + pub fn handle_timeout(&mut self, now: Instant) { let expired_allocations = self .allocations .values() @@ -435,9 +435,9 @@ where &mut self, request: Allocate, sender: ClientSocket, - now: SystemTime, + now: Instant, ) -> Result<(), Message> { - self.verify_auth(&request, now)?; + self.verify_auth(&request)?; if let Some(allocation) = self.allocations.get(&sender) { Span::current().record("allocation", display(&allocation.port)); @@ -545,9 +545,9 @@ where &mut self, request: Refresh, sender: ClientSocket, - now: SystemTime, + now: Instant, ) -> Result<(), Message> { - self.verify_auth(&request, now)?; + self.verify_auth(&request)?; // TODO: Verify that this is the correct error code. let allocation = self @@ -594,9 +594,9 @@ where &mut self, request: ChannelBind, sender: ClientSocket, - now: SystemTime, + now: Instant, ) -> Result<(), Message> { - self.verify_auth(&request, now)?; + self.verify_auth(&request)?; let allocation = self .allocations @@ -678,14 +678,13 @@ where /// /// This TURN server implementation does not support relaying data other than through channels. /// Thus, creating a permission is a no-op that always succeeds. - #[tracing::instrument(level = "debug", skip(self, message, now), fields(%sender))] + #[tracing::instrument(level = "debug", skip_all, fields(%sender))] fn handle_create_permission_request( &mut self, message: CreatePermission, sender: ClientSocket, - now: SystemTime, ) -> Result<(), Message> { - self.verify_auth(&message, now)?; + self.verify_auth(&message)?; self.send_message( create_permission_success_response(message.transaction_id()), @@ -699,7 +698,6 @@ where &mut self, message: ChannelData, sender: ClientSocket, - _: SystemTime, ) -> Option<(AllocationPort, PeerSocket)> { let channel_number = message.channel(); let data = message.data(); @@ -735,7 +733,6 @@ where fn verify_auth( &mut self, request: &(impl StunRequest + ProtectedRequest), - now: SystemTime, ) -> Result<(), Message> { let message_integrity = request .message_integrity() @@ -757,7 +754,7 @@ where .map_err(|_| error_response(StaleNonce, request))?; message_integrity - .verify(&self.auth_secret, username.name(), now) + .verify(&self.auth_secret, username.name(), SystemTime::now()) // This is impure but we don't need to control this in our tests. .map_err(|_| error_response(Unauthorized, request))?; Ok(()) @@ -765,7 +762,7 @@ where fn create_new_allocation( &mut self, - now: SystemTime, + now: Instant, lifetime: &Lifetime, first_relay_addr: IpAddr, second_relay_addr: Option, @@ -801,7 +798,7 @@ where requested_channel: ChannelNumber, peer: PeerSocket, id: AllocationPort, - now: SystemTime, + now: Instant, ) { let expiry = now + CHANNEL_BINDING_DURATION; @@ -945,7 +942,7 @@ fn create_permission_success_response(transaction_id: TransactionId) -> Message< struct Allocation { /// Data arriving on this port will be forwarded to the client iff there is an active data channel. port: AllocationPort, - expires_at: SystemTime, + expires_at: Instant, first_relay_addr: IpAddr, second_relay_addr: Option, @@ -953,7 +950,7 @@ struct Allocation { struct Channel { /// When the channel expires. - expiry: SystemTime, + expiry: Instant, /// The address of the peer that the channel is bound to. peer_address: PeerSocket, @@ -974,15 +971,15 @@ struct Channel { } impl Channel { - fn refresh(&mut self, now: SystemTime) { + fn refresh(&mut self, now: Instant) { self.expiry = now + CHANNEL_BINDING_DURATION; } - fn is_expired(&self, now: SystemTime) -> bool { + fn is_expired(&self, now: Instant) -> bool { self.expiry <= now } - fn can_be_deleted(&self, now: SystemTime) -> bool { + fn can_be_deleted(&self, now: Instant) -> bool { self.expiry + CHANNEL_REBIND_TIMEOUT <= now } } @@ -1004,7 +1001,7 @@ impl Allocation { } impl Allocation { - fn is_expired(&self, now: SystemTime) -> bool { + fn is_expired(&self, now: Instant) -> bool { self.expires_at <= now } } @@ -1175,7 +1172,7 @@ stun_codec::define_attribute_enums!( ] ); -fn earliest(left: Option, right: Option) -> Option { +fn earliest(left: Option, right: Option) -> Option { match (left, right) { (None, None) => None, (Some(left), Some(right)) => Some(std::cmp::min(left, right)), diff --git a/rust/relay/src/sleep.rs b/rust/relay/src/sleep.rs index ced0d6367..fbcf23086 100644 --- a/rust/relay/src/sleep.rs +++ b/rust/relay/src/sleep.rs @@ -1,9 +1,7 @@ -use futures::future::BoxFuture; -use futures::FutureExt; use std::future::Future; use std::pin::Pin; use std::task::{ready, Context, Poll, Waker}; -use std::time::SystemTime; +use std::time::Instant; /// A future that sleeps until a given instant. /// @@ -12,43 +10,42 @@ use std::time::SystemTime; #[derive(Default)] pub struct Sleep { /// The inner sleep future. Boxed for convenience to make [`Sleep`] implement [`Unpin`]. - inner: Option>, - current_deadline: Option, + inner: Option>>, waker: Option, } impl Sleep { - pub fn reset(self: Pin<&mut Self>, deadline: SystemTime) { + pub fn reset(self: Pin<&mut Self>, deadline: Instant) { let this = self.get_mut(); + let deadline = tokio::time::Instant::from_std(deadline); - if this.current_deadline.is_some_and(|c| c == deadline) { - return; - } + match this.inner.as_mut() { + Some(sleep) if sleep.deadline() != deadline => sleep.as_mut().reset(deadline), + Some(_) => (), + None => { + this.inner = Some(Box::pin(tokio::time::sleep_until(deadline))); - let duration = deadline - .duration_since(SystemTime::now()) - .unwrap_or_default(); - - this.inner = Some(tokio::time::sleep(duration).boxed()); - this.current_deadline = Some(deadline); - - if let Some(waker) = this.waker.take() { - waker.wake(); + if let Some(waker) = this.waker.take() { + waker.wake(); + } + } } } } impl Future for Sleep { - type Output = (); + type Output = Instant; - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> { + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let this = self.get_mut(); if let Some(inner) = &mut this.inner { + let deadline = inner.deadline(); + ready!(Pin::new(inner).poll(cx)); this.inner = None; - return Poll::Ready(()); + return Poll::Ready(deadline.into()); } this.waker = Some(cx.waker().clone()); @@ -60,6 +57,7 @@ impl Future for Sleep { #[cfg(test)] mod tests { use super::*; + use futures::FutureExt as _; use std::pin::pin; use std::time::Duration; @@ -75,7 +73,7 @@ mod tests { #[tokio::test] async fn finished_sleep_returns_pending() { let mut sleep = Sleep::default(); - Pin::new(&mut sleep).reset(SystemTime::now() + Duration::from_millis(100)); + Pin::new(&mut sleep).reset(Instant::now() + Duration::from_millis(100)); tokio::time::sleep(Duration::from_millis(200)).await; @@ -91,7 +89,7 @@ mod tests { #[tokio::test] async fn does_not_crash_and_fires_immediately_when_reset_to_past() { let mut sleep = Sleep::default(); - Pin::new(&mut sleep).reset(SystemTime::now() - Duration::from_millis(100)); + Pin::new(&mut sleep).reset(Instant::now() - Duration::from_millis(100)); sleep.await; } diff --git a/rust/relay/tests/regression.rs b/rust/relay/tests/regression.rs index c80f39322..779675a1c 100644 --- a/rust/relay/tests/regression.rs +++ b/rust/relay/tests/regression.rs @@ -7,7 +7,7 @@ use rand::rngs::mock::StepRng; use secrecy::SecretString; use std::iter; use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6}; -use std::time::{Duration, SystemTime}; +use std::time::{Duration, Instant, SystemTime}; use stun_codec::rfc5389::attributes::{ErrorCode, Nonce, Realm, Username, XorMappedAddress}; use stun_codec::rfc5389::errors::Unauthorized; use stun_codec::rfc5389::methods::BINDING; @@ -30,7 +30,7 @@ fn can_answer_stun_request_from_ip4_address( let transaction_id = request.transaction_id(); server.assert_commands( - from_client(source, request, SystemTime::now()), + from_client(source, request, Instant::now()), [send_message( source, binding_response(transaction_id, source), @@ -45,9 +45,10 @@ fn deallocate_once_time_expired( #[strategy(firezone_relay::proptest::username_salt())] username_salt: String, source: SocketAddrV4, public_relay_addr: Ipv4Addr, - #[strategy(firezone_relay::proptest::now())] now: SystemTime, #[strategy(firezone_relay::proptest::nonce())] nonce: Uuid, ) { + let now = Instant::now(); + let mut server = TestServer::new(public_relay_addr).with_nonce(nonce); let secret = server.auth_secret(); @@ -57,7 +58,7 @@ fn deallocate_once_time_expired( Allocate::new_authenticated_udp_implicit_ip4( transaction_id, Some(lifetime.clone()), - valid_username(now, &username_salt), + valid_username(&username_salt), secret, nonce, ), @@ -90,8 +91,9 @@ fn unauthenticated_allocate_triggers_authentication( #[strategy(firezone_relay::proptest::username_salt())] username_salt: String, source: SocketAddrV4, public_relay_addr: Ipv4Addr, - #[strategy(firezone_relay::proptest::now())] now: SystemTime, ) { + let now = Instant::now(); + // Nonces are generated randomly and we control the randomness in the test, thus this is deterministic. let first_nonce = Uuid::from_u128(0x0); @@ -116,7 +118,7 @@ fn unauthenticated_allocate_triggers_authentication( Allocate::new_authenticated_udp_implicit_ip4( transaction_id, Some(lifetime.clone()), - valid_username(now, &username_salt), + valid_username(&username_salt), &secret, first_nonce, ), @@ -146,9 +148,10 @@ fn when_refreshed_in_time_allocation_does_not_expire( #[strategy(firezone_relay::proptest::username_salt())] username_salt: String, source: SocketAddrV4, public_relay_addr: Ipv4Addr, - #[strategy(firezone_relay::proptest::now())] now: SystemTime, #[strategy(firezone_relay::proptest::nonce())] nonce: Uuid, ) { + let now = Instant::now(); + let mut server = TestServer::new(public_relay_addr).with_nonce(nonce); let secret = server.auth_secret().to_owned(); let first_wake = now + allocate_lifetime.lifetime(); @@ -159,7 +162,7 @@ fn when_refreshed_in_time_allocation_does_not_expire( Allocate::new_authenticated_udp_implicit_ip4( allocate_transaction_id, Some(allocate_lifetime.clone()), - valid_username(now, &username_salt), + valid_username(&username_salt), &secret, nonce, ), @@ -192,7 +195,7 @@ fn when_refreshed_in_time_allocation_does_not_expire( Refresh::new( refresh_transaction_id, Some(refresh_lifetime.clone()), - valid_username(now, &username_salt), + valid_username(&username_salt), &secret, nonce, ), @@ -222,9 +225,10 @@ fn when_receiving_lifetime_0_for_existing_allocation_then_delete( #[strategy(firezone_relay::proptest::username_salt())] username_salt: String, source: SocketAddrV4, public_relay_addr: Ipv4Addr, - #[strategy(firezone_relay::proptest::now())] now: SystemTime, #[strategy(firezone_relay::proptest::nonce())] nonce: Uuid, ) { + let now = Instant::now(); + let mut server = TestServer::new(public_relay_addr).with_nonce(nonce); let secret = server.auth_secret().to_owned(); let first_wake = now + allocate_lifetime.lifetime(); @@ -235,7 +239,7 @@ fn when_receiving_lifetime_0_for_existing_allocation_then_delete( Allocate::new_authenticated_udp_implicit_ip4( allocate_transaction_id, Some(allocate_lifetime.clone()), - valid_username(now, &username_salt), + valid_username(&username_salt), &secret, nonce, ), @@ -267,7 +271,7 @@ fn when_receiving_lifetime_0_for_existing_allocation_then_delete( Refresh::new( refresh_transaction_id, Some(Lifetime::new(Duration::ZERO).unwrap()), - valid_username(now, &username_salt), + valid_username(&username_salt), &secret, nonce, ), @@ -303,11 +307,12 @@ fn ping_pong_relay( source: SocketAddrV4, peer: SocketAddrV4, public_relay_addr: Ipv4Addr, - #[strategy(firezone_relay::proptest::now())] now: SystemTime, peer_to_client_ping: [u8; 32], #[strategy(firezone_relay::proptest::channel_data())] client_to_peer_ping: ChannelData<'static>, #[strategy(firezone_relay::proptest::nonce())] nonce: Uuid, ) { + let now = Instant::now(); + let _ = env_logger::try_init(); let mut server = TestServer::new(public_relay_addr).with_nonce(nonce); @@ -320,7 +325,7 @@ fn ping_pong_relay( Allocate::new_authenticated_udp_implicit_ip4( allocate_transaction_id, Some(lifetime.clone()), - valid_username(now, &username_salt), + valid_username(&username_salt), &secret, nonce, ), @@ -355,7 +360,7 @@ fn ping_pong_relay( channel_bind_transaction_id, client_to_peer_ping.channel(), XorPeerAddress::new(peer.into()), - valid_username(now, &username_salt), + valid_username(&username_salt), &secret, nonce, ), @@ -413,9 +418,10 @@ fn allows_rebind_channel_after_expiry( peer: SocketAddrV4, peer2: SocketAddrV4, public_relay_addr: Ipv4Addr, - #[strategy(firezone_relay::proptest::now())] now: SystemTime, #[strategy(firezone_relay::proptest::nonce())] nonce: Uuid, ) { + let now = Instant::now(); + let _ = env_logger::try_init(); let mut server = TestServer::new(public_relay_addr).with_nonce(nonce); @@ -428,7 +434,7 @@ fn allows_rebind_channel_after_expiry( Allocate::new_authenticated_udp_implicit_ip4( allocate_transaction_id, Some(lifetime.clone()), - valid_username(now, &username_salt), + valid_username(&username_salt), &secret, nonce, ), @@ -463,7 +469,7 @@ fn allows_rebind_channel_after_expiry( channel_bind_transaction_id, channel, XorPeerAddress::new(peer.into()), - valid_username(now, &username_salt), + valid_username(&username_salt), &secret, nonce, ), @@ -498,7 +504,7 @@ fn allows_rebind_channel_after_expiry( channel_bind_2_transaction_id, channel, XorPeerAddress::new(peer2.into()), - valid_username(now, &username_salt), + valid_username(&username_salt), &secret, nonce, ), @@ -527,11 +533,12 @@ fn ping_pong_ip6_relay( peer: SocketAddrV6, public_relay_ip4_addr: Ipv4Addr, public_relay_ip6_addr: Ipv6Addr, - #[strategy(firezone_relay::proptest::now())] now: SystemTime, peer_to_client_ping: [u8; 32], mut client_to_peer_ping: [u8; 36], #[strategy(firezone_relay::proptest::nonce())] nonce: Uuid, ) { + let now = Instant::now(); + let _ = env_logger::try_init(); let mut server = @@ -545,7 +552,7 @@ fn ping_pong_ip6_relay( Allocate::new_authenticated_udp_ip6( allocate_transaction_id, Some(lifetime.clone()), - valid_username(now, &username_salt), + valid_username(&username_salt), &secret, nonce, ), @@ -580,7 +587,7 @@ fn ping_pong_ip6_relay( channel_bind_transaction_id, channel, XorPeerAddress::new(peer.into()), - valid_username(now, &username_salt), + valid_username(&username_salt), &secret, nonce, ), @@ -719,8 +726,8 @@ impl TestServer { } } -fn valid_username(now: SystemTime, salt: &str) -> Username { - let now_unix = now +fn valid_username(salt: &str) -> Username { + let now_unix = SystemTime::now() .duration_since(SystemTime::UNIX_EPOCH) .unwrap() .as_secs(); @@ -793,19 +800,19 @@ fn parse_message(message: &[u8]) -> Message { } enum Input<'a> { - Client(ClientSocket, ClientMessage<'a>, SystemTime), - Time(SystemTime), + Client(ClientSocket, ClientMessage<'a>, Instant), + Time(Instant), } fn from_client<'a>( from: impl Into, message: impl Into>, - now: SystemTime, + now: Instant, ) -> Input<'a> { Input::Client(ClientSocket::new(from.into()), message.into(), now) } -fn forward_time_to<'a>(when: SystemTime) -> Input<'a> { +fn forward_time_to<'a>(when: Instant) -> Input<'a> { Input::Time(when) }