diff --git a/.github/workflows/_rust.yml b/.github/workflows/_rust.yml index de9b62364..b95a57370 100644 --- a/.github/workflows/_rust.yml +++ b/.github/workflows/_rust.yml @@ -100,7 +100,7 @@ jobs: # Poor man's test coverage testing: Grep the generated logs for specific patterns / lines. rg --count --no-ignore SendIcmpPacket "$TESTCASES_DIR" rg --count --no-ignore SendUdpPacket "$TESTCASES_DIR" - rg --count --no-ignore SendTcpPayload "$TESTCASES_DIR" + rg --count --no-ignore ConnectTcp "$TESTCASES_DIR" rg --count --no-ignore SendDnsQueries "$TESTCASES_DIR" rg --count --no-ignore "Packet for DNS resource" "$TESTCASES_DIR" rg --count --no-ignore "Packet for CIDR resource" "$TESTCASES_DIR" diff --git a/rust/Cargo.lock b/rust/Cargo.lock index 57f921321..0224515f0 100644 --- a/rust/Cargo.lock +++ b/rust/Cargo.lock @@ -2035,8 +2035,8 @@ dependencies = [ "futures", "ip-packet", "ip_network", + "l3-tcp", "rand 0.8.5", - "smoltcp", "tokio", "tracing", "tun", @@ -2631,6 +2631,7 @@ dependencies = [ "ip_network", "ip_network_table", "itertools 0.14.0", + "l3-tcp", "l4-tcp-dns-server", "l4-udp-dns-server", "lru", @@ -4026,6 +4027,16 @@ dependencies = [ "selectors", ] +[[package]] +name = "l3-tcp" +version = "0.1.0" +dependencies = [ + "anyhow", + "ip-packet", + "smoltcp", + "tracing", +] + [[package]] name = "l4-tcp-dns-server" version = "0.1.0" diff --git a/rust/Cargo.toml b/rust/Cargo.toml index bcbf643ca..14d73b61f 100644 --- a/rust/Cargo.toml +++ b/rust/Cargo.toml @@ -9,6 +9,7 @@ members = [ "connlib/dns-types", "connlib/etherparse-ext", "connlib/ip-packet", + "connlib/l3-tcp", "connlib/l4-tcp-dns-server", "connlib/l4-udp-dns-server", "connlib/model", @@ -96,6 +97,7 @@ jemallocator = "0.5.4" jni = "0.21.1" keyring = "3.6.2" known-folders = "1.2.0" +l3-tcp = { path = "connlib/l3-tcp" } l4-tcp-dns-server = { path = "connlib/l4-tcp-dns-server" } l4-udp-dns-server = { path = "connlib/l4-udp-dns-server" } libc = "0.2.174" diff --git a/rust/connlib/dns-over-tcp/Cargo.toml 
b/rust/connlib/dns-over-tcp/Cargo.toml index 08e7e71f4..bf2e3d55f 100644 --- a/rust/connlib/dns-over-tcp/Cargo.toml +++ b/rust/connlib/dns-over-tcp/Cargo.toml @@ -10,8 +10,8 @@ anyhow = { workspace = true } dns-types = { workspace = true } firezone-logging = { workspace = true } ip-packet = { workspace = true } +l3-tcp = { workspace = true } rand = { workspace = true } -smoltcp = { workspace = true, features = ["std", "log", "medium-ip", "proto-ipv4", "proto-ipv6", "socket-tcp"] } tracing = { workspace = true } [dev-dependencies] diff --git a/rust/connlib/dns-over-tcp/src/client.rs b/rust/connlib/dns-over-tcp/src/client.rs index 65b7335a1..d6680d158 100644 --- a/rust/connlib/dns-over-tcp/src/client.rs +++ b/rust/connlib/dns-over-tcp/src/client.rs @@ -4,17 +4,13 @@ use std::{ time::{Duration, Instant}, }; -use crate::{ - codec, create_tcp_socket, interface::create_interface, stub_device::InMemoryDevice, - time::smol_now, -}; +use crate::codec; use anyhow::{Context as _, Result, anyhow, bail}; use ip_packet::IpPacket; -use rand::{Rng, SeedableRng, rngs::StdRng}; -use smoltcp::{ - iface::{Interface, PollResult, SocketSet}, - socket::tcp, +use l3_tcp::{ + InMemoryDevice, Interface, PollResult, SocketSet, create_interface, create_tcp_socket, }; +use rand::{Rng, SeedableRng, rngs::StdRng}; /// A sans-io DNS-over-TCP client. /// @@ -32,8 +28,8 @@ pub struct Client { source_ips: Option<(Ipv4Addr, Ipv6Addr)>, sockets: SocketSet<'static>, - sockets_by_remote: BTreeMap, - local_ports_by_socket: HashMap, + sockets_by_remote: BTreeMap, + local_ports_by_socket: HashMap, /// Queries we should send to a DNS resolver. pending_queries_by_remote: HashMap>, /// Queries we have sent to a DNS resolver and are waiting for a reply. 
@@ -182,7 +178,7 @@ impl Client { }; let result = self.interface.poll( - smol_now(self.created_at, now), + l3_tcp::now(self.created_at, now), &mut self.device, &mut self.sockets, ); @@ -194,7 +190,7 @@ impl Client { for (remote, handle) in self.sockets_by_remote.iter_mut() { let _guard = tracing::trace_span!("socket", %handle).entered(); - let socket = self.sockets.get_mut::(*handle); + let socket = self.sockets.get_mut::(*handle); let server = *remote; let pending_queries = self.pending_queries_by_remote.entry(server).or_default(); @@ -219,7 +215,7 @@ impl Client { ); // Third, if the socket got closed, reconnect it. - if matches!(socket.state(), tcp::State::Closed) && !pending_queries.is_empty() { + if matches!(socket.state(), l3_tcp::State::Closed) && !pending_queries.is_empty() { let local_port = self .local_ports_by_socket .get(handle) @@ -248,7 +244,7 @@ impl Client { } pub fn poll_timeout(&mut self) -> Option { - let now = smol_now(self.created_at, self.last_now); + let now = l3_tcp::now(self.created_at, self.last_now); let poll_in = self.interface.poll_delay(now, &self.sockets)?; @@ -303,7 +299,7 @@ impl Client { } fn send_pending_queries( - socket: &mut tcp::Socket, + socket: &mut l3_tcp::Socket, server: SocketAddr, pending_queries: &mut VecDeque, sent_queries: &mut HashMap, @@ -339,7 +335,7 @@ fn send_pending_queries( } fn recv_responses( - socket: &mut tcp::Socket, + socket: &mut l3_tcp::Socket, server: SocketAddr, pending_queries: &mut VecDeque, sent_queries: &mut HashMap, @@ -398,7 +394,7 @@ fn into_failed_results( }) } -fn try_recv_response(socket: &mut tcp::Socket) -> Result> { +fn try_recv_response(socket: &mut l3_tcp::Socket) -> Result> { if !socket.can_recv() { tracing::trace!(state = %socket.state(), "Not yet ready to receive next message"); diff --git a/rust/connlib/dns-over-tcp/src/codec.rs b/rust/connlib/dns-over-tcp/src/codec.rs index da95b0e71..d7b7ce4e0 100644 --- a/rust/connlib/dns-over-tcp/src/codec.rs +++ 
b/rust/connlib/dns-over-tcp/src/codec.rs @@ -6,9 +6,8 @@ //! Source: . use anyhow::{Context as _, Result}; -use smoltcp::socket::tcp; -pub fn try_send(socket: &mut tcp::Socket, message: &[u8]) -> Result<()> { +pub fn try_send(socket: &mut l3_tcp::Socket, message: &[u8]) -> Result<()> { let dns_message_length = (message.len() as u16).to_be_bytes(); let written = socket @@ -51,7 +50,7 @@ pub fn try_send(socket: &mut tcp::Socket, message: &[u8]) -> Result<()> { Ok(()) } -pub fn try_recv<'b, M>(socket: &'b mut tcp::Socket) -> Result> +pub fn try_recv<'b, M>(socket: &'b mut l3_tcp::Socket) -> Result> where M: TryFrom<&'b [u8], Error: std::error::Error + Send + Sync + 'static>, { diff --git a/rust/connlib/dns-over-tcp/src/lib.rs b/rust/connlib/dns-over-tcp/src/lib.rs index 7e4872141..e5f786dff 100644 --- a/rust/connlib/dns-over-tcp/src/lib.rs +++ b/rust/connlib/dns-over-tcp/src/lib.rs @@ -1,22 +1,6 @@ mod client; mod codec; -mod interface; mod server; -mod stub_device; -mod time; pub use client::{Client, QueryResult}; pub use server::{Query, Server}; - -fn create_tcp_socket() -> smoltcp::socket::tcp::Socket<'static> { - /// The 2-byte length prefix of DNS over TCP messages limits their size to effectively u16::MAX. - /// It is quite unlikely that we have to buffer _multiple_ of these max-sized messages. - /// Being able to buffer at least one of them means we can handle the extreme case. - /// In practice, this allows the OS to queue multiple queries even if we can't immediately process them. 
- const MAX_TCP_DNS_MSG_LENGTH: usize = u16::MAX as usize; - - smoltcp::socket::tcp::Socket::new( - smoltcp::storage::RingBuffer::new(vec![0u8; MAX_TCP_DNS_MSG_LENGTH]), - smoltcp::storage::RingBuffer::new(vec![0u8; MAX_TCP_DNS_MSG_LENGTH]), - ) -} diff --git a/rust/connlib/dns-over-tcp/src/server.rs b/rust/connlib/dns-over-tcp/src/server.rs index 8ea8097d9..3385dc5f5 100644 --- a/rust/connlib/dns-over-tcp/src/server.rs +++ b/rust/connlib/dns-over-tcp/src/server.rs @@ -4,16 +4,12 @@ use std::{ time::{Duration, Instant}, }; -use crate::{ - codec, create_tcp_socket, interface::create_interface, stub_device::InMemoryDevice, - time::smol_now, -}; +use crate::codec; use anyhow::{Context as _, Result}; use ip_packet::IpPacket; -use smoltcp::{ - iface::{Interface, PollResult, SocketHandle, SocketSet}, - socket::tcp, - wire::IpEndpoint, +use l3_tcp::{ + InMemoryDevice, Interface, IpEndpoint, PollResult, SocketHandle, SocketSet, create_interface, + create_tcp_socket, }; /// A sans-IO implementation of DNS-over-TCP server. @@ -158,7 +154,7 @@ impl Server { .remove(&(src, dst, response.id())) .context("No pending query found for message")?; - let socket = self.sockets.get_mut::(handle); + let socket = self.sockets.get_mut::(handle); codec::try_send(socket, &response.into_bytes(u16::MAX)) .inspect_err(|_| socket.abort()) // Abort socket on error. 
@@ -174,7 +170,7 @@ impl Server { self.last_now = now; let result = self.interface.poll( - smol_now(self.created_at, now), + l3_tcp::now(self.created_at, now), &mut self.device, &mut self.sockets, ); @@ -183,7 +179,7 @@ impl Server { return; } - for (handle, smoltcp::socket::Socket::Tcp(socket)) in self.sockets.iter_mut() { + for (handle, l3_tcp::AnySocket::Tcp(socket)) in self.sockets.iter_mut() { let local = self.listen_endpoints.get(&handle).copied().unwrap(); let _guard = tracing::trace_span!("socket", %handle).entered(); @@ -215,7 +211,7 @@ impl Server { } pub fn poll_timeout(&mut self) -> Option { - let now = smol_now(self.created_at, self.last_now); + let now = l3_tcp::now(self.created_at, self.last_now); let poll_in = self.interface.poll_delay(now, &self.sockets)?; @@ -234,13 +230,13 @@ impl Server { } fn try_recv_query( - socket: &mut tcp::Socket, + socket: &mut l3_tcp::Socket, listen: SocketAddr, ) -> Result> { // smoltcp's sockets can only ever handle a single remote, i.e. there is no permanent listening socket. // to be able to handle a new connection, reset the socket back to `listen` once the connection is closed / closing. { - use smoltcp::socket::tcp::State::*; + use l3_tcp::State::*; if matches!(socket.state(), Closed | TimeWait | CloseWait) { tracing::debug!(state = %socket.state(), "Resetting socket to listen state"); diff --git a/rust/connlib/dns-over-tcp/src/time.rs b/rust/connlib/dns-over-tcp/src/time.rs deleted file mode 100644 index 8b97f1d26..000000000 --- a/rust/connlib/dns-over-tcp/src/time.rs +++ /dev/null @@ -1,8 +0,0 @@ -use std::time::Instant; - -/// Computes an instance of [`smoltcp::time::Instant`] based on a given starting point and the current time. 
-pub fn smol_now(boot: Instant, now: Instant) -> smoltcp::time::Instant { - let millis_since_startup = now.duration_since(boot).as_millis(); - - smoltcp::time::Instant::from_millis(millis_since_startup as i64) -} diff --git a/rust/connlib/l3-tcp/Cargo.toml b/rust/connlib/l3-tcp/Cargo.toml new file mode 100644 index 000000000..93f206c2a --- /dev/null +++ b/rust/connlib/l3-tcp/Cargo.toml @@ -0,0 +1,15 @@ +[package] +name = "l3-tcp" +version = "0.1.0" +description = "The TCP protocol from an OSI-layer 3 perspective, i.e. on IP level." +edition = { workspace = true } +license = { workspace = true } + +[dependencies] +anyhow = { workspace = true } +ip-packet = { workspace = true } +smoltcp = { workspace = true, features = ["std", "log", "medium-ip", "proto-ipv4", "proto-ipv6", "socket-tcp"] } +tracing = { workspace = true } + +[lints] +workspace = true diff --git a/rust/connlib/dns-over-tcp/src/interface.rs b/rust/connlib/l3-tcp/src/interface.rs similarity index 89% rename from rust/connlib/dns-over-tcp/src/interface.rs rename to rust/connlib/l3-tcp/src/interface.rs index 0e2c61c45..30fc7c382 100644 --- a/rust/connlib/dns-over-tcp/src/interface.rs +++ b/rust/connlib/l3-tcp/src/interface.rs @@ -32,8 +32,10 @@ pub fn create_interface(device: &mut InMemoryDevice) -> Interface { // Set our interface IPs. These are just dummies and don't show up anywhere! interface.update_ip_addrs(|ips| { - ips.push(Ipv4Cidr::new(IP4_ADDR, 32).into()).unwrap(); - ips.push(Ipv6Cidr::new(IP6_ADDR, 128).into()).unwrap(); + ips.push(Ipv4Cidr::new(IP4_ADDR, 32).into()) + .expect("should be a valid IPv4 CIDR"); + ips.push(Ipv6Cidr::new(IP6_ADDR, 128).into()) + .expect("should be a valid IPv6 CIDR"); }); // Configure catch-all routes, meaning all packets given to `smoltcp` will be routed to our interface. 
diff --git a/rust/connlib/l3-tcp/src/lib.rs b/rust/connlib/l3-tcp/src/lib.rs new file mode 100644 index 000000000..caa062d64 --- /dev/null +++ b/rust/connlib/l3-tcp/src/lib.rs @@ -0,0 +1,35 @@ +//! Abstractions for working with the TCP protocol from an OSI-layer 3 perspective, i.e. IP. +//! +//! This crate is very much work-in-progress. +//! The abstractions in here are intended to grow as we learn more about our needs for interacting with TCP. + +mod interface; +mod stub_device; + +pub use crate::interface::create_interface; +pub use crate::stub_device::InMemoryDevice; +pub use smoltcp::iface::{Interface, PollResult, SocketHandle, SocketSet}; +pub use smoltcp::socket::Socket as AnySocket; +pub use smoltcp::socket::tcp::{Socket, State}; +pub use smoltcp::time::{Duration, Instant}; +pub use smoltcp::wire::IpEndpoint; + +pub fn create_tcp_socket() -> Socket<'static> { + /// The 2-byte length prefix of DNS over TCP messages limits their size to effectively u16::MAX. + /// It is quite unlikely that we have to buffer _multiple_ of these max-sized messages. + /// Being able to buffer at least one of them means we can handle the extreme case. + /// In practice, this allows the OS to queue multiple queries even if we can't immediately process them. + const MAX_TCP_DNS_MSG_LENGTH: usize = u16::MAX as usize; + + Socket::new( + smoltcp::storage::RingBuffer::new(vec![0u8; MAX_TCP_DNS_MSG_LENGTH]), + smoltcp::storage::RingBuffer::new(vec![0u8; MAX_TCP_DNS_MSG_LENGTH]), + ) +} + +/// Computes an instance of [`smoltcp::time::Instant`] based on a given starting point and the current time. 
+pub fn now(boot: std::time::Instant, now: std::time::Instant) -> Instant { + let millis_since_startup = now.duration_since(boot).as_millis(); + + Instant::from_millis(millis_since_startup as i64) +} diff --git a/rust/connlib/dns-over-tcp/src/stub_device.rs b/rust/connlib/l3-tcp/src/stub_device.rs similarity index 91% rename from rust/connlib/dns-over-tcp/src/stub_device.rs rename to rust/connlib/l3-tcp/src/stub_device.rs index cbae806c5..d677e0a0b 100644 --- a/rust/connlib/dns-over-tcp/src/stub_device.rs +++ b/rust/connlib/l3-tcp/src/stub_device.rs @@ -4,17 +4,17 @@ use ip_packet::{IpPacket, IpPacketBuf}; /// A in-memory device for [`smoltcp`] that is entirely backed by buffers. #[derive(Debug, Default)] -pub(crate) struct InMemoryDevice { +pub struct InMemoryDevice { inbound_packets: VecDeque, outbound_packets: VecDeque, } impl InMemoryDevice { - pub(crate) fn receive(&mut self, packet: IpPacket) { + pub fn receive(&mut self, packet: IpPacket) { self.inbound_packets.push_back(packet); } - pub(crate) fn next_send(&mut self) -> Option { + pub fn next_send(&mut self) -> Option { self.outbound_packets.pop_front() } } @@ -52,7 +52,7 @@ impl smoltcp::phy::Device for InMemoryDevice { } } -pub(crate) struct SmolTxToken<'a> { +pub struct SmolTxToken<'a> { outbound_packets: &'a mut VecDeque, } @@ -88,7 +88,7 @@ impl smoltcp::phy::TxToken for SmolTxToken<'_> { } } -pub(crate) struct SmolRxToken { +pub struct SmolRxToken { packet: IpPacket, } diff --git a/rust/connlib/snownet/src/allocation.rs b/rust/connlib/snownet/src/allocation.rs index c55fa050f..1f802e8a5 100644 --- a/rust/connlib/snownet/src/allocation.rs +++ b/rust/connlib/snownet/src/allocation.rs @@ -36,6 +36,7 @@ use stun_codec::{ use tracing::{Span, field}; const REQUEST_TIMEOUT: Duration = Duration::from_secs(1); +const REQUEST_MAX_ELAPSED: Duration = Duration::from_secs(8); /// How often to send a STUN binding request after the initial connection to the relay. 
/// @@ -272,7 +273,17 @@ impl Allocation { tracing::debug!("Refreshing allocation"); - self.authenticate_and_queue(make_refresh_request(self.software.clone()), None, now); + // By using the `REQUEST_TIMEOUT` for timeout and max_elapsed, we effectively only perform + // a single request. + // + // When pro-actively refreshing the allocation, we don't want to timeout after 8s but much earlier. + let backoff = backoff::new(now, REQUEST_TIMEOUT, REQUEST_TIMEOUT); + + self.authenticate_and_queue( + make_refresh_request(self.software.clone()), + Some(backoff), + now, + ); } #[tracing::instrument(level = "debug", skip_all, fields(%from, tid, method, class, rtt))] @@ -1075,7 +1086,7 @@ impl Allocation { backoff: Option, now: Instant, ) -> bool { - let backoff = backoff.unwrap_or(backoff::new(now, REQUEST_TIMEOUT)); + let backoff = backoff.unwrap_or(backoff::new(now, REQUEST_TIMEOUT, REQUEST_MAX_ELAPSED)); let id = message.transaction_id(); if backoff.is_expired(now) { diff --git a/rust/connlib/snownet/src/backoff.rs b/rust/connlib/snownet/src/backoff.rs index f29203353..01f6e2771 100644 --- a/rust/connlib/snownet/src/backoff.rs +++ b/rust/connlib/snownet/src/backoff.rs @@ -1,11 +1,11 @@ use std::time::{Duration, Instant}; const MULTIPLIER: f32 = 1.5; -const MAX_ELAPSED_TIME: Duration = Duration::from_secs(8); #[derive(Debug)] pub struct ExponentialBackoff { start_time: Instant, + max_elapsed: Duration, next_trigger: Instant, interval: Duration, } @@ -29,7 +29,7 @@ impl ExponentialBackoff { } pub(crate) fn is_expired(&self, at: Instant) -> bool { - at >= self.start_time + MAX_ELAPSED_TIME + at >= self.start_time + self.max_elapsed } pub(crate) fn interval(&self) -> Duration { @@ -41,10 +41,11 @@ impl ExponentialBackoff { } } -pub fn new(now: Instant, interval: Duration) -> ExponentialBackoff { +pub fn new(now: Instant, interval: Duration, max_elapsed: Duration) -> ExponentialBackoff { ExponentialBackoff { interval, start_time: now, + max_elapsed, next_trigger: now + 
interval, } } @@ -77,7 +78,7 @@ mod tests { let steps = Vec::from_iter( iter::from_fn({ - let mut backoff = super::new(now, Duration::from_secs(1)); + let mut backoff = super::new(now, Duration::from_secs(1), Duration::from_secs(8)); move || { if backoff.is_expired(now) { diff --git a/rust/connlib/tunnel/Cargo.toml b/rust/connlib/tunnel/Cargo.toml index 74c75aae2..99d0c9f42 100644 --- a/rust/connlib/tunnel/Cargo.toml +++ b/rust/connlib/tunnel/Cargo.toml @@ -56,6 +56,7 @@ uuid = { workspace = true, features = ["std", "v4"] } [dev-dependencies] firezone-relay = { workspace = true, features = ["proptest"] } ip-packet = { workspace = true, features = ["proptest"] } +l3-tcp = { workspace = true } proptest-state-machine = { workspace = true } rand = { workspace = true } sha2 = { workspace = true } diff --git a/rust/connlib/tunnel/proptest-regressions/tests.txt b/rust/connlib/tunnel/proptest-regressions/tests.txt index 0b6d120c3..00192877c 100644 --- a/rust/connlib/tunnel/proptest-regressions/tests.txt +++ b/rust/connlib/tunnel/proptest-regressions/tests.txt @@ -167,3 +167,12 @@ cc 36a7bb4eff285399b9c431675d4337712e7edf016a3a02b05cba5115c8bf8fe4 cc 235333b8c818e464ba339e8c73b2467894d68d594ac896c4f6a36b25ac6b823d cc 436afa9076f65f9abbe801ef2a7f26631e433650a6f717358972f37a1fbf1542 cc ee518414c1632fb9d49272b985476de0d9de2786cadef997ad7d626e1a4b975a +cc b5ba38b054ffa7eb0e5687d69d6ef0d48c7bbcb60b4e8c8aa30fbc2338e5adcb +cc 3ff12104b0e754383c7d118363274c3a2a3d5493f985d6736338aea72ef795cf +cc 2c6eb0aa6c94363c27034ca3318ad85ed51fd6fefeb1f5b65b8c60bd8c6d381d +cc e281e909d1204d9891afc01b8f70eeb1db74938e7256dc2601333eec1175b59e +cc ee946b209f553b29b8a6ae2b71959c99c926328bc43bf8e213cd2f49e938fb70 +cc 7ab081a00991a3265b2ca82f2203284759bc50ef2805e5514baa0c24c966a580 +cc 9cac073e45583d9940fd8813b93c4cadea91c5d304c454ab8d050b44ba49dc13 +cc 608f3ed9392aa067bc730538d75f3692edf2ad5c3fa98beb3e95b166e04f7b5f +cc 57c9d6263fdae8b6bb51fbb7108372c7d695d1186163fcfcdce010a6666c3db5 diff --git 
a/rust/connlib/tunnel/src/peer.rs b/rust/connlib/tunnel/src/peer.rs index 1117c88e0..48c9fd36e 100644 --- a/rust/connlib/tunnel/src/peer.rs +++ b/rust/connlib/tunnel/src/peer.rs @@ -363,6 +363,8 @@ impl ClientOnGateway { } let Some(state) = self.permanent_translations.get_mut(&packet.destination()) else { + tracing::debug!(%dst, "No translation entry"); + return Ok(TranslateOutboundResult::DestinationUnreachable( ip_packet::make::icmp_dest_unreachable( &packet, @@ -373,6 +375,12 @@ impl ClientOnGateway { }; if state.resolved_ip.is_ipv4() != dst.is_ipv4() { + tracing::debug!( + %dst, + resolved = %state.resolved_ip, + "Cannot translate between IP versions" + ); + return Ok(TranslateOutboundResult::DestinationUnreachable( ip_packet::make::icmp_dest_unreachable( &packet, diff --git a/rust/connlib/tunnel/src/tests.rs b/rust/connlib/tunnel/src/tests.rs index 6d7bc6518..764202a29 100644 --- a/rust/connlib/tunnel/src/tests.rs +++ b/rust/connlib/tunnel/src/tests.rs @@ -27,6 +27,7 @@ mod sim_relay; mod strategies; mod stub_portal; mod sut; +mod tcp; mod transition; mod unreachable_hosts; diff --git a/rust/connlib/tunnel/src/tests/assertions.rs b/rust/connlib/tunnel/src/tests/assertions.rs index a1b9dfce0..264e6bded 100644 --- a/rust/connlib/tunnel/src/tests/assertions.rs +++ b/rust/connlib/tunnel/src/tests/assertions.rs @@ -11,7 +11,7 @@ use std::{ collections::{BTreeMap, HashMap, VecDeque, hash_map::Entry}, hash::Hash, marker::PhantomData, - net::IpAddr, + net::{IpAddr, SocketAddr}, sync::atomic::{AtomicBool, Ordering}, }; use tracing::{Level, Span, Subscriber}; @@ -75,42 +75,71 @@ pub(crate) fn assert_udp_packets_properties( ); } -/// Asserts the following properties for all TCP handshakes: -/// 1. An TCP request on the client MUST result in an TCP response using the flipped src & dst IP and sport and dport. -/// 2. An TCP request on the gateway MUST target the intended resource: -/// - For CIDR resources, that is the actual CIDR resource IP. 
-/// - For DNS resources, the IP must match one of the resolved IPs for the domain. -/// 3. For DNS resources, the mapping of proxy IP to actual resource IP must be stable. -pub(crate) fn assert_tcp_packets_properties( - ref_client: &RefClient, - sim_client: &SimClient, - sim_gateways: &BTreeMap, - global_dns_records: &DnsRecords, -) { - let received_tcp_requests = sim_gateways - .iter() - .map(|(g, s)| (*g, &s.received_tcp_requests)) - .collect(); +pub(crate) fn assert_tcp_connections(ref_client: &RefClient, sim_client: &SimClient) { + for (src, _, sport, dport) in ref_client.expected_tcp_connections.keys() { + let src = SocketAddr::new(*src, sport.0); + let received_icmp_error_for_tuple = sim_client + .failed_tcp_packets + .contains_key(&(*sport, *dport)); - assert_packets_properties( - ref_client, - &sim_client.sent_tcp_requests, - &received_tcp_requests, - &ref_client.expected_tcp_exchanges, - &sim_client.received_tcp_replies, - "TCP", - global_dns_records, - |sport, dport| tracing::info_span!(target: "assertions", "TCP", ?sport, ?dport), - ); + let Some((socket, local)) = sim_client.tcp_client.iter_sockets().find_map(|s| { + let endpoint = s.local_endpoint()?; + + (l3_tcp::IpEndpoint::from(src) == endpoint).then_some((s, endpoint)) + }) else { + // If we received an ICMP error for this port tuple, not having a socket is okay. 
+ if received_icmp_error_for_tuple { + continue; + } + + tracing::error!(target: "assertions", %src, "Missing TCP connection"); + continue; + }; + let Some(remote) = socket.remote_endpoint() else { + tracing::error!(target: "assertions", %src, "TCP socket does not have a remote endpoint"); + continue; + }; + + let port = remote.port; + + if port == dport.0 { + tracing::info!(target: "assertions", %port, "TCP connection is targeting expected port"); + } else { + tracing::error!(target: "assertions", expected = %dport.0, actual = %port, "TCP connection dst port does not match"); + } + + let actual = socket.state(); + let expected = l3_tcp::State::Established; + + if actual == expected { + tracing::info!(target: "assertions", %local, %remote, "TCP connection is {expected}"); + } else { + tracing::error!(target: "assertions", %actual, %local, %remote, "TCP connection is not {expected}"); + } + + if received_icmp_error_for_tuple { + tracing::error!(target: "assertions", %local, %remote, "TCP socket should have been reset from ICMP error"); + } + } } pub(crate) fn assert_resource_status(ref_client: &RefClient, sim_client: &SimClient) { - let expected_status_map = &ref_client.expected_resource_status(); + use connlib_model::ResourceStatus::*; + + let (expected_status_map, tcp_resources) = &ref_client + .expected_resource_status(|tuple| sim_client.failed_tcp_packets.contains_key(&tuple)); let actual_status_map = &sim_client.resource_status; if expected_status_map != actual_status_map { for (resource, expected_status) in expected_status_map { match actual_status_map.get(resource) { + // For resources with TCP connections, the expected status might be off. + // The TCP client sends its own keep-alive's so we cannot always track the internal connection state. 
+ Some(&Online) + if expected_status == &Unknown && tcp_resources.contains(resource) => {} + Some(&Unknown) + if expected_status == &Online && tcp_resources.contains(resource) => {} + Some(actual_status) if actual_status != expected_status => { tracing::error!(target: "assertions", %expected_status, %actual_status, %resource, "Resource status doesn't match"); } diff --git a/rust/connlib/tunnel/src/tests/reference.rs b/rust/connlib/tunnel/src/tests/reference.rs index 4dfb44e3d..ecb16d118 100644 --- a/rust/connlib/tunnel/src/tests/reference.rs +++ b/rust/connlib/tunnel/src/tests/reference.rs @@ -37,6 +37,9 @@ pub(crate) struct ReferenceState { /// This is used to e.g. mock DNS resolution on the gateway. pub(crate) global_dns_records: DnsRecords, + /// DNS Resources that listen for TCP connections. + pub(crate) tcp_resources: BTreeMap>, + /// A subset of all DNS resource records that have been selected to produce an ICMP error. pub(crate) unreachable_hosts: UnreachableHosts, @@ -74,7 +77,7 @@ impl ReferenceState { client, gateways, portal, - records, + dns_resource_records, relays, global_dns, drop_direct_client_traffic, @@ -83,8 +86,32 @@ impl ReferenceState { Just(client), Just(gateways), Just(portal), - Just(records.clone()), - unreachable_hosts(records), + Just(dns_resource_records.clone()), + unreachable_hosts(dns_resource_records), + Just(relays), + Just(global_dns), + Just(drop_direct_client_traffic), + ) + }, + ) + .prop_flat_map( + |( + client, + gateways, + portal, + dns_resource_records, + unreachable_hosts, + relays, + global_dns, + drop_direct_client_traffic, + )| { + ( + Just(client), + Just(gateways), + Just(portal), + Just(dns_resource_records.clone()), + Just(unreachable_hosts.clone()), + tcp_resources(dns_resource_records, unreachable_hosts), Just(relays), Just(global_dns), Just(drop_direct_client_traffic), @@ -99,6 +126,7 @@ impl ReferenceState { portal, records, unreachable_hosts, + tcp_resources, relays, mut global_dns, drop_direct_client_traffic, 
@@ -130,6 +158,7 @@ impl ReferenceState { portal, global_dns, unreachable_hosts, + tcp_resources, drop_direct_client_traffic, routing_table, )) @@ -137,7 +166,7 @@ impl ReferenceState { ) .prop_filter( "private keys must be unique", - |(c, gateways, _, _, _, _, _, _)| { + |(c, gateways, _, _, _, _, _, _, _)| { let different_keys = gateways .iter() .map(|(_, g)| g.inner().key) @@ -155,6 +184,7 @@ impl ReferenceState { portal, global_dns_records, unreachable_hosts, + tcp_resources, drop_direct_client_traffic, network, )| { @@ -167,6 +197,7 @@ impl ReferenceState { unreachable_hosts, network, drop_direct_client_traffic, + tcp_resources, } }, ) @@ -228,7 +259,6 @@ impl ReferenceState { prop_oneof![ icmp_packet(packet_source_v4(tunnel_ip4), select_host_v4(&ip4_resources)), udp_packet(packet_source_v4(tunnel_ip4), select_host_v4(&ip4_resources)), - tcp_packet(packet_source_v4(tunnel_ip4), select_host_v4(&ip4_resources)), ] }, ) @@ -241,7 +271,6 @@ impl ReferenceState { prop_oneof![ icmp_packet(packet_source_v6(tunnel_ip6), select_host_v6(&ip6_resources)), udp_packet(packet_source_v6(tunnel_ip6), select_host_v6(&ip6_resources)), - tcp_packet(packet_source_v6(tunnel_ip6), select_host_v6(&ip6_resources)), ] }, ) @@ -253,8 +282,7 @@ impl ReferenceState { prop_oneof![ icmp_packet(packet_source_v4(tunnel_ip4), select(dns_v4_domains.clone())), - udp_packet(packet_source_v4(tunnel_ip4), select(dns_v4_domains.clone())), - tcp_packet(packet_source_v4(tunnel_ip4), select(dns_v4_domains)), + udp_packet(packet_source_v4(tunnel_ip4), select(dns_v4_domains)), ] }, ) @@ -266,11 +294,28 @@ impl ReferenceState { prop_oneof![ icmp_packet(packet_source_v6(tunnel_ip6), select(dns_v6_domains.clone()),), - udp_packet(packet_source_v6(tunnel_ip6), select(dns_v6_domains.clone()),), - tcp_packet(packet_source_v6(tunnel_ip6), select(dns_v6_domains),), + udp_packet(packet_source_v6(tunnel_ip6), select(dns_v6_domains),), ] }, ) + .with_if_not_empty( + 10, + 
state.resolved_v4_domains_with_tcp_resources(), + |dns_v4_domains| { + let tunnel_ip4 = state.client.inner().tunnel_ip4; + + connect_tcp(Just(tunnel_ip4), select(dns_v4_domains)) + }, + ) + .with_if_not_empty( + 10, + state.resolved_v6_domains_with_tcp_resources(), + |dns_v6_domains| { + let tunnel_ip6 = state.client.inner().tunnel_ip6; + + connect_tcp(Just(tunnel_ip6), select(dns_v6_domains)) + }, + ) .with_if_not_empty( 5, (state.all_domains(), state.reachable_dns_servers()), @@ -294,10 +339,6 @@ impl ReferenceState { select(resolved_non_resource_ip4s.clone()), ), udp_packet( - packet_source_v4(tunnel_ip4), - select(resolved_non_resource_ip4s.clone()), - ), - tcp_packet( packet_source_v4(tunnel_ip4), select(resolved_non_resource_ip4s), ), @@ -319,10 +360,6 @@ impl ReferenceState { select(resolved_non_resource_ip6s.clone()), ), udp_packet( - packet_source_v6(tunnel_ip6), - select(resolved_non_resource_ip6s.clone()), - ), - tcp_packet( packet_source_v6(tunnel_ip6), select(resolved_non_resource_ip6s), ), @@ -335,7 +372,6 @@ impl ReferenceState { prop_oneof![ icmp_packet(packet_source_v4(tunnel_ip4), select_host_v4(&gateway_ips)), udp_packet(packet_source_v4(tunnel_ip4), select_host_v4(&gateway_ips)), - tcp_packet(packet_source_v4(tunnel_ip4), select_host_v4(&gateway_ips)), ] }) .with_if_not_empty(1, state.connected_gateway_ipv6_ips(), |gateway_ips| { @@ -344,7 +380,6 @@ impl ReferenceState { prop_oneof![ icmp_packet(packet_source_v6(tunnel_ip4), select_host_v6(&gateway_ips)), udp_packet(packet_source_v6(tunnel_ip4), select_host_v6(&gateway_ips)), - tcp_packet(packet_source_v6(tunnel_ip4), select_host_v6(&gateway_ips)), ] }) .boxed() @@ -427,24 +462,14 @@ impl ReferenceState { ) }); } - Transition::SendTcpPayload { + Transition::ConnectTcp { + src, dst, sport, dport, - payload, - .. 
- } => { - state.client.exec_mut(|client| { - client.on_tcp_packet( - dst.clone(), - *sport, - *dport, - *payload, - |r| state.portal.gateway_for_resource(r).copied(), - |ip| state.portal.gateway_by_ip(ip), - ) - }); - } + } => state.client.exec_mut(|client| { + client.on_connect_tcp(*src, dst.clone(), *sport, *dport); + }), Transition::UpdateSystemDnsServers(servers) => { state .client @@ -507,13 +532,18 @@ impl ReferenceState { true } - Transition::DisableResources(resources) => { - // Don't disabled resources we don't have. - // It doesn't hurt but makes the logs of reduced testcases weird. - resources - .iter() - .all(|r| state.client.inner().has_resource(*r)) - } + Transition::DisableResources(resources) => resources.iter().all(|r| { + let has_resource = state.client.inner().has_resource(*r); + let has_tcp_connection = state + .client + .inner() + .tcp_connection_tuple_to_resource(*r) + .is_some(); + + // Don't disable resources we don't have. It doesn't hurt but makes the logs of reduced testcases weird. + // Also don't disable resources where we have TCP connections as those would get interrupted. + has_resource && !has_tcp_connection + }), Transition::SendIcmpPacket { src, dst: Destination::DomainName { name, .. }, @@ -538,17 +568,16 @@ impl ReferenceState { ref_client.is_valid_udp_packet(sport, dport, payload) && state.is_valid_dst_domain(name, src) } - Transition::SendTcpPayload { + Transition::ConnectTcp { src, - dst: Destination::DomainName { name, .. }, + dst: dst @ Destination::DomainName { name, .. 
}, sport, dport, - payload, } => { let ref_client = state.client.inner(); - ref_client.is_valid_tcp_packet(sport, dport, payload) - && state.is_valid_dst_domain(name, src) + state.is_valid_dst_domain(name, src) + && !ref_client.has_tcp_connection(*src, dst.clone(), *sport, *dport) } Transition::SendIcmpPacket { dst: Destination::IpAddr(dst), @@ -573,16 +602,17 @@ impl ReferenceState { ref_client.is_valid_udp_packet(sport, dport, payload) && state.is_valid_dst_ip(*dst) } - Transition::SendTcpPayload { - dst: Destination::IpAddr(dst), + Transition::ConnectTcp { + src, + dst: dst @ Destination::IpAddr(dst_ip), sport, dport, - payload, .. } => { let ref_client = state.client.inner(); - ref_client.is_valid_tcp_packet(sport, dport, payload) && state.is_valid_dst_ip(*dst) + state.is_valid_dst_ip(*dst_ip) + && !ref_client.has_tcp_connection(*src, dst.clone(), *sport, *dport) } Transition::UpdateSystemDnsServers(servers) => { if servers.is_empty() { @@ -647,7 +677,16 @@ impl ReferenceState { } Transition::ReconnectPortal => true, Transition::DeactivateResource(r) => { - state.client.inner().all_resource_ids().contains(r) + let has_resource = state.client.inner().has_resource(*r); + let has_tcp_connection = state + .client + .inner() + .tcp_connection_tuple_to_resource(*r) + .is_some(); + + // Don't deactivate resources we don't have. It doesn't hurt but makes the logs of reduced testcases weird. + // Also don't deactivate resources where we have TCP connections as those would get interrupted. 
+ has_resource && !has_tcp_connection } Transition::RebootRelaysWhilePartitioned(new_relays) | Transition::DeployNewRelays(new_relays) => { @@ -779,6 +818,24 @@ impl ReferenceState { .collect() } + fn resolved_v4_domains_with_tcp_resources(&self) -> Vec { + self.client + .inner() + .resolved_v4_domains() + .into_iter() + .filter(|domain| self.tcp_resources.contains_key(domain)) + .collect() + } + + fn resolved_v6_domains_with_tcp_resources(&self) -> Vec { + self.client + .inner() + .resolved_v6_domains() + .into_iter() + .filter(|domain| self.tcp_resources.contains_key(domain)) + .collect() + } + fn deploy_new_relays(&mut self, new_relays: &BTreeMap>) { // Always take down all relays because we can't know which one was sampled for the connection. for relay in self.relays.values() { diff --git a/rust/connlib/tunnel/src/tests/sim_client.rs b/rust/connlib/tunnel/src/tests/sim_client.rs index 97f91a429..c8e2f0993 100644 --- a/rust/connlib/tunnel/src/tests/sim_client.rs +++ b/rust/connlib/tunnel/src/tests/sim_client.rs @@ -13,7 +13,7 @@ use crate::{ messages::{DnsServer, Interface}, }; use bimap::BiMap; -use connlib_model::{ClientId, GatewayId, RelayId, ResourceId, ResourceStatus, SiteId}; +use connlib_model::{ClientId, GatewayId, RelayId, ResourceId, ResourceStatus, Site, SiteId}; use dns_types::{DomainName, Query, RecordData, RecordType}; use ip_network::{IpNetwork, Ipv4Network, Ipv6Network}; use ip_network_table::IpNetworkTable; @@ -59,13 +59,14 @@ pub(crate) struct SimClient { pub(crate) sent_icmp_requests: HashMap<(Seq, Identifier), IpPacket>, pub(crate) received_icmp_replies: BTreeMap<(Seq, Identifier), IpPacket>, - pub(crate) sent_tcp_requests: HashMap<(SPort, DPort), IpPacket>, - pub(crate) received_tcp_replies: BTreeMap<(SPort, DPort), IpPacket>, - pub(crate) sent_udp_requests: HashMap<(SPort, DPort), IpPacket>, pub(crate) received_udp_replies: BTreeMap<(SPort, DPort), IpPacket>, pub(crate) tcp_dns_client: dns_over_tcp::Client, + + /// TCP connections to 
resources. + pub(crate) tcp_client: crate::tests::tcp::Client, + pub(crate) failed_tcp_packets: BTreeMap<(SPort, DPort), IpPacket>, } impl SimClient { @@ -84,8 +85,6 @@ impl SimClient { received_tcp_dns_responses: Default::default(), sent_icmp_requests: Default::default(), received_icmp_replies: Default::default(), - sent_tcp_requests: Default::default(), - received_tcp_replies: Default::default(), sent_udp_requests: Default::default(), received_udp_replies: Default::default(), ipv4_routes: Default::default(), @@ -93,6 +92,8 @@ impl SimClient { search_domain: Default::default(), resource_status: Default::default(), tcp_dns_client, + tcp_client: crate::tests::tcp::Client::new(now), + failed_tcp_packets: Default::default(), } } @@ -163,6 +164,15 @@ impl SimClient { } } + pub fn connect_tcp(&mut self, src: IpAddr, dst: IpAddr, sport: SPort, dport: DPort) { + let local = SocketAddr::new(src, sport.0); + let remote = SocketAddr::new(dst, dport.0); + + if let Err(e) = self.tcp_client.connect(local, remote) { + tracing::error!("TCP connect failed: {e:#}") + } + } + pub(crate) fn encapsulate( &mut self, packet: IpPacket, @@ -178,6 +188,21 @@ impl SimClient { Some(transmit) } + pub fn poll_outbound(&mut self) -> Option { + self.tcp_dns_client + .poll_outbound() + .or_else(|| self.tcp_client.poll_outbound()) + } + + pub fn handle_timeout(&mut self, now: Instant) { + self.tcp_dns_client.handle_timeout(now); + self.tcp_client.handle_timeout(now); + + if self.sut.poll_timeout().is_some_and(|t| t <= now) { + self.sut.handle_timeout(now) + } + } + fn update_sent_requests(&mut self, packet: &IpPacket) { if let Some(icmp) = packet.as_icmpv4() { if let Icmpv4Type::EchoRequest(echo) = icmp.icmp_type() { @@ -195,24 +220,12 @@ impl SimClient { } } - if let Some(tcp) = packet.as_tcp() { - self.sent_tcp_requests.insert( - (SPort(tcp.source_port()), DPort(tcp.destination_port())), - packet.clone(), - ); - return; - } - if let Some(udp) = packet.as_udp() { self.sent_udp_requests.insert( 
(SPort(udp.source_port()), DPort(udp.destination_port())), packet.clone(), ); - - return; } - - tracing::error!("Sent a request with an unknown transport protocol"); } pub(crate) fn receive(&mut self, transmit: Transmit, now: Instant) { @@ -239,8 +252,11 @@ impl SimClient { .insert((SPort(dst), DPort(src)), packet); } Layer4Protocol::Tcp { src, dst } => { - self.received_tcp_replies - .insert((SPort(dst), DPort(src)), packet); + self.failed_tcp_packets + .insert((SPort(src), DPort(dst)), packet.clone()); + + // Allow the client to process the ICMP error. + self.tcp_client.handle_inbound(packet); } Layer4Protocol::Icmp { seq, id } => { self.received_icmp_replies @@ -290,11 +306,8 @@ impl SimClient { return; } - if let Some(tcp) = packet.as_tcp() { - self.received_tcp_replies.insert( - (SPort(tcp.source_port()), DPort(tcp.destination_port())), - packet.clone(), - ); + if self.tcp_client.accepts(&packet) { + self.tcp_client.handle_inbound(packet); return; } @@ -438,10 +451,9 @@ pub struct RefClient { pub(crate) expected_udp_handshakes: BTreeMap>, - /// The expected TCP exchanges. + /// The expected TCP connections. #[debug(skip)] - pub(crate) expected_tcp_exchanges: - BTreeMap>, + pub(crate) expected_tcp_connections: HashMap<(IpAddr, Destination, SPort, DPort), ResourceId>, /// The expected UDP DNS handshakes. #[debug(skip)] @@ -582,8 +594,28 @@ impl RefClient { } } - pub(crate) fn expected_resource_status(&self) -> BTreeMap { - self.resources + #[expect( + clippy::disallowed_methods, + reason = "We don't care about the ordering of the expected TCP connections." 
+ )] + pub(crate) fn expected_resource_status( + &self, + has_failed_tcp_connection: impl Fn((SPort, DPort)) -> bool, + ) -> (BTreeMap, BTreeSet) { + let maybe_online_sites = self + .expected_tcp_connections + .iter() + .filter(|((_, _, sport, dport), _)| !has_failed_tcp_connection((*sport, *dport))) + .filter_map(|(_, resource)| self.site_for_resource(*resource)) + .flat_map(|site| { + self.resources + .iter() + .filter_map(move |r| r.sites().contains(&site).then_some(r.id())) + }) + .collect(); + + let resource_status = self + .resources .iter() .filter_map(|r| { let status = self @@ -594,7 +626,9 @@ impl RefClient { Some((r.id(), status)) }) - .collect() + .collect(); + + (resource_status, maybe_online_sites) } pub(crate) fn tunnel_ip_for(&self, dst: IpAddr) -> IpAddr { @@ -642,25 +676,6 @@ impl RefClient { ); } - pub(crate) fn on_tcp_packet( - &mut self, - dst: Destination, - sport: SPort, - dport: DPort, - payload: u64, - gateway_by_resource: impl Fn(ResourceId) -> Option, - gateway_by_ip: impl Fn(IpAddr) -> Option, - ) { - self.on_packet( - dst.clone(), - (dst, sport, dport), - |ref_client| &mut ref_client.expected_tcp_exchanges, - payload, - gateway_by_resource, - gateway_by_ip, - ); - } - #[tracing::instrument(level = "debug", skip_all, fields(dst, resource, gateway))] fn on_packet( &mut self, @@ -708,6 +723,25 @@ impl RefClient { .insert(payload, packet_id); } + pub(crate) fn on_connect_tcp( + &mut self, + src: IpAddr, + dst: Destination, + sport: SPort, + dport: DPort, + ) { + let Some(resource) = self.resource_by_dst(&dst) else { + tracing::warn!("Unknown resource"); + return; + }; + + self.connect_to_resource(resource, dst.clone()); + self.set_resource_online(resource); + + self.expected_tcp_connections + .insert((src, dst, sport, dport), resource); + } + fn connect_to_resource(&mut self, resource: ResourceId, destination: Destination) { match destination { Destination::DomainName { .. 
} => {} @@ -716,11 +750,7 @@ impl RefClient { } fn set_resource_online(&mut self, resource: ResourceId) { - let Some(Ok(site)) = self - .resources - .iter() - .find_map(|r| (r.id() == resource).then_some(r.site())) - else { + let Some(site) = self.site_for_resource(resource) else { tracing::error!(%resource, "Unknown resource or multi-site resource"); return; }; @@ -801,6 +831,17 @@ impl RefClient { self.connected_cidr_resources.contains(&id) } + fn site_for_resource(&self, resource: ResourceId) -> Option { + let site = self + .resources + .iter() + .find_map(|r| (r.id() == resource).then_some(r.site()))? + .ok()? + .clone(); + + Some(site) + } + pub(crate) fn active_internet_resource(&self) -> Option { self.internet_resource .filter(|r| !self.disabled_resources.contains(r)) @@ -866,15 +907,6 @@ impl RefClient { ) } - /// An TCP packet is valid if we didn't yet send an TCP packet with the same sport, dport and payload. - pub(crate) fn is_valid_tcp_packet(&self, sport: &SPort, dport: &DPort, payload: &u64) -> bool { - self.expected_tcp_exchanges.values().flatten().all( - |(existig_payload, (_, existing_sport, existing_dport))| { - existing_dport != dport && existing_sport != sport && existig_payload != payload - }, - ) - } - pub(crate) fn resolved_v4_domains(&self) -> Vec { self.resolved_domains() .filter_map(|(domain, records)| { @@ -1053,6 +1085,30 @@ impl RefClient { pub(crate) fn upstream_dns_resolvers(&self) -> Vec { self.upstream_dns_resolvers.clone() } + + pub(crate) fn has_tcp_connection( + &self, + src: IpAddr, + dst: Destination, + sport: SPort, + dport: DPort, + ) -> bool { + self.expected_tcp_connections + .contains_key(&(src, dst, sport, dport)) + } + + #[expect( + clippy::disallowed_methods, + reason = "Iteration order does not matter here." 
+ )] + pub(crate) fn tcp_connection_tuple_to_resource( + &self, + resource: ResourceId, + ) -> Option<(SPort, DPort)> { + self.expected_tcp_connections + .iter() + .find_map(|((_, _, sport, dport), res)| (resource == *res).then_some((*sport, *dport))) + } } // This function only works on the tests because we are limited to resources with a single wildcard at the beginning of the resource. @@ -1138,7 +1194,7 @@ fn ref_client( connected_internet_resource: Default::default(), expected_icmp_handshakes: Default::default(), expected_udp_handshakes: Default::default(), - expected_tcp_exchanges: Default::default(), + expected_tcp_connections: Default::default(), expected_udp_dns_handshakes: Default::default(), expected_tcp_dns_handshakes: Default::default(), disabled_resources: Default::default(), diff --git a/rust/connlib/tunnel/src/tests/sim_gateway.rs b/rust/connlib/tunnel/src/tests/sim_gateway.rs index 5e1a67fb2..b701d00ec 100644 --- a/rust/connlib/tunnel/src/tests/sim_gateway.rs +++ b/rust/connlib/tunnel/src/tests/sim_gateway.rs @@ -15,7 +15,7 @@ use ip_packet::{IcmpEchoHeader, Icmpv4Type, Icmpv6Type, IpPacket}; use proptest::prelude::*; use snownet::Transmit; use std::{ - collections::BTreeMap, + collections::{BTreeMap, BTreeSet}, iter, net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}, time::Instant, @@ -32,19 +32,20 @@ pub(crate) struct SimGateway { /// The received UDP packets, indexed by our custom UDP payload. pub(crate) received_udp_requests: BTreeMap, - /// The received TCP packets, indexed by our custom TCP payload. 
- pub(crate) received_tcp_requests: BTreeMap, - site_specific_dns_records: DnsRecords, udp_dns_server_resources: BTreeMap, tcp_dns_server_resources: BTreeMap, + + tcp_resources: BTreeMap, } impl SimGateway { pub(crate) fn new( id: GatewayId, sut: GatewayState, + tcp_resources: BTreeSet, site_specific_dns_records: DnsRecords, + now: Instant, ) -> Self { Self { id, @@ -54,7 +55,17 @@ impl SimGateway { udp_dns_server_resources: Default::default(), tcp_dns_server_resources: Default::default(), received_udp_requests: Default::default(), - received_tcp_requests: Default::default(), + tcp_resources: tcp_resources + .into_iter() + .map(|address| { + let mut server = crate::tests::tcp::Server::new(now); + if let Err(e) = server.listen(address) { + tracing::error!(%address, "Failed to listen on address: {e}") + } + + (address, server) + }) + .collect(), } } @@ -113,9 +124,15 @@ impl SimGateway { std::iter::from_fn(|| server.poll_outbound()) }); + let tcp_resource_packets = self.tcp_resources.values_mut().flat_map(|server| { + server.handle_timeout(now); + + std::iter::from_fn(|| server.poll_outbound()) + }); udp_server_packets .chain(tcp_server_packets) + .chain(tcp_resource_packets) .filter_map(|packet| self.sut.handle_tun_input(packet, now).unwrap()) .collect() } @@ -203,6 +220,11 @@ impl SimGateway { if let Some(tcp) = packet.as_tcp() { let socket = SocketAddr::new(dst_ip, tcp.destination_port()); + if let Some(server) = self.tcp_resources.get_mut(&socket) { + server.handle_inbound(packet); + return None; + } + // NOTE: we can make this assumption because port 53 is excluded from non-dns query packets if let Some(server) = self.tcp_dns_server_resources.get_mut(&socket) { server.handle_input(packet); @@ -240,12 +262,6 @@ impl SimGateway { tracing::debug!(%packet_id, "Received UDP request"); self.received_udp_requests.insert(packet_id, packet.clone()); } - - if let Some(tcp) = packet.as_tcp() { - let packet_id = u64::from_be_bytes(*tcp.payload().first_chunk().unwrap()); - 
tracing::debug!(%packet_id, "Received TCP request"); - self.received_tcp_requests.insert(packet_id, packet.clone()); - } } fn handle_icmp_request( @@ -287,14 +303,19 @@ impl RefGateway { /// Initialize the [`GatewayState`]. /// /// This simulates receiving the `init` message from the portal. - pub(crate) fn init(self, id: GatewayId, now: Instant) -> SimGateway { + pub(crate) fn init( + self, + id: GatewayId, + tcp_resources: BTreeSet, + now: Instant, + ) -> SimGateway { let mut sut = GatewayState::new(self.key.0, now); // Cheating a bit here by reusing the key as seed. sut.update_tun_device(IpConfig { v4: self.tunnel_ip4, v6: self.tunnel_ip6, }); - SimGateway::new(id, sut, self.site_specific_dns_records) + SimGateway::new(id, sut, tcp_resources, self.site_specific_dns_records, now) } pub fn dns_records(&self) -> &DnsRecords { diff --git a/rust/connlib/tunnel/src/tests/strategies.rs b/rust/connlib/tunnel/src/tests/strategies.rs index f66ea4f78..0f4612372 100644 --- a/rust/connlib/tunnel/src/tests/strategies.rs +++ b/rust/connlib/tunnel/src/tests/strategies.rs @@ -1,4 +1,5 @@ use super::dns_records::DnsRecords; +use super::unreachable_hosts::UnreachableHosts; use super::{sim_net::Host, sim_relay::ref_relay_host, stub_portal::StubPortal}; use crate::client::{ CidrResource, DNS_SENTINELS_V4, DNS_SENTINELS_V6, DnsResource, IPV4_RESOURCES, IPV6_RESOURCES, @@ -7,7 +8,7 @@ use crate::client::{ use crate::messages::DnsServer; use crate::{IPV4_TUNNEL, IPV6_TUNNEL, proptest::*}; use connlib_model::{RelayId, Site}; -use dns_types::OwnedRecordData; +use dns_types::{DomainName, OwnedRecordData}; use ip_network::{Ipv4Network, Ipv6Network}; use itertools::Itertools; use prop::sample; @@ -148,6 +149,48 @@ pub(crate) fn stub_portal() -> impl Strategy { ) } +/// Samples a list of TCP resource addresses from the given DNS records. +/// +/// We sample at most 1 domain from the given records and create a [`SocketAddr`] +/// for _each_ IP that this domain resolves to. 
+/// This is equivalent to how one would deploy a service in the real world. +/// If `example.com` resolves to 4 IPs, an HTTP server needs to run on all 4 IPs on the same port. +/// +/// The port is sampled together with the domain. +pub(crate) fn tcp_resources( + dns_records: DnsRecords, + unreachable_hosts: UnreachableHosts, +) -> impl Strategy>> { + let all_domains = dns_records.domains_iter().collect::>(); + + collection::btree_set( + (sample::select(all_domains.clone()), any::()), + 1..=all_domains.len(), + ) + .prop_map(move |domains| { + domains + .into_iter() + .filter(|(domain, _)| { + dns_records + .domain_ips_iter(domain) + .all(|ip| !unreachable_hosts.is_unreachable(ip)) + }) + .map({ + let dns_records = dns_records.clone(); + + move |(domain, port)| { + let addresses = dns_records + .domain_ips_iter(&domain) + .map(|address| SocketAddr::new(address, port)) + .collect::>(); + + (domain, addresses) + } + }) + .collect() + }) +} + fn create_internet_site(mut sites: BTreeSet) -> (Site, BTreeSet) { // Rebrand the first site as the Internet site. That way, we can guarantee to always have one. 
let mut internet_site = sites.pop_first().unwrap(); diff --git a/rust/connlib/tunnel/src/tests/sut.rs b/rust/connlib/tunnel/src/tests/sut.rs index 08d887e3d..cced4c19a 100644 --- a/rust/connlib/tunnel/src/tests/sut.rs +++ b/rust/connlib/tunnel/src/tests/sut.rs @@ -20,11 +20,10 @@ use bufferpool::BufferPool; use connlib_model::{ClientId, GatewayId, PublicKey, RelayId}; use dns_types::ResponseCode; use dns_types::prelude::*; -use ip_packet::make::TcpFlags; use rand::SeedableRng; use rand::distributions::DistString; use sha2::Digest; -use snownet::Transmit; +use snownet::{NoTurnServers, Transmit}; use std::iter; use std::{ collections::BTreeMap, @@ -63,7 +62,18 @@ impl TunnelTest { .iter() .map(|(gid, gateway)| { let gateway = gateway.map( - |ref_gateway, _, _| ref_gateway.init(*gid, flux_capacitor.now()), + |ref_gateway, _, _| { + ref_gateway.init( + *gid, + ref_state + .tcp_resources + .values() + .flatten() + .copied() + .collect(), + flux_capacitor.now(), + ) + }, debug_span!("gateway", %gid), ); @@ -186,28 +196,17 @@ impl TunnelTest { buffered_transmits.push_from(transmit, &state.client, now); } - Transition::SendTcpPayload { + Transition::ConnectTcp { src, dst, sport, dport, - payload, } => { let dst = address_from_destination(&dst, &state, &src); - let packet = ip_packet::make::tcp_packet( - src, - dst, - sport.0, - dport.0, - TcpFlags::default(), - payload.to_be_bytes().to_vec(), - ) - .unwrap(); - - let transmit = state.client.exec_mut(|sim| sim.encapsulate(packet, now)); - - buffered_transmits.push_from(transmit, &state.client, now); + state + .client + .exec_mut(|sim| sim.connect_tcp(src, dst, sport, dport)); } Transition::SendDnsQueries(queries) => { for DnsQuery { @@ -360,12 +359,7 @@ impl TunnelTest { &sim_gateways, &ref_state.global_dns_records, ); - assert_tcp_packets_properties( - ref_client, - sim_client, - &sim_gateways, - &ref_state.global_dns_records, - ); + assert_tcp_connections(ref_client, sim_client); 
assert_udp_dns_packets_properties(ref_client, sim_client); assert_tcp_dns(ref_client, sim_client); assert_dns_servers_are_valid(ref_client, sim_client); @@ -422,7 +416,21 @@ impl TunnelTest { } if let Some(event) = self.client.exec_mut(|c| c.sut.poll_event()) { - self.on_client_event(self.client.inner().id, event, &ref_state.portal); + match self.on_client_event(self.client.inner().id, event, &ref_state.portal) { + Ok(()) => {} + Err(AuthorizeFlowError::Client(_)) => { + self.client.exec_mut(|c| { + c.update_relays(iter::empty(), self.relays.iter(), now); + }); + } + Err(AuthorizeFlowError::Gateway(_)) => { + for gateway in self.gateways.values_mut() { + gateway.exec_mut(|g| { + g.update_relays(iter::empty(), self.relays.iter(), now) + }); + } + } + }; continue; } if let Some(query) = self.client.exec_mut(|c| c.sut.poll_dns_queries()) { @@ -536,8 +544,6 @@ impl TunnelTest { ) { // Handle the TCP DNS client, i.e. simulate applications making TCP DNS queries. self.client.exec_mut(|c| { - c.tcp_dns_client.handle_timeout(now); - while let Some(result) = c.tcp_dns_client.poll_query_result() { match result.result { Ok(message) => { @@ -554,21 +560,17 @@ impl TunnelTest { } }); while let Some(transmit) = self.client.exec_mut(|c| { - let packet = c.tcp_dns_client.poll_outbound()?; + let packet = c.poll_outbound()?; c.encapsulate(packet, now) }) { buffered_transmits.push_from(transmit, &self.client, now) } + self.client.exec_mut(|c| c.handle_timeout(now)); - // Handle the client's `Transmit`s and timeout. + // Handle the client's `Transmit`s. while let Some(transmit) = self.client.poll_inbox(now) { self.client.exec_mut(|c| c.receive(transmit, now)) } - self.client.exec_mut(|c| { - if c.sut.poll_timeout().is_some_and(|t| t <= now) { - c.sut.handle_timeout(now) - } - }); // Handle all gateway `Transmit`s and timeouts. 
for (_, gateway) in self.gateways.iter_mut() { @@ -680,7 +682,12 @@ impl TunnelTest { } } - fn on_client_event(&mut self, src: ClientId, event: ClientEvent, portal: &StubPortal) { + fn on_client_event( + &mut self, + src: ClientId, + event: ClientEvent, + portal: &StubPortal, + ) -> Result<(), AuthorizeFlowError> { let now = self.flux_capacitor.now(); match event { @@ -694,7 +701,9 @@ impl TunnelTest { for candidate in candidates { g.sut.add_ice_candidate(src, candidate, now) } - }) + }); + + Ok(()) } ClientEvent::RemovedIceCandidates { candidates, @@ -706,7 +715,9 @@ impl TunnelTest { for candidate in candidates { g.sut.remove_ice_candidate(src, candidate, now) } - }) + }); + + Ok(()) } ClientEvent::ConnectionIntent { resource: resource_id, @@ -736,22 +747,29 @@ impl TunnelTest { now, ) }) - .unwrap(); - if let Err(e) = self.client.exec_mut(|c| { - c.sut.handle_flow_created( - resource_id, - gateway_id, - gateway_key, - gateway.inner().sut.tunnel_ip_config().unwrap(), - site_id, - preshared_key, - client_ice, - gateway_ice, - now, - ) - }) { - tracing::error!("{e:#}") - }; + .map_err(AuthorizeFlowError::Gateway)?; + self.client + .exec_mut(|c| { + c.sut.handle_flow_created( + resource_id, + gateway_id, + gateway_key, + gateway.inner().sut.tunnel_ip_config().unwrap(), + site_id, + preshared_key, + client_ice, + gateway_ice, + now, + ) + }) + .unwrap_or_else(|e| { + tracing::error!("{e:#}"); + + Ok(()) + }) + .map_err(AuthorizeFlowError::Client)?; + + Ok(()) } ClientEvent::ResourcesChanged { resources } => { @@ -761,6 +779,8 @@ impl TunnelTest { .map(|r| (r.id(), r.status())) .collect(); }); + + Ok(()) } ClientEvent::TunInterfaceUpdated(config) => { if self.client.inner().dns_mapping() == &config.dns_by_sentinel @@ -788,6 +808,8 @@ impl TunnelTest { c.ipv6_routes = config.ipv6_routes; c.search_domain = config.search_domain }); + + Ok(()) } } } @@ -843,6 +865,11 @@ impl TunnelTest { } } +enum AuthorizeFlowError { + Client(NoTurnServers), + Gateway(NoTurnServers), +} 
+ fn address_from_destination(destination: &Destination, state: &TunnelTest, src: &IpAddr) -> IpAddr { match destination { Destination::DomainName { resolved_ip, name } => { diff --git a/rust/connlib/tunnel/src/tests/tcp.rs b/rust/connlib/tunnel/src/tests/tcp.rs new file mode 100644 index 000000000..d3cd3affa --- /dev/null +++ b/rust/connlib/tunnel/src/tests/tcp.rs @@ -0,0 +1,165 @@ +use std::{ + collections::BTreeMap, + net::SocketAddr, + time::{Duration, Instant}, +}; + +use anyhow::{Context, Result}; +use ip_packet::{IpPacket, Layer4Protocol}; +use l3_tcp::Socket; + +pub struct Client { + sockets: l3_tcp::SocketSet<'static>, + sockets_by_remote: BTreeMap, + device: l3_tcp::InMemoryDevice, + interface: l3_tcp::Interface, + + created_at: Instant, + last_now: Instant, +} + +pub struct Server { + sockets: l3_tcp::SocketSet<'static>, + listen_endpoints: BTreeMap, + device: l3_tcp::InMemoryDevice, + interface: l3_tcp::Interface, + + created_at: Instant, + last_now: Instant, +} + +impl Client { + pub fn new(now: Instant) -> Self { + let mut device = l3_tcp::InMemoryDevice::default(); + let interface = l3_tcp::create_interface(&mut device); + + Self { + sockets: l3_tcp::SocketSet::new(Vec::default()), + sockets_by_remote: Default::default(), + device, + interface, + created_at: now, + last_now: now, + } + } + + pub fn connect(&mut self, local: SocketAddr, remote: SocketAddr) -> Result<()> { + anyhow::ensure!(!self.sockets_by_remote.contains_key(&remote)); + + let mut socket = l3_tcp::create_tcp_socket(); + socket + .connect(self.interface.context(), remote, local) + .context("Failed to create TCP connection")?; + + // A short keep-alive ensures we detect broken connections. + socket.set_keep_alive(Some(l3_tcp::Duration::from_secs(5))); + + // 30s is a common timeout for TCP connections. 
+ socket.set_timeout(Some(l3_tcp::Duration::from_secs(30))); + + let handle = self.sockets.add(socket); + + self.sockets_by_remote.insert(remote, handle); + + Ok(()) + } + + pub fn accepts(&self, packet: &IpPacket) -> bool { + let Some(tcp) = packet.as_tcp() else { + return false; + }; + + self.sockets_by_remote + .contains_key(&SocketAddr::new(packet.source(), tcp.source_port())) + } + + pub fn handle_inbound(&mut self, packet: IpPacket) { + // TODO: Upstream ICMP error handling to `smoltcp`. + if let Ok(Some((failed_packet, _))) = packet.icmp_unreachable_destination() { + if let Layer4Protocol::Tcp { dst, .. } = failed_packet.layer4_protocol() { + if let Some(handle) = self + .sockets_by_remote + .get(&SocketAddr::new(failed_packet.dst(), dst)) + { + self.sockets.get_mut::(*handle).abort(); + } + } + } + + self.device.receive(packet); + } + + pub fn handle_timeout(&mut self, now: Instant) { + self.last_now = now; + + let _result = self.interface.poll( + l3_tcp::now(self.created_at, now), + &mut self.device, + &mut self.sockets, + ); + } + + pub fn poll_outbound(&mut self) -> Option { + self.device.next_send() + } + + pub fn _poll_timeout(&mut self) -> Option { + let now = l3_tcp::now(self.created_at, self.last_now); + + let poll_in = self.interface.poll_delay(now, &self.sockets)?; + + Some(self.last_now + Duration::from(poll_in)) + } + + pub fn iter_sockets(&self) -> impl Iterator { + self.sockets.iter().map(|(_, s)| match s { + l3_tcp::AnySocket::Tcp(socket) => socket, + }) + } +} + +impl Server { + pub fn new(now: Instant) -> Self { + let mut device = l3_tcp::InMemoryDevice::default(); + let interface = l3_tcp::create_interface(&mut device); + + Self { + sockets: l3_tcp::SocketSet::new(Vec::default()), + listen_endpoints: Default::default(), + device, + interface, + created_at: now, + last_now: now, + } + } + + pub fn listen(&mut self, address: SocketAddr) -> Result<()> { + let mut socket = l3_tcp::create_tcp_socket(); + socket + .listen(address) + 
.with_context(|| format!("Failed to listen on {address}"))?; + + let handle = self.sockets.add(socket); + self.listen_endpoints.insert(handle, address); + + Ok(()) + } + + pub fn handle_inbound(&mut self, packet: IpPacket) { + self.device.receive(packet); + } + + pub fn handle_timeout(&mut self, now: Instant) { + self.last_now = now; + + let _result = self.interface.poll( + l3_tcp::now(self.created_at, now), + &mut self.device, + &mut self.sockets, + ); + } + + pub fn poll_outbound(&mut self) -> Option { + self.device.next_send() + } +} diff --git a/rust/connlib/tunnel/src/tests/transition.rs b/rust/connlib/tunnel/src/tests/transition.rs index 0803019a5..2be47c957 100644 --- a/rust/connlib/tunnel/src/tests/transition.rs +++ b/rust/connlib/tunnel/src/tests/transition.rs @@ -12,6 +12,7 @@ use proptest::{prelude::*, sample}; use std::{ collections::{BTreeMap, BTreeSet}, net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}, + num::NonZeroU16, }; /// The possible transitions of the state machine. @@ -39,13 +40,12 @@ pub(crate) enum Transition { dport: DPort, payload: u64, }, - /// Send an TCP payload to destination (IP resource, DNS resource or IP non-resource). - SendTcpPayload { + + ConnectTcp { src: IpAddr, dst: Destination, sport: SPort, dport: DPort, - payload: u64, }, /// Send a DNS query. @@ -125,6 +125,29 @@ pub(crate) enum Destination { IpAddr(IpAddr), } +impl Eq for Destination {} + +impl std::hash::Hash for Destination { + fn hash(&self, state: &mut H) { + match self { + Destination::DomainName { name, .. } => name.hash(state), + Destination::IpAddr(ip_addr) => ip_addr.hash(state), + } + } +} + +impl PartialEq for Destination { + fn eq(&self, other: &Self) -> bool { + match (self, other) { + (Self::DomainName { name: l_name, .. }, Self::DomainName { name: r_name, .. 
}) => { + l_name == r_name + } + (Self::IpAddr(l0), Self::IpAddr(r0)) => l0 == r0, + _ => false, + } + } +} + impl Destination { pub(crate) fn ip_addr(&self) -> Option { match self { @@ -246,32 +269,28 @@ where ) } -#[expect(private_bounds)] -pub(crate) fn tcp_packet( +pub(crate) fn connect_tcp( src: impl Strategy, - dst: impl Strategy, + dst: impl Strategy, ) -> impl Strategy where I: Into, - D: Into, { ( src.prop_map(Into::into), - dst.prop_map(Into::into), - any::(), - non_dns_ports(), + dst, + any::().prop_map(|p| p.get()), + non_dns_ports().prop_filter("avoid zero port", |p| *p != 0), any::(), - any::(), ) - .prop_map(|(src, dst, sport, dport, resolved_ip, payload)| { - Transition::SendTcpPayload { + .prop_map( + |(src, name, sport, dport, resolved_ip)| Transition::ConnectTcp { src, - dst: dst.into_destination(resolved_ip), + dst: Destination::DomainName { resolved_ip, name }, sport: SPort(sport), dport: DPort(dport), - payload, - } - }) + }, + ) } fn non_dns_ports() -> impl Strategy { diff --git a/rust/connlib/tunnel/src/tests/unreachable_hosts.rs b/rust/connlib/tunnel/src/tests/unreachable_hosts.rs index 2fb8ce3c9..562a20b2e 100644 --- a/rust/connlib/tunnel/src/tests/unreachable_hosts.rs +++ b/rust/connlib/tunnel/src/tests/unreachable_hosts.rs @@ -16,6 +16,10 @@ impl UnreachableHosts { pub(crate) fn icmp_error_for_ip(&self, ip: IpAddr) -> Option { self.inner.get(&ip).copied() } + + pub(crate) fn is_unreachable(&self, ip: IpAddr) -> bool { + self.inner.contains_key(&ip) + } } /// Samples a subset of the provided DNS records which we will treat as "unreachable".