From 29bc276bf207ec0b3a8fd06070dd091d6db46741 Mon Sep 17 00:00:00 2001 From: Thomas Eizinger Date: Thu, 26 Sep 2024 13:03:35 +1000 Subject: [PATCH] refactor(connlib): parallelise TUN operations (#6673) Currently, `connlib` is entirely single-threaded. This allows us to reuse a single buffer for processing IP packets and makes reasoning about the packet-processing code very simple. Being single-threaded also means we can only make use of a single CPU core and all operations have to be sequential. Analyzing `connlib` using `perf` shows that we spend 26% of our CPU time writing packets to the TUN interface [0]. Because we are single-threaded, `connlib` cannot do anything else during this time. If we could offload the writing of these packets to a different thread, `connlib` could already process the next packet while the current one is being written. Packets that we send to the TUN interface arrive as encrypted WG packets over UDP and get decrypted into a - currently - shared buffer. Moving the writing to a different thread implies that we have to have more of these buffers that the next packet(s) can be decrypted into. To avoid IP fragmentation, we set the maximum IP MTU to 1280 bytes on the TUN interface. That actually isn't very big and easily fits into a stack frame. The default stack size for threads is 2MB [1]. Instead of creating more buffers and cycling through them, we can also simply stack-allocate our IP packets. This incurs some overhead from copying packets but it is only ~3.5% [2] (this was measured without a separate thread). With stack-allocated packets, almost all lifetime annotations go away which in itself is already a welcome ergonomics boost. Stack-allocated packets also mean we can simply spawn a new thread for the packet processing. This thread is connected with two channels to `connlib`'s main thread. The capacity of 1000 packets will at most consume an additional 3.5 MB of memory which is fine even on our most-constrained devices such as iOS. 
[0]: https://share.firefox.dev/3z78CzD [1]: https://doc.rust-lang.org/std/thread/#stack-size [2]: https://share.firefox.dev/3Bf4zla Resolves: #6653. Resolves: #5541. --- rust/Cargo.lock | 2 + rust/bin-shared/benches/tunnel.rs | 8 +- rust/connlib/shared/src/lib.rs | 2 - rust/connlib/snownet/src/node.rs | 36 +-- rust/connlib/snownet/tests/lib.rs | 4 - .../tunnel/proptest-regressions/tests.txt | 1 + rust/connlib/tunnel/src/client.rs | 40 ++- rust/connlib/tunnel/src/device_channel.rs | 15 +- rust/connlib/tunnel/src/dns.rs | 6 +- rust/connlib/tunnel/src/gateway.rs | 14 +- rust/connlib/tunnel/src/io.rs | 169 +++++++++-- rust/connlib/tunnel/src/lib.rs | 34 +-- rust/connlib/tunnel/src/peer.rs | 22 +- rust/connlib/tunnel/src/peer/nat_table.rs | 6 +- rust/connlib/tunnel/src/tests/assertions.rs | 16 +- rust/connlib/tunnel/src/tests/sim_client.rs | 26 +- rust/connlib/tunnel/src/tests/sim_gateway.rs | 21 +- rust/gateway/Cargo.toml | 1 + rust/gateway/src/main.rs | 4 +- rust/headless-client/Cargo.toml | 3 +- rust/headless-client/src/ipc_service.rs | 4 +- rust/headless-client/src/main.rs | 4 +- rust/ip-packet/src/lib.rs | 274 ++++++++---------- rust/ip-packet/src/make.rs | 21 +- rust/ip-packet/src/nat46.rs | 12 +- rust/ip-packet/src/proptest.rs | 8 +- rust/ip-packet/src/proptests.rs | 24 +- 27 files changed, 408 insertions(+), 369 deletions(-) diff --git a/rust/Cargo.lock b/rust/Cargo.lock index 8cff35fbc..c58b081bc 100644 --- a/rust/Cargo.lock +++ b/rust/Cargo.lock @@ -2268,6 +2268,7 @@ dependencies = [ "firezone-tunnel", "futures", "futures-bounded", + "ip-packet", "ip_network", "libc", "phoenix-channel", @@ -2380,6 +2381,7 @@ dependencies = [ "firezone-logging", "futures", "humantime", + "ip-packet", "ip_network", "ipconfig", "known-folders", diff --git a/rust/bin-shared/benches/tunnel.rs b/rust/bin-shared/benches/tunnel.rs index d58626247..51fee12ec 100644 --- a/rust/bin-shared/benches/tunnel.rs +++ b/rust/bin-shared/benches/tunnel.rs @@ -26,7 +26,7 @@ mod platform { mod 
platform { use anyhow::Result; use firezone_bin_shared::TunDeviceManager; - use ip_packet::IpPacket; + use ip_packet::{IpPacket, IpPacketBuf}; use std::{ future::poll_fn, net::{Ipv4Addr, Ipv6Addr}, @@ -65,10 +65,10 @@ mod platform { let mut response_pkt = None; let mut time_spent = Duration::from_millis(0); loop { - let mut req_buf = [0u8; MTU + 20]; - poll_fn(|cx| tun.poll_read(&mut req_buf[20..], cx)).await?; + let mut req_buf = IpPacketBuf::new(); + let n = poll_fn(|cx| tun.poll_read(req_buf.buf(), cx)).await?; let start = Instant::now(); - let original_pkt = IpPacket::new(&mut req_buf).unwrap(); + let original_pkt = IpPacket::new(req_buf, n).unwrap(); let Some(original_udp) = original_pkt.as_udp() else { continue; }; diff --git a/rust/connlib/shared/src/lib.rs b/rust/connlib/shared/src/lib.rs index ffd847bf5..25f51e6fa 100644 --- a/rust/connlib/shared/src/lib.rs +++ b/rust/connlib/shared/src/lib.rs @@ -12,8 +12,6 @@ pub use phoenix_channel::{LoginUrl, LoginUrlError}; pub type DomainName = domain::base::Name>; -pub const DEFAULT_MTU: usize = 1280; - const LIB_NAME: &str = "connlib"; pub fn get_user_agent(os_version_override: Option, app_version: &str) -> String { diff --git a/rust/connlib/snownet/src/node.rs b/rust/connlib/snownet/src/node.rs index b6c8be09f..35cc5b35c 100644 --- a/rust/connlib/snownet/src/node.rs +++ b/rust/connlib/snownet/src/node.rs @@ -9,7 +9,7 @@ use boringtun::x25519::PublicKey; use boringtun::{noise::rate_limiter::RateLimiter, x25519::StaticSecret}; use core::fmt; use hex_display::HexDisplayExt; -use ip_packet::{ConvertibleIpv4Packet, ConvertibleIpv6Packet, IpPacket}; +use ip_packet::{ConvertibleIpv4Packet, ConvertibleIpv6Packet, IpPacket, IpPacketBuf}; use rand::rngs::StdRng; use rand::seq::IteratorRandom; use rand::{random, SeedableRng}; @@ -92,8 +92,6 @@ pub struct Node { connections: Connections, pending_events: VecDeque>, - buffer: Vec, - stats: NodeStats, marker: PhantomData, @@ -123,7 +121,7 @@ where TId: Eq + Hash + Copy + Ord + 
fmt::Display, RId: Copy + Eq + Hash + PartialEq + Ord + fmt::Debug + fmt::Display, { - pub fn new(private_key: StaticSecret, buf_size: usize, seed: [u8; 32]) -> Self { + pub fn new(private_key: StaticSecret, seed: [u8; 32]) -> Self { let public_key = &(&private_key).into(); Self { rng: StdRng::from_seed(seed), // TODO: Use this seed for private key too. Requires refactoring of how we generate the login-url because that one needs to know the public key. @@ -137,7 +135,6 @@ where buffered_transmits: VecDeque::default(), next_rate_limiter_reset: None, pending_events: VecDeque::default(), - buffer: vec![0; buf_size], allocations: Default::default(), connections: Default::default(), stats: Default::default(), @@ -285,14 +282,13 @@ where /// - `Ok(None)` if the packet was handled internally, for example, a response from a TURN server. /// - `Ok(Some)` if the packet was an encrypted wireguard packet from a peer. /// The `Option` contains the connection on which the packet was decrypted. - pub fn decapsulate<'b>( + pub fn decapsulate( &mut self, local: SocketAddr, from: SocketAddr, packet: &[u8], now: Instant, - buffer: &'b mut [u8], - ) -> Result)>, Error> { + ) -> Result, Error> { self.add_local_as_host_candidate(local)?; let (from, packet, relayed) = match self.allocations_try_handle(from, local, packet, now) { @@ -309,7 +305,7 @@ where ControlFlow::Break(Err(e)) => return Err(e), }; - let (id, packet) = match self.connections_try_handle(from, packet, buffer, now) { + let (id, packet) = match self.connections_try_handle(from, packet, now) { ControlFlow::Continue(c) => c, ControlFlow::Break(Ok(())) => return Ok(None), ControlFlow::Break(Err(e)) => return Err(e), @@ -326,7 +322,7 @@ where pub fn encapsulate( &mut self, connection: TId, - packet: IpPacket<'_>, + packet: IpPacket, now: Instant, buffer: &mut EncryptBuffer, ) -> Result, Error> { @@ -572,7 +568,7 @@ where ), next_timer_update: now, stats: Default::default(), - buffer: vec![0; self.buffer.capacity()], + buffer: 
vec![0; ip_packet::MAX_DATAGRAM_PAYLOAD], intent_sent_at, signalling_completed_at: now, remote_pub_key: remote, @@ -707,13 +703,12 @@ where } #[must_use] - fn connections_try_handle<'b>( + fn connections_try_handle( &mut self, from: SocketAddr, packet: &[u8], - buffer: &'b mut [u8], now: Instant, - ) -> ControlFlow, (TId, IpPacket<'b>)> { + ) -> ControlFlow, (TId, IpPacket)> { for (cid, conn) in self.connections.iter_established_mut() { if !conn.accepts(&from) { continue; @@ -723,7 +718,6 @@ where let control_flow = conn.decapsulate( packet, - buffer, &mut self.allocations, &mut self.buffered_transmits, now, @@ -1702,17 +1696,17 @@ where Ok(Some(&buffer[..len])) } - fn decapsulate<'b>( + fn decapsulate( &mut self, packet: &[u8], - buffer: &'b mut [u8], allocations: &mut BTreeMap, transmits: &mut VecDeque>, now: Instant, - ) -> ControlFlow, IpPacket<'b>> { + ) -> ControlFlow, IpPacket> { let _guard = self.span.enter(); + let mut ip_packet = IpPacketBuf::new(); - let control_flow = match self.tunnel.decapsulate(None, packet, &mut buffer[20..]) { + let control_flow = match self.tunnel.decapsulate(None, packet, ip_packet.buf()) { TunnResult::Done => ControlFlow::Break(Ok(())), TunnResult::Err(e) => ControlFlow::Break(Err(Error::Decapsulate(e))), @@ -1722,7 +1716,7 @@ where // Thus, the caller can query whatever data they'd like, not just the source IP so we don't return it in addition. TunnResult::WriteToTunnelV4(packet, ip) => { let packet_len = packet.len(); - let ipv4_packet = ConvertibleIpv4Packet::new(&mut buffer[..(packet_len + 20)]) + let ipv4_packet = ConvertibleIpv4Packet::new(ip_packet, packet_len) .expect("boringtun verifies validity"); debug_assert_eq!(ipv4_packet.get_source(), ip); @@ -1733,7 +1727,7 @@ where // for ipv6 we just need this to convince the borrow-checker that `packet`'s lifetime isn't `'b`, otherwise it's taken // as `'b` for all branches. 
let packet_len = packet.len(); - let ipv6_packet = ConvertibleIpv6Packet::new(&mut buffer[20..(packet_len + 20)]) + let ipv6_packet = ConvertibleIpv6Packet::new(ip_packet, packet_len) .expect("boringtun verifies validity"); debug_assert_eq!(ipv6_packet.get_source(), ip); diff --git a/rust/connlib/snownet/tests/lib.rs b/rust/connlib/snownet/tests/lib.rs index 16b6ec6c9..343ad6a3b 100644 --- a/rust/connlib/snownet/tests/lib.rs +++ b/rust/connlib/snownet/tests/lib.rs @@ -75,14 +75,12 @@ fn only_generate_candidate_event_after_answer() { let mut alice = ClientNode::::new( StaticSecret::random_from_rng(rand::thread_rng()), - 0, rand::random(), ); alice.add_local_host_candidate(local_candidate).unwrap(); let mut bob = ServerNode::::new( StaticSecret::random_from_rng(rand::thread_rng()), - 0, rand::random(), ); @@ -110,12 +108,10 @@ fn only_generate_candidate_event_after_answer() { fn alice_and_bob() -> (ClientNode, ServerNode) { let alice = ClientNode::new( StaticSecret::random_from_rng(rand::thread_rng()), - 0, rand::random(), ); let bob = ServerNode::new( StaticSecret::random_from_rng(rand::thread_rng()), - 0, rand::random(), ); diff --git a/rust/connlib/tunnel/proptest-regressions/tests.txt b/rust/connlib/tunnel/proptest-regressions/tests.txt index e7c80f255..80b8e45b8 100644 --- a/rust/connlib/tunnel/proptest-regressions/tests.txt +++ b/rust/connlib/tunnel/proptest-regressions/tests.txt @@ -105,3 +105,4 @@ cc 44a16aa3dad95769d5fc1ff907952af31126b5ae7e5298dca72046bb2a98205f # shrinks to cc e83b5929e8d9f3d26be4d341c0d3cce206d4c137855de5668c27a53052546681 cc 4f5a2f6c9162963e20d82f9e6dfddca6992401ae65287776b1c3a736f7e1f1f7 cc f0ffe4d3c6a019810f4dc87fb0f741c9b76d5756c0077edeee657ec3a5193df9 +cc aa50269a0b4c691fd00812648f5d26853e3f2581939c1c3d35c4aff2811ee2a4 # shrinks to (ReferenceState { client: Host { inner: RefClient { id: aa161b58-2acd-0f88-a1ff-6af707903c09, key: PrivateKey("9fd04b63de8d3ae6e83d039369bf29de01fab81f31cafe5732e186cdd9ca5118"), known_hosts: 
{"api.firez.one": [2bf4:75e5:1edd:1eca:8e47:1e7f:7706:cdce]}, tunnel_ip4: 100.64.0.1, tunnel_ip6: fd00:2021:1111::, ipv4_routes: {Ipv4Network { network_address: 100.96.0.0, netmask: 11 }, Ipv4Network { network_address: 100.100.111.0, netmask: 24 }}, ipv6_routes: {Ipv6Network { network_address: fd00:2021:1111:8000::, netmask: 107 }, Ipv6Network { network_address: fd00:2021:1111:8000:100:100:111:0, netmask: 120 }} }, ip4: None, ip6: Some(2001:db80::40), default_port: 13004, latency: 285ms }, gateways: {1b82ddd6-902f-d519-394f-92ae107fec08: Host { inner: RefGateway { key: PrivateKey("ce2232673aa01c948004881d3c86ed13212393f8aaea40a7e853c494a645189e") }, ip4: Some(203.0.113.25), ip6: Some(2001:db80::b), default_port: 39793, latency: 125ms }, 3476f787-9f99-9cfb-845d-6bc3068085b8: Host { inner: RefGateway { key: PrivateKey("ed9647b905876947cbc98cfbce2362215b76a76fbeec72eea179280a0685b0c5") }, ip4: Some(203.0.113.90), ip6: Some(2001:db80::37), default_port: 42103, latency: 82ms }, 3a154caa-8804-7461-9a86-81614e2bafe0: Host { inner: RefGateway { key: PrivateKey("8659e1d83a85b5f2bc8d6321af680364f1969f1d51ab2c1f2307ed8573afe751") }, ip4: Some(203.0.113.4), ip6: Some(2001:db80::42), default_port: 16650, latency: 124ms }, e16813f6-f291-9c92-7385-2ddf57598b07: Host { inner: RefGateway { key: PrivateKey("56d997c3583de22abfbe5cbf2f30ebe1883823f757dec45037db25c577a9fd59") }, ip4: Some(203.0.113.92), ip6: Some(2001:db80::32), default_port: 14859, latency: 38ms }}, relays: {5a54264a-0e72-0e93-897d-8ddd92da7424: Host { inner: 5725775471353133246, ip4: Some(203.0.113.97), ip6: Some(2001:db80::10), default_port: 3478, latency: 45ms }}, dns_servers: {1AEA7C457539FE1D6979A3E69BA2117D: Host { inner: RefDns, ip4: Some(217.255.130.74), ip6: None, default_port: 53, latency: 36ms }, 33ECD04A65AC03B0DBBE5949AAC04019: Host { inner: RefDns, ip4: Some(225.49.6.141), ip6: None, default_port: 53, latency: 32ms }, A173FB62F64BEDE5A4649C9E15363674: Host { inner: RefDns, ip4: Some(246.68.236.104), ip6: 
None, default_port: 53, latency: 23ms }, F59A62324FA1DB3C5B3D82765D679D3F: Host { inner: RefDns, ip4: None, ip6: Some(::ffff:115.190.101.132), default_port: 53, latency: 46ms }}, portal: StubPortal { gateways_by_site: {2475243d-3a89-de38-9b22-a6a89f161466: {1b82ddd6-902f-d519-394f-92ae107fec08}, 24ca4825-0e32-2e43-fb5f-a00b66c729b0: {3476f787-9f99-9cfb-845d-6bc3068085b8}, db992e5e-afbb-a6ac-d143-8207a10e66cc: {3a154caa-8804-7461-9a86-81614e2bafe0, e16813f6-f291-9c92-7385-2ddf57598b07}}, cidr_resources: {14ed97c5-a780-eef3-a0f6-874f61f737a8: ResourceDescriptionCidr { id: 14ed97c5-a780-eef3-a0f6-874f61f737a8, address: V4(Ipv4Network { network_address: 127.0.0.0, netmask: 25 }), name: "bvujezlta", address_description: None, sites: [Site { id: 24ca4825-0e32-2e43-fb5f-a00b66c729b0, name: "tviv" }] }, 44bac9fa-202f-34b3-b673-3c43b1d232b6: ResourceDescriptionCidr { id: 44bac9fa-202f-34b3-b673-3c43b1d232b6, address: V6(Ipv6Network { network_address: ::ffff:127.0.0.0, netmask: 122 }), name: "xhnb", address_description: None, sites: [Site { id: 2475243d-3a89-de38-9b22-a6a89f161466, name: "xdlakl" }] }, f2b00be1-cdca-4675-2378-a7ff2ecf9f68: ResourceDescriptionCidr { id: f2b00be1-cdca-4675-2378-a7ff2ecf9f68, address: V4(Ipv4Network { network_address: 80.174.231.1, netmask: 32 }), name: "estabe", address_description: None, sites: [Site { id: 2475243d-3a89-de38-9b22-a6a89f161466, name: "xdlakl" }] }, ffe8a2ab-0a71-d306-98da-8123af83e162: ResourceDescriptionCidr { id: ffe8a2ab-0a71-d306-98da-8123af83e162, address: V6(Ipv6Network { network_address: 520e:e55f:10df:6d4a:add5:e191:3def:fbb0, netmask: 125 }), name: "nbheoefpb", address_description: None, sites: [Site { id: 24ca4825-0e32-2e43-fb5f-a00b66c729b0, name: "tviv" }] }}, dns_resources: {a58cff08-68cf-e48c-cbef-e52a472cd27b: ResourceDescriptionDns { id: a58cff08-68cf-e48c-cbef-e52a472cd27b, address: "**.efyyd.apgalu.avsm", name: "gfrz", address_description: Some("bwzhi"), sites: [Site { id: 
db992e5e-afbb-a6ac-d143-8207a10e66cc, name: "wewshsy" }] }, af51261a-a158-5734-3752-871a27a655e2: ResourceDescriptionDns { id: af51261a-a158-5734-3752-871a27a655e2, address: "**.rnjqja.whmjsb.wzix", name: "efsqxffipr", address_description: Some("qgylk"), sites: [Site { id: 24ca4825-0e32-2e43-fb5f-a00b66c729b0, name: "tviv" }] }}, internet_resource: ResourceDescriptionInternet { name: "Internet Resource", id: b4a20971-d535-9810-9745-6b90fed299bd, sites: [Site { id: 24ca4825-0e32-2e43-fb5f-a00b66c729b0, name: "tviv" }] } }, drop_direct_client_traffic: true, global_dns_records: {Name(etcww.efyyd.apgalu.avsm.): {2001:db80::f9}, Name(slqfgi.efyyd.apgalu.avsm.): {198.51.100.73}, Name(yhtb.efyyd.apgalu.avsm.): {198.51.100.19, 198.51.100.35, 198.51.100.159, 2001:db80::3e, 2001:db80::e0}, Name(wnysm.kll.): {0.0.0.0, 18.168.101.152, 127.0.0.1}, Name(yvkgu.pkgc.): {119.85.130.192, ::ffff:37.204.219.209, ::ffff:74.157.127.159, 2e16:188:431b:3296:d16f:5cd7:a4f7:7f0d, 87a0:47fa:103f:510c:3fc0:44fd:3168:4ad5}, Name(hml.kzais.ugedo.): {::ffff:133.198.237.87, ::ffff:162.175.130.36, 43f7:2392:8e05:e8db:4f48:c307:ec8b:712}, Name(ctxysv.rnjqja.whmjsb.wzix.): {198.51.100.118}}, network: RoutingTable { routes: {(V4(Ipv4Network { network_address: 203.0.113.4, netmask: 32 }), Gateway(3a154caa-8804-7461-9a86-81614e2bafe0)), (V4(Ipv4Network { network_address: 203.0.113.25, netmask: 32 }), Gateway(1b82ddd6-902f-d519-394f-92ae107fec08)), (V4(Ipv4Network { network_address: 203.0.113.90, netmask: 32 }), Gateway(3476f787-9f99-9cfb-845d-6bc3068085b8)), (V4(Ipv4Network { network_address: 203.0.113.92, netmask: 32 }), Gateway(e16813f6-f291-9c92-7385-2ddf57598b07)), (V4(Ipv4Network { network_address: 203.0.113.97, netmask: 32 }), Relay(5a54264a-0e72-0e93-897d-8ddd92da7424)), (V4(Ipv4Network { network_address: 217.255.130.74, netmask: 32 }), DnsServer(1AEA7C457539FE1D6979A3E69BA2117D)), (V4(Ipv4Network { network_address: 225.49.6.141, netmask: 32 }), DnsServer(33ECD04A65AC03B0DBBE5949AAC04019)), 
(V4(Ipv4Network { network_address: 246.68.236.104, netmask: 32 }), DnsServer(A173FB62F64BEDE5A4649C9E15363674)), (V6(Ipv6Network { network_address: ::ffff:115.190.101.132, netmask: 128 }), DnsServer(F59A62324FA1DB3C5B3D82765D679D3F)), (V6(Ipv6Network { network_address: 2001:db80::b, netmask: 128 }), Gateway(1b82ddd6-902f-d519-394f-92ae107fec08)), (V6(Ipv6Network { network_address: 2001:db80::10, netmask: 128 }), Relay(5a54264a-0e72-0e93-897d-8ddd92da7424)), (V6(Ipv6Network { network_address: 2001:db80::32, netmask: 128 }), Gateway(e16813f6-f291-9c92-7385-2ddf57598b07)), (V6(Ipv6Network { network_address: 2001:db80::37, netmask: 128 }), Gateway(3476f787-9f99-9cfb-845d-6bc3068085b8)), (V6(Ipv6Network { network_address: 2001:db80::40, netmask: 128 }), Client(aa161b58-2acd-0f88-a1ff-6af707903c09)), (V6(Ipv6Network { network_address: 2001:db80::42, netmask: 128 }), Gateway(3a154caa-8804-7461-9a86-81614e2bafe0))} } }, [ActivateResource(Dns(ResourceDescriptionDns { id: a58cff08-68cf-e48c-cbef-e52a472cd27b, address: "**.efyyd.apgalu.avsm", name: "gfrz", address_description: Some("bwzhi"), sites: [Site { id: db992e5e-afbb-a6ac-d143-8207a10e66cc, name: "wewshsy" }] })), SendDnsQueries([DnsQuery { domain: Name(slqfgi.efyyd.apgalu.avsm.), r_type: Rtype::AAAA, query_id: 8803, dns_server: [::ffff:115.190.101.132]:53 }, DnsQuery { domain: Name(yhtb.efyyd.apgalu.avsm.), r_type: Rtype::A, query_id: 32250, dns_server: [::ffff:115.190.101.132]:53 }]), SendICMPPacketToDnsResource { src: fd00:2021:1111::, dst: Name(slqfgi.efyyd.apgalu.avsm.), seq: 0, identifier: 0, payload: 0 }, SendICMPPacketToDnsResource { src: fd00:2021:1111::, dst: Name(slqfgi.efyyd.apgalu.avsm.), seq: 0, identifier: 0, payload: 0 }], None) diff --git a/rust/connlib/tunnel/src/client.rs b/rust/connlib/tunnel/src/client.rs index 434b6eb67..0e825d517 100644 --- a/rust/connlib/tunnel/src/client.rs +++ b/rust/connlib/tunnel/src/client.rs @@ -1,6 +1,6 @@ use crate::dns::StubResolver; use crate::peer_store::PeerStore; 
-use crate::{dns, TunConfig, BUF_SIZE}; +use crate::{dns, TunConfig}; use anyhow::Context; use bimap::BiMap; use connlib_shared::callbacks::Status; @@ -87,7 +87,7 @@ impl ClientTunnel { } pub fn set_tun(&mut self, tun: Box) { - self.io.device_mut().set_tun(tun); + self.io.set_tun(tun); } pub fn update_relays(&mut self, to_remove: BTreeSet, to_add: Vec) { @@ -252,7 +252,7 @@ pub struct ClientState { recently_connected_gateways: LruCache, buffered_events: VecDeque, - buffered_packets: VecDeque>, + buffered_packets: VecDeque, buffered_transmits: VecDeque>, } @@ -278,7 +278,7 @@ impl ClientState { buffered_events: Default::default(), tun_config: Default::default(), buffered_packets: Default::default(), - node: ClientNode::new(private_key.into(), BUF_SIZE, seed), + node: ClientNode::new(private_key.into(), seed), system_resolvers: Default::default(), sites_status: Default::default(), gateways_site: Default::default(), @@ -399,7 +399,7 @@ impl ClientState { pub(crate) fn encapsulate( &mut self, - packet: IpPacket<'_>, + packet: IpPacket, now: Instant, buffer: &mut EncryptBuffer, ) -> Option { @@ -452,14 +452,13 @@ impl ClientState { Some(transmit) } - pub(crate) fn decapsulate<'b>( + pub(crate) fn decapsulate( &mut self, local: SocketAddr, from: SocketAddr, packet: &[u8], now: Instant, - buffer: &'b mut [u8], - ) -> Option> { + ) -> Option { if let Some(response) = self.try_handle_forwarded_dns_response(from, packet) { return Some(response); }; @@ -469,7 +468,6 @@ impl ClientState { from, packet.as_ref(), now, - buffer, ) .inspect_err(|e| tracing::debug!(%local, num_bytes = %packet.len(), "Failed to decapsulate incoming packet: {e}")) .ok()??; @@ -619,11 +617,11 @@ impl ClientState { } /// Attempt to handle the given packet as a DNS query packet. 
- fn try_handle_dns_query<'a>( + fn try_handle_dns_query( &mut self, - packet: IpPacket<'a>, + packet: IpPacket, now: Instant, - ) -> ControlFlow<(), (IpPacket<'a>, IpAddr)> { + ) -> ControlFlow<(), (IpPacket, IpAddr)> { match self.stub_resolver.handle(&self.dns_mapping, &packet) { Ok(ControlFlow::Break(dns::ResolveStrategy::LocalResponse(query))) => { self.buffered_packets.push_back(query); @@ -664,11 +662,11 @@ impl ClientState { } } - fn try_handle_forwarded_dns_response<'a>( + fn try_handle_forwarded_dns_response( &mut self, from: SocketAddr, packet: &[u8], - ) -> Option> { + ) -> Option { // The sentinel DNS server shall be the source. If we don't have a sentinel DNS for this socket, it cannot be a DNS response. let saddr = *self.dns_mapping.get_by_right(&DnsServer::from(from))?; let sport = DNS_PORT; @@ -888,7 +886,7 @@ impl ClientState { self.update_dns_mapping() } - pub fn poll_packets(&mut self) -> Option> { + pub fn poll_packets(&mut self) -> Option { self.buffered_packets.pop_front() } @@ -1366,12 +1364,12 @@ fn is_definitely_not_a_resource(ip: IpAddr) -> bool { } /// In case the given packet is a DNS query, change its source IP to that of the actual DNS server. 
-fn maybe_mangle_dns_query_to_cidr_resource<'p>( - mut packet: IpPacket<'p>, +fn maybe_mangle_dns_query_to_cidr_resource( + mut packet: IpPacket, dns_mapping: &BiMap, mangeled_dns_queries: &mut HashMap, now: Instant, -) -> IpPacket<'p> { +) -> IpPacket { let dst = packet.destination(); let Some(srv) = dns_mapping.get_by_left(&dst) else { @@ -1395,12 +1393,12 @@ fn maybe_mangle_dns_query_to_cidr_resource<'p>( packet } -fn maybe_mangle_dns_response_from_cidr_resource<'p>( - mut packet: IpPacket<'p>, +fn maybe_mangle_dns_response_from_cidr_resource( + mut packet: IpPacket, dns_mapping: &BiMap, mangeled_dns_queries: &mut HashMap, now: Instant, -) -> IpPacket<'p> { +) -> IpPacket { let src_ip = packet.source(); let Some(udp) = packet.as_udp() else { diff --git a/rust/connlib/tunnel/src/device_channel.rs b/rust/connlib/tunnel/src/device_channel.rs index 1844f27c6..980feb56c 100644 --- a/rust/connlib/tunnel/src/device_channel.rs +++ b/rust/connlib/tunnel/src/device_channel.rs @@ -1,4 +1,4 @@ -use ip_packet::IpPacket; +use ip_packet::{IpPacket, IpPacketBuf}; use std::io; use std::task::{Context, Poll, Waker}; use tun::Tun; @@ -26,17 +26,14 @@ impl Device { } } - pub(crate) fn poll_read<'b>( - &mut self, - buf: &'b mut [u8], - cx: &mut Context<'_>, - ) -> Poll>> { + pub(crate) fn poll_read(&mut self, cx: &mut Context<'_>) -> Poll> { let Some(tun) = self.tun.as_mut() else { self.waker = Some(cx.waker().clone()); return Poll::Pending; }; - let n = std::task::ready!(tun.poll_read(&mut buf[20..], cx))?; + let mut ip_packet = IpPacketBuf::new(); + let n = std::task::ready!(tun.poll_read(ip_packet.buf(), cx))?; if n == 0 { return Poll::Ready(Err(io::Error::new( @@ -45,7 +42,7 @@ impl Device { ))); } - let packet = IpPacket::new(&mut buf[..(n + 20)]).ok_or_else(|| { + let packet = IpPacket::new(ip_packet, n).ok_or_else(|| { io::Error::new( io::ErrorKind::InvalidInput, "received bytes are not an IP packet", @@ -57,7 +54,7 @@ impl Device { Poll::Ready(Ok(packet)) } - pub fn 
write(&self, packet: IpPacket<'_>) -> io::Result { + pub fn write(&self, packet: IpPacket) -> io::Result { tracing::trace!(target: "wire::dev::send", dst = %packet.destination(), src = %packet.source(), bytes = %packet.packet().len()); match packet { diff --git a/rust/connlib/tunnel/src/dns.rs b/rust/connlib/tunnel/src/dns.rs index d16af1ef8..79dcd8088 100644 --- a/rust/connlib/tunnel/src/dns.rs +++ b/rust/connlib/tunnel/src/dns.rs @@ -32,9 +32,13 @@ pub struct StubResolver { /// Tells the Client how to reply to a single DNS query #[derive(Debug)] +#[expect( + clippy::large_enum_variant, + reason = "We purposely don't want to allocate each IP packet." +)] pub(crate) enum ResolveStrategy { /// The query is for a Resource, we have an IP mapped already, and we can respond instantly - LocalResponse(IpPacket<'static>), + LocalResponse(IpPacket), /// The query is for a non-Resource, forward it to an upstream or system resolver. ForwardQuery { upstream: SocketAddr, diff --git a/rust/connlib/tunnel/src/gateway.rs b/rust/connlib/tunnel/src/gateway.rs index 02f358c86..c7901a776 100644 --- a/rust/connlib/tunnel/src/gateway.rs +++ b/rust/connlib/tunnel/src/gateway.rs @@ -1,7 +1,7 @@ use crate::peer::ClientOnGateway; use crate::peer_store::PeerStore; use crate::utils::earliest; -use crate::{GatewayEvent, GatewayTunnel, BUF_SIZE}; +use crate::{GatewayEvent, GatewayTunnel}; use anyhow::bail; use boringtun::x25519::PublicKey; use chrono::{DateTime, Utc}; @@ -33,7 +33,7 @@ const EXPIRE_RESOURCES_INTERVAL: Duration = Duration::from_secs(1); impl GatewayTunnel { pub fn set_tun(&mut self, tun: Box) { - self.io.device_mut().set_tun(tun); + self.io.set_tun(tun); } /// Accept a connection request from a client. 
@@ -144,7 +144,7 @@ impl GatewayState { pub(crate) fn new(private_key: impl Into, seed: [u8; 32]) -> Self { Self { peers: Default::default(), - node: ServerNode::new(private_key.into(), BUF_SIZE, seed), + node: ServerNode::new(private_key.into(), seed), next_expiry_resources_check: Default::default(), buffered_events: VecDeque::default(), } @@ -157,7 +157,7 @@ impl GatewayState { pub(crate) fn encapsulate( &mut self, - packet: IpPacket<'_>, + packet: IpPacket, now: Instant, buffer: &mut EncryptBuffer, ) -> Option { @@ -188,20 +188,18 @@ impl GatewayState { Some(transmit) } - pub(crate) fn decapsulate<'b>( + pub(crate) fn decapsulate( &mut self, local: SocketAddr, from: SocketAddr, packet: &[u8], now: Instant, - buffer: &'b mut [u8], - ) -> Option> { + ) -> Option { let (cid, packet) = self.node.decapsulate( local, from, packet, now, - buffer, ) .inspect_err(|e| tracing::debug!(%from, num_bytes = %packet.len(), "Failed to decapsulate incoming packet: {e}")) .ok()??; diff --git a/rust/connlib/tunnel/src/io.rs b/rust/connlib/tunnel/src/io.rs index 098270fb3..3f69c478b 100644 --- a/rust/connlib/tunnel/src/io.rs +++ b/rust/connlib/tunnel/src/io.rs @@ -1,6 +1,10 @@ -use crate::{device_channel::Device, sockets::Sockets, BUF_SIZE}; +use crate::{device_channel::Device, sockets::Sockets}; +use futures::{ + future::{self, Either}, + stream, Stream, StreamExt, +}; use futures_util::FutureExt as _; -use ip_packet::IpPacket; +use ip_packet::{IpPacket, MAX_DATAGRAM_PAYLOAD}; use snownet::{EncryptBuffer, EncryptedPacket}; use socket_factory::{DatagramIn, DatagramOut, SocketFactory, TcpSocket, UdpSocket}; use std::{ @@ -10,13 +14,11 @@ use std::{ task::{ready, Context, Poll}, time::Instant, }; +use tokio::sync::mpsc; +use tun::Tun; /// Bundles together all side-effects that connlib needs to have access to. pub struct Io { - /// The TUN device offered to the user. - /// - /// This is the `tun-firezone` network interface that users see when they e.g. type `ip addr` on Linux. 
- device: Device, /// The UDP sockets used to send & receive packets from the network. sockets: Sockets, unwritten_packet: Option, @@ -25,14 +27,23 @@ pub struct Io { udp_socket_factory: Arc>, timeout: Option>>, + tun_tx: mpsc::Sender>, + outbound_packet_tx: mpsc::Sender, + inbound_packet_rx: mpsc::Receiver, } -pub enum Input<'a, I> { +#[expect( + clippy::large_enum_variant, + reason = "We purposely don't want to allocate each IP packet." +)] +pub enum Input { Timeout(Instant), - Device(IpPacket<'a>), + Device(IpPacket), Network(I), } +const IP_CHANNEL_SIZE: usize = 1000; + impl Io { /// Creates a new I/O abstraction /// @@ -44,8 +55,22 @@ impl Io { let mut sockets = Sockets::default(); sockets.rebind(udp_socket_factory.as_ref()); // Bind sockets on startup. Must happen within a tokio runtime context. + let (inbound_packet_tx, inbound_packet_rx) = mpsc::channel(IP_CHANNEL_SIZE); + let (outbound_packet_tx, outbound_packet_rx) = mpsc::channel(IP_CHANNEL_SIZE); + let (tun_tx, tun_rx) = mpsc::channel(10); + + std::thread::spawn(|| { + futures::executor::block_on(tun_send_recv( + tun_rx, + outbound_packet_rx, + inbound_packet_tx, + )) + }); + Self { - device: Device::new(), + tun_tx, + outbound_packet_tx, + inbound_packet_rx, timeout: None, sockets, _tcp_socket_factory: tcp_socket_factory, @@ -58,21 +83,20 @@ impl Io { self.sockets.poll_has_sockets(cx) } - pub fn poll<'b1, 'b2>( + pub fn poll<'b>( &mut self, cx: &mut Context<'_>, - ip4_buffer: &'b1 mut [u8], - ip6_bffer: &'b1 mut [u8], - device_buffer: &'b2 mut [u8], + ip4_buffer: &'b mut [u8], + ip6_bffer: &'b mut [u8], encrypt_buffer: &EncryptBuffer, - ) -> Poll>>>> { + ) -> Poll>>>> { ready!(self.poll_send_unwritten(cx, encrypt_buffer)?); if let Poll::Ready(network) = self.sockets.poll_recv_from(ip4_buffer, ip6_bffer, cx)? { return Poll::Ready(Ok(Input::Network(network.filter(is_max_wg_packet_size)))); } - if let Poll::Ready(packet) = self.device.poll_read(device_buffer, cx)? 
{ + if let Poll::Ready(Some(packet)) = self.inbound_packet_rx.poll_recv(cx) { return Poll::Ready(Ok(Input::Device(packet))); } @@ -105,8 +129,28 @@ impl Io { Poll::Ready(Ok(())) } - pub fn device_mut(&mut self) -> &mut Device { - &mut self.device + pub fn set_tun(&mut self, tun: Box) { + // If we can't set a new TUN device, shut down connlib. + + self.tun_tx + .try_send(tun) + .expect("Channel to set new TUN device should always have capacity"); + } + + pub fn send_tun(&mut self, packet: IpPacket) -> io::Result<()> { + let Err(e) = self.outbound_packet_tx.try_send(packet) else { + return Ok(()); + }; + + match e { + mpsc::error::TrySendError::Full(_) => { + Err(io::Error::other("Outbound packet channel is at capacity")) + } + mpsc::error::TrySendError::Closed(_) => Err(io::Error::new( + io::ErrorKind::NotConnected, + "Outbound packet channel is disconnected", + )), + } } pub fn rebind_sockets(&mut self) { @@ -156,21 +200,102 @@ impl Io { Ok(()) } +} - pub fn send_device(&self, packet: IpPacket<'_>) -> io::Result<()> { - self.device.write(packet)?; +async fn tun_send_recv( + mut tun_rx: mpsc::Receiver>, + mut outbound_packet_rx: mpsc::Receiver, + inbound_packet_tx: mpsc::Sender, +) { + let mut device = Device::new(); - Ok(()) + let mut command_stream = stream::select_all([ + new_tun_stream(&mut tun_rx), + outgoing_packet_stream(&mut outbound_packet_rx), + ]); + + loop { + match future::select( + command_stream.next(), + future::poll_fn(|cx| device.poll_read(cx)), + ) + .await + { + Either::Left((Some(Command::SendPacket(p)), _)) => { + if let Err(e) = device.write(p) { + tracing::debug!("Failed to write TUN packet: {e}"); + }; + } + Either::Left((Some(Command::UpdateTun(tun)), _)) => { + device.set_tun(tun); + } + Either::Left((None, _)) => { + tracing::debug!("Command stream closed"); + return; + } + Either::Right((Ok(p), _)) => { + if inbound_packet_tx.send(p).await.is_err() { + tracing::debug!("Inbound packet channel closed"); + return; + }; + } + 
Either::Right((Err(e), _)) => { + tracing::debug!("Failed to read packet from TUN device: {e}"); + return; + } + }; } } +#[expect( + clippy::large_enum_variant, + reason = "We purposely don't want to allocate each IP packet." +)] +enum Command { + UpdateTun(Box), + SendPacket(IpPacket), +} + +fn new_tun_stream( + tun_rx: &mut mpsc::Receiver>, +) -> Pin + '_>> { + Box::pin(stream::poll_fn(|cx| { + tun_rx + .poll_recv(cx) + .map(|maybe_t| maybe_t.map(Command::UpdateTun)) + })) +} + +fn outgoing_packet_stream( + outbound_packet_rx: &mut mpsc::Receiver, +) -> Pin + '_>> { + Box::pin(stream::poll_fn(|cx| { + outbound_packet_rx + .poll_recv(cx) + .map(|maybe_p| maybe_p.map(Command::SendPacket)) + })) +} + fn is_max_wg_packet_size(d: &DatagramIn) -> bool { let len = d.packet.len(); - if len > BUF_SIZE { - tracing::debug!(from = %d.from, %len, "Dropping too large datagram (max allowed: {BUF_SIZE} bytes)"); + if len > MAX_DATAGRAM_PAYLOAD { + tracing::debug!(from = %d.from, %len, "Dropping too large datagram (max allowed: {MAX_DATAGRAM_PAYLOAD} bytes)"); return false; } true } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn max_ip_channel_size_is_reasonable() { + let one_ip_packet = std::mem::size_of::(); + let max_channel_size = IP_CHANNEL_SIZE * one_ip_packet; + + assert_eq!(max_channel_size, 1_360_000); // 1.36MB is fine, we only have 2 of these channels, meaning less than 3MB additional buffer in total. 
+ } +} diff --git a/rust/connlib/tunnel/src/lib.rs b/rust/connlib/tunnel/src/lib.rs index 05a7628e9..5b7cd1781 100644 --- a/rust/connlib/tunnel/src/lib.rs +++ b/rust/connlib/tunnel/src/lib.rs @@ -9,10 +9,11 @@ use chrono::Utc; use connlib_shared::{ callbacks, messages::{ClientId, GatewayId, Offer, Relay, RelayId, ResolveRequest, ResourceId, SecretKey}, - DomainName, PublicKey, DEFAULT_MTU, + DomainName, PublicKey, }; use io::Io; use ip_network::{Ipv4Network, Ipv6Network}; +use ip_packet::MAX_DATAGRAM_PAYLOAD; use rand::rngs::OsRng; use socket_factory::{SocketFactory, TcpSocket, UdpSocket}; use std::{ @@ -51,15 +52,6 @@ const REALM: &str = "firezone"; /// Thus, it is chosen as a safe, upper boundary that is not meant to be hit (and thus doesn't affect performance), yet acts as a safe guard, just in case. const MAX_EVENTLOOP_ITERS: u32 = 5000; -/// Wireguard has a 32-byte overhead (4b message type + 4b receiver idx + 8b packet counter + 16b AEAD tag) -const WG_OVERHEAD: usize = 32; -/// In order to do NAT46 without copying, we need 20 extra byte in the buffer (IPv6 packets are 20 byte bigger than IPv4). -const NAT46_OVERHEAD: usize = 20; -/// TURN's data channels have a 4 byte overhead. -const DATA_CHANNEL_OVERHEAD: usize = 4; - -const BUF_SIZE: usize = DEFAULT_MTU + WG_OVERHEAD + NAT46_OVERHEAD + DATA_CHANNEL_OVERHEAD; - pub type GatewayTunnel = Tunnel; pub type ClientTunnel = Tunnel; @@ -83,10 +75,6 @@ pub struct Tunnel { ip4_read_buf: Box<[u8; MAX_UDP_SIZE]>, ip6_read_buf: Box<[u8; MAX_UDP_SIZE]>, - /// Buffer for reading a single IP packet. - device_read_buf: Box<[u8; BUF_SIZE]>, - /// Buffer for decryping a single packet. - decrypt_buf: Box<[u8; BUF_SIZE]>, /// Buffer for encrypting a single packet. 
encrypt_buf: EncryptBuffer, } @@ -101,11 +89,9 @@ impl ClientTunnel { Self { io: Io::new(tcp_socket_factory, udp_socket_factory), role_state: ClientState::new(private_key, known_hosts, rand::random()), - device_read_buf: Box::new([0u8; BUF_SIZE]), ip4_read_buf: Box::new([0u8; MAX_UDP_SIZE]), ip6_read_buf: Box::new([0u8; MAX_UDP_SIZE]), - encrypt_buf: EncryptBuffer::new(BUF_SIZE), - decrypt_buf: Box::new([0u8; BUF_SIZE]), + encrypt_buf: EncryptBuffer::new(MAX_DATAGRAM_PAYLOAD), } } @@ -123,7 +109,7 @@ impl ClientTunnel { } if let Some(packet) = self.role_state.poll_packets() { - self.io.send_device(packet)?; + self.io.send_tun(packet)?; continue; } @@ -140,7 +126,6 @@ impl ClientTunnel { cx, self.ip4_read_buf.as_mut(), self.ip6_read_buf.as_mut(), - self.device_read_buf.as_mut(), &self.encrypt_buf, )? { Poll::Ready(io::Input::Timeout(timeout)) => { @@ -169,12 +154,11 @@ impl ClientTunnel { received.from, received.packet, Instant::now(), - self.decrypt_buf.as_mut(), ) else { continue; }; - self.io.device_mut().write(packet)?; + self.io.send_tun(packet)?; } continue; @@ -200,11 +184,9 @@ impl GatewayTunnel { Self { io: Io::new(tcp_socket_factory, udp_socket_factory), role_state: GatewayState::new(private_key, rand::random()), - device_read_buf: Box::new([0u8; BUF_SIZE]), ip4_read_buf: Box::new([0u8; MAX_UDP_SIZE]), ip6_read_buf: Box::new([0u8; MAX_UDP_SIZE]), - encrypt_buf: EncryptBuffer::new(BUF_SIZE), - decrypt_buf: Box::new([0u8; BUF_SIZE]), + encrypt_buf: EncryptBuffer::new(MAX_DATAGRAM_PAYLOAD), } } @@ -234,7 +216,6 @@ impl GatewayTunnel { cx, self.ip4_read_buf.as_mut(), self.ip6_read_buf.as_mut(), - self.device_read_buf.as_mut(), &self.encrypt_buf, )? 
{ Poll::Ready(io::Input::Timeout(timeout)) => { @@ -263,12 +244,11 @@ impl GatewayTunnel { received.from, received.packet, Instant::now(), - self.device_read_buf.as_mut(), ) else { continue; }; - self.io.device_mut().write(packet)?; + self.io.send_tun(packet)?; } continue; diff --git a/rust/connlib/tunnel/src/peer.rs b/rust/connlib/tunnel/src/peer.rs index 0d8cb910c..f45348959 100644 --- a/rust/connlib/tunnel/src/peer.rs +++ b/rust/connlib/tunnel/src/peer.rs @@ -373,11 +373,11 @@ impl ClientOnGateway { } } - fn transform_network_to_tun<'a>( + fn transform_network_to_tun( &mut self, - packet: IpPacket<'a>, + packet: IpPacket, now: Instant, - ) -> anyhow::Result> { + ) -> anyhow::Result { let Some(state) = self.permanent_translations.get_mut(&packet.destination()) else { return Ok(packet); }; @@ -396,11 +396,7 @@ impl ClientOnGateway { Ok(packet) } - pub fn decapsulate<'a>( - &mut self, - packet: IpPacket<'a>, - now: Instant, - ) -> anyhow::Result> { + pub fn decapsulate(&mut self, packet: IpPacket, now: Instant) -> anyhow::Result { self.ensure_allowed_src(&packet)?; let packet = self.transform_network_to_tun(packet, now)?; @@ -410,11 +406,11 @@ impl ClientOnGateway { Ok(packet) } - pub fn encapsulate<'a>( + pub fn encapsulate( &mut self, - packet: IpPacket<'a>, + packet: IpPacket, now: Instant, - ) -> anyhow::Result>> { + ) -> anyhow::Result> { let Some((proto, ip)) = self.nat_table.translate_incoming(&packet, now)? else { return Ok(Some(packet)); }; @@ -433,7 +429,7 @@ impl ClientOnGateway { Ok(Some(packet)) } - fn ensure_allowed_src(&self, packet: &IpPacket<'_>) -> anyhow::Result<()> { + fn ensure_allowed_src(&self, packet: &IpPacket) -> anyhow::Result<()> { let src = packet.source(); if !self.allowed_ips().contains(&src) { @@ -444,7 +440,7 @@ impl ClientOnGateway { } /// Check if an incoming packet arriving over the network is ok to be forwarded to the TUN device. 
- fn ensure_allowed_dst(&mut self, packet: &IpPacket<'_>) -> anyhow::Result<()> { + fn ensure_allowed_dst(&mut self, packet: &IpPacket) -> anyhow::Result<()> { let dst = packet.destination(); if !self .filters diff --git a/rust/connlib/tunnel/src/peer/nat_table.rs b/rust/connlib/tunnel/src/peer/nat_table.rs index af2497eec..b3c1b48ee 100644 --- a/rust/connlib/tunnel/src/peer/nat_table.rs +++ b/rust/connlib/tunnel/src/peer/nat_table.rs @@ -110,7 +110,7 @@ mod tests { #[test_strategy::proptest(ProptestConfig { max_local_rejects: 10_000, max_global_rejects: 10_000, ..ProptestConfig::default() })] fn translates_back_and_forth_packet( - #[strategy(udp_or_tcp_or_icmp_packet())] packet: IpPacket<'static>, + #[strategy(udp_or_tcp_or_icmp_packet())] packet: IpPacket, #[strategy(any::())] outside_dst: IpAddr, #[strategy(0..120u64)] response_delay: u64, ) { @@ -152,9 +152,9 @@ mod tests { #[test_strategy::proptest(ProptestConfig { max_local_rejects: 10_000, max_global_rejects: 10_000, ..ProptestConfig::default() })] fn can_handle_multiple_packets( - #[strategy(udp_or_tcp_or_icmp_packet())] packet1: IpPacket<'static>, + #[strategy(udp_or_tcp_or_icmp_packet())] packet1: IpPacket, #[strategy(any::())] outside_dst1: IpAddr, - #[strategy(udp_or_tcp_or_icmp_packet())] packet2: IpPacket<'static>, + #[strategy(udp_or_tcp_or_icmp_packet())] packet2: IpPacket, #[strategy(any::())] outside_dst2: IpAddr, ) { proptest::prop_assume!(packet1.destination().is_ipv4() == outside_dst1.is_ipv4()); // Required for our test to simulate a response. 
diff --git a/rust/connlib/tunnel/src/tests/assertions.rs b/rust/connlib/tunnel/src/tests/assertions.rs index 0f717c2bd..24aad8af5 100644 --- a/rust/connlib/tunnel/src/tests/assertions.rs +++ b/rust/connlib/tunnel/src/tests/assertions.rs @@ -193,8 +193,8 @@ pub(crate) fn assert_dns_packets_properties(ref_client: &RefClient, sim_client: } fn assert_correct_src_and_dst_ips( - client_sent_request: &IpPacket<'_>, - client_received_reply: &IpPacket<'_>, + client_sent_request: &IpPacket, + client_received_reply: &IpPacket, ) { let req_dst = client_sent_request.destination(); let res_src = client_received_reply.source(); @@ -216,8 +216,8 @@ fn assert_correct_src_and_dst_ips( } fn assert_correct_src_and_dst_udp_ports( - client_sent_request: &IpPacket<'_>, - client_received_reply: &IpPacket<'_>, + client_sent_request: &IpPacket, + client_received_reply: &IpPacket, ) { let client_sent_request = client_sent_request.as_udp().unwrap(); let client_received_reply = client_received_reply.as_udp().unwrap(); @@ -241,7 +241,7 @@ fn assert_correct_src_and_dst_udp_ports( } } -fn assert_destination_is_cdir_resource(gateway_received_request: &IpPacket<'_>, expected: &IpAddr) { +fn assert_destination_is_cdir_resource(gateway_received_request: &IpPacket, expected: &IpAddr) { let actual = gateway_received_request.destination(); if actual != *expected { @@ -252,7 +252,7 @@ fn assert_destination_is_cdir_resource(gateway_received_request: &IpPacket<'_>, } fn assert_destination_is_dns_resource( - gateway_received_request: &IpPacket<'_>, + gateway_received_request: &IpPacket, global_dns_records: &BTreeMap>, domain: &DomainName, ) { @@ -275,8 +275,8 @@ fn assert_destination_is_dns_resource( /// Yet, we care that it remains stable to ensure that any form of sticky sessions don't get broken (i.e. packets to one IP are always routed to the same IP on the gateway). /// To assert this, we build up a map as we iterate through all packets that have been sent. 
fn assert_proxy_ip_mapping_is_stable( - client_sent_request: &IpPacket<'_>, - gateway_received_request: &IpPacket<'_>, + client_sent_request: &IpPacket, + gateway_received_request: &IpPacket, mapping: &mut HashMap, ) { let proxy_ip = client_sent_request.destination(); diff --git a/rust/connlib/tunnel/src/tests/sim_client.rs b/rust/connlib/tunnel/src/tests/sim_client.rs index cfd5c45ad..b4ba2e462 100644 --- a/rust/connlib/tunnel/src/tests/sim_client.rs +++ b/rust/connlib/tunnel/src/tests/sim_client.rs @@ -53,13 +53,12 @@ pub(crate) struct SimClient { pub(crate) ipv4_routes: BTreeSet, pub(crate) ipv6_routes: BTreeSet, - pub(crate) sent_dns_queries: HashMap<(SocketAddr, QueryId), IpPacket<'static>>, - pub(crate) received_dns_responses: BTreeMap<(SocketAddr, QueryId), IpPacket<'static>>, + pub(crate) sent_dns_queries: HashMap<(SocketAddr, QueryId), IpPacket>, + pub(crate) received_dns_responses: BTreeMap<(SocketAddr, QueryId), IpPacket>, - pub(crate) sent_icmp_requests: HashMap<(u16, u16), IpPacket<'static>>, - pub(crate) received_icmp_replies: BTreeMap<(u16, u16), IpPacket<'static>>, + pub(crate) sent_icmp_requests: HashMap<(u16, u16), IpPacket>, + pub(crate) received_icmp_replies: BTreeMap<(u16, u16), IpPacket>, - buffer: Vec, enc_buffer: EncryptBuffer, } @@ -74,7 +73,6 @@ impl SimClient { received_dns_responses: Default::default(), sent_icmp_requests: Default::default(), received_icmp_replies: Default::default(), - buffer: vec![0u8; (1 << 16) - 1], enc_buffer: EncryptBuffer::new((1 << 16) - 1), ipv4_routes: Default::default(), ipv6_routes: Default::default(), @@ -120,7 +118,7 @@ impl SimClient { pub(crate) fn encapsulate( &mut self, - packet: IpPacket<'static>, + packet: IpPacket, now: Instant, ) -> Option> { if let Some(icmp) = packet.as_icmpv4() { @@ -164,22 +162,18 @@ impl SimClient { } pub(crate) fn receive(&mut self, transmit: Transmit, now: Instant) { - let Some(packet) = self.sut.decapsulate( - transmit.dst, - transmit.src.unwrap(), - &transmit.payload, - 
now, - &mut self.buffer, - ) else { + let Some(packet) = + self.sut + .decapsulate(transmit.dst, transmit.src.unwrap(), &transmit.payload, now) + else { return; }; - let packet = packet.to_owned(); self.on_received_packet(packet); } /// Process an IP packet received on the client. - pub(crate) fn on_received_packet(&mut self, packet: IpPacket<'static>) { + pub(crate) fn on_received_packet(&mut self, packet: IpPacket) { if let Some(icmp) = packet.as_icmpv4() { if let Icmpv4Type::EchoReply(echo) = icmp.icmp_type() { self.received_icmp_replies diff --git a/rust/connlib/tunnel/src/tests/sim_gateway.rs b/rust/connlib/tunnel/src/tests/sim_gateway.rs index 062e7cb6a..92ad6b741 100644 --- a/rust/connlib/tunnel/src/tests/sim_gateway.rs +++ b/rust/connlib/tunnel/src/tests/sim_gateway.rs @@ -24,9 +24,8 @@ pub(crate) struct SimGateway { pub(crate) sut: GatewayState, /// The received ICMP packets, indexed by our custom ICMP payload. - pub(crate) received_icmp_requests: BTreeMap>, + pub(crate) received_icmp_requests: BTreeMap, - buffer: Vec, enc_buffer: EncryptBuffer, } @@ -36,7 +35,6 @@ impl SimGateway { id, sut, received_icmp_requests: Default::default(), - buffer: vec![0u8; (1 << 16) - 1], enc_buffer: EncryptBuffer::new((1 << 16) - 1), } } @@ -47,16 +45,9 @@ impl SimGateway { transmit: Transmit, now: Instant, ) -> Option> { - let packet = self - .sut - .decapsulate( - transmit.dst, - transmit.src.unwrap(), - &transmit.payload, - now, - &mut self.buffer, - )? - .to_owned(); + let packet = + self.sut + .decapsulate(transmit.dst, transmit.src.unwrap(), &transmit.payload, now)?; self.on_received_packet(global_dns_records, packet, now) } @@ -65,7 +56,7 @@ impl SimGateway { fn on_received_packet( &mut self, global_dns_records: &BTreeMap>, - packet: IpPacket<'static>, + packet: IpPacket, now: Instant, ) -> Option> { // TODO: Instead of handling these things inline, here, should we dispatch them via `RoutingTable`? 
@@ -115,7 +106,7 @@ impl SimGateway { fn handle_icmp_request( &mut self, - packet: &IpPacket<'static>, + packet: &IpPacket, echo: IcmpEchoHeader, payload: &[u8], now: Instant, diff --git a/rust/gateway/Cargo.toml b/rust/gateway/Cargo.toml index 579385102..59762209e 100644 --- a/rust/gateway/Cargo.toml +++ b/rust/gateway/Cargo.toml @@ -21,6 +21,7 @@ firezone-logging = { workspace = true } firezone-tunnel = { workspace = true } futures = "0.3.29" futures-bounded = { workspace = true } +ip-packet = { workspace = true } ip_network = { version = "0.4", default-features = false } libc = { version = "0.2", default-features = false, features = ["std", "const-extern-fn", "extra_traits"] } phoenix-channel = { workspace = true } diff --git a/rust/gateway/src/main.rs b/rust/gateway/src/main.rs index 708ab7d3c..537221dd0 100644 --- a/rust/gateway/src/main.rs +++ b/rust/gateway/src/main.rs @@ -2,7 +2,7 @@ use crate::eventloop::{Eventloop, PHOENIX_TOPIC}; use anyhow::{Context, Result}; use backoff::ExponentialBackoffBuilder; use clap::Parser; -use connlib_shared::{get_user_agent, messages::Interface, LoginUrl, StaticSecret, DEFAULT_MTU}; +use connlib_shared::{get_user_agent, messages::Interface, LoginUrl, StaticSecret}; use firezone_bin_shared::{ http_health_check, linux::{tcp_socket_factory, udp_socket_factory}, @@ -123,7 +123,7 @@ async fn run(login: LoginUrl, private_key: StaticSecret) -> Result { )?; let (sender, receiver) = mpsc::channel::(10); - let mut tun_device_manager = TunDeviceManager::new(DEFAULT_MTU)?; + let mut tun_device_manager = TunDeviceManager::new(ip_packet::PACKET_SIZE)?; let tun = tun_device_manager.make_tun()?; tunnel.set_tun(Box::new(tun)); diff --git a/rust/headless-client/Cargo.toml b/rust/headless-client/Cargo.toml index b4bda3f30..f1a4c2b20 100644 --- a/rust/headless-client/Cargo.toml +++ b/rust/headless-client/Cargo.toml @@ -10,13 +10,14 @@ authors = ["Firezone, Inc."] anyhow = { version = "1.0" } atomicwrites = { workspace = true } # Needed to 
safely backup `/etc/resolv.conf` and write the device ID on behalf of `gui-client` backoff = "0.4.0" -clap = { version = "4.5", features = ["derive", "env", "string"] } +clap = { version = "4.5", features = ["derive", "env", "string"] } connlib-client-shared = { workspace = true } connlib-shared = { workspace = true } firezone-bin-shared = { workspace = true } firezone-logging = { workspace = true } futures = "0.3.30" humantime = "2.1" +ip-packet = { workspace = true } ip_network = { version = "0.4", default-features = false } phoenix-channel = { workspace = true } rustls = { workspace = true } diff --git a/rust/headless-client/src/ipc_service.rs b/rust/headless-client/src/ipc_service.rs index 8fe27afdc..82d03779b 100644 --- a/rust/headless-client/src/ipc_service.rs +++ b/rust/headless-client/src/ipc_service.rs @@ -25,7 +25,7 @@ use url::Url; pub mod ipc; use backoff::ExponentialBackoffBuilder; -use connlib_shared::{get_user_agent, messages::ResourceId, DEFAULT_MTU}; +use connlib_shared::{get_user_agent, messages::ResourceId}; use ipc::{Server as IpcServer, ServiceId}; use phoenix_channel::PhoenixChannel; use secrecy::Secret; @@ -277,7 +277,7 @@ impl<'a> Handler<'a> { .next_client_split() .await .context("Failed to wait for incoming IPC connection from a GUI")?; - let tun_device = TunDeviceManager::new(DEFAULT_MTU)?; + let tun_device = TunDeviceManager::new(ip_packet::PACKET_SIZE)?; Ok(Self { dns_controller, diff --git a/rust/headless-client/src/main.rs b/rust/headless-client/src/main.rs index 3f236968e..2d892a9cd 100644 --- a/rust/headless-client/src/main.rs +++ b/rust/headless-client/src/main.rs @@ -4,7 +4,7 @@ use anyhow::{anyhow, Context as _, Result}; use backoff::ExponentialBackoffBuilder; use clap::Parser; use connlib_client_shared::{keypair, ConnectArgs, LoginUrl, Session}; -use connlib_shared::{get_user_agent, DEFAULT_MTU}; +use connlib_shared::get_user_agent; use firezone_bin_shared::{ new_dns_notifier, new_network_notifier, platform::{tcp_socket_factory, 
udp_socket_factory}, @@ -213,7 +213,7 @@ fn main() -> Result<()> { // Deactivate Firezone DNS control in case the system or IPC service crashed // and we need to recover. dns_controller.deactivate()?; - let mut tun_device = TunDeviceManager::new(DEFAULT_MTU)?; + let mut tun_device = TunDeviceManager::new(ip_packet::PACKET_SIZE)?; let mut cb_rx = ReceiverStream::new(cb_rx).fuse(); let tokio_handle = tokio::runtime::Handle::current(); diff --git a/rust/ip-packet/src/lib.rs b/rust/ip-packet/src/lib.rs index 51c912642..8ac2ea5b6 100644 --- a/rust/ip-packet/src/lib.rs +++ b/rust/ip-packet/src/lib.rs @@ -21,13 +21,22 @@ use icmpv4_header_slice_mut::Icmpv4HeaderSliceMut; use icmpv6_header_slice_mut::Icmpv6EchoHeaderSliceMut; use ipv4_header_slice_mut::Ipv4HeaderSliceMut; use ipv6_header_slice_mut::Ipv6HeaderSliceMut; -use std::{ - net::{IpAddr, Ipv4Addr, Ipv6Addr}, - ops::{Deref, DerefMut}, -}; +use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; use tcp_header_slice_mut::TcpHeaderSliceMut; use udp_header_slice_mut::UdpHeaderSliceMut; +/// The maximum size of an IP packet we can handle. +pub const PACKET_SIZE: usize = 1280; +/// The maximum payload of a UDP packet that carries an encrypted IP packet. +pub const MAX_DATAGRAM_PAYLOAD: usize = + PACKET_SIZE + WG_OVERHEAD + NAT46_OVERHEAD + DATA_CHANNEL_OVERHEAD; +/// Wireguard has a 32-byte overhead (4b message type + 4b receiver idx + 8b packet counter + 16b AEAD tag) +const WG_OVERHEAD: usize = 32; +/// In order to do NAT46 without copying, we need 20 extra byte in the buffer (IPv6 packets are 20 byte bigger than IPv4). +const NAT46_OVERHEAD: usize = 20; +/// TURN's data channels have a 4 byte overhead. +const DATA_CHANNEL_OVERHEAD: usize = 4; + macro_rules! 
for_both { ($this:ident, |$name:ident| $body:expr) => { match $this { @@ -74,13 +83,36 @@ impl Protocol { } } -#[derive(PartialEq, Clone)] -pub enum IpPacket<'a> { - Ipv4(ConvertibleIpv4Packet<'a>), - Ipv6(ConvertibleIpv6Packet<'a>), +/// A buffer for reading a new [`IpPacket`] from the network. +pub struct IpPacketBuf { + inner: [u8; MAX_DATAGRAM_PAYLOAD], } -impl<'a> std::fmt::Debug for IpPacket<'a> { +impl IpPacketBuf { + pub fn new() -> Self { + Self { + inner: [0u8; MAX_DATAGRAM_PAYLOAD], + } + } + + pub fn buf(&mut self) -> &mut [u8] { + &mut self.inner[NAT46_OVERHEAD..] // We read packets at an offset so we can convert without copying. + } +} + +impl Default for IpPacketBuf { + fn default() -> Self { + Self::new() + } +} + +#[derive(PartialEq, Clone)] +pub enum IpPacket { + Ipv4(ConvertibleIpv4Packet), + Ipv6(ConvertibleIpv6Packet), +} + +impl std::fmt::Debug for IpPacket { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { let mut dbg = f.debug_struct("Packet"); @@ -117,71 +149,23 @@ impl<'a> std::fmt::Debug for IpPacket<'a> { } } -#[derive(Debug, PartialEq)] -enum MaybeOwned<'a> { - RefMut(&'a mut [u8]), - Owned(Vec), -} - -impl<'a> MaybeOwned<'a> { - fn remove_from_head(self, bytes: usize) -> MaybeOwned<'a> { - match self { - MaybeOwned::RefMut(ref_mut) => MaybeOwned::RefMut(&mut ref_mut[bytes..]), - MaybeOwned::Owned(mut owned) => { - owned.drain(0..bytes); - MaybeOwned::Owned(owned) - } - } - } -} - -impl<'a> Clone for MaybeOwned<'a> { - fn clone(&self) -> Self { - match self { - Self::RefMut(i) => Self::Owned(i.to_vec()), - Self::Owned(i) => Self::Owned(i.clone()), - } - } -} - -impl<'a> Deref for MaybeOwned<'a> { - type Target = [u8]; - - fn deref(&self) -> &Self::Target { - match self { - MaybeOwned::RefMut(ref_mut) => ref_mut, - MaybeOwned::Owned(owned) => owned, - } - } -} - -impl<'a> DerefMut for MaybeOwned<'a> { - fn deref_mut(&mut self) -> &mut Self::Target { - match self { - MaybeOwned::RefMut(ref_mut) => ref_mut, - 
MaybeOwned::Owned(owned) => owned, - } - } -} - #[derive(Debug, PartialEq, Clone)] -pub struct ConvertibleIpv4Packet<'a> { - buf: MaybeOwned<'a>, +pub struct ConvertibleIpv4Packet { + buf: [u8; MAX_DATAGRAM_PAYLOAD], + start: usize, + len: usize, } -impl<'a> ConvertibleIpv4Packet<'a> { - pub fn new(buf: &'a mut [u8]) -> Option> { - Ipv4HeaderSlice::from_slice(&buf[20..]).ok()?; - Some(Self { - buf: MaybeOwned::RefMut(buf), - }) - } +impl ConvertibleIpv4Packet { + pub fn new(ip: IpPacketBuf, len: usize) -> Option { + let this = Self { + buf: ip.inner, + start: NAT46_OVERHEAD, + len, + }; + Ipv4HeaderSlice::from_slice(this.packet()).ok()?; - fn owned(buf: Vec) -> Option> { - Ipv4HeaderSlice::from_slice(&buf[20..]).ok()?; - Some(Self { - buf: MaybeOwned::Owned(buf), - }) + Some(this) } fn ip_header(&self) -> Ipv4HeaderSlice { @@ -200,17 +184,34 @@ impl<'a> ConvertibleIpv4Packet<'a> { self.ip_header().destination_addr() } - fn consume_to_ipv6( - mut self, - src: Ipv6Addr, - dst: Ipv6Addr, - ) -> Option> { - let offset = nat46::translate_in_place(&mut self.buf, src, dst) - .inspect_err(|e| tracing::trace!("NAT64 failed: {e:#}")) - .ok()?; - let buf = self.buf.remove_from_head(offset); + fn consume_to_ipv6(mut self, src: Ipv6Addr, dst: Ipv6Addr) -> Option { + // `translate_in_place` expects the packet to sit at a 20-byte offset. + // `self.start` tells us where the packet actually starts, thus we need to pass `self.start - 20` to the function. + let start_minus_padding = self.start.checked_sub(NAT46_OVERHEAD)?; - Some(ConvertibleIpv6Packet { buf }) + let offset = nat46::translate_in_place( + &mut self.buf[start_minus_padding..(self.start + self.len)], + src, + dst, + ) + .inspect_err(|e| tracing::trace!("NAT46 failed: {e:#}")) + .ok()?; + + // We need to handle 2 cases here: + // `offset` > NAT46_OVERHEAD + // `offset` < NAT46_OVERHEAD + // By casting to an `isize` we can simply compute the _difference_ of the packet length. 
+ // `offset` points into the buffer we passed to `translate_in_place`. + // We passed 20 (NAT46_OVERHEAD) bytes more to that function. + // Thus, 20 - offset tells us the difference in length of the new packet. + let len_diff = (NAT46_OVERHEAD as isize) - (offset as isize); + let len = (self.len as isize) + len_diff; + + Some(ConvertibleIpv6Packet { + buf: self.buf, + start: start_minus_padding + offset, + len: len as usize, + }) } fn header_length(&self) -> usize { @@ -218,34 +219,32 @@ impl<'a> ConvertibleIpv4Packet<'a> { } pub fn packet(&self) -> &[u8] { - &self.buf[20..] + &self.buf[self.start..(self.start + self.len)] } fn packet_mut(&mut self) -> &mut [u8] { - &mut self.buf[20..] + &mut self.buf[self.start..(self.start + self.len)] } } #[derive(Debug, PartialEq, Clone)] -pub struct ConvertibleIpv6Packet<'a> { - buf: MaybeOwned<'a>, +pub struct ConvertibleIpv6Packet { + buf: [u8; MAX_DATAGRAM_PAYLOAD], + start: usize, + len: usize, } -impl<'a> ConvertibleIpv6Packet<'a> { - pub fn new(buf: &'a mut [u8]) -> Option> { - Ipv6HeaderSlice::from_slice(buf).ok()?; +impl ConvertibleIpv6Packet { + pub fn new(ip: IpPacketBuf, len: usize) -> Option { + let this = Self { + buf: ip.inner, + start: NAT46_OVERHEAD, + len, + }; - Some(Self { - buf: MaybeOwned::RefMut(buf), - }) - } + Ipv6HeaderSlice::from_slice(this.packet()).ok()?; - fn owned(buf: Vec) -> Option> { - Ipv6HeaderSlice::from_slice(&buf).ok()?; - - Some(Self { - buf: MaybeOwned::Owned(buf), - }) + Some(this) } fn header(&self) -> Ipv6HeaderSlice { @@ -265,24 +264,24 @@ impl<'a> ConvertibleIpv6Packet<'a> { self.header().destination_addr() } - fn consume_to_ipv4( - mut self, - src: Ipv4Addr, - dst: Ipv4Addr, - ) -> Option> { - nat64::translate_in_place(&mut self.buf, src, dst) + fn consume_to_ipv4(mut self, src: Ipv4Addr, dst: Ipv4Addr) -> Option { + nat64::translate_in_place(self.packet_mut(), src, dst) .inspect_err(|e| tracing::trace!("NAT64 failed: {e:#}")) .ok()?; - Some(ConvertibleIpv4Packet { buf: self.buf 
}) + Some(ConvertibleIpv4Packet { + buf: self.buf, + start: self.start + 20, + len: self.len - 20, + }) } pub fn packet(&self) -> &[u8] { - &self.buf + &self.buf[self.start..(self.start + self.len)] } fn packet_mut(&mut self) -> &mut [u8] { - &mut self.buf + &mut self.buf[self.start..(self.start + self.len)] } } @@ -318,48 +317,23 @@ pub fn ipv6_translated(ip: Ipv6Addr) -> Option { )) } -impl<'a> IpPacket<'a> { - // TODO: this API is a bit akward, since you have to pass the extra prepended 20 bytes - pub fn new(buf: &'a mut [u8]) -> Option { - match buf[20] >> 4 { - 4 => Some(IpPacket::Ipv4(ConvertibleIpv4Packet::new(buf)?)), - 6 => Some(IpPacket::Ipv6(ConvertibleIpv6Packet::new(&mut buf[20..])?)), +impl IpPacket { + pub fn new(buf: IpPacketBuf, len: usize) -> Option { + match buf.inner[NAT46_OVERHEAD] >> 4 { + 4 => Some(IpPacket::Ipv4(ConvertibleIpv4Packet::new(buf, len)?)), + 6 => Some(IpPacket::Ipv6(ConvertibleIpv6Packet::new(buf, len)?)), _ => None, } } - pub(crate) fn owned(mut data: Vec) -> Option> { - let packet = match data[20] >> 4 { - 4 => ConvertibleIpv4Packet::owned(data)?.into(), - 6 => { - data.drain(0..20); - ConvertibleIpv6Packet::owned(data)?.into() - } - _ => return None, - }; - - Some(packet) - } - - pub fn to_owned(&self) -> IpPacket<'static> { - match self { - IpPacket::Ipv4(i) => IpPacket::Ipv4(ConvertibleIpv4Packet { - buf: MaybeOwned::Owned(i.buf.to_vec()), - }), - IpPacket::Ipv6(i) => IpPacket::Ipv6(ConvertibleIpv6Packet { - buf: MaybeOwned::Owned(i.buf.to_vec()), - }), - } - } - - pub(crate) fn consume_to_ipv4(self, src: Ipv4Addr, dst: Ipv4Addr) -> Option> { + pub(crate) fn consume_to_ipv4(self, src: Ipv4Addr, dst: Ipv4Addr) -> Option { match self { IpPacket::Ipv4(pkt) => Some(IpPacket::Ipv4(pkt)), IpPacket::Ipv6(pkt) => Some(IpPacket::Ipv4(pkt.consume_to_ipv4(src, dst)?)), } } - pub(crate) fn consume_to_ipv6(self, src: Ipv6Addr, dst: Ipv6Addr) -> Option> { + pub(crate) fn consume_to_ipv6(self, src: Ipv6Addr, dst: Ipv6Addr) -> Option { 
match self { IpPacket::Ipv4(pkt) => Some(IpPacket::Ipv6(pkt.consume_to_ipv6(src, dst)?)), IpPacket::Ipv6(pkt) => Some(IpPacket::Ipv6(pkt)), @@ -668,7 +642,7 @@ impl<'a> IpPacket<'a> { src_v6: Ipv6Addr, src_proto: Protocol, dst: IpAddr, - ) -> Option> { + ) -> Option { let mut packet = match (&self, dst) { (&IpPacket::Ipv4(_), IpAddr::V6(dst)) => self.consume_to_ipv6(src_v6, dst)?, (&IpPacket::Ipv6(_), IpAddr::V4(dst)) => self.consume_to_ipv4(src_v4, dst)?, @@ -688,7 +662,7 @@ impl<'a> IpPacket<'a> { dst_v6: Ipv6Addr, dst_proto: Protocol, src: IpAddr, - ) -> Option> { + ) -> Option { let mut packet = match (&self, src) { (&IpPacket::Ipv4(_), IpAddr::V6(src)) => self.consume_to_ipv6(src, dst_v6)?, (&IpPacket::Ipv6(_), IpAddr::V4(src)) => self.consume_to_ipv4(src, dst_v4)?, @@ -740,11 +714,7 @@ impl<'a> IpPacket<'a> { pub fn ipv4_header(&self) -> Option { match self { - Self::Ipv4(p) => Some( - Ipv4HeaderSlice::from_slice(p.packet()) - .expect("Should be a valid packet") - .to_header(), - ), + Self::Ipv4(p) => Some(p.ip_header().to_header()), Self::Ipv6(_) => None, } } @@ -752,11 +722,7 @@ impl<'a> IpPacket<'a> { pub fn ipv6_header(&self) -> Option { match self { Self::Ipv4(_) => None, - Self::Ipv6(p) => Some( - Ipv6HeaderSlice::from_slice(p.packet()) - .expect("Should be a valid packet") - .to_header(), - ), + Self::Ipv6(p) => Some(p.header().to_header()), } } @@ -817,14 +783,14 @@ impl<'a> IpPacket<'a> { } } -impl<'a> From> for IpPacket<'a> { - fn from(value: ConvertibleIpv4Packet<'a>) -> Self { +impl From for IpPacket { + fn from(value: ConvertibleIpv4Packet) -> Self { Self::Ipv4(value) } } -impl<'a> From> for IpPacket<'a> { - fn from(value: ConvertibleIpv6Packet<'a>) -> Self { +impl From for IpPacket { + fn from(value: ConvertibleIpv6Packet) -> Self { Self::Ipv6(value) } } diff --git a/rust/ip-packet/src/make.rs b/rust/ip-packet/src/make.rs index c08086f33..852415860 100644 --- a/rust/ip-packet/src/make.rs +++ b/rust/ip-packet/src/make.rs @@ -16,13 +16,13 @@ use 
std::net::{IpAddr, SocketAddr}; macro_rules! build { ($packet:expr, $payload:ident) => {{ let size = $packet.size($payload.len()); - let mut buf = vec![0u8; size + 20]; + let mut ip = $crate::IpPacketBuf::new(); $packet - .write(&mut std::io::Cursor::new(&mut buf[20..]), &$payload) + .write(&mut std::io::Cursor::new(ip.buf()), &$payload) .expect("Buffer should be big enough"); - IpPacket::owned(buf).expect("Should be a valid IP packet") + IpPacket::new(ip, size).expect("Should be a valid IP packet") }}; } @@ -32,7 +32,7 @@ pub fn icmp_request_packet( seq: u16, identifier: u16, payload: &[u8], -) -> Result, IpVersionMismatch> { +) -> Result { match (src, dst.into()) { (IpAddr::V4(src), IpAddr::V4(dst)) => { let packet = PacketBuilder::ipv4(src.octets(), dst.octets(), 64) @@ -56,7 +56,7 @@ pub fn icmp_reply_packet( seq: u16, identifier: u16, payload: &[u8], -) -> Result, IpVersionMismatch> { +) -> Result { match (src, dst.into()) { (IpAddr::V4(src), IpAddr::V4(dst)) => { let packet = PacketBuilder::ipv4(src.octets(), dst.octets(), 64) @@ -80,7 +80,7 @@ pub fn tcp_packet( sport: u16, dport: u16, payload: Vec, -) -> Result, IpVersionMismatch> +) -> Result where IP: Into, { @@ -107,7 +107,7 @@ pub fn udp_packet( sport: u16, dport: u16, payload: Vec, -) -> Result, IpVersionMismatch> +) -> Result where IP: Into, { @@ -132,7 +132,7 @@ pub fn dns_query( src: SocketAddr, dst: SocketAddr, id: u16, -) -> Result, IpVersionMismatch> { +) -> Result { // Create the DNS query message let mut msg_builder = MessageBuilder::new_vec(); @@ -152,10 +152,7 @@ pub fn dns_query( } /// Makes a DNS response to the given DNS query packet, using a resolver callback. 
-pub fn dns_ok_response( - packet: IpPacket<'static>, - resolve: impl Fn(&Name>) -> I, -) -> IpPacket<'static> +pub fn dns_ok_response(packet: IpPacket, resolve: impl Fn(&Name>) -> I) -> IpPacket where I: Iterator, { diff --git a/rust/ip-packet/src/nat46.rs b/rust/ip-packet/src/nat46.rs index 19936bd75..5320a35e2 100644 --- a/rust/ip-packet/src/nat46.rs +++ b/rust/ip-packet/src/nat46.rs @@ -6,16 +6,18 @@ use etherparse::{ }; use std::{io::Cursor, net::Ipv6Addr}; +use crate::NAT46_OVERHEAD; + /// Performs IPv4 -> IPv6 NAT on the packet in `buf` to the given src & dst IP. /// /// An IPv6 IP-header may be up to 20 bytes bigger than its corresponding IPv4 counterpart. -/// Thus, the IPv4 packet is expected to sit at an offset of 20 bytes in `buf`. +/// Thus, the IPv4 packet is expected to sit at an offset of [`NAT46_OVERHEAD`] bytes in `buf`. /// /// # Returns /// /// - Ok(offset): The offset within `buf` at which the new IPv6 packet starts. pub fn translate_in_place(buf: &mut [u8], src: Ipv6Addr, dst: Ipv6Addr) -> Result { - let ipv4_packet = &buf[20..]; + let ipv4_packet = &buf[NAT46_OVERHEAD..]; let (headers, payload) = etherparse::IpHeaders::from_ipv4_slice(ipv4_packet)?; let (ipv4_header, _extensions) = headers.ipv4().expect("We successfully parsed as IPv4"); @@ -149,12 +151,10 @@ pub fn translate_in_place(buf: &mut [u8], src: Ipv6Addr, dst: Ipv6Addr) -> Resul let start_of_ipv6_header = start_of_ip_payload - Ipv6Header::LEN; - let (excess_padding, ipv6_header_buf) = buf.split_at_mut(start_of_ipv6_header); + let (_, ipv6_header_buf) = buf.split_at_mut(start_of_ipv6_header); ipv6_header.write(&mut Cursor::new(ipv6_header_buf))?; - let excess_padding_length = excess_padding.len(); - - Ok(excess_padding_length) + Ok(start_of_ipv6_header) } fn translate_icmpv4_header( diff --git a/rust/ip-packet/src/proptest.rs b/rust/ip-packet/src/proptest.rs index c68d6b564..446192b0f 100644 --- a/rust/ip-packet/src/proptest.rs +++ b/rust/ip-packet/src/proptest.rs @@ -2,7 +2,7 @@ use 
crate::IpPacket; use proptest::{arbitrary::any, prop_oneof, strategy::Strategy}; use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; -pub fn udp_packet() -> impl Strategy> { +pub fn udp_packet() -> impl Strategy { prop_oneof![ (ip4_tuple(), any::(), any::()).prop_map(|((saddr, daddr), sport, dport)| { crate::make::udp_packet(saddr, daddr, sport, dport, Vec::new()).unwrap() @@ -13,7 +13,7 @@ pub fn udp_packet() -> impl Strategy> { ] } -pub fn tcp_packet() -> impl Strategy> { +pub fn tcp_packet() -> impl Strategy { prop_oneof![ (ip4_tuple(), any::(), any::()).prop_map(|((saddr, daddr), sport, dport)| { crate::make::tcp_packet(saddr, daddr, sport, dport, Vec::new()).unwrap() @@ -24,7 +24,7 @@ pub fn tcp_packet() -> impl Strategy> { ] } -pub fn icmp_request_packet() -> impl Strategy> { +pub fn icmp_request_packet() -> impl Strategy { prop_oneof![ (ip4_tuple(), any::(), any::()).prop_map(|((saddr, daddr), sport, dport)| { crate::make::icmp_request_packet(IpAddr::V4(saddr), daddr, sport, dport, &[]).unwrap() @@ -35,7 +35,7 @@ pub fn icmp_request_packet() -> impl Strategy> { ] } -pub fn udp_or_tcp_or_icmp_packet() -> impl Strategy> { +pub fn udp_or_tcp_or_icmp_packet() -> impl Strategy { prop_oneof![udp_packet(), tcp_packet(), icmp_request_packet()] } diff --git a/rust/ip-packet/src/proptests.rs b/rust/ip-packet/src/proptests.rs index c2a69ef96..e45b2808f 100644 --- a/rust/ip-packet/src/proptests.rs +++ b/rust/ip-packet/src/proptests.rs @@ -10,7 +10,7 @@ use proptest::prelude::Just; const EMPTY_PAYLOAD: &[u8] = &[]; -fn tcp_packet_v4() -> impl Strategy> { +fn tcp_packet_v4() -> impl Strategy { ( any::(), any::(), @@ -26,7 +26,7 @@ fn tcp_packet_v4() -> impl Strategy> { }) } -fn tcp_packet_v6() -> impl Strategy> { +fn tcp_packet_v6() -> impl Strategy { ( any::(), any::(), @@ -42,7 +42,7 @@ fn tcp_packet_v6() -> impl Strategy> { }) } -fn udp_packet_v4() -> impl Strategy> { +fn udp_packet_v4() -> impl Strategy { ( any::(), any::(), @@ -58,7 +58,7 @@ fn udp_packet_v4() -> impl 
Strategy> { }) } -fn udp_packet_v6() -> impl Strategy> { +fn udp_packet_v6() -> impl Strategy { ( any::(), any::(), @@ -74,7 +74,7 @@ fn udp_packet_v6() -> impl Strategy> { }) } -fn icmp_request_packet_v4() -> impl Strategy> { +fn icmp_request_packet_v4() -> impl Strategy { ( any::(), any::(), @@ -98,7 +98,7 @@ fn icmp_request_packet_v4() -> impl Strategy> { }) } -fn icmp_reply_packet_v4() -> impl Strategy> { +fn icmp_reply_packet_v4() -> impl Strategy { ( any::(), any::(), @@ -122,7 +122,7 @@ fn icmp_reply_packet_v4() -> impl Strategy> { }) } -fn icmp_request_packet_v6() -> impl Strategy> { +fn icmp_request_packet_v6() -> impl Strategy { ( any::(), any::(), @@ -137,7 +137,7 @@ fn icmp_request_packet_v6() -> impl Strategy> { }) } -fn icmp_reply_packet_v6() -> impl Strategy> { +fn icmp_reply_packet_v6() -> impl Strategy { ( any::(), any::(), @@ -168,7 +168,7 @@ fn ipv4_options() -> impl Strategy { ] } -fn packet_v4() -> impl Strategy> { +fn packet_v4() -> impl Strategy { prop_oneof![ tcp_packet_v4(), udp_packet_v4(), @@ -177,7 +177,7 @@ fn packet_v4() -> impl Strategy> { ] } -fn packet_v6() -> impl Strategy> { +fn packet_v6() -> impl Strategy { prop_oneof![ tcp_packet_v6(), udp_packet_v6(), @@ -188,7 +188,7 @@ fn packet_v6() -> impl Strategy> { #[test_strategy::proptest()] fn nat_6446( - #[strategy(packet_v6())] packet_v6: IpPacket<'static>, + #[strategy(packet_v6())] packet_v6: IpPacket, #[strategy(any::())] new_src: Ipv4Addr, #[strategy(any::())] new_dst: Ipv4Addr, ) { @@ -211,7 +211,7 @@ fn nat_6446( #[test_strategy::proptest()] fn nat_4664( - #[strategy(packet_v4())] packet_v4: IpPacket<'static>, + #[strategy(packet_v4())] packet_v4: IpPacket, #[strategy(any::())] new_src: Ipv6Addr, #[strategy(any::())] new_dst: Ipv6Addr, ) {