test(connlib): establish real TCP connections in proptests (#9814)

With this patch, we sample a list of DNS resources on each test run and
create a "TCP service" for each of their addresses. Using this list of
resources, we then change the `SendTcpPayload` transition to
`ConnectTcp` and establish TCP connections using `smoltcp` to these
services.

For now, we don't send any data on these connections but we do set the
keep-alive interval to 5s, meaning `smoltcp` itself will keep these
connections alive. We also set the timeout to 30s and after each
transition in a test-run, we assert that all TCP sockets are still in
their expected state:

- `ESTABLISHED` for most of them.
- `CLOSED` for all sockets where we ended up sampling an IPv4 address
but the DNS resource only supports IPv6 addresses (or vice-versa). In
these cases, we use the ICMP error sent by the Gateway to assert that
the socket is `CLOSED`. Unfortunately, `smoltcp` currently does not
handle ICMP messages for its sockets, so we have to call `abort`
ourselves.

Overall, this should assert that regardless of whether we roam networks,
switch relays or do other kinds of things with the underlying connection,
the tunneled TCP connection stays alive.

In order to make this work, I had to tweak the timeouts when we are
on-demand refreshing allocations. This only happens in one particular
case: When we are being given new relays by the portal, we refresh all
_other_ relays to make sure they are still present. In other words, all
relays that we didn't remove and didn't just add but still had in-memory
are refreshed. This is important for cases where we are
network-partitioned from the portal whilst relays are deployed or reset
their state otherwise. Instead of the previous 8s max elapsed time of
the exponential backoff like we have it for other requests, we now only
use a single message with a 1s timeout there. With the increased ICE
timeout of 15s, a TCP connection with a 30s timeout would otherwise not
survive such an event. This is because it takes the above mentioned 8s
for us to remove a non-functioning relay, all whilst trying to establish
a new connection (which also incurs its own ICE timeout then).

With the reduced timeout on the on-demand refresh of 1s, we detect the
disappeared relay much quicker and can immediately establish a new
connection via one of the new ones. As always with reduced timeouts,
this can create false-positives if the relay doesn't reply within 1s for
some reason.

Resolves: #9531
This commit is contained in:
Thomas Eizinger
2025-07-11 17:10:22 +02:00
committed by GitHub
parent 26cfab3b88
commit 55eaa7cdc7
28 changed files with 787 additions and 303 deletions

View File

@@ -100,7 +100,7 @@ jobs:
# Poor man's test coverage testing: Grep the generated logs for specific patterns / lines.
rg --count --no-ignore SendIcmpPacket "$TESTCASES_DIR"
rg --count --no-ignore SendUdpPacket "$TESTCASES_DIR"
rg --count --no-ignore SendTcpPayload "$TESTCASES_DIR"
rg --count --no-ignore ConnectTcp "$TESTCASES_DIR"
rg --count --no-ignore SendDnsQueries "$TESTCASES_DIR"
rg --count --no-ignore "Packet for DNS resource" "$TESTCASES_DIR"
rg --count --no-ignore "Packet for CIDR resource" "$TESTCASES_DIR"

13
rust/Cargo.lock generated
View File

@@ -2035,8 +2035,8 @@ dependencies = [
"futures",
"ip-packet",
"ip_network",
"l3-tcp",
"rand 0.8.5",
"smoltcp",
"tokio",
"tracing",
"tun",
@@ -2631,6 +2631,7 @@ dependencies = [
"ip_network",
"ip_network_table",
"itertools 0.14.0",
"l3-tcp",
"l4-tcp-dns-server",
"l4-udp-dns-server",
"lru",
@@ -4026,6 +4027,16 @@ dependencies = [
"selectors",
]
[[package]]
name = "l3-tcp"
version = "0.1.0"
dependencies = [
"anyhow",
"ip-packet",
"smoltcp",
"tracing",
]
[[package]]
name = "l4-tcp-dns-server"
version = "0.1.0"

View File

@@ -9,6 +9,7 @@ members = [
"connlib/dns-types",
"connlib/etherparse-ext",
"connlib/ip-packet",
"connlib/l3-tcp",
"connlib/l4-tcp-dns-server",
"connlib/l4-udp-dns-server",
"connlib/model",
@@ -96,6 +97,7 @@ jemallocator = "0.5.4"
jni = "0.21.1"
keyring = "3.6.2"
known-folders = "1.2.0"
l3-tcp = { path = "connlib/l3-tcp" }
l4-tcp-dns-server = { path = "connlib/l4-tcp-dns-server" }
l4-udp-dns-server = { path = "connlib/l4-udp-dns-server" }
libc = "0.2.174"

View File

@@ -10,8 +10,8 @@ anyhow = { workspace = true }
dns-types = { workspace = true }
firezone-logging = { workspace = true }
ip-packet = { workspace = true }
l3-tcp = { workspace = true }
rand = { workspace = true }
smoltcp = { workspace = true, features = ["std", "log", "medium-ip", "proto-ipv4", "proto-ipv6", "socket-tcp"] }
tracing = { workspace = true }
[dev-dependencies]

View File

@@ -4,17 +4,13 @@ use std::{
time::{Duration, Instant},
};
use crate::{
codec, create_tcp_socket, interface::create_interface, stub_device::InMemoryDevice,
time::smol_now,
};
use crate::codec;
use anyhow::{Context as _, Result, anyhow, bail};
use ip_packet::IpPacket;
use rand::{Rng, SeedableRng, rngs::StdRng};
use smoltcp::{
iface::{Interface, PollResult, SocketSet},
socket::tcp,
use l3_tcp::{
InMemoryDevice, Interface, PollResult, SocketSet, create_interface, create_tcp_socket,
};
use rand::{Rng, SeedableRng, rngs::StdRng};
/// A sans-io DNS-over-TCP client.
///
@@ -32,8 +28,8 @@ pub struct Client<const MIN_PORT: u16 = 49152, const MAX_PORT: u16 = 65535> {
source_ips: Option<(Ipv4Addr, Ipv6Addr)>,
sockets: SocketSet<'static>,
sockets_by_remote: BTreeMap<SocketAddr, smoltcp::iface::SocketHandle>,
local_ports_by_socket: HashMap<smoltcp::iface::SocketHandle, u16>,
sockets_by_remote: BTreeMap<SocketAddr, l3_tcp::SocketHandle>,
local_ports_by_socket: HashMap<l3_tcp::SocketHandle, u16>,
/// Queries we should send to a DNS resolver.
pending_queries_by_remote: HashMap<SocketAddr, VecDeque<dns_types::Query>>,
/// Queries we have sent to a DNS resolver and are waiting for a reply.
@@ -182,7 +178,7 @@ impl<const MIN_PORT: u16, const MAX_PORT: u16> Client<MIN_PORT, MAX_PORT> {
};
let result = self.interface.poll(
smol_now(self.created_at, now),
l3_tcp::now(self.created_at, now),
&mut self.device,
&mut self.sockets,
);
@@ -194,7 +190,7 @@ impl<const MIN_PORT: u16, const MAX_PORT: u16> Client<MIN_PORT, MAX_PORT> {
for (remote, handle) in self.sockets_by_remote.iter_mut() {
let _guard = tracing::trace_span!("socket", %handle).entered();
let socket = self.sockets.get_mut::<tcp::Socket>(*handle);
let socket = self.sockets.get_mut::<l3_tcp::Socket>(*handle);
let server = *remote;
let pending_queries = self.pending_queries_by_remote.entry(server).or_default();
@@ -219,7 +215,7 @@ impl<const MIN_PORT: u16, const MAX_PORT: u16> Client<MIN_PORT, MAX_PORT> {
);
// Third, if the socket got closed, reconnect it.
if matches!(socket.state(), tcp::State::Closed) && !pending_queries.is_empty() {
if matches!(socket.state(), l3_tcp::State::Closed) && !pending_queries.is_empty() {
let local_port = self
.local_ports_by_socket
.get(handle)
@@ -248,7 +244,7 @@ impl<const MIN_PORT: u16, const MAX_PORT: u16> Client<MIN_PORT, MAX_PORT> {
}
pub fn poll_timeout(&mut self) -> Option<Instant> {
let now = smol_now(self.created_at, self.last_now);
let now = l3_tcp::now(self.created_at, self.last_now);
let poll_in = self.interface.poll_delay(now, &self.sockets)?;
@@ -303,7 +299,7 @@ impl<const MIN_PORT: u16, const MAX_PORT: u16> Client<MIN_PORT, MAX_PORT> {
}
fn send_pending_queries(
socket: &mut tcp::Socket,
socket: &mut l3_tcp::Socket,
server: SocketAddr,
pending_queries: &mut VecDeque<dns_types::Query>,
sent_queries: &mut HashMap<u16, dns_types::Query>,
@@ -339,7 +335,7 @@ fn send_pending_queries(
}
fn recv_responses(
socket: &mut tcp::Socket,
socket: &mut l3_tcp::Socket,
server: SocketAddr,
pending_queries: &mut VecDeque<dns_types::Query>,
sent_queries: &mut HashMap<u16, dns_types::Query>,
@@ -398,7 +394,7 @@ fn into_failed_results(
})
}
fn try_recv_response(socket: &mut tcp::Socket) -> Result<Option<dns_types::Response>> {
fn try_recv_response(socket: &mut l3_tcp::Socket) -> Result<Option<dns_types::Response>> {
if !socket.can_recv() {
tracing::trace!(state = %socket.state(), "Not yet ready to receive next message");

View File

@@ -6,9 +6,8 @@
//! Source: <https://datatracker.ietf.org/doc/html/rfc1035#section-4.2.2>.
use anyhow::{Context as _, Result};
use smoltcp::socket::tcp;
pub fn try_send(socket: &mut tcp::Socket, message: &[u8]) -> Result<()> {
pub fn try_send(socket: &mut l3_tcp::Socket, message: &[u8]) -> Result<()> {
let dns_message_length = (message.len() as u16).to_be_bytes();
let written = socket
@@ -51,7 +50,7 @@ pub fn try_send(socket: &mut tcp::Socket, message: &[u8]) -> Result<()> {
Ok(())
}
pub fn try_recv<'b, M>(socket: &'b mut tcp::Socket) -> Result<Option<M>>
pub fn try_recv<'b, M>(socket: &'b mut l3_tcp::Socket) -> Result<Option<M>>
where
M: TryFrom<&'b [u8], Error: std::error::Error + Send + Sync + 'static>,
{

View File

@@ -1,22 +1,6 @@
mod client;
mod codec;
mod interface;
mod server;
mod stub_device;
mod time;
pub use client::{Client, QueryResult};
pub use server::{Query, Server};
fn create_tcp_socket() -> smoltcp::socket::tcp::Socket<'static> {
/// The 2-byte length prefix of DNS over TCP messages limits their size to effectively u16::MAX.
/// It is quite unlikely that we have to buffer _multiple_ of these max-sized messages.
/// Being able to buffer at least one of them means we can handle the extreme case.
/// In practice, this allows the OS to queue multiple queries even if we can't immediately process them.
const MAX_TCP_DNS_MSG_LENGTH: usize = u16::MAX as usize;
smoltcp::socket::tcp::Socket::new(
smoltcp::storage::RingBuffer::new(vec![0u8; MAX_TCP_DNS_MSG_LENGTH]),
smoltcp::storage::RingBuffer::new(vec![0u8; MAX_TCP_DNS_MSG_LENGTH]),
)
}

View File

@@ -4,16 +4,12 @@ use std::{
time::{Duration, Instant},
};
use crate::{
codec, create_tcp_socket, interface::create_interface, stub_device::InMemoryDevice,
time::smol_now,
};
use crate::codec;
use anyhow::{Context as _, Result};
use ip_packet::IpPacket;
use smoltcp::{
iface::{Interface, PollResult, SocketHandle, SocketSet},
socket::tcp,
wire::IpEndpoint,
use l3_tcp::{
InMemoryDevice, Interface, IpEndpoint, PollResult, SocketHandle, SocketSet, create_interface,
create_tcp_socket,
};
/// A sans-IO implementation of DNS-over-TCP server.
@@ -158,7 +154,7 @@ impl Server {
.remove(&(src, dst, response.id()))
.context("No pending query found for message")?;
let socket = self.sockets.get_mut::<tcp::Socket>(handle);
let socket = self.sockets.get_mut::<l3_tcp::Socket>(handle);
codec::try_send(socket, &response.into_bytes(u16::MAX))
.inspect_err(|_| socket.abort()) // Abort socket on error.
@@ -174,7 +170,7 @@ impl Server {
self.last_now = now;
let result = self.interface.poll(
smol_now(self.created_at, now),
l3_tcp::now(self.created_at, now),
&mut self.device,
&mut self.sockets,
);
@@ -183,7 +179,7 @@ impl Server {
return;
}
for (handle, smoltcp::socket::Socket::Tcp(socket)) in self.sockets.iter_mut() {
for (handle, l3_tcp::AnySocket::Tcp(socket)) in self.sockets.iter_mut() {
let local = self.listen_endpoints.get(&handle).copied().unwrap();
let _guard = tracing::trace_span!("socket", %handle).entered();
@@ -215,7 +211,7 @@ impl Server {
}
pub fn poll_timeout(&mut self) -> Option<Instant> {
let now = smol_now(self.created_at, self.last_now);
let now = l3_tcp::now(self.created_at, self.last_now);
let poll_in = self.interface.poll_delay(now, &self.sockets)?;
@@ -234,13 +230,13 @@ impl Server {
}
fn try_recv_query(
socket: &mut tcp::Socket,
socket: &mut l3_tcp::Socket,
listen: SocketAddr,
) -> Result<Option<(dns_types::Query, SocketAddr)>> {
// smoltcp's sockets can only ever handle a single remote, i.e. there is no permanent listening socket.
// to be able to handle a new connection, reset the socket back to `listen` once the connection is closed / closing.
{
use smoltcp::socket::tcp::State::*;
use l3_tcp::State::*;
if matches!(socket.state(), Closed | TimeWait | CloseWait) {
tracing::debug!(state = %socket.state(), "Resetting socket to listen state");

View File

@@ -1,8 +0,0 @@
use std::time::Instant;
/// Computes an instance of [`smoltcp::time::Instant`] based on a given starting point and the current time.
pub fn smol_now(boot: Instant, now: Instant) -> smoltcp::time::Instant {
let millis_since_startup = now.duration_since(boot).as_millis();
smoltcp::time::Instant::from_millis(millis_since_startup as i64)
}

View File

@@ -0,0 +1,15 @@
[package]
name = "l3-tcp"
version = "0.1.0"
description = "The TCP protocol from an OSI-layer 3 perspective, i.e. on IP level."
edition = { workspace = true }
license = { workspace = true }
[dependencies]
anyhow = { workspace = true }
ip-packet = { workspace = true }
smoltcp = { workspace = true, features = ["std", "log", "medium-ip", "proto-ipv4", "proto-ipv6", "socket-tcp"] }
tracing = { workspace = true }
[lints]
workspace = true

View File

@@ -32,8 +32,10 @@ pub fn create_interface(device: &mut InMemoryDevice) -> Interface {
// Set our interface IPs. These are just dummies and don't show up anywhere!
interface.update_ip_addrs(|ips| {
ips.push(Ipv4Cidr::new(IP4_ADDR, 32).into()).unwrap();
ips.push(Ipv6Cidr::new(IP6_ADDR, 128).into()).unwrap();
ips.push(Ipv4Cidr::new(IP4_ADDR, 32).into())
.expect("should be a valid IPv4 CIDR");
ips.push(Ipv6Cidr::new(IP6_ADDR, 128).into())
.expect("should be a valid IPv6 CIDR");
});
// Configure catch-all routes, meaning all packets given to `smoltcp` will be routed to our interface.

View File

@@ -0,0 +1,35 @@
//! Abstractions for working with the TCP protocol from an OSI-layer 3 perspective, i.e. IP.
//!
//! This crate is very much work-in-progress.
//! The abstractions in here are intended to grow as we learn more about our needs for interacting with TCP.
mod interface;
mod stub_device;
pub use crate::interface::create_interface;
pub use crate::stub_device::InMemoryDevice;
pub use smoltcp::iface::{Interface, PollResult, SocketHandle, SocketSet};
pub use smoltcp::socket::Socket as AnySocket;
pub use smoltcp::socket::tcp::{Socket, State};
pub use smoltcp::time::{Duration, Instant};
pub use smoltcp::wire::IpEndpoint;
/// Creates a fresh TCP [`Socket`] backed by two heap-allocated ring buffers of `u16::MAX` bytes each.
pub fn create_tcp_socket() -> Socket<'static> {
    // The 2-byte length prefix of DNS over TCP messages limits their size to effectively u16::MAX.
    // It is quite unlikely that we have to buffer _multiple_ of these max-sized messages.
    // Being able to buffer at least one of them means we can handle the extreme case.
    // In practice, this allows the OS to queue multiple queries even if we can't immediately process them.
    const MAX_TCP_DNS_MSG_LENGTH: usize = u16::MAX as usize;

    // `Socket::new` takes the receive buffer first and the transmit buffer second.
    let rx_buffer = smoltcp::storage::RingBuffer::new(vec![0u8; MAX_TCP_DNS_MSG_LENGTH]);
    let tx_buffer = smoltcp::storage::RingBuffer::new(vec![0u8; MAX_TCP_DNS_MSG_LENGTH]);

    Socket::new(rx_buffer, tx_buffer)
}
/// Computes an instance of [`smoltcp::time::Instant`] based on a given starting point and the current time.
///
/// The returned instant carries the number of whole milliseconds between `boot` and `now`.
pub fn now(boot: std::time::Instant, now: std::time::Instant) -> Instant {
    let millis_since_startup = now.duration_since(boot).as_millis();

    // `as_millis` yields a `u128`. Clamp before narrowing so that an out-of-range value
    // saturates at `i64::MAX` instead of silently wrapping in the `as` cast.
    let clamped_millis = millis_since_startup.min(i64::MAX as u128);

    Instant::from_millis(clamped_millis as i64)
}

View File

@@ -4,17 +4,17 @@ use ip_packet::{IpPacket, IpPacketBuf};
/// An in-memory device for [`smoltcp`] that is entirely backed by buffers.
#[derive(Debug, Default)]
pub(crate) struct InMemoryDevice {
pub struct InMemoryDevice {
inbound_packets: VecDeque<IpPacket>,
outbound_packets: VecDeque<IpPacket>,
}
impl InMemoryDevice {
pub(crate) fn receive(&mut self, packet: IpPacket) {
pub fn receive(&mut self, packet: IpPacket) {
self.inbound_packets.push_back(packet);
}
pub(crate) fn next_send(&mut self) -> Option<IpPacket> {
pub fn next_send(&mut self) -> Option<IpPacket> {
self.outbound_packets.pop_front()
}
}
@@ -52,7 +52,7 @@ impl smoltcp::phy::Device for InMemoryDevice {
}
}
pub(crate) struct SmolTxToken<'a> {
pub struct SmolTxToken<'a> {
outbound_packets: &'a mut VecDeque<IpPacket>,
}
@@ -88,7 +88,7 @@ impl smoltcp::phy::TxToken for SmolTxToken<'_> {
}
}
pub(crate) struct SmolRxToken {
pub struct SmolRxToken {
packet: IpPacket,
}

View File

@@ -36,6 +36,7 @@ use stun_codec::{
use tracing::{Span, field};
const REQUEST_TIMEOUT: Duration = Duration::from_secs(1);
const REQUEST_MAX_ELAPSED: Duration = Duration::from_secs(8);
/// How often to send a STUN binding request after the initial connection to the relay.
///
@@ -272,7 +273,17 @@ impl Allocation {
tracing::debug!("Refreshing allocation");
self.authenticate_and_queue(make_refresh_request(self.software.clone()), None, now);
// By using the `REQUEST_TIMEOUT` for timeout and max_elapsed, we effectively only perform
// a single request.
//
// When pro-actively refreshing the allocation, we don't want to timeout after 8s but much earlier.
let backoff = backoff::new(now, REQUEST_TIMEOUT, REQUEST_TIMEOUT);
self.authenticate_and_queue(
make_refresh_request(self.software.clone()),
Some(backoff),
now,
);
}
#[tracing::instrument(level = "debug", skip_all, fields(%from, tid, method, class, rtt))]
@@ -1075,7 +1086,7 @@ impl Allocation {
backoff: Option<ExponentialBackoff>,
now: Instant,
) -> bool {
let backoff = backoff.unwrap_or(backoff::new(now, REQUEST_TIMEOUT));
let backoff = backoff.unwrap_or(backoff::new(now, REQUEST_TIMEOUT, REQUEST_MAX_ELAPSED));
let id = message.transaction_id();
if backoff.is_expired(now) {

View File

@@ -1,11 +1,11 @@
use std::time::{Duration, Instant};
const MULTIPLIER: f32 = 1.5;
const MAX_ELAPSED_TIME: Duration = Duration::from_secs(8);
#[derive(Debug)]
pub struct ExponentialBackoff {
start_time: Instant,
max_elapsed: Duration,
next_trigger: Instant,
interval: Duration,
}
@@ -29,7 +29,7 @@ impl ExponentialBackoff {
}
pub(crate) fn is_expired(&self, at: Instant) -> bool {
at >= self.start_time + MAX_ELAPSED_TIME
at >= self.start_time + self.max_elapsed
}
pub(crate) fn interval(&self) -> Duration {
@@ -41,10 +41,11 @@ impl ExponentialBackoff {
}
}
pub fn new(now: Instant, interval: Duration) -> ExponentialBackoff {
pub fn new(now: Instant, interval: Duration, max_elapsed: Duration) -> ExponentialBackoff {
ExponentialBackoff {
interval,
start_time: now,
max_elapsed,
next_trigger: now + interval,
}
}
@@ -77,7 +78,7 @@ mod tests {
let steps = Vec::from_iter(
iter::from_fn({
let mut backoff = super::new(now, Duration::from_secs(1));
let mut backoff = super::new(now, Duration::from_secs(1), Duration::from_secs(8));
move || {
if backoff.is_expired(now) {

View File

@@ -56,6 +56,7 @@ uuid = { workspace = true, features = ["std", "v4"] }
[dev-dependencies]
firezone-relay = { workspace = true, features = ["proptest"] }
ip-packet = { workspace = true, features = ["proptest"] }
l3-tcp = { workspace = true }
proptest-state-machine = { workspace = true }
rand = { workspace = true }
sha2 = { workspace = true }

View File

@@ -167,3 +167,12 @@ cc 36a7bb4eff285399b9c431675d4337712e7edf016a3a02b05cba5115c8bf8fe4
cc 235333b8c818e464ba339e8c73b2467894d68d594ac896c4f6a36b25ac6b823d
cc 436afa9076f65f9abbe801ef2a7f26631e433650a6f717358972f37a1fbf1542
cc ee518414c1632fb9d49272b985476de0d9de2786cadef997ad7d626e1a4b975a
cc b5ba38b054ffa7eb0e5687d69d6ef0d48c7bbcb60b4e8c8aa30fbc2338e5adcb
cc 3ff12104b0e754383c7d118363274c3a2a3d5493f985d6736338aea72ef795cf
cc 2c6eb0aa6c94363c27034ca3318ad85ed51fd6fefeb1f5b65b8c60bd8c6d381d
cc e281e909d1204d9891afc01b8f70eeb1db74938e7256dc2601333eec1175b59e
cc ee946b209f553b29b8a6ae2b71959c99c926328bc43bf8e213cd2f49e938fb70
cc 7ab081a00991a3265b2ca82f2203284759bc50ef2805e5514baa0c24c966a580
cc 9cac073e45583d9940fd8813b93c4cadea91c5d304c454ab8d050b44ba49dc13
cc 608f3ed9392aa067bc730538d75f3692edf2ad5c3fa98beb3e95b166e04f7b5f
cc 57c9d6263fdae8b6bb51fbb7108372c7d695d1186163fcfcdce010a6666c3db5

View File

@@ -363,6 +363,8 @@ impl ClientOnGateway {
}
let Some(state) = self.permanent_translations.get_mut(&packet.destination()) else {
tracing::debug!(%dst, "No translation entry");
return Ok(TranslateOutboundResult::DestinationUnreachable(
ip_packet::make::icmp_dest_unreachable(
&packet,
@@ -373,6 +375,12 @@ impl ClientOnGateway {
};
if state.resolved_ip.is_ipv4() != dst.is_ipv4() {
tracing::debug!(
%dst,
resolved = %state.resolved_ip,
"Cannot translate between IP versions"
);
return Ok(TranslateOutboundResult::DestinationUnreachable(
ip_packet::make::icmp_dest_unreachable(
&packet,

View File

@@ -27,6 +27,7 @@ mod sim_relay;
mod strategies;
mod stub_portal;
mod sut;
mod tcp;
mod transition;
mod unreachable_hosts;

View File

@@ -11,7 +11,7 @@ use std::{
collections::{BTreeMap, HashMap, VecDeque, hash_map::Entry},
hash::Hash,
marker::PhantomData,
net::IpAddr,
net::{IpAddr, SocketAddr},
sync::atomic::{AtomicBool, Ordering},
};
use tracing::{Level, Span, Subscriber};
@@ -75,42 +75,71 @@ pub(crate) fn assert_udp_packets_properties(
);
}
/// Asserts the following properties for all TCP handshakes:
/// 1. An TCP request on the client MUST result in an TCP response using the flipped src & dst IP and sport and dport.
/// 2. An TCP request on the gateway MUST target the intended resource:
/// - For CIDR resources, that is the actual CIDR resource IP.
/// - For DNS resources, the IP must match one of the resolved IPs for the domain.
/// 3. For DNS resources, the mapping of proxy IP to actual resource IP must be stable.
pub(crate) fn assert_tcp_packets_properties(
ref_client: &RefClient,
sim_client: &SimClient,
sim_gateways: &BTreeMap<GatewayId, &SimGateway>,
global_dns_records: &DnsRecords,
) {
let received_tcp_requests = sim_gateways
.iter()
.map(|(g, s)| (*g, &s.received_tcp_requests))
.collect();
pub(crate) fn assert_tcp_connections(ref_client: &RefClient, sim_client: &SimClient) {
for (src, _, sport, dport) in ref_client.expected_tcp_connections.keys() {
let src = SocketAddr::new(*src, sport.0);
let received_icmp_error_for_tuple = sim_client
.failed_tcp_packets
.contains_key(&(*sport, *dport));
assert_packets_properties(
ref_client,
&sim_client.sent_tcp_requests,
&received_tcp_requests,
&ref_client.expected_tcp_exchanges,
&sim_client.received_tcp_replies,
"TCP",
global_dns_records,
|sport, dport| tracing::info_span!(target: "assertions", "TCP", ?sport, ?dport),
);
let Some((socket, local)) = sim_client.tcp_client.iter_sockets().find_map(|s| {
let endpoint = s.local_endpoint()?;
(l3_tcp::IpEndpoint::from(src) == endpoint).then_some((s, endpoint))
}) else {
// If we received an ICMP error for this port tuple, not having a socket is okay.
if received_icmp_error_for_tuple {
continue;
}
tracing::error!(target: "assertions", %src, "Missing TCP connection");
continue;
};
let Some(remote) = socket.remote_endpoint() else {
tracing::error!(target: "assertions", %src, "TCP socket does not have a remote endpoint");
continue;
};
let port = remote.port;
if port == dport.0 {
tracing::info!(target: "assertions", %port, "TCP connection is targeting expected port");
} else {
tracing::error!(target: "assertions", expected = %dport.0, actual = %port, "TCP connection dst port does not match");
}
let actual = socket.state();
let expected = l3_tcp::State::Established;
if actual == expected {
tracing::info!(target: "assertions", %local, %remote, "TCP connection is {expected}");
} else {
tracing::error!(target: "assertions", %actual, %local, %remote, "TCP connection is not {expected}");
}
if received_icmp_error_for_tuple {
tracing::error!(target: "assertions", %local, %remote, "TCP socket should have been reset from ICMP error");
}
}
}
pub(crate) fn assert_resource_status(ref_client: &RefClient, sim_client: &SimClient) {
let expected_status_map = &ref_client.expected_resource_status();
use connlib_model::ResourceStatus::*;
let (expected_status_map, tcp_resources) = &ref_client
.expected_resource_status(|tuple| sim_client.failed_tcp_packets.contains_key(&tuple));
let actual_status_map = &sim_client.resource_status;
if expected_status_map != actual_status_map {
for (resource, expected_status) in expected_status_map {
match actual_status_map.get(resource) {
// For resources with TCP connections, the expected status might be off.
// The TCP client sends its own keep-alive's so we cannot always track the internal connection state.
Some(&Online)
if expected_status == &Unknown && tcp_resources.contains(resource) => {}
Some(&Unknown)
if expected_status == &Online && tcp_resources.contains(resource) => {}
Some(actual_status) if actual_status != expected_status => {
tracing::error!(target: "assertions", %expected_status, %actual_status, %resource, "Resource status doesn't match");
}

View File

@@ -37,6 +37,9 @@ pub(crate) struct ReferenceState {
/// This is used to e.g. mock DNS resolution on the gateway.
pub(crate) global_dns_records: DnsRecords,
/// DNS Resources that listen for TCP connections.
pub(crate) tcp_resources: BTreeMap<DomainName, BTreeSet<SocketAddr>>,
/// A subset of all DNS resource records that have been selected to produce an ICMP error.
pub(crate) unreachable_hosts: UnreachableHosts,
@@ -74,7 +77,7 @@ impl ReferenceState {
client,
gateways,
portal,
records,
dns_resource_records,
relays,
global_dns,
drop_direct_client_traffic,
@@ -83,8 +86,32 @@ impl ReferenceState {
Just(client),
Just(gateways),
Just(portal),
Just(records.clone()),
unreachable_hosts(records),
Just(dns_resource_records.clone()),
unreachable_hosts(dns_resource_records),
Just(relays),
Just(global_dns),
Just(drop_direct_client_traffic),
)
},
)
.prop_flat_map(
|(
client,
gateways,
portal,
dns_resource_records,
unreachable_hosts,
relays,
global_dns,
drop_direct_client_traffic,
)| {
(
Just(client),
Just(gateways),
Just(portal),
Just(dns_resource_records.clone()),
Just(unreachable_hosts.clone()),
tcp_resources(dns_resource_records, unreachable_hosts),
Just(relays),
Just(global_dns),
Just(drop_direct_client_traffic),
@@ -99,6 +126,7 @@ impl ReferenceState {
portal,
records,
unreachable_hosts,
tcp_resources,
relays,
mut global_dns,
drop_direct_client_traffic,
@@ -130,6 +158,7 @@ impl ReferenceState {
portal,
global_dns,
unreachable_hosts,
tcp_resources,
drop_direct_client_traffic,
routing_table,
))
@@ -137,7 +166,7 @@ impl ReferenceState {
)
.prop_filter(
"private keys must be unique",
|(c, gateways, _, _, _, _, _, _)| {
|(c, gateways, _, _, _, _, _, _, _)| {
let different_keys = gateways
.iter()
.map(|(_, g)| g.inner().key)
@@ -155,6 +184,7 @@ impl ReferenceState {
portal,
global_dns_records,
unreachable_hosts,
tcp_resources,
drop_direct_client_traffic,
network,
)| {
@@ -167,6 +197,7 @@ impl ReferenceState {
unreachable_hosts,
network,
drop_direct_client_traffic,
tcp_resources,
}
},
)
@@ -228,7 +259,6 @@ impl ReferenceState {
prop_oneof![
icmp_packet(packet_source_v4(tunnel_ip4), select_host_v4(&ip4_resources)),
udp_packet(packet_source_v4(tunnel_ip4), select_host_v4(&ip4_resources)),
tcp_packet(packet_source_v4(tunnel_ip4), select_host_v4(&ip4_resources)),
]
},
)
@@ -241,7 +271,6 @@ impl ReferenceState {
prop_oneof![
icmp_packet(packet_source_v6(tunnel_ip6), select_host_v6(&ip6_resources)),
udp_packet(packet_source_v6(tunnel_ip6), select_host_v6(&ip6_resources)),
tcp_packet(packet_source_v6(tunnel_ip6), select_host_v6(&ip6_resources)),
]
},
)
@@ -253,8 +282,7 @@ impl ReferenceState {
prop_oneof![
icmp_packet(packet_source_v4(tunnel_ip4), select(dns_v4_domains.clone())),
udp_packet(packet_source_v4(tunnel_ip4), select(dns_v4_domains.clone())),
tcp_packet(packet_source_v4(tunnel_ip4), select(dns_v4_domains)),
udp_packet(packet_source_v4(tunnel_ip4), select(dns_v4_domains)),
]
},
)
@@ -266,11 +294,28 @@ impl ReferenceState {
prop_oneof![
icmp_packet(packet_source_v6(tunnel_ip6), select(dns_v6_domains.clone()),),
udp_packet(packet_source_v6(tunnel_ip6), select(dns_v6_domains.clone()),),
tcp_packet(packet_source_v6(tunnel_ip6), select(dns_v6_domains),),
udp_packet(packet_source_v6(tunnel_ip6), select(dns_v6_domains),),
]
},
)
.with_if_not_empty(
10,
state.resolved_v4_domains_with_tcp_resources(),
|dns_v4_domains| {
let tunnel_ip4 = state.client.inner().tunnel_ip4;
connect_tcp(Just(tunnel_ip4), select(dns_v4_domains))
},
)
.with_if_not_empty(
10,
state.resolved_v6_domains_with_tcp_resources(),
|dns_v6_domains| {
let tunnel_ip6 = state.client.inner().tunnel_ip6;
connect_tcp(Just(tunnel_ip6), select(dns_v6_domains))
},
)
.with_if_not_empty(
5,
(state.all_domains(), state.reachable_dns_servers()),
@@ -294,10 +339,6 @@ impl ReferenceState {
select(resolved_non_resource_ip4s.clone()),
),
udp_packet(
packet_source_v4(tunnel_ip4),
select(resolved_non_resource_ip4s.clone()),
),
tcp_packet(
packet_source_v4(tunnel_ip4),
select(resolved_non_resource_ip4s),
),
@@ -319,10 +360,6 @@ impl ReferenceState {
select(resolved_non_resource_ip6s.clone()),
),
udp_packet(
packet_source_v6(tunnel_ip6),
select(resolved_non_resource_ip6s.clone()),
),
tcp_packet(
packet_source_v6(tunnel_ip6),
select(resolved_non_resource_ip6s),
),
@@ -335,7 +372,6 @@ impl ReferenceState {
prop_oneof![
icmp_packet(packet_source_v4(tunnel_ip4), select_host_v4(&gateway_ips)),
udp_packet(packet_source_v4(tunnel_ip4), select_host_v4(&gateway_ips)),
tcp_packet(packet_source_v4(tunnel_ip4), select_host_v4(&gateway_ips)),
]
})
.with_if_not_empty(1, state.connected_gateway_ipv6_ips(), |gateway_ips| {
@@ -344,7 +380,6 @@ impl ReferenceState {
prop_oneof![
icmp_packet(packet_source_v6(tunnel_ip4), select_host_v6(&gateway_ips)),
udp_packet(packet_source_v6(tunnel_ip4), select_host_v6(&gateway_ips)),
tcp_packet(packet_source_v6(tunnel_ip4), select_host_v6(&gateway_ips)),
]
})
.boxed()
@@ -427,24 +462,14 @@ impl ReferenceState {
)
});
}
Transition::SendTcpPayload {
Transition::ConnectTcp {
src,
dst,
sport,
dport,
payload,
..
} => {
state.client.exec_mut(|client| {
client.on_tcp_packet(
dst.clone(),
*sport,
*dport,
*payload,
|r| state.portal.gateway_for_resource(r).copied(),
|ip| state.portal.gateway_by_ip(ip),
)
});
}
} => state.client.exec_mut(|client| {
client.on_connect_tcp(*src, dst.clone(), *sport, *dport);
}),
Transition::UpdateSystemDnsServers(servers) => {
state
.client
@@ -507,13 +532,18 @@ impl ReferenceState {
true
}
Transition::DisableResources(resources) => {
// Don't disabled resources we don't have.
// It doesn't hurt but makes the logs of reduced testcases weird.
resources
.iter()
.all(|r| state.client.inner().has_resource(*r))
}
Transition::DisableResources(resources) => resources.iter().all(|r| {
let has_resource = state.client.inner().has_resource(*r);
let has_tcp_connection = state
.client
.inner()
.tcp_connection_tuple_to_resource(*r)
.is_some();
// Don't disable resources we don't have. It doesn't hurt but makes the logs of reduced testcases weird.
// Also don't disable resources where we have TCP connections as those would get interrupted.
has_resource && !has_tcp_connection
}),
Transition::SendIcmpPacket {
src,
dst: Destination::DomainName { name, .. },
@@ -538,17 +568,16 @@ impl ReferenceState {
ref_client.is_valid_udp_packet(sport, dport, payload)
&& state.is_valid_dst_domain(name, src)
}
Transition::SendTcpPayload {
Transition::ConnectTcp {
src,
dst: Destination::DomainName { name, .. },
dst: dst @ Destination::DomainName { name, .. },
sport,
dport,
payload,
} => {
let ref_client = state.client.inner();
ref_client.is_valid_tcp_packet(sport, dport, payload)
&& state.is_valid_dst_domain(name, src)
state.is_valid_dst_domain(name, src)
&& !ref_client.has_tcp_connection(*src, dst.clone(), *sport, *dport)
}
Transition::SendIcmpPacket {
dst: Destination::IpAddr(dst),
@@ -573,16 +602,17 @@ impl ReferenceState {
ref_client.is_valid_udp_packet(sport, dport, payload) && state.is_valid_dst_ip(*dst)
}
Transition::SendTcpPayload {
dst: Destination::IpAddr(dst),
Transition::ConnectTcp {
src,
dst: dst @ Destination::IpAddr(dst_ip),
sport,
dport,
payload,
..
} => {
let ref_client = state.client.inner();
ref_client.is_valid_tcp_packet(sport, dport, payload) && state.is_valid_dst_ip(*dst)
state.is_valid_dst_ip(*dst_ip)
&& !ref_client.has_tcp_connection(*src, dst.clone(), *sport, *dport)
}
Transition::UpdateSystemDnsServers(servers) => {
if servers.is_empty() {
@@ -647,7 +677,16 @@ impl ReferenceState {
}
Transition::ReconnectPortal => true,
Transition::DeactivateResource(r) => {
state.client.inner().all_resource_ids().contains(r)
let has_resource = state.client.inner().has_resource(*r);
let has_tcp_connection = state
.client
.inner()
.tcp_connection_tuple_to_resource(*r)
.is_some();
// Don't deactivate resources we don't have. It doesn't hurt but makes the logs of reduced testcases weird.
// Also don't deactivate resources where we have TCP connections as those would get interrupted.
has_resource && !has_tcp_connection
}
Transition::RebootRelaysWhilePartitioned(new_relays)
| Transition::DeployNewRelays(new_relays) => {
@@ -779,6 +818,24 @@ impl ReferenceState {
.collect()
}
/// Narrows the reference client's `resolved_v4_domains()` down to the
/// domains that are keys of `tcp_resources`, preserving their order.
fn resolved_v4_domains_with_tcp_resources(&self) -> Vec<DomainName> {
    let mut domains = self.client.inner().resolved_v4_domains();
    domains.retain(|domain| self.tcp_resources.contains_key(domain));

    domains
}
/// Narrows the reference client's `resolved_v6_domains()` down to the
/// domains that are keys of `tcp_resources`, preserving their order.
fn resolved_v6_domains_with_tcp_resources(&self) -> Vec<DomainName> {
    let mut domains = self.client.inner().resolved_v6_domains();
    domains.retain(|domain| self.tcp_resources.contains_key(domain));

    domains
}
fn deploy_new_relays(&mut self, new_relays: &BTreeMap<RelayId, Host<u64>>) {
// Always take down all relays because we can't know which one was sampled for the connection.
for relay in self.relays.values() {

View File

@@ -13,7 +13,7 @@ use crate::{
messages::{DnsServer, Interface},
};
use bimap::BiMap;
use connlib_model::{ClientId, GatewayId, RelayId, ResourceId, ResourceStatus, SiteId};
use connlib_model::{ClientId, GatewayId, RelayId, ResourceId, ResourceStatus, Site, SiteId};
use dns_types::{DomainName, Query, RecordData, RecordType};
use ip_network::{IpNetwork, Ipv4Network, Ipv6Network};
use ip_network_table::IpNetworkTable;
@@ -59,13 +59,14 @@ pub(crate) struct SimClient {
pub(crate) sent_icmp_requests: HashMap<(Seq, Identifier), IpPacket>,
pub(crate) received_icmp_replies: BTreeMap<(Seq, Identifier), IpPacket>,
pub(crate) sent_tcp_requests: HashMap<(SPort, DPort), IpPacket>,
pub(crate) received_tcp_replies: BTreeMap<(SPort, DPort), IpPacket>,
pub(crate) sent_udp_requests: HashMap<(SPort, DPort), IpPacket>,
pub(crate) received_udp_replies: BTreeMap<(SPort, DPort), IpPacket>,
pub(crate) tcp_dns_client: dns_over_tcp::Client,
/// TCP connections to resources.
pub(crate) tcp_client: crate::tests::tcp::Client,
pub(crate) failed_tcp_packets: BTreeMap<(SPort, DPort), IpPacket>,
}
impl SimClient {
@@ -84,8 +85,6 @@ impl SimClient {
received_tcp_dns_responses: Default::default(),
sent_icmp_requests: Default::default(),
received_icmp_replies: Default::default(),
sent_tcp_requests: Default::default(),
received_tcp_replies: Default::default(),
sent_udp_requests: Default::default(),
received_udp_replies: Default::default(),
ipv4_routes: Default::default(),
@@ -93,6 +92,8 @@ impl SimClient {
search_domain: Default::default(),
resource_status: Default::default(),
tcp_dns_client,
tcp_client: crate::tests::tcp::Client::new(now),
failed_tcp_packets: Default::default(),
}
}
@@ -163,6 +164,15 @@ impl SimClient {
}
}
/// Asks the simulated TCP client to open a connection from `src:sport` to `dst:dport`.
///
/// A failure to start the connection is only logged, never propagated.
pub fn connect_tcp(&mut self, src: IpAddr, dst: IpAddr, sport: SPort, dport: DPort) {
    let result = self.tcp_client.connect(
        SocketAddr::new(src, sport.0),
        SocketAddr::new(dst, dport.0),
    );

    let Err(e) = result else {
        return;
    };

    tracing::error!("TCP connect failed: {e:#}")
}
pub(crate) fn encapsulate(
&mut self,
packet: IpPacket,
@@ -178,6 +188,21 @@ impl SimClient {
Some(transmit)
}
/// Returns the next packet one of the client-side TCP stacks wants to send.
///
/// The TCP DNS client is drained first; the resource TCP client is only
/// asked once the DNS client has nothing to send.
pub fn poll_outbound(&mut self) -> Option<IpPacket> {
    if let Some(packet) = self.tcp_dns_client.poll_outbound() {
        return Some(packet);
    }

    self.tcp_client.poll_outbound()
}
/// Advances all time-driven components of the simulated client to `now`.
///
/// Both TCP clients are always advanced. The system under test is only
/// advanced when it reports a timeout that is due at or before `now`.
pub fn handle_timeout(&mut self, now: Instant) {
    self.tcp_dns_client.handle_timeout(now);
    self.tcp_client.handle_timeout(now);

    match self.sut.poll_timeout() {
        Some(deadline) if deadline <= now => self.sut.handle_timeout(now),
        _ => {}
    }
}
fn update_sent_requests(&mut self, packet: &IpPacket) {
if let Some(icmp) = packet.as_icmpv4() {
if let Icmpv4Type::EchoRequest(echo) = icmp.icmp_type() {
@@ -195,24 +220,12 @@ impl SimClient {
}
}
if let Some(tcp) = packet.as_tcp() {
self.sent_tcp_requests.insert(
(SPort(tcp.source_port()), DPort(tcp.destination_port())),
packet.clone(),
);
return;
}
if let Some(udp) = packet.as_udp() {
self.sent_udp_requests.insert(
(SPort(udp.source_port()), DPort(udp.destination_port())),
packet.clone(),
);
return;
}
tracing::error!("Sent a request with an unknown transport protocol");
}
pub(crate) fn receive(&mut self, transmit: Transmit, now: Instant) {
@@ -239,8 +252,11 @@ impl SimClient {
.insert((SPort(dst), DPort(src)), packet);
}
Layer4Protocol::Tcp { src, dst } => {
self.received_tcp_replies
.insert((SPort(dst), DPort(src)), packet);
self.failed_tcp_packets
.insert((SPort(src), DPort(dst)), packet.clone());
// Allow the client to process the ICMP error.
self.tcp_client.handle_inbound(packet);
}
Layer4Protocol::Icmp { seq, id } => {
self.received_icmp_replies
@@ -290,11 +306,8 @@ impl SimClient {
return;
}
if let Some(tcp) = packet.as_tcp() {
self.received_tcp_replies.insert(
(SPort(tcp.source_port()), DPort(tcp.destination_port())),
packet.clone(),
);
if self.tcp_client.accepts(&packet) {
self.tcp_client.handle_inbound(packet);
return;
}
@@ -438,10 +451,9 @@ pub struct RefClient {
pub(crate) expected_udp_handshakes:
BTreeMap<GatewayId, BTreeMap<u64, (Destination, SPort, DPort)>>,
/// The expected TCP exchanges.
/// The expected TCP connections.
#[debug(skip)]
pub(crate) expected_tcp_exchanges:
BTreeMap<GatewayId, BTreeMap<u64, (Destination, SPort, DPort)>>,
pub(crate) expected_tcp_connections: HashMap<(IpAddr, Destination, SPort, DPort), ResourceId>,
/// The expected UDP DNS handshakes.
#[debug(skip)]
@@ -582,8 +594,28 @@ impl RefClient {
}
}
pub(crate) fn expected_resource_status(&self) -> BTreeMap<ResourceId, ResourceStatus> {
self.resources
#[expect(
clippy::disallowed_methods,
reason = "We don't care about the ordering of the expected TCP connections."
)]
pub(crate) fn expected_resource_status(
&self,
has_failed_tcp_connection: impl Fn((SPort, DPort)) -> bool,
) -> (BTreeMap<ResourceId, ResourceStatus>, BTreeSet<ResourceId>) {
let maybe_online_sites = self
.expected_tcp_connections
.iter()
.filter(|((_, _, sport, dport), _)| !has_failed_tcp_connection((*sport, *dport)))
.filter_map(|(_, resource)| self.site_for_resource(*resource))
.flat_map(|site| {
self.resources
.iter()
.filter_map(move |r| r.sites().contains(&site).then_some(r.id()))
})
.collect();
let resource_status = self
.resources
.iter()
.filter_map(|r| {
let status = self
@@ -594,7 +626,9 @@ impl RefClient {
Some((r.id(), status))
})
.collect()
.collect();
(resource_status, maybe_online_sites)
}
pub(crate) fn tunnel_ip_for(&self, dst: IpAddr) -> IpAddr {
@@ -642,25 +676,6 @@ impl RefClient {
);
}
/// Forwards a TCP packet to the generic `on_packet` handler.
///
/// The packet is identified by `(dst, sport, dport)` and the closure hands
/// `on_packet` the `expected_tcp_exchanges` map to operate on.
/// `payload`, `gateway_by_resource` and `gateway_by_ip` are passed through unchanged.
pub(crate) fn on_tcp_packet(
&mut self,
dst: Destination,
sport: SPort,
dport: DPort,
payload: u64,
gateway_by_resource: impl Fn(ResourceId) -> Option<GatewayId>,
gateway_by_ip: impl Fn(IpAddr) -> Option<GatewayId>,
) {
self.on_packet(
// `dst` is needed twice: once as the destination, once inside the packet identity.
dst.clone(),
(dst, sport, dport),
|ref_client| &mut ref_client.expected_tcp_exchanges,
payload,
gateway_by_resource,
gateway_by_ip,
);
}
#[tracing::instrument(level = "debug", skip_all, fields(dst, resource, gateway))]
fn on_packet<E>(
&mut self,
@@ -708,6 +723,25 @@ impl RefClient {
.insert(payload, packet_id);
}
/// Records in the reference state that a TCP connection was initiated.
///
/// If `dst` maps to a known resource, that resource is connected to, marked
/// as online, and the connection tuple `(src, dst, sport, dport)` is stored
/// in `expected_tcp_connections`. Otherwise a warning is logged and nothing
/// is recorded.
pub(crate) fn on_connect_tcp(
    &mut self,
    src: IpAddr,
    dst: Destination,
    sport: SPort,
    dport: DPort,
) {
    let resource = match self.resource_by_dst(&dst) {
        Some(resource) => resource,
        None => {
            tracing::warn!("Unknown resource");
            return;
        }
    };

    self.connect_to_resource(resource, dst.clone());
    self.set_resource_online(resource);

    let connection = (src, dst, sport, dport);
    self.expected_tcp_connections.insert(connection, resource);
}
fn connect_to_resource(&mut self, resource: ResourceId, destination: Destination) {
match destination {
Destination::DomainName { .. } => {}
@@ -716,11 +750,7 @@ impl RefClient {
}
fn set_resource_online(&mut self, resource: ResourceId) {
let Some(Ok(site)) = self
.resources
.iter()
.find_map(|r| (r.id() == resource).then_some(r.site()))
else {
let Some(site) = self.site_for_resource(resource) else {
tracing::error!(%resource, "Unknown resource or multi-site resource");
return;
};
@@ -801,6 +831,17 @@ impl RefClient {
self.connected_cidr_resources.contains(&id)
}
/// Returns a clone of the site of the resource with the given ID.
///
/// Yields `None` if no such resource exists or if its `site()` lookup fails.
fn site_for_resource(&self, resource: ResourceId) -> Option<Site> {
    let matching = self.resources.iter().find(|r| r.id() == resource)?;
    let site = matching.site().ok()?;

    Some(site.clone())
}
pub(crate) fn active_internet_resource(&self) -> Option<ResourceId> {
self.internet_resource
.filter(|r| !self.disabled_resources.contains(r))
@@ -866,15 +907,6 @@ impl RefClient {
)
}
/// A TCP packet is valid if none of the already expected TCP exchanges shares
/// its destination port, its source port or its payload.
pub(crate) fn is_valid_tcp_packet(&self, sport: &SPort, dport: &DPort, payload: &u64) -> bool {
    let collides_with_existing = self
        .expected_tcp_exchanges
        .values()
        .flat_map(|exchanges| exchanges.iter())
        .any(|(seen_payload, (_, seen_sport, seen_dport))| {
            seen_dport == dport || seen_sport == sport || seen_payload == payload
        });

    !collides_with_existing
}
pub(crate) fn resolved_v4_domains(&self) -> Vec<DomainName> {
self.resolved_domains()
.filter_map(|(domain, records)| {
@@ -1053,6 +1085,30 @@ impl RefClient {
pub(crate) fn upstream_dns_resolvers(&self) -> Vec<DnsServer> {
self.upstream_dns_resolvers.clone()
}
/// Returns whether a TCP connection with exactly this `(src, dst, sport, dport)`
/// tuple is already tracked in `expected_tcp_connections`.
pub(crate) fn has_tcp_connection(
    &self,
    src: IpAddr,
    dst: Destination,
    sport: SPort,
    dport: DPort,
) -> bool {
    let connection = (src, dst, sport, dport);

    self.expected_tcp_connections.get(&connection).is_some()
}
/// Returns the `(sport, dport)` pair of one expected TCP connection to `resource`, if any.
///
/// When several connections target the same resource, which one is returned
/// depends on the map's iteration order.
#[expect(
    clippy::disallowed_methods,
    reason = "Iteration order does not matter here."
)]
pub(crate) fn tcp_connection_tuple_to_resource(
    &self,
    resource: ResourceId,
) -> Option<(SPort, DPort)> {
    self.expected_tcp_connections
        .iter()
        .find(|(_, res)| **res == resource)
        .map(|((_, _, sport, dport), _)| (*sport, *dport))
}
}
// This function only works in the tests because we are limited to resources with a single wildcard at the beginning of the resource.
@@ -1138,7 +1194,7 @@ fn ref_client(
connected_internet_resource: Default::default(),
expected_icmp_handshakes: Default::default(),
expected_udp_handshakes: Default::default(),
expected_tcp_exchanges: Default::default(),
expected_tcp_connections: Default::default(),
expected_udp_dns_handshakes: Default::default(),
expected_tcp_dns_handshakes: Default::default(),
disabled_resources: Default::default(),

View File

@@ -15,7 +15,7 @@ use ip_packet::{IcmpEchoHeader, Icmpv4Type, Icmpv6Type, IpPacket};
use proptest::prelude::*;
use snownet::Transmit;
use std::{
collections::BTreeMap,
collections::{BTreeMap, BTreeSet},
iter,
net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr},
time::Instant,
@@ -32,19 +32,20 @@ pub(crate) struct SimGateway {
/// The received UDP packets, indexed by our custom UDP payload.
pub(crate) received_udp_requests: BTreeMap<u64, IpPacket>,
/// The received TCP packets, indexed by our custom TCP payload.
pub(crate) received_tcp_requests: BTreeMap<u64, IpPacket>,
site_specific_dns_records: DnsRecords,
udp_dns_server_resources: BTreeMap<SocketAddr, UdpDnsServerResource>,
tcp_dns_server_resources: BTreeMap<SocketAddr, TcpDnsServerResource>,
tcp_resources: BTreeMap<SocketAddr, crate::tests::tcp::Server>,
}
impl SimGateway {
pub(crate) fn new(
id: GatewayId,
sut: GatewayState,
tcp_resources: BTreeSet<SocketAddr>,
site_specific_dns_records: DnsRecords,
now: Instant,
) -> Self {
Self {
id,
@@ -54,7 +55,17 @@ impl SimGateway {
udp_dns_server_resources: Default::default(),
tcp_dns_server_resources: Default::default(),
received_udp_requests: Default::default(),
received_tcp_requests: Default::default(),
tcp_resources: tcp_resources
.into_iter()
.map(|address| {
let mut server = crate::tests::tcp::Server::new(now);
if let Err(e) = server.listen(address) {
tracing::error!(%address, "Failed to listen on address: {e}")
}
(address, server)
})
.collect(),
}
}
@@ -113,9 +124,15 @@ impl SimGateway {
std::iter::from_fn(|| server.poll_outbound())
});
let tcp_resource_packets = self.tcp_resources.values_mut().flat_map(|server| {
server.handle_timeout(now);
std::iter::from_fn(|| server.poll_outbound())
});
udp_server_packets
.chain(tcp_server_packets)
.chain(tcp_resource_packets)
.filter_map(|packet| self.sut.handle_tun_input(packet, now).unwrap())
.collect()
}
@@ -203,6 +220,11 @@ impl SimGateway {
if let Some(tcp) = packet.as_tcp() {
let socket = SocketAddr::new(dst_ip, tcp.destination_port());
if let Some(server) = self.tcp_resources.get_mut(&socket) {
server.handle_inbound(packet);
return None;
}
// NOTE: we can make this assumption because port 53 is excluded from non-dns query packets
if let Some(server) = self.tcp_dns_server_resources.get_mut(&socket) {
server.handle_input(packet);
@@ -240,12 +262,6 @@ impl SimGateway {
tracing::debug!(%packet_id, "Received UDP request");
self.received_udp_requests.insert(packet_id, packet.clone());
}
if let Some(tcp) = packet.as_tcp() {
let packet_id = u64::from_be_bytes(*tcp.payload().first_chunk().unwrap());
tracing::debug!(%packet_id, "Received TCP request");
self.received_tcp_requests.insert(packet_id, packet.clone());
}
}
fn handle_icmp_request(
@@ -287,14 +303,19 @@ impl RefGateway {
/// Initialize the [`GatewayState`].
///
/// This simulates receiving the `init` message from the portal.
pub(crate) fn init(self, id: GatewayId, now: Instant) -> SimGateway {
pub(crate) fn init(
self,
id: GatewayId,
tcp_resources: BTreeSet<SocketAddr>,
now: Instant,
) -> SimGateway {
let mut sut = GatewayState::new(self.key.0, now); // Cheating a bit here by reusing the key as seed.
sut.update_tun_device(IpConfig {
v4: self.tunnel_ip4,
v6: self.tunnel_ip6,
});
SimGateway::new(id, sut, self.site_specific_dns_records)
SimGateway::new(id, sut, tcp_resources, self.site_specific_dns_records, now)
}
pub fn dns_records(&self) -> &DnsRecords {

View File

@@ -1,4 +1,5 @@
use super::dns_records::DnsRecords;
use super::unreachable_hosts::UnreachableHosts;
use super::{sim_net::Host, sim_relay::ref_relay_host, stub_portal::StubPortal};
use crate::client::{
CidrResource, DNS_SENTINELS_V4, DNS_SENTINELS_V6, DnsResource, IPV4_RESOURCES, IPV6_RESOURCES,
@@ -7,7 +8,7 @@ use crate::client::{
use crate::messages::DnsServer;
use crate::{IPV4_TUNNEL, IPV6_TUNNEL, proptest::*};
use connlib_model::{RelayId, Site};
use dns_types::OwnedRecordData;
use dns_types::{DomainName, OwnedRecordData};
use ip_network::{Ipv4Network, Ipv6Network};
use itertools::Itertools;
use prop::sample;
@@ -148,6 +149,48 @@ pub(crate) fn stub_portal() -> impl Strategy<Value = StubPortal> {
)
}
/// Samples a set of TCP resource addresses from the given DNS records.
///
/// We sample between 1 and `N` `(domain, port)` pairs, where `N` is the number of domains
/// in the given records, and create a [`SocketAddr`] for _each_ IP that the domain resolves to.
/// This is equivalent to how one would deploy a service in the real world.
/// If `example.com` resolves to 4 IPs, an HTTP server needs to run on all 4 IPs on the same port.
///
/// The port is sampled together with the domain.
///
/// Domains for which at least one IP is unreachable are dropped after sampling,
/// so the resulting map may contain fewer entries than were sampled (possibly none).
/// Because the result is keyed by domain, each domain ends up with the addresses of a single port.
pub(crate) fn tcp_resources(
    dns_records: DnsRecords,
    unreachable_hosts: UnreachableHosts,
) -> impl Strategy<Value = BTreeMap<DomainName, BTreeSet<SocketAddr>>> {
    let all_domains = dns_records.domains_iter().collect::<Vec<_>>();

    collection::btree_set(
        (sample::select(all_domains.clone()), any::<u16>()),
        1..=all_domains.len(),
    )
    .prop_map(move |domains| {
        domains
            .into_iter()
            // Only keep domains where every resolved IP is reachable.
            .filter(|(domain, _)| {
                dns_records
                    .domain_ips_iter(domain)
                    .all(|ip| !unreachable_hosts.is_unreachable(ip))
            })
            // Borrowing `dns_records` is sufficient here: the addresses are collected
            // before the closure returns, so no clone of the records is needed.
            .map(|(domain, port)| {
                let addresses = dns_records
                    .domain_ips_iter(&domain)
                    .map(|address| SocketAddr::new(address, port))
                    .collect::<BTreeSet<_>>();

                (domain, addresses)
            })
            .collect()
    })
}
fn create_internet_site(mut sites: BTreeSet<Site>) -> (Site, BTreeSet<Site>) {
// Rebrand the first site as the Internet site. That way, we can guarantee to always have one.
let mut internet_site = sites.pop_first().unwrap();

View File

@@ -20,11 +20,10 @@ use bufferpool::BufferPool;
use connlib_model::{ClientId, GatewayId, PublicKey, RelayId};
use dns_types::ResponseCode;
use dns_types::prelude::*;
use ip_packet::make::TcpFlags;
use rand::SeedableRng;
use rand::distributions::DistString;
use sha2::Digest;
use snownet::Transmit;
use snownet::{NoTurnServers, Transmit};
use std::iter;
use std::{
collections::BTreeMap,
@@ -63,7 +62,18 @@ impl TunnelTest {
.iter()
.map(|(gid, gateway)| {
let gateway = gateway.map(
|ref_gateway, _, _| ref_gateway.init(*gid, flux_capacitor.now()),
|ref_gateway, _, _| {
ref_gateway.init(
*gid,
ref_state
.tcp_resources
.values()
.flatten()
.copied()
.collect(),
flux_capacitor.now(),
)
},
debug_span!("gateway", %gid),
);
@@ -186,28 +196,17 @@ impl TunnelTest {
buffered_transmits.push_from(transmit, &state.client, now);
}
Transition::SendTcpPayload {
Transition::ConnectTcp {
src,
dst,
sport,
dport,
payload,
} => {
let dst = address_from_destination(&dst, &state, &src);
let packet = ip_packet::make::tcp_packet(
src,
dst,
sport.0,
dport.0,
TcpFlags::default(),
payload.to_be_bytes().to_vec(),
)
.unwrap();
let transmit = state.client.exec_mut(|sim| sim.encapsulate(packet, now));
buffered_transmits.push_from(transmit, &state.client, now);
state
.client
.exec_mut(|sim| sim.connect_tcp(src, dst, sport, dport));
}
Transition::SendDnsQueries(queries) => {
for DnsQuery {
@@ -360,12 +359,7 @@ impl TunnelTest {
&sim_gateways,
&ref_state.global_dns_records,
);
assert_tcp_packets_properties(
ref_client,
sim_client,
&sim_gateways,
&ref_state.global_dns_records,
);
assert_tcp_connections(ref_client, sim_client);
assert_udp_dns_packets_properties(ref_client, sim_client);
assert_tcp_dns(ref_client, sim_client);
assert_dns_servers_are_valid(ref_client, sim_client);
@@ -422,7 +416,21 @@ impl TunnelTest {
}
if let Some(event) = self.client.exec_mut(|c| c.sut.poll_event()) {
self.on_client_event(self.client.inner().id, event, &ref_state.portal);
match self.on_client_event(self.client.inner().id, event, &ref_state.portal) {
Ok(()) => {}
Err(AuthorizeFlowError::Client(_)) => {
self.client.exec_mut(|c| {
c.update_relays(iter::empty(), self.relays.iter(), now);
});
}
Err(AuthorizeFlowError::Gateway(_)) => {
for gateway in self.gateways.values_mut() {
gateway.exec_mut(|g| {
g.update_relays(iter::empty(), self.relays.iter(), now)
});
}
}
};
continue;
}
if let Some(query) = self.client.exec_mut(|c| c.sut.poll_dns_queries()) {
@@ -536,8 +544,6 @@ impl TunnelTest {
) {
// Handle the TCP DNS client, i.e. simulate applications making TCP DNS queries.
self.client.exec_mut(|c| {
c.tcp_dns_client.handle_timeout(now);
while let Some(result) = c.tcp_dns_client.poll_query_result() {
match result.result {
Ok(message) => {
@@ -554,21 +560,17 @@ impl TunnelTest {
}
});
while let Some(transmit) = self.client.exec_mut(|c| {
let packet = c.tcp_dns_client.poll_outbound()?;
let packet = c.poll_outbound()?;
c.encapsulate(packet, now)
}) {
buffered_transmits.push_from(transmit, &self.client, now)
}
self.client.exec_mut(|c| c.handle_timeout(now));
// Handle the client's `Transmit`s and timeout.
// Handle the client's `Transmit`s.
while let Some(transmit) = self.client.poll_inbox(now) {
self.client.exec_mut(|c| c.receive(transmit, now))
}
self.client.exec_mut(|c| {
if c.sut.poll_timeout().is_some_and(|t| t <= now) {
c.sut.handle_timeout(now)
}
});
// Handle all gateway `Transmit`s and timeouts.
for (_, gateway) in self.gateways.iter_mut() {
@@ -680,7 +682,12 @@ impl TunnelTest {
}
}
fn on_client_event(&mut self, src: ClientId, event: ClientEvent, portal: &StubPortal) {
fn on_client_event(
&mut self,
src: ClientId,
event: ClientEvent,
portal: &StubPortal,
) -> Result<(), AuthorizeFlowError> {
let now = self.flux_capacitor.now();
match event {
@@ -694,7 +701,9 @@ impl TunnelTest {
for candidate in candidates {
g.sut.add_ice_candidate(src, candidate, now)
}
})
});
Ok(())
}
ClientEvent::RemovedIceCandidates {
candidates,
@@ -706,7 +715,9 @@ impl TunnelTest {
for candidate in candidates {
g.sut.remove_ice_candidate(src, candidate, now)
}
})
});
Ok(())
}
ClientEvent::ConnectionIntent {
resource: resource_id,
@@ -736,22 +747,29 @@ impl TunnelTest {
now,
)
})
.unwrap();
if let Err(e) = self.client.exec_mut(|c| {
c.sut.handle_flow_created(
resource_id,
gateway_id,
gateway_key,
gateway.inner().sut.tunnel_ip_config().unwrap(),
site_id,
preshared_key,
client_ice,
gateway_ice,
now,
)
}) {
tracing::error!("{e:#}")
};
.map_err(AuthorizeFlowError::Gateway)?;
self.client
.exec_mut(|c| {
c.sut.handle_flow_created(
resource_id,
gateway_id,
gateway_key,
gateway.inner().sut.tunnel_ip_config().unwrap(),
site_id,
preshared_key,
client_ice,
gateway_ice,
now,
)
})
.unwrap_or_else(|e| {
tracing::error!("{e:#}");
Ok(())
})
.map_err(AuthorizeFlowError::Client)?;
Ok(())
}
ClientEvent::ResourcesChanged { resources } => {
@@ -761,6 +779,8 @@ impl TunnelTest {
.map(|r| (r.id(), r.status()))
.collect();
});
Ok(())
}
ClientEvent::TunInterfaceUpdated(config) => {
if self.client.inner().dns_mapping() == &config.dns_by_sentinel
@@ -788,6 +808,8 @@ impl TunnelTest {
c.ipv6_routes = config.ipv6_routes;
c.search_domain = config.search_domain
});
Ok(())
}
}
}
@@ -843,6 +865,11 @@ impl TunnelTest {
}
}
/// Which side reported [`NoTurnServers`] while handling a client event.
///
/// The caller uses the variant to decide whose relays to update.
enum AuthorizeFlowError {
/// The client's `handle_flow_created` call failed.
Client(NoTurnServers),
/// The call into the gateway failed.
Gateway(NoTurnServers),
}
fn address_from_destination(destination: &Destination, state: &TunnelTest, src: &IpAddr) -> IpAddr {
match destination {
Destination::DomainName { resolved_ip, name } => {

View File

@@ -0,0 +1,165 @@
use std::{
collections::BTreeMap,
net::SocketAddr,
time::{Duration, Instant},
};
use anyhow::{Context, Result};
use ip_packet::{IpPacket, Layer4Protocol};
use l3_tcp::Socket;
/// In-memory TCP client that the tests use to open connections to resources.
pub struct Client {
/// All sockets created via `connect`.
sockets: l3_tcp::SocketSet<'static>,
/// Socket handle per remote address; `connect` refuses a second socket for the same remote.
sockets_by_remote: BTreeMap<SocketAddr, l3_tcp::SocketHandle>,
/// Packet queue: `handle_inbound` feeds packets in, `poll_outbound` takes packets out.
device: l3_tcp::InMemoryDevice,
/// Polled against `device` and `sockets` in `handle_timeout`.
interface: l3_tcp::Interface,
/// Passed to `l3_tcp::now` together with the current time.
created_at: Instant,
/// The `now` of the most recent `handle_timeout` call (initially the creation time).
last_now: Instant,
}
/// In-memory TCP server that listens on the addresses passed to `listen`.
pub struct Server {
/// All listening sockets created via `listen`.
sockets: l3_tcp::SocketSet<'static>,
/// The address each socket handle was told to listen on.
listen_endpoints: BTreeMap<l3_tcp::SocketHandle, SocketAddr>,
/// Packet queue: `handle_inbound` feeds packets in, `poll_outbound` takes packets out.
device: l3_tcp::InMemoryDevice,
/// Polled against `device` and `sockets` in `handle_timeout`.
interface: l3_tcp::Interface,
/// Passed to `l3_tcp::now` together with the current time.
created_at: Instant,
/// The `now` of the most recent `handle_timeout` call (initially the creation time).
last_now: Instant,
}
impl Client {
pub fn new(now: Instant) -> Self {
let mut device = l3_tcp::InMemoryDevice::default();
let interface = l3_tcp::create_interface(&mut device);
Self {
sockets: l3_tcp::SocketSet::new(Vec::default()),
sockets_by_remote: Default::default(),
device,
interface,
created_at: now,
last_now: now,
}
}
pub fn connect(&mut self, local: SocketAddr, remote: SocketAddr) -> Result<()> {
anyhow::ensure!(!self.sockets_by_remote.contains_key(&remote));
let mut socket = l3_tcp::create_tcp_socket();
socket
.connect(self.interface.context(), remote, local)
.context("Failed to create TCP connection")?;
// A short keep-alive ensures we detect broken connections.
socket.set_keep_alive(Some(l3_tcp::Duration::from_secs(5)));
// 30s is a common timeout for TCP connections.
socket.set_timeout(Some(l3_tcp::Duration::from_secs(30)));
let handle = self.sockets.add(socket);
self.sockets_by_remote.insert(remote, handle);
Ok(())
}
pub fn accepts(&self, packet: &IpPacket) -> bool {
let Some(tcp) = packet.as_tcp() else {
return false;
};
self.sockets_by_remote
.contains_key(&SocketAddr::new(packet.source(), tcp.source_port()))
}
pub fn handle_inbound(&mut self, packet: IpPacket) {
// TODO: Upstream ICMP error handling to `smoltcp`.
if let Ok(Some((failed_packet, _))) = packet.icmp_unreachable_destination() {
if let Layer4Protocol::Tcp { dst, .. } = failed_packet.layer4_protocol() {
if let Some(handle) = self
.sockets_by_remote
.get(&SocketAddr::new(failed_packet.dst(), dst))
{
self.sockets.get_mut::<l3_tcp::Socket>(*handle).abort();
}
}
}
self.device.receive(packet);
}
pub fn handle_timeout(&mut self, now: Instant) {
self.last_now = now;
let _result = self.interface.poll(
l3_tcp::now(self.created_at, now),
&mut self.device,
&mut self.sockets,
);
}
pub fn poll_outbound(&mut self) -> Option<IpPacket> {
self.device.next_send()
}
pub fn _poll_timeout(&mut self) -> Option<Instant> {
let now = l3_tcp::now(self.created_at, self.last_now);
let poll_in = self.interface.poll_delay(now, &self.sockets)?;
Some(self.last_now + Duration::from(poll_in))
}
pub fn iter_sockets(&self) -> impl Iterator<Item = &Socket> {
self.sockets.iter().map(|(_, s)| match s {
l3_tcp::AnySocket::Tcp(socket) => socket,
})
}
}
impl Server {
    /// Creates a server without any listening sockets; `now` becomes the reference point for all later timestamps.
    pub fn new(now: Instant) -> Self {
        let mut device = l3_tcp::InMemoryDevice::default();
        let interface = l3_tcp::create_interface(&mut device);
        let sockets = l3_tcp::SocketSet::new(Vec::new());

        Self {
            created_at: now,
            last_now: now,
            listen_endpoints: BTreeMap::new(),
            sockets,
            interface,
            device,
        }
    }

    /// Adds a new socket listening on `address`.
    ///
    /// # Errors
    ///
    /// Fails if the socket refuses to listen on `address`.
    pub fn listen(&mut self, address: SocketAddr) -> Result<()> {
        let mut listener = l3_tcp::create_tcp_socket();
        listener
            .listen(address)
            .with_context(|| format!("Failed to listen on {address}"))?;

        self.listen_endpoints
            .insert(self.sockets.add(listener), address);

        Ok(())
    }

    /// Queues `packet` for processing on the next `handle_timeout`.
    pub fn handle_inbound(&mut self, packet: IpPacket) {
        self.device.receive(packet);
    }

    /// Remembers `now` and polls the interface with the time elapsed since creation.
    pub fn handle_timeout(&mut self, now: Instant) {
        self.last_now = now;

        let timestamp = l3_tcp::now(self.created_at, now);
        let _poll_result = self
            .interface
            .poll(timestamp, &mut self.device, &mut self.sockets);
    }

    /// Takes the next packet the device wants to send, if any.
    pub fn poll_outbound(&mut self) -> Option<IpPacket> {
        self.device.next_send()
    }
}

View File

@@ -12,6 +12,7 @@ use proptest::{prelude::*, sample};
use std::{
collections::{BTreeMap, BTreeSet},
net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr},
num::NonZeroU16,
};
/// The possible transitions of the state machine.
@@ -39,13 +40,12 @@ pub(crate) enum Transition {
dport: DPort,
payload: u64,
},
/// Send an TCP payload to destination (IP resource, DNS resource or IP non-resource).
SendTcpPayload {
ConnectTcp {
src: IpAddr,
dst: Destination,
sport: SPort,
dport: DPort,
payload: u64,
},
/// Send a DNS query.
@@ -125,6 +125,29 @@ pub(crate) enum Destination {
IpAddr(IpAddr),
}
// Marker impl; equality itself is defined by the manual `PartialEq` impl, which compares only the domain name or the IP address.
impl Eq for Destination {}
impl std::hash::Hash for Destination {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
match self {
Destination::DomainName { name, .. } => name.hash(state),
Destination::IpAddr(ip_addr) => ip_addr.hash(state),
}
}
}
impl PartialEq for Destination {
    /// Two destinations are equal if they are the same kind and carry the same
    /// domain name or IP address; any other field of `DomainName` is ignored.
    fn eq(&self, other: &Self) -> bool {
        match self {
            Self::DomainName { name, .. } => {
                matches!(other, Self::DomainName { name: other_name, .. } if name == other_name)
            }
            Self::IpAddr(ip) => matches!(other, Self::IpAddr(other_ip) if ip == other_ip),
        }
    }
}
impl Destination {
pub(crate) fn ip_addr(&self) -> Option<IpAddr> {
match self {
@@ -246,32 +269,28 @@ where
)
}
#[expect(private_bounds)]
pub(crate) fn tcp_packet<I, D>(
pub(crate) fn connect_tcp<I>(
src: impl Strategy<Value = I>,
dst: impl Strategy<Value = D>,
dst: impl Strategy<Value = DomainName>,
) -> impl Strategy<Value = Transition>
where
I: Into<IpAddr>,
D: Into<PacketDestination>,
{
(
src.prop_map(Into::into),
dst.prop_map(Into::into),
any::<u16>(),
non_dns_ports(),
dst,
any::<NonZeroU16>().prop_map(|p| p.get()),
non_dns_ports().prop_filter("avoid zero port", |p| *p != 0),
any::<sample::Selector>(),
any::<u64>(),
)
.prop_map(|(src, dst, sport, dport, resolved_ip, payload)| {
Transition::SendTcpPayload {
.prop_map(
|(src, name, sport, dport, resolved_ip)| Transition::ConnectTcp {
src,
dst: dst.into_destination(resolved_ip),
dst: Destination::DomainName { resolved_ip, name },
sport: SPort(sport),
dport: DPort(dport),
payload,
}
})
},
)
}
fn non_dns_ports() -> impl Strategy<Value = u16> {

View File

@@ -16,6 +16,10 @@ impl UnreachableHosts {
/// Looks up `ip` and returns a copy of the [`IcmpError`] stored for it, or `None` if there is no entry.
pub(crate) fn icmp_error_for_ip(&self, ip: IpAddr) -> Option<IcmpError> {
self.inner.get(&ip).copied()
}
/// Returns whether an entry exists for `ip`, i.e. whether `ip` is treated as unreachable.
pub(crate) fn is_unreachable(&self, ip: IpAddr) -> bool {
    self.inner.get(&ip).is_some()
}
}
/// Samples a subset of the provided DNS records which we will treat as "unreachable".