diff --git a/rust/connlib/snownet/src/node.rs b/rust/connlib/snownet/src/node.rs index 9e232a887..5683f3f2c 100644 --- a/rust/connlib/snownet/src/node.rs +++ b/rust/connlib/snownet/src/node.rs @@ -40,8 +40,6 @@ const CANDIDATE_TIMEOUT: Duration = Duration::from_secs(10); /// How long we will at most wait for an [`Answer`] from the remote. pub const HANDSHAKE_TIMEOUT: Duration = Duration::from_secs(20); -const MAX_UDP_SIZE: usize = (1 << 16) - 1; - /// Manages a set of wireguard connections for a server. pub type ServerNode = Node; /// Manages a set of wireguard connections for a client. @@ -96,7 +94,7 @@ pub struct Node { connections: Connections, pending_events: VecDeque>, - buffer: Box<[u8; MAX_UDP_SIZE]>, + buffer: Vec, stats: NodeStats, @@ -127,7 +125,7 @@ where TId: Eq + Hash + Copy + Ord + fmt::Display, RId: Copy + Eq + Hash + PartialEq + Ord + fmt::Debug + fmt::Display, { - pub fn new(private_key: StaticSecret, seed: [u8; 32]) -> Self { + pub fn new(private_key: StaticSecret, buf_size: usize, seed: [u8; 32]) -> Self { let public_key = &(&private_key).into(); Self { rng: StdRng::from_seed(seed), // TODO: Use this seed for private key too. Requires refactoring of how we generate the login-url because that one needs to know the public key. @@ -141,7 +139,7 @@ where buffered_transmits: VecDeque::default(), next_rate_limiter_reset: None, pending_events: VecDeque::default(), - buffer: Box::new([0u8; MAX_UDP_SIZE]), + buffer: vec![0; buf_size], allocations: Default::default(), connections: Default::default(), stats: Default::default(), @@ -372,7 +370,7 @@ where tracing::warn!(%relay, "No allocation"); return Ok(None); }; - let packet = &mut self.buffer.as_mut()[..packet_end]; + let packet = &mut self.buffer[..packet_end]; let Some(transmit) = allocation.encode_to_borrowed_transmit(peer, packet, now) else { @@ -574,7 +572,7 @@ where ), next_timer_update: now, stats: Default::default(), - buffer: Box::new([0u8; MAX_UDP_SIZE]), + buffer: vec![0; self.buffer.capacity()], intent_sent_at, signalling_completed_at: now, remote_pub_key: remote, @@ -1346,7 +1344,7 @@ struct Connection { intent_sent_at: Instant, signalling_completed_at: Instant, - buffer: Box<[u8; MAX_UDP_SIZE]>, + buffer: Vec, last_outgoing: Instant, last_incoming: Instant, diff --git a/rust/connlib/snownet/tests/lib.rs b/rust/connlib/snownet/tests/lib.rs index 343ad6a3b..16b6ec6c9 100644 --- a/rust/connlib/snownet/tests/lib.rs +++ b/rust/connlib/snownet/tests/lib.rs @@ -75,12 +75,14 @@ fn only_generate_candidate_event_after_answer() { let mut alice = ClientNode::::new( StaticSecret::random_from_rng(rand::thread_rng()), + 0, rand::random(), ); alice.add_local_host_candidate(local_candidate).unwrap(); let mut bob = ServerNode::::new( StaticSecret::random_from_rng(rand::thread_rng()), + 0, rand::random(), ); @@ -108,10 +110,12 @@ fn only_generate_candidate_event_after_answer() { fn alice_and_bob() -> (ClientNode, ServerNode) { let alice = ClientNode::new( StaticSecret::random_from_rng(rand::thread_rng()), + 0, rand::random(), ); let bob = ServerNode::new( StaticSecret::random_from_rng(rand::thread_rng()), + 0, rand::random(), ); diff --git a/rust/connlib/tunnel/src/client.rs b/rust/connlib/tunnel/src/client.rs index e8b0c9c55..73f365f5e 100644 --- a/rust/connlib/tunnel/src/client.rs +++ b/rust/connlib/tunnel/src/client.rs @@ -1,6 +1,6 @@ -use crate::dns; use crate::dns::StubResolver; use crate::peer_store::PeerStore; +use crate::{dns, BUF_SIZE}; use anyhow::Context; use bimap::BiMap; use connlib_shared::callbacks::Status; @@ -309,7 +309,7 @@ impl ClientState { buffered_events: Default::default(), interface_config: Default::default(), buffered_packets: Default::default(), - node: ClientNode::new(private_key.into(), seed), + node: ClientNode::new(private_key.into(), BUF_SIZE, seed), system_resolvers: Default::default(), sites_status: Default::default(), gateways_site: Default::default(), diff --git a/rust/connlib/tunnel/src/gateway.rs b/rust/connlib/tunnel/src/gateway.rs index 22d21c2b8..d0a727e65 100644 --- a/rust/connlib/tunnel/src/gateway.rs +++ b/rust/connlib/tunnel/src/gateway.rs @@ -1,7 +1,7 @@ use crate::peer::ClientOnGateway; use crate::peer_store::PeerStore; use crate::utils::earliest; -use crate::{GatewayEvent, GatewayTunnel}; +use crate::{GatewayEvent, GatewayTunnel, BUF_SIZE}; use anyhow::{bail, Context}; use boringtun::x25519::PublicKey; use chrono::{DateTime, Utc}; @@ -146,7 +146,7 @@ impl GatewayState { pub(crate) fn new(private_key: impl Into, seed: [u8; 32]) -> Self { Self { peers: Default::default(), - node: ServerNode::new(private_key.into(), seed), + node: ServerNode::new(private_key.into(), BUF_SIZE, seed), next_expiry_resources_check: Default::default(), buffered_events: VecDeque::default(), } diff --git a/rust/connlib/tunnel/src/io.rs b/rust/connlib/tunnel/src/io.rs index b5a3ea89f..a9221bebe 100644 --- a/rust/connlib/tunnel/src/io.rs +++ b/rust/connlib/tunnel/src/io.rs @@ -1,6 +1,7 @@ use crate::{ device_channel::Device, sockets::{NoInterfaces, Sockets}, + BUF_SIZE, }; use futures_util::FutureExt as _; use ip_packet::{IpPacket, MutableIpPacket}; @@ -54,15 +55,15 @@ impl Io { }) } - pub fn poll<'b>( + pub fn poll<'b1, 'b2>( &mut self, cx: &mut Context<'_>, - ip4_buffer: &'b mut [u8], - ip6_bffer: &'b mut [u8], - device_buffer: &'b mut [u8], - ) -> Poll>>>> { + ip4_buffer: &'b1 mut [u8], + ip6_bffer: &'b1 mut [u8], + device_buffer: &'b2 mut [u8], + ) -> Poll>>>> { if let Poll::Ready(network) = self.sockets.poll_recv_from(ip4_buffer, ip6_bffer, cx)? { - return Poll::Ready(Ok(Input::Network(network))); + return Poll::Ready(Ok(Input::Network(network.filter(is_max_wg_packet_size)))); } ready!(self.sockets.poll_flush(cx))?; @@ -121,3 +122,14 @@ impl Io { Ok(()) } } + +fn is_max_wg_packet_size(d: &DatagramIn) -> bool { + let len = d.packet.len(); + if len > BUF_SIZE { + tracing::debug!(from = %d.from, %len, "Dropping too large datagram (max allowed: {BUF_SIZE} bytes)"); + + return false; + } + + true +} diff --git a/rust/connlib/tunnel/src/lib.rs b/rust/connlib/tunnel/src/lib.rs index 06873090d..a3b9c5ac2 100644 --- a/rust/connlib/tunnel/src/lib.rs +++ b/rust/connlib/tunnel/src/lib.rs @@ -50,6 +50,15 @@ const REALM: &str = "firezone"; /// Thus, it is chosen as a safe, upper boundary that is not meant to be hit (and thus doesn't affect performance), yet acts as a safe guard, just in case. const MAX_EVENTLOOP_ITERS: u32 = 5000; +/// Wireguard has a 32-byte overhead (4b message type + 4b receiver idx + 8b packet counter + 16b AEAD tag) +const WG_OVERHEAD: usize = 32; +/// In order to do NAT46 without copying, we need 20 extra byte in the buffer (IPv6 packets are 20 byte bigger than IPv4). +const NAT46_OVERHEAD: usize = 20; +/// TURN's data channels have a 4 byte overhead. +const DATA_CHANNEL_OVERHEAD: usize = 4; + +const BUF_SIZE: usize = DEFAULT_MTU + WG_OVERHEAD + NAT46_OVERHEAD + DATA_CHANNEL_OVERHEAD; + pub type GatewayTunnel = Tunnel; pub type ClientTunnel = Tunnel; @@ -73,10 +82,8 @@ pub struct Tunnel { ip4_read_buf: Box<[u8; MAX_UDP_SIZE]>, ip6_read_buf: Box<[u8; MAX_UDP_SIZE]>, - // We need an extra 16 bytes on top of the MTU for write_buf since boringtun copies the extra AEAD tag before decrypting it - write_buf: Box<[u8; DEFAULT_MTU + 16 + 20]>, - // We have 20 extra bytes to be able to convert between ipv4 and ipv6 - device_read_buf: Box<[u8; DEFAULT_MTU + 20]>, + /// Buffer for processing a single IP packet. + packet_buffer: Box<[u8; BUF_SIZE]>, } impl ClientTunnel { @@ -89,10 +96,9 @@ impl ClientTunnel { Ok(Self { io: Io::new(tcp_socket_factory, udp_socket_factory)?, role_state: ClientState::new(private_key, known_hosts, rand::random()), - write_buf: Box::new([0u8; DEFAULT_MTU + 16 + 20]), + packet_buffer: Box::new([0u8; BUF_SIZE]), ip4_read_buf: Box::new([0u8; MAX_UDP_SIZE]), ip6_read_buf: Box::new([0u8; MAX_UDP_SIZE]), - device_read_buf: Box::new([0u8; DEFAULT_MTU + 20]), }) } @@ -127,7 +133,7 @@ impl ClientTunnel { cx, self.ip4_read_buf.as_mut(), self.ip6_read_buf.as_mut(), - self.device_read_buf.as_mut(), + self.packet_buffer.as_mut(), )? { Poll::Ready(io::Input::Timeout(timeout)) => { self.role_state.handle_timeout(timeout); @@ -149,7 +155,7 @@ impl ClientTunnel { received.from, received.packet, std::time::Instant::now(), - self.write_buf.as_mut(), + self.packet_buffer.as_mut(), ) else { continue; }; @@ -180,10 +186,9 @@ impl GatewayTunnel { Ok(Self { io: Io::new(tcp_socket_factory, udp_socket_factory)?, role_state: GatewayState::new(private_key, rand::random()), - write_buf: Box::new([0u8; DEFAULT_MTU + 20 + 16]), + packet_buffer: Box::new([0u8; BUF_SIZE]), ip4_read_buf: Box::new([0u8; MAX_UDP_SIZE]), ip6_read_buf: Box::new([0u8; MAX_UDP_SIZE]), - device_read_buf: Box::new([0u8; DEFAULT_MTU + 20]), }) } @@ -211,7 +216,7 @@ impl GatewayTunnel { cx, self.ip4_read_buf.as_mut(), self.ip6_read_buf.as_mut(), - self.device_read_buf.as_mut(), + self.packet_buffer.as_mut(), )? { Poll::Ready(io::Input::Timeout(timeout)) => { self.role_state.handle_timeout(timeout, Utc::now()); @@ -236,7 +241,7 @@ impl GatewayTunnel { received.from, received.packet, std::time::Instant::now(), - self.write_buf.as_mut(), + self.packet_buffer.as_mut(), ) else { continue; };