chore(connlib): reduce buffer sizes (#6360)

Currently, `snownet` allocates a 65KB buffer per connection as a
scratch-space for encrypting packets. 65KB is the theoretical limit of a
UDP packet. In practice, the largest UDP packets we send are 1336 bytes
due to the MTU of 1280 set on our TUN interface and various overheads
for WG, TURN channels and NAT46.

Thus, it is unnecessary to allocate such a large buffer per connection.
For gateways with many connections, reducing these buffers results in a
smaller memory footprint.

Additionally, any UDP packets larger than this buffer could be an
indicator of a DoS attack and we can thus drop them without processing.
A legitimate client / gateway will never send a packet larger than that.
This commit is contained in:
Thomas Eizinger
2024-08-20 23:17:55 +01:00
committed by GitHub
parent 95ec1871e7
commit 99aa973db4
6 changed files with 49 additions and 30 deletions

View File

@@ -40,8 +40,6 @@ const CANDIDATE_TIMEOUT: Duration = Duration::from_secs(10);
/// How long we will at most wait for an [`Answer`] from the remote.
pub const HANDSHAKE_TIMEOUT: Duration = Duration::from_secs(20);
const MAX_UDP_SIZE: usize = (1 << 16) - 1;
/// Manages a set of wireguard connections for a server.
pub type ServerNode<TId, RId> = Node<Server, TId, RId>;
/// Manages a set of wireguard connections for a client.
@@ -96,7 +94,7 @@ pub struct Node<T, TId, RId> {
connections: Connections<TId, RId>,
pending_events: VecDeque<Event<TId>>,
buffer: Box<[u8; MAX_UDP_SIZE]>,
buffer: Vec<u8>,
stats: NodeStats,
@@ -127,7 +125,7 @@ where
TId: Eq + Hash + Copy + Ord + fmt::Display,
RId: Copy + Eq + Hash + PartialEq + Ord + fmt::Debug + fmt::Display,
{
pub fn new(private_key: StaticSecret, seed: [u8; 32]) -> Self {
pub fn new(private_key: StaticSecret, buf_size: usize, seed: [u8; 32]) -> Self {
let public_key = &(&private_key).into();
Self {
rng: StdRng::from_seed(seed), // TODO: Use this seed for private key too. Requires refactoring of how we generate the login-url because that one needs to know the public key.
@@ -141,7 +139,7 @@ where
buffered_transmits: VecDeque::default(),
next_rate_limiter_reset: None,
pending_events: VecDeque::default(),
buffer: Box::new([0u8; MAX_UDP_SIZE]),
buffer: vec![0; buf_size],
allocations: Default::default(),
connections: Default::default(),
stats: Default::default(),
@@ -372,7 +370,7 @@ where
tracing::warn!(%relay, "No allocation");
return Ok(None);
};
let packet = &mut self.buffer.as_mut()[..packet_end];
let packet = &mut self.buffer[..packet_end];
let Some(transmit) = allocation.encode_to_borrowed_transmit(peer, packet, now)
else {
@@ -574,7 +572,7 @@ where
),
next_timer_update: now,
stats: Default::default(),
buffer: Box::new([0u8; MAX_UDP_SIZE]),
buffer: vec![0; self.buffer.capacity()],
intent_sent_at,
signalling_completed_at: now,
remote_pub_key: remote,
@@ -1346,7 +1344,7 @@ struct Connection<RId> {
intent_sent_at: Instant,
signalling_completed_at: Instant,
buffer: Box<[u8; MAX_UDP_SIZE]>,
buffer: Vec<u8>,
last_outgoing: Instant,
last_incoming: Instant,

View File

@@ -75,12 +75,14 @@ fn only_generate_candidate_event_after_answer() {
let mut alice = ClientNode::<u64, u64>::new(
StaticSecret::random_from_rng(rand::thread_rng()),
0,
rand::random(),
);
alice.add_local_host_candidate(local_candidate).unwrap();
let mut bob = ServerNode::<u64, u64>::new(
StaticSecret::random_from_rng(rand::thread_rng()),
0,
rand::random(),
);
@@ -108,10 +110,12 @@ fn only_generate_candidate_event_after_answer() {
fn alice_and_bob() -> (ClientNode<u64, u64>, ServerNode<u64, u64>) {
let alice = ClientNode::new(
StaticSecret::random_from_rng(rand::thread_rng()),
0,
rand::random(),
);
let bob = ServerNode::new(
StaticSecret::random_from_rng(rand::thread_rng()),
0,
rand::random(),
);

View File

@@ -1,6 +1,6 @@
use crate::dns;
use crate::dns::StubResolver;
use crate::peer_store::PeerStore;
use crate::{dns, BUF_SIZE};
use anyhow::Context;
use bimap::BiMap;
use connlib_shared::callbacks::Status;
@@ -309,7 +309,7 @@ impl ClientState {
buffered_events: Default::default(),
interface_config: Default::default(),
buffered_packets: Default::default(),
node: ClientNode::new(private_key.into(), seed),
node: ClientNode::new(private_key.into(), BUF_SIZE, seed),
system_resolvers: Default::default(),
sites_status: Default::default(),
gateways_site: Default::default(),

View File

@@ -1,7 +1,7 @@
use crate::peer::ClientOnGateway;
use crate::peer_store::PeerStore;
use crate::utils::earliest;
use crate::{GatewayEvent, GatewayTunnel};
use crate::{GatewayEvent, GatewayTunnel, BUF_SIZE};
use anyhow::{bail, Context};
use boringtun::x25519::PublicKey;
use chrono::{DateTime, Utc};
@@ -146,7 +146,7 @@ impl GatewayState {
pub(crate) fn new(private_key: impl Into<StaticSecret>, seed: [u8; 32]) -> Self {
Self {
peers: Default::default(),
node: ServerNode::new(private_key.into(), seed),
node: ServerNode::new(private_key.into(), BUF_SIZE, seed),
next_expiry_resources_check: Default::default(),
buffered_events: VecDeque::default(),
}

View File

@@ -1,6 +1,7 @@
use crate::{
device_channel::Device,
sockets::{NoInterfaces, Sockets},
BUF_SIZE,
};
use futures_util::FutureExt as _;
use ip_packet::{IpPacket, MutableIpPacket};
@@ -54,15 +55,15 @@ impl Io {
})
}
pub fn poll<'b>(
pub fn poll<'b1, 'b2>(
&mut self,
cx: &mut Context<'_>,
ip4_buffer: &'b mut [u8],
ip6_bffer: &'b mut [u8],
device_buffer: &'b mut [u8],
) -> Poll<io::Result<Input<'b, impl Iterator<Item = DatagramIn<'b>>>>> {
ip4_buffer: &'b1 mut [u8],
ip6_bffer: &'b1 mut [u8],
device_buffer: &'b2 mut [u8],
) -> Poll<io::Result<Input<'b2, impl Iterator<Item = DatagramIn<'b1>>>>> {
if let Poll::Ready(network) = self.sockets.poll_recv_from(ip4_buffer, ip6_bffer, cx)? {
return Poll::Ready(Ok(Input::Network(network)));
return Poll::Ready(Ok(Input::Network(network.filter(is_max_wg_packet_size))));
}
ready!(self.sockets.poll_flush(cx))?;
@@ -121,3 +122,14 @@ impl Io {
Ok(())
}
}
fn is_max_wg_packet_size(d: &DatagramIn) -> bool {
let len = d.packet.len();
if len > BUF_SIZE {
tracing::debug!(from = %d.from, %len, "Dropping too large datagram (max allowed: {BUF_SIZE} bytes)");
return false;
}
true
}

View File

@@ -50,6 +50,15 @@ const REALM: &str = "firezone";
/// Thus, it is chosen as a safe, upper boundary that is not meant to be hit (and thus doesn't affect performance), yet acts as a safe guard, just in case.
const MAX_EVENTLOOP_ITERS: u32 = 5000;
/// Wireguard has a 32-byte overhead (4b message type + 4b receiver idx + 8b packet counter + 16b AEAD tag)
const WG_OVERHEAD: usize = 32;
/// In order to do NAT46 without copying, we need 20 extra byte in the buffer (IPv6 packets are 20 byte bigger than IPv4).
const NAT46_OVERHEAD: usize = 20;
/// TURN's data channels have a 4 byte overhead.
const DATA_CHANNEL_OVERHEAD: usize = 4;
const BUF_SIZE: usize = DEFAULT_MTU + WG_OVERHEAD + NAT46_OVERHEAD + DATA_CHANNEL_OVERHEAD;
pub type GatewayTunnel = Tunnel<GatewayState>;
pub type ClientTunnel = Tunnel<ClientState>;
@@ -73,10 +82,8 @@ pub struct Tunnel<TRoleState> {
ip4_read_buf: Box<[u8; MAX_UDP_SIZE]>,
ip6_read_buf: Box<[u8; MAX_UDP_SIZE]>,
// We need an extra 16 bytes on top of the MTU for write_buf since boringtun copies the extra AEAD tag before decrypting it
write_buf: Box<[u8; DEFAULT_MTU + 16 + 20]>,
// We have 20 extra bytes to be able to convert between ipv4 and ipv6
device_read_buf: Box<[u8; DEFAULT_MTU + 20]>,
/// Buffer for processing a single IP packet.
packet_buffer: Box<[u8; BUF_SIZE]>,
}
impl ClientTunnel {
@@ -89,10 +96,9 @@ impl ClientTunnel {
Ok(Self {
io: Io::new(tcp_socket_factory, udp_socket_factory)?,
role_state: ClientState::new(private_key, known_hosts, rand::random()),
write_buf: Box::new([0u8; DEFAULT_MTU + 16 + 20]),
packet_buffer: Box::new([0u8; BUF_SIZE]),
ip4_read_buf: Box::new([0u8; MAX_UDP_SIZE]),
ip6_read_buf: Box::new([0u8; MAX_UDP_SIZE]),
device_read_buf: Box::new([0u8; DEFAULT_MTU + 20]),
})
}
@@ -127,7 +133,7 @@ impl ClientTunnel {
cx,
self.ip4_read_buf.as_mut(),
self.ip6_read_buf.as_mut(),
self.device_read_buf.as_mut(),
self.packet_buffer.as_mut(),
)? {
Poll::Ready(io::Input::Timeout(timeout)) => {
self.role_state.handle_timeout(timeout);
@@ -149,7 +155,7 @@ impl ClientTunnel {
received.from,
received.packet,
std::time::Instant::now(),
self.write_buf.as_mut(),
self.packet_buffer.as_mut(),
) else {
continue;
};
@@ -180,10 +186,9 @@ impl GatewayTunnel {
Ok(Self {
io: Io::new(tcp_socket_factory, udp_socket_factory)?,
role_state: GatewayState::new(private_key, rand::random()),
write_buf: Box::new([0u8; DEFAULT_MTU + 20 + 16]),
packet_buffer: Box::new([0u8; BUF_SIZE]),
ip4_read_buf: Box::new([0u8; MAX_UDP_SIZE]),
ip6_read_buf: Box::new([0u8; MAX_UDP_SIZE]),
device_read_buf: Box::new([0u8; DEFAULT_MTU + 20]),
})
}
@@ -211,7 +216,7 @@ impl GatewayTunnel {
cx,
self.ip4_read_buf.as_mut(),
self.ip6_read_buf.as_mut(),
self.device_read_buf.as_mut(),
self.packet_buffer.as_mut(),
)? {
Poll::Ready(io::Input::Timeout(timeout)) => {
self.role_state.handle_timeout(timeout, Utc::now());
@@ -236,7 +241,7 @@ impl GatewayTunnel {
received.from,
received.packet,
std::time::Instant::now(),
self.write_buf.as_mut(),
self.packet_buffer.as_mut(),
) else {
continue;
};