From 99aa973db472dedd1e61475220d2cc4584102061 Mon Sep 17 00:00:00 2001 From: Thomas Eizinger Date: Tue, 20 Aug 2024 23:17:55 +0100 Subject: [PATCH] chore(connlib): reduce buffer sizes (#6360) Currently, `snownet` allocates a 65KB buffer per connection as a scratch-space for encrypting packets. 65KB is the theoretical limit of a UDP packet. In practice, the largest UDP packets we send are 1336 bytes due to the MTU of 1280 set on our TUN interface and various overheads for WG, TURN channels and NAT46. Thus, it is unnecessary to allocate such a large buffer per connection. For gateways with many connections, reducing these buffers results in a smaller memory footprint. Additionally, any UDP packets larger than this buffer could be an indicator of a DoS attack and we can thus drop them without processing. A legitimate client / gateway will never send a packet larger than that. --- rust/connlib/snownet/src/node.rs | 14 ++++++-------- rust/connlib/snownet/tests/lib.rs | 4 ++++ rust/connlib/tunnel/src/client.rs | 4 ++-- rust/connlib/tunnel/src/gateway.rs | 4 ++-- rust/connlib/tunnel/src/io.rs | 24 ++++++++++++++++++------ rust/connlib/tunnel/src/lib.rs | 29 +++++++++++++++++------------ 6 files changed, 49 insertions(+), 30 deletions(-) diff --git a/rust/connlib/snownet/src/node.rs b/rust/connlib/snownet/src/node.rs index 9e232a887..5683f3f2c 100644 --- a/rust/connlib/snownet/src/node.rs +++ b/rust/connlib/snownet/src/node.rs @@ -40,8 +40,6 @@ const CANDIDATE_TIMEOUT: Duration = Duration::from_secs(10); /// How long we will at most wait for an [`Answer`] from the remote. pub const HANDSHAKE_TIMEOUT: Duration = Duration::from_secs(20); -const MAX_UDP_SIZE: usize = (1 << 16) - 1; - /// Manages a set of wireguard connections for a server. pub type ServerNode = Node; /// Manages a set of wireguard connections for a client. @@ -96,7 +94,7 @@ pub struct Node { connections: Connections, pending_events: VecDeque>, - buffer: Box<[u8; MAX_UDP_SIZE]>, + buffer: Vec, stats: NodeStats, @@ -127,7 +125,7 @@ where TId: Eq + Hash + Copy + Ord + fmt::Display, RId: Copy + Eq + Hash + PartialEq + Ord + fmt::Debug + fmt::Display, { - pub fn new(private_key: StaticSecret, seed: [u8; 32]) -> Self { + pub fn new(private_key: StaticSecret, buf_size: usize, seed: [u8; 32]) -> Self { let public_key = &(&private_key).into(); Self { rng: StdRng::from_seed(seed), // TODO: Use this seed for private key too. Requires refactoring of how we generate the login-url because that one needs to know the public key. @@ -141,7 +139,7 @@ where buffered_transmits: VecDeque::default(), next_rate_limiter_reset: None, pending_events: VecDeque::default(), - buffer: Box::new([0u8; MAX_UDP_SIZE]), + buffer: vec![0; buf_size], allocations: Default::default(), connections: Default::default(), stats: Default::default(), @@ -372,7 +370,7 @@ where tracing::warn!(%relay, "No allocation"); return Ok(None); }; - let packet = &mut self.buffer.as_mut()[..packet_end]; + let packet = &mut self.buffer[..packet_end]; let Some(transmit) = allocation.encode_to_borrowed_transmit(peer, packet, now) else { @@ -574,7 +572,7 @@ where ), next_timer_update: now, stats: Default::default(), - buffer: Box::new([0u8; MAX_UDP_SIZE]), + buffer: vec![0; self.buffer.capacity()], intent_sent_at, signalling_completed_at: now, remote_pub_key: remote, @@ -1346,7 +1344,7 @@ struct Connection { intent_sent_at: Instant, signalling_completed_at: Instant, - buffer: Box<[u8; MAX_UDP_SIZE]>, + buffer: Vec, last_outgoing: Instant, last_incoming: Instant, diff --git a/rust/connlib/snownet/tests/lib.rs b/rust/connlib/snownet/tests/lib.rs index 343ad6a3b..16b6ec6c9 100644 --- a/rust/connlib/snownet/tests/lib.rs +++ b/rust/connlib/snownet/tests/lib.rs @@ -75,12 +75,14 @@ fn only_generate_candidate_event_after_answer() { let mut alice = ClientNode::::new( StaticSecret::random_from_rng(rand::thread_rng()), + 0, rand::random(), ); alice.add_local_host_candidate(local_candidate).unwrap(); let mut bob = ServerNode::::new( StaticSecret::random_from_rng(rand::thread_rng()), + 0, rand::random(), ); @@ -108,10 +110,12 @@ fn only_generate_candidate_event_after_answer() { fn alice_and_bob() -> (ClientNode, ServerNode) { let alice = ClientNode::new( StaticSecret::random_from_rng(rand::thread_rng()), + 0, rand::random(), ); let bob = ServerNode::new( StaticSecret::random_from_rng(rand::thread_rng()), + 0, rand::random(), ); diff --git a/rust/connlib/tunnel/src/client.rs b/rust/connlib/tunnel/src/client.rs index e8b0c9c55..73f365f5e 100644 --- a/rust/connlib/tunnel/src/client.rs +++ b/rust/connlib/tunnel/src/client.rs @@ -1,6 +1,6 @@ -use crate::dns; use crate::dns::StubResolver; use crate::peer_store::PeerStore; +use crate::{dns, BUF_SIZE}; use anyhow::Context; use bimap::BiMap; use connlib_shared::callbacks::Status; @@ -309,7 +309,7 @@ impl ClientState { buffered_events: Default::default(), interface_config: Default::default(), buffered_packets: Default::default(), - node: ClientNode::new(private_key.into(), seed), + node: ClientNode::new(private_key.into(), BUF_SIZE, seed), system_resolvers: Default::default(), sites_status: Default::default(), gateways_site: Default::default(), diff --git a/rust/connlib/tunnel/src/gateway.rs b/rust/connlib/tunnel/src/gateway.rs index 22d21c2b8..d0a727e65 100644 --- a/rust/connlib/tunnel/src/gateway.rs +++ b/rust/connlib/tunnel/src/gateway.rs @@ -1,7 +1,7 @@ use crate::peer::ClientOnGateway; use crate::peer_store::PeerStore; use crate::utils::earliest; -use crate::{GatewayEvent, GatewayTunnel}; +use crate::{GatewayEvent, GatewayTunnel, BUF_SIZE}; use anyhow::{bail, Context}; use boringtun::x25519::PublicKey; use chrono::{DateTime, Utc}; @@ -146,7 +146,7 @@ impl GatewayState { pub(crate) fn new(private_key: impl Into, seed: [u8; 32]) -> Self { Self { peers: Default::default(), - node: ServerNode::new(private_key.into(), seed), + node: ServerNode::new(private_key.into(), BUF_SIZE, seed), next_expiry_resources_check: Default::default(), buffered_events: VecDeque::default(), } diff --git a/rust/connlib/tunnel/src/io.rs b/rust/connlib/tunnel/src/io.rs index b5a3ea89f..a9221bebe 100644 --- a/rust/connlib/tunnel/src/io.rs +++ b/rust/connlib/tunnel/src/io.rs @@ -1,6 +1,7 @@ use crate::{ device_channel::Device, sockets::{NoInterfaces, Sockets}, + BUF_SIZE, }; use futures_util::FutureExt as _; use ip_packet::{IpPacket, MutableIpPacket}; @@ -54,15 +55,15 @@ impl Io { }) } - pub fn poll<'b>( + pub fn poll<'b1, 'b2>( &mut self, cx: &mut Context<'_>, - ip4_buffer: &'b mut [u8], - ip6_bffer: &'b mut [u8], - device_buffer: &'b mut [u8], - ) -> Poll>>>> { + ip4_buffer: &'b1 mut [u8], + ip6_bffer: &'b1 mut [u8], + device_buffer: &'b2 mut [u8], + ) -> Poll>>>> { if let Poll::Ready(network) = self.sockets.poll_recv_from(ip4_buffer, ip6_bffer, cx)? { - return Poll::Ready(Ok(Input::Network(network))); + return Poll::Ready(Ok(Input::Network(network.filter(is_max_wg_packet_size)))); } ready!(self.sockets.poll_flush(cx))?; @@ -121,3 +122,14 @@ impl Io { Ok(()) } } + +fn is_max_wg_packet_size(d: &DatagramIn) -> bool { + let len = d.packet.len(); + if len > BUF_SIZE { + tracing::debug!(from = %d.from, %len, "Dropping too large datagram (max allowed: {BUF_SIZE} bytes)"); + + return false; + } + + true +} diff --git a/rust/connlib/tunnel/src/lib.rs b/rust/connlib/tunnel/src/lib.rs index 06873090d..a3b9c5ac2 100644 --- a/rust/connlib/tunnel/src/lib.rs +++ b/rust/connlib/tunnel/src/lib.rs @@ -50,6 +50,15 @@ const REALM: &str = "firezone"; /// Thus, it is chosen as a safe, upper boundary that is not meant to be hit (and thus doesn't affect performance), yet acts as a safe guard, just in case. const MAX_EVENTLOOP_ITERS: u32 = 5000; +/// Wireguard has a 32-byte overhead (4b message type + 4b receiver idx + 8b packet counter + 16b AEAD tag) +const WG_OVERHEAD: usize = 32; +/// In order to do NAT46 without copying, we need 20 extra byte in the buffer (IPv6 packets are 20 byte bigger than IPv4). +const NAT46_OVERHEAD: usize = 20; +/// TURN's data channels have a 4 byte overhead. +const DATA_CHANNEL_OVERHEAD: usize = 4; + +const BUF_SIZE: usize = DEFAULT_MTU + WG_OVERHEAD + NAT46_OVERHEAD + DATA_CHANNEL_OVERHEAD; + pub type GatewayTunnel = Tunnel; pub type ClientTunnel = Tunnel; @@ -73,10 +82,8 @@ pub struct Tunnel { ip4_read_buf: Box<[u8; MAX_UDP_SIZE]>, ip6_read_buf: Box<[u8; MAX_UDP_SIZE]>, - // We need an extra 16 bytes on top of the MTU for write_buf since boringtun copies the extra AEAD tag before decrypting it - write_buf: Box<[u8; DEFAULT_MTU + 16 + 20]>, - // We have 20 extra bytes to be able to convert between ipv4 and ipv6 - device_read_buf: Box<[u8; DEFAULT_MTU + 20]>, + /// Buffer for processing a single IP packet. + packet_buffer: Box<[u8; BUF_SIZE]>, } impl ClientTunnel { @@ -89,10 +96,9 @@ impl ClientTunnel { Ok(Self { io: Io::new(tcp_socket_factory, udp_socket_factory)?, role_state: ClientState::new(private_key, known_hosts, rand::random()), - write_buf: Box::new([0u8; DEFAULT_MTU + 16 + 20]), + packet_buffer: Box::new([0u8; BUF_SIZE]), ip4_read_buf: Box::new([0u8; MAX_UDP_SIZE]), ip6_read_buf: Box::new([0u8; MAX_UDP_SIZE]), - device_read_buf: Box::new([0u8; DEFAULT_MTU + 20]), }) } @@ -127,7 +133,7 @@ impl ClientTunnel { cx, self.ip4_read_buf.as_mut(), self.ip6_read_buf.as_mut(), - self.device_read_buf.as_mut(), + self.packet_buffer.as_mut(), )? { Poll::Ready(io::Input::Timeout(timeout)) => { self.role_state.handle_timeout(timeout); @@ -149,7 +155,7 @@ impl ClientTunnel { received.from, received.packet, std::time::Instant::now(), - self.write_buf.as_mut(), + self.packet_buffer.as_mut(), ) else { continue; }; @@ -180,10 +186,9 @@ impl GatewayTunnel { Ok(Self { io: Io::new(tcp_socket_factory, udp_socket_factory)?, role_state: GatewayState::new(private_key, rand::random()), - write_buf: Box::new([0u8; DEFAULT_MTU + 20 + 16]), + packet_buffer: Box::new([0u8; BUF_SIZE]), ip4_read_buf: Box::new([0u8; MAX_UDP_SIZE]), ip6_read_buf: Box::new([0u8; MAX_UDP_SIZE]), - device_read_buf: Box::new([0u8; DEFAULT_MTU + 20]), }) } @@ -211,7 +216,7 @@ impl GatewayTunnel { cx, self.ip4_read_buf.as_mut(), self.ip6_read_buf.as_mut(), - self.device_read_buf.as_mut(), + self.packet_buffer.as_mut(), )? { Poll::Ready(io::Input::Timeout(timeout)) => { self.role_state.handle_timeout(timeout, Utc::now()); @@ -236,7 +241,7 @@ impl GatewayTunnel { received.from, received.packet, std::time::Instant::now(), - self.write_buf.as_mut(), + self.packet_buffer.as_mut(), ) else { continue; };