refactor(relay): favor Instant over SystemTime (#4468)

This one is a bit tricky. Our auth scheme requires me to know the
current time as a UNIX timestamp and that I can only get from
`SystemTime` but not `Instant`. The `Server` is meant to be SANS-IO,
including the current time so technically, I would have to pass that in
as a parameter.

I ended up settling on a compromise of making the auth verification
impure and internally calling `SystemTime::now`. That results in a much
nicer API and allows us to use `Instant` for everything else, e.g.
expiry of channel bindings, allocations etc.

Resolves: #4464.
This commit is contained in:
Thomas Eizinger
2024-04-09 09:37:19 +10:00
committed by GitHub
parent 5fa27ecc66
commit 8900e263ca
5 changed files with 90 additions and 99 deletions

View File

@@ -16,7 +16,7 @@ use secrecy::{Secret, SecretString};
use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr};
use std::pin::Pin;
use std::task::Poll;
use std::time::{Duration, SystemTime};
use std::time::{Duration, Instant};
use tracing::{level_filters::LevelFilter, Instrument, Subscriber};
use tracing_core::Dispatch;
use tracing_stackdriver::CloudTraceConfiguration;
@@ -350,8 +350,6 @@ where
}
fn poll(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<()>> {
let now = SystemTime::now();
loop {
// Priority 1: Execute the pending commands of the server.
if let Some(next_command) = self.server.next_command() {
@@ -417,10 +415,11 @@ where
from,
packet,
})) => {
if let Some((port, peer)) =
self.server
.handle_client_input(packet, ClientSocket::new(from), now)
{
if let Some((port, peer)) = self.server.handle_client_input(
packet,
ClientSocket::new(from),
Instant::now(),
) {
// Re-parse as `ChannelData` if we should relay it.
let payload = ChannelData::parse(packet)
.expect("valid ChannelData if we should relay it")
@@ -476,8 +475,8 @@ where
}
// Priority 4: Handle time-sensitive tasks:
if self.sleep.poll_unpin(cx).is_ready() {
self.server.handle_timeout(now);
if let Poll::Ready(deadline) = self.sleep.poll_unpin(cx) {
self.server.handle_timeout(deadline);
continue; // Handle potentially new commands.
}

View File

@@ -4,8 +4,7 @@ use proptest::arbitrary::any;
use proptest::strategy::Just;
use proptest::strategy::Strategy;
use proptest::string::string_regex;
use std::ops::Add;
use std::time::{Duration, SystemTime};
use std::time::Duration;
use stun_codec::rfc5766::attributes::{ChannelNumber, Lifetime, RequestedTransport};
use stun_codec::TransactionId;
use uuid::Uuid;
@@ -65,12 +64,3 @@ pub fn username_salt() -> impl Strategy<Value = String> {
pub fn nonce() -> impl Strategy<Value = Uuid> {
any::<u128>().prop_map(Uuid::from_u128)
}
/// We let "now" begin somewhere around 2000 up until 2100.
pub fn now() -> impl Strategy<Value = SystemTime> {
const YEAR: u64 = 60 * 60 * 24 * 365;
(30 * YEAR..100 * YEAR)
.prop_map(Duration::from_secs)
.prop_map(|duration| SystemTime::UNIX_EPOCH.add(duration))
}

View File

@@ -19,7 +19,7 @@ use secrecy::SecretString;
use std::collections::{HashMap, VecDeque};
use std::hash::Hash;
use std::net::{IpAddr, SocketAddr};
use std::time::{Duration, SystemTime};
use std::time::{Duration, Instant, SystemTime};
use stun_codec::rfc5389::attributes::{
ErrorCode, MessageIntegrity, Nonce, Realm, Username, XorMappedAddress,
};
@@ -230,7 +230,7 @@ where
&mut self,
bytes: &[u8],
sender: ClientSocket,
now: SystemTime,
now: Instant,
) -> Option<(AllocationPort, PeerSocket)> {
tracing::trace!(target: "wire", num_bytes = %bytes.len());
@@ -268,7 +268,7 @@ where
&mut self,
message: ClientMessage,
sender: ClientSocket,
now: SystemTime,
now: Instant,
) -> Option<(AllocationPort, PeerSocket)> {
let result = match message {
ClientMessage::Allocate(request) => self.handle_allocate_request(request, sender, now),
@@ -277,14 +277,14 @@ where
self.handle_channel_bind_request(request, sender, now)
}
ClientMessage::CreatePermission(request) => {
self.handle_create_permission_request(request, sender, now)
self.handle_create_permission_request(request, sender)
}
ClientMessage::Binding(request) => {
self.handle_binding_request(request, sender);
return None;
}
ClientMessage::ChannelData(msg) => {
return self.handle_channel_data_message(msg, sender, now);
return self.handle_channel_data_message(msg, sender);
}
};
@@ -368,7 +368,7 @@ where
}
// TODO: It might be worth to do some caching here?
pub fn poll_timeout(&self) -> Option<SystemTime> {
pub fn poll_timeout(&self) -> Option<Instant> {
let channel_expiries = self.channels_by_client_and_number.values().map(|c| {
if c.bound {
c.expiry
@@ -383,7 +383,7 @@ where
.fold(None, |current, next| earliest(current, Some(next)))
}
pub fn handle_timeout(&mut self, now: SystemTime) {
pub fn handle_timeout(&mut self, now: Instant) {
let expired_allocations = self
.allocations
.values()
@@ -435,9 +435,9 @@ where
&mut self,
request: Allocate,
sender: ClientSocket,
now: SystemTime,
now: Instant,
) -> Result<(), Message<Attribute>> {
self.verify_auth(&request, now)?;
self.verify_auth(&request)?;
if let Some(allocation) = self.allocations.get(&sender) {
Span::current().record("allocation", display(&allocation.port));
@@ -545,9 +545,9 @@ where
&mut self,
request: Refresh,
sender: ClientSocket,
now: SystemTime,
now: Instant,
) -> Result<(), Message<Attribute>> {
self.verify_auth(&request, now)?;
self.verify_auth(&request)?;
// TODO: Verify that this is the correct error code.
let allocation = self
@@ -594,9 +594,9 @@ where
&mut self,
request: ChannelBind,
sender: ClientSocket,
now: SystemTime,
now: Instant,
) -> Result<(), Message<Attribute>> {
self.verify_auth(&request, now)?;
self.verify_auth(&request)?;
let allocation = self
.allocations
@@ -678,14 +678,13 @@ where
///
/// This TURN server implementation does not support relaying data other than through channels.
/// Thus, creating a permission is a no-op that always succeeds.
#[tracing::instrument(level = "debug", skip(self, message, now), fields(%sender))]
#[tracing::instrument(level = "debug", skip_all, fields(%sender))]
fn handle_create_permission_request(
&mut self,
message: CreatePermission,
sender: ClientSocket,
now: SystemTime,
) -> Result<(), Message<Attribute>> {
self.verify_auth(&message, now)?;
self.verify_auth(&message)?;
self.send_message(
create_permission_success_response(message.transaction_id()),
@@ -699,7 +698,6 @@ where
&mut self,
message: ChannelData,
sender: ClientSocket,
_: SystemTime,
) -> Option<(AllocationPort, PeerSocket)> {
let channel_number = message.channel();
let data = message.data();
@@ -735,7 +733,6 @@ where
fn verify_auth(
&mut self,
request: &(impl StunRequest + ProtectedRequest),
now: SystemTime,
) -> Result<(), Message<Attribute>> {
let message_integrity = request
.message_integrity()
@@ -757,7 +754,7 @@ where
.map_err(|_| error_response(StaleNonce, request))?;
message_integrity
.verify(&self.auth_secret, username.name(), now)
.verify(&self.auth_secret, username.name(), SystemTime::now()) // This is impure but we don't need to control this in our tests.
.map_err(|_| error_response(Unauthorized, request))?;
Ok(())
@@ -765,7 +762,7 @@ where
fn create_new_allocation(
&mut self,
now: SystemTime,
now: Instant,
lifetime: &Lifetime,
first_relay_addr: IpAddr,
second_relay_addr: Option<IpAddr>,
@@ -801,7 +798,7 @@ where
requested_channel: ChannelNumber,
peer: PeerSocket,
id: AllocationPort,
now: SystemTime,
now: Instant,
) {
let expiry = now + CHANNEL_BINDING_DURATION;
@@ -945,7 +942,7 @@ fn create_permission_success_response(transaction_id: TransactionId) -> Message<
struct Allocation {
/// Data arriving on this port will be forwarded to the client iff there is an active data channel.
port: AllocationPort,
expires_at: SystemTime,
expires_at: Instant,
first_relay_addr: IpAddr,
second_relay_addr: Option<IpAddr>,
@@ -953,7 +950,7 @@ struct Allocation {
struct Channel {
/// When the channel expires.
expiry: SystemTime,
expiry: Instant,
/// The address of the peer that the channel is bound to.
peer_address: PeerSocket,
@@ -974,15 +971,15 @@ struct Channel {
}
impl Channel {
fn refresh(&mut self, now: SystemTime) {
fn refresh(&mut self, now: Instant) {
self.expiry = now + CHANNEL_BINDING_DURATION;
}
fn is_expired(&self, now: SystemTime) -> bool {
fn is_expired(&self, now: Instant) -> bool {
self.expiry <= now
}
fn can_be_deleted(&self, now: SystemTime) -> bool {
fn can_be_deleted(&self, now: Instant) -> bool {
self.expiry + CHANNEL_REBIND_TIMEOUT <= now
}
}
@@ -1004,7 +1001,7 @@ impl Allocation {
}
impl Allocation {
fn is_expired(&self, now: SystemTime) -> bool {
fn is_expired(&self, now: Instant) -> bool {
self.expires_at <= now
}
}
@@ -1175,7 +1172,7 @@ stun_codec::define_attribute_enums!(
]
);
fn earliest(left: Option<SystemTime>, right: Option<SystemTime>) -> Option<SystemTime> {
fn earliest(left: Option<Instant>, right: Option<Instant>) -> Option<Instant> {
match (left, right) {
(None, None) => None,
(Some(left), Some(right)) => Some(std::cmp::min(left, right)),

View File

@@ -1,9 +1,7 @@
use futures::future::BoxFuture;
use futures::FutureExt;
use std::future::Future;
use std::pin::Pin;
use std::task::{ready, Context, Poll, Waker};
use std::time::SystemTime;
use std::time::Instant;
/// A future that sleeps until a given instant.
///
@@ -12,43 +10,42 @@ use std::time::SystemTime;
#[derive(Default)]
pub struct Sleep {
/// The inner sleep future. Boxed for convenience to make [`Sleep`] implement [`Unpin`].
inner: Option<BoxFuture<'static, ()>>,
current_deadline: Option<SystemTime>,
inner: Option<Pin<Box<tokio::time::Sleep>>>,
waker: Option<Waker>,
}
impl Sleep {
pub fn reset(self: Pin<&mut Self>, deadline: SystemTime) {
pub fn reset(self: Pin<&mut Self>, deadline: Instant) {
let this = self.get_mut();
let deadline = tokio::time::Instant::from_std(deadline);
if this.current_deadline.is_some_and(|c| c == deadline) {
return;
}
match this.inner.as_mut() {
Some(sleep) if sleep.deadline() != deadline => sleep.as_mut().reset(deadline),
Some(_) => (),
None => {
this.inner = Some(Box::pin(tokio::time::sleep_until(deadline)));
let duration = deadline
.duration_since(SystemTime::now())
.unwrap_or_default();
this.inner = Some(tokio::time::sleep(duration).boxed());
this.current_deadline = Some(deadline);
if let Some(waker) = this.waker.take() {
waker.wake();
if let Some(waker) = this.waker.take() {
waker.wake();
}
}
}
}
}
impl Future for Sleep {
type Output = ();
type Output = Instant;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.get_mut();
if let Some(inner) = &mut this.inner {
let deadline = inner.deadline();
ready!(Pin::new(inner).poll(cx));
this.inner = None;
return Poll::Ready(());
return Poll::Ready(deadline.into());
}
this.waker = Some(cx.waker().clone());
@@ -60,6 +57,7 @@ impl Future for Sleep {
#[cfg(test)]
mod tests {
use super::*;
use futures::FutureExt as _;
use std::pin::pin;
use std::time::Duration;
@@ -75,7 +73,7 @@ mod tests {
#[tokio::test]
async fn finished_sleep_returns_pending() {
let mut sleep = Sleep::default();
Pin::new(&mut sleep).reset(SystemTime::now() + Duration::from_millis(100));
Pin::new(&mut sleep).reset(Instant::now() + Duration::from_millis(100));
tokio::time::sleep(Duration::from_millis(200)).await;
@@ -91,7 +89,7 @@ mod tests {
#[tokio::test]
async fn does_not_crash_and_fires_immediately_when_reset_to_past() {
let mut sleep = Sleep::default();
Pin::new(&mut sleep).reset(SystemTime::now() - Duration::from_millis(100));
Pin::new(&mut sleep).reset(Instant::now() - Duration::from_millis(100));
sleep.await;
}

View File

@@ -7,7 +7,7 @@ use rand::rngs::mock::StepRng;
use secrecy::SecretString;
use std::iter;
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6};
use std::time::{Duration, SystemTime};
use std::time::{Duration, Instant, SystemTime};
use stun_codec::rfc5389::attributes::{ErrorCode, Nonce, Realm, Username, XorMappedAddress};
use stun_codec::rfc5389::errors::Unauthorized;
use stun_codec::rfc5389::methods::BINDING;
@@ -30,7 +30,7 @@ fn can_answer_stun_request_from_ip4_address(
let transaction_id = request.transaction_id();
server.assert_commands(
from_client(source, request, SystemTime::now()),
from_client(source, request, Instant::now()),
[send_message(
source,
binding_response(transaction_id, source),
@@ -45,9 +45,10 @@ fn deallocate_once_time_expired(
#[strategy(firezone_relay::proptest::username_salt())] username_salt: String,
source: SocketAddrV4,
public_relay_addr: Ipv4Addr,
#[strategy(firezone_relay::proptest::now())] now: SystemTime,
#[strategy(firezone_relay::proptest::nonce())] nonce: Uuid,
) {
let now = Instant::now();
let mut server = TestServer::new(public_relay_addr).with_nonce(nonce);
let secret = server.auth_secret();
@@ -57,7 +58,7 @@ fn deallocate_once_time_expired(
Allocate::new_authenticated_udp_implicit_ip4(
transaction_id,
Some(lifetime.clone()),
valid_username(now, &username_salt),
valid_username(&username_salt),
secret,
nonce,
),
@@ -90,8 +91,9 @@ fn unauthenticated_allocate_triggers_authentication(
#[strategy(firezone_relay::proptest::username_salt())] username_salt: String,
source: SocketAddrV4,
public_relay_addr: Ipv4Addr,
#[strategy(firezone_relay::proptest::now())] now: SystemTime,
) {
let now = Instant::now();
// Nonces are generated randomly and we control the randomness in the test, thus this is deterministic.
let first_nonce = Uuid::from_u128(0x0);
@@ -116,7 +118,7 @@ fn unauthenticated_allocate_triggers_authentication(
Allocate::new_authenticated_udp_implicit_ip4(
transaction_id,
Some(lifetime.clone()),
valid_username(now, &username_salt),
valid_username(&username_salt),
&secret,
first_nonce,
),
@@ -146,9 +148,10 @@ fn when_refreshed_in_time_allocation_does_not_expire(
#[strategy(firezone_relay::proptest::username_salt())] username_salt: String,
source: SocketAddrV4,
public_relay_addr: Ipv4Addr,
#[strategy(firezone_relay::proptest::now())] now: SystemTime,
#[strategy(firezone_relay::proptest::nonce())] nonce: Uuid,
) {
let now = Instant::now();
let mut server = TestServer::new(public_relay_addr).with_nonce(nonce);
let secret = server.auth_secret().to_owned();
let first_wake = now + allocate_lifetime.lifetime();
@@ -159,7 +162,7 @@ fn when_refreshed_in_time_allocation_does_not_expire(
Allocate::new_authenticated_udp_implicit_ip4(
allocate_transaction_id,
Some(allocate_lifetime.clone()),
valid_username(now, &username_salt),
valid_username(&username_salt),
&secret,
nonce,
),
@@ -192,7 +195,7 @@ fn when_refreshed_in_time_allocation_does_not_expire(
Refresh::new(
refresh_transaction_id,
Some(refresh_lifetime.clone()),
valid_username(now, &username_salt),
valid_username(&username_salt),
&secret,
nonce,
),
@@ -222,9 +225,10 @@ fn when_receiving_lifetime_0_for_existing_allocation_then_delete(
#[strategy(firezone_relay::proptest::username_salt())] username_salt: String,
source: SocketAddrV4,
public_relay_addr: Ipv4Addr,
#[strategy(firezone_relay::proptest::now())] now: SystemTime,
#[strategy(firezone_relay::proptest::nonce())] nonce: Uuid,
) {
let now = Instant::now();
let mut server = TestServer::new(public_relay_addr).with_nonce(nonce);
let secret = server.auth_secret().to_owned();
let first_wake = now + allocate_lifetime.lifetime();
@@ -235,7 +239,7 @@ fn when_receiving_lifetime_0_for_existing_allocation_then_delete(
Allocate::new_authenticated_udp_implicit_ip4(
allocate_transaction_id,
Some(allocate_lifetime.clone()),
valid_username(now, &username_salt),
valid_username(&username_salt),
&secret,
nonce,
),
@@ -267,7 +271,7 @@ fn when_receiving_lifetime_0_for_existing_allocation_then_delete(
Refresh::new(
refresh_transaction_id,
Some(Lifetime::new(Duration::ZERO).unwrap()),
valid_username(now, &username_salt),
valid_username(&username_salt),
&secret,
nonce,
),
@@ -303,11 +307,12 @@ fn ping_pong_relay(
source: SocketAddrV4,
peer: SocketAddrV4,
public_relay_addr: Ipv4Addr,
#[strategy(firezone_relay::proptest::now())] now: SystemTime,
peer_to_client_ping: [u8; 32],
#[strategy(firezone_relay::proptest::channel_data())] client_to_peer_ping: ChannelData<'static>,
#[strategy(firezone_relay::proptest::nonce())] nonce: Uuid,
) {
let now = Instant::now();
let _ = env_logger::try_init();
let mut server = TestServer::new(public_relay_addr).with_nonce(nonce);
@@ -320,7 +325,7 @@ fn ping_pong_relay(
Allocate::new_authenticated_udp_implicit_ip4(
allocate_transaction_id,
Some(lifetime.clone()),
valid_username(now, &username_salt),
valid_username(&username_salt),
&secret,
nonce,
),
@@ -355,7 +360,7 @@ fn ping_pong_relay(
channel_bind_transaction_id,
client_to_peer_ping.channel(),
XorPeerAddress::new(peer.into()),
valid_username(now, &username_salt),
valid_username(&username_salt),
&secret,
nonce,
),
@@ -413,9 +418,10 @@ fn allows_rebind_channel_after_expiry(
peer: SocketAddrV4,
peer2: SocketAddrV4,
public_relay_addr: Ipv4Addr,
#[strategy(firezone_relay::proptest::now())] now: SystemTime,
#[strategy(firezone_relay::proptest::nonce())] nonce: Uuid,
) {
let now = Instant::now();
let _ = env_logger::try_init();
let mut server = TestServer::new(public_relay_addr).with_nonce(nonce);
@@ -428,7 +434,7 @@ fn allows_rebind_channel_after_expiry(
Allocate::new_authenticated_udp_implicit_ip4(
allocate_transaction_id,
Some(lifetime.clone()),
valid_username(now, &username_salt),
valid_username(&username_salt),
&secret,
nonce,
),
@@ -463,7 +469,7 @@ fn allows_rebind_channel_after_expiry(
channel_bind_transaction_id,
channel,
XorPeerAddress::new(peer.into()),
valid_username(now, &username_salt),
valid_username(&username_salt),
&secret,
nonce,
),
@@ -498,7 +504,7 @@ fn allows_rebind_channel_after_expiry(
channel_bind_2_transaction_id,
channel,
XorPeerAddress::new(peer2.into()),
valid_username(now, &username_salt),
valid_username(&username_salt),
&secret,
nonce,
),
@@ -527,11 +533,12 @@ fn ping_pong_ip6_relay(
peer: SocketAddrV6,
public_relay_ip4_addr: Ipv4Addr,
public_relay_ip6_addr: Ipv6Addr,
#[strategy(firezone_relay::proptest::now())] now: SystemTime,
peer_to_client_ping: [u8; 32],
mut client_to_peer_ping: [u8; 36],
#[strategy(firezone_relay::proptest::nonce())] nonce: Uuid,
) {
let now = Instant::now();
let _ = env_logger::try_init();
let mut server =
@@ -545,7 +552,7 @@ fn ping_pong_ip6_relay(
Allocate::new_authenticated_udp_ip6(
allocate_transaction_id,
Some(lifetime.clone()),
valid_username(now, &username_salt),
valid_username(&username_salt),
&secret,
nonce,
),
@@ -580,7 +587,7 @@ fn ping_pong_ip6_relay(
channel_bind_transaction_id,
channel,
XorPeerAddress::new(peer.into()),
valid_username(now, &username_salt),
valid_username(&username_salt),
&secret,
nonce,
),
@@ -719,8 +726,8 @@ impl TestServer {
}
}
fn valid_username(now: SystemTime, salt: &str) -> Username {
let now_unix = now
fn valid_username(salt: &str) -> Username {
let now_unix = SystemTime::now()
.duration_since(SystemTime::UNIX_EPOCH)
.unwrap()
.as_secs();
@@ -793,19 +800,19 @@ fn parse_message(message: &[u8]) -> Message<Attribute> {
}
enum Input<'a> {
Client(ClientSocket, ClientMessage<'a>, SystemTime),
Time(SystemTime),
Client(ClientSocket, ClientMessage<'a>, Instant),
Time(Instant),
}
fn from_client<'a>(
from: impl Into<SocketAddr>,
message: impl Into<ClientMessage<'a>>,
now: SystemTime,
now: Instant,
) -> Input<'a> {
Input::Client(ClientSocket::new(from.into()), message.into(), now)
}
fn forward_time_to<'a>(when: SystemTime) -> Input<'a> {
fn forward_time_to<'a>(when: Instant) -> Input<'a> {
Input::Time(when)
}