diff --git a/rust/connlib/tunnel/src/gateway.rs b/rust/connlib/tunnel/src/gateway.rs index 333a57abb..b7a737e12 100644 --- a/rust/connlib/tunnel/src/gateway.rs +++ b/rust/connlib/tunnel/src/gateway.rs @@ -475,6 +475,7 @@ impl GatewayState { inner_dst_ip = %flow.inner_dst_ip, inner_src_port = %flow.inner_src_port, inner_dst_port = %flow.inner_dst_port, + inner_domain = flow.inner_domain.map(tracing::field::display), outer_src_ip = %flow.outer_src_ip, outer_dst_ip = %flow.outer_dst_ip, @@ -502,6 +503,7 @@ impl GatewayState { inner_dst_ip = %flow.inner_dst_ip, inner_src_port = %flow.inner_src_port, inner_dst_port = %flow.inner_dst_port, + inner_domain = flow.inner_domain.map(tracing::field::display), outer_src_ip = %flow.outer_src_ip, outer_dst_ip = %flow.outer_dst_ip, diff --git a/rust/connlib/tunnel/src/gateway/client_on_gateway.rs b/rust/connlib/tunnel/src/gateway/client_on_gateway.rs index 7f0eca9a8..60054facf 100644 --- a/rust/connlib/tunnel/src/gateway/client_on_gateway.rs +++ b/rust/connlib/tunnel/src/gateway/client_on_gateway.rs @@ -439,6 +439,8 @@ impl ClientOnGateway { )); } + flow_tracker::inbound_wg::record_domain(state.domain.clone()); + let (source_protocol, real_ip) = self.nat_table .translate_outgoing(&packet, resolved_ip, now)?; diff --git a/rust/connlib/tunnel/src/gateway/flow_tracker.rs b/rust/connlib/tunnel/src/gateway/flow_tracker.rs index a2474fb9e..766ee172d 100644 --- a/rust/connlib/tunnel/src/gateway/flow_tracker.rs +++ b/rust/connlib/tunnel/src/gateway/flow_tracker.rs @@ -7,6 +7,7 @@ use std::{ use chrono::{DateTime, TimeDelta, Utc}; use connlib_model::{ClientId, ResourceId}; +use dns_types::DomainName; use ip_packet::{IcmpError, IpPacket, Protocol, UnsupportedProtocol}; use std::time::Instant; @@ -91,6 +92,7 @@ impl FlowTracker { client: None, resource: None, icmp_error: None, + domain: None, }))); debug_assert!( current.is_none(), @@ -161,6 +163,7 @@ impl FlowTracker { }), client: Some(client), resource: Some(resource), + domain, icmp_error: _, // TODO: What to do with ICMP errors? } = flow else { @@ -203,6 +206,7 @@ impl FlowTracker { context, fin_tx: false, fin_rx: false, + domain, }); } hash_map::Entry::Occupied(occupied) if occupied.get().context != context => { @@ -228,6 +232,7 @@ impl FlowTracker { context, fin_tx: false, fin_rx: false, + domain, }, ); } @@ -249,6 +254,7 @@ impl FlowTracker { context, fin_tx: false, fin_rx: false, + domain, }, ); } @@ -291,6 +297,7 @@ impl FlowTracker { last_packet: now_utc, stats: FlowStats::default().with_tx(payload_len as u64), context, + domain, }); } hash_map::Entry::Occupied(occupied) if occupied.get().context != context => { @@ -314,6 +321,7 @@ impl FlowTracker { last_packet: now_utc, stats: FlowStats::default().with_tx(payload_len as u64), context, + domain, }, ); } @@ -449,6 +457,7 @@ pub struct CompletedTcpFlow { pub inner_dst_ip: IpAddr, pub inner_src_port: u16, pub inner_dst_port: u16, + pub inner_domain: Option, pub outer_src_ip: IpAddr, pub outer_dst_ip: IpAddr, @@ -473,6 +482,7 @@ pub struct CompletedUdpFlow { pub inner_dst_ip: IpAddr, pub inner_src_port: u16, pub inner_dst_port: u16, + pub inner_domain: Option, pub outer_src_ip: IpAddr, pub outer_dst_ip: IpAddr, @@ -497,6 +507,7 @@ impl CompletedTcpFlow { inner_dst_ip: key.dst_ip, inner_src_port: key.src_port, inner_dst_port: key.dst_port, + inner_domain: value.domain, outer_src_ip: value.context.src_ip, outer_dst_ip: value.context.dst_ip, outer_src_port: value.context.src_port, @@ -521,6 +532,7 @@ impl CompletedUdpFlow { inner_dst_ip: key.dst_ip, inner_src_port: key.src_port, inner_dst_port: key.dst_port, + inner_domain: value.domain, outer_src_ip: value.context.src_ip, outer_dst_ip: value.context.dst_ip, outer_src_port: value.context.src_port, @@ -560,6 +572,8 @@ struct TcpFlowValue { stats: FlowStats, context: FlowContext, + domain: Option, + fin_tx: bool, fin_rx: bool, } @@ -570,6 +584,8 @@ struct UdpFlowValue { last_packet: DateTime, stats: FlowStats, context: FlowContext, + + domain: Option, } #[derive(Debug, Default)] @@ -660,6 +676,8 @@ impl std::fmt::Debug for FlowContextDiff { } pub mod inbound_wg { + use dns_types::DomainName; + use super::*; pub fn record_client(cid: ClientId) { @@ -670,6 +688,10 @@ pub mod inbound_wg { update_current_flow_inbound_wireguard(|wg| wg.resource.replace(rid)); } + pub fn record_domain(name: DomainName) { + update_current_flow_inbound_wireguard(|wg| wg.domain.replace(name)); + } + pub fn record_decrypted_packet(packet: &IpPacket) { update_current_flow_inbound_wireguard(|wg| { wg.inner = Some(InnerFlow::from(packet)); @@ -762,6 +784,8 @@ struct InboundWireGuard { inner: Option, client: Option, resource: Option, + /// The domain name in case this packet is for a DNS resource. + domain: Option, icmp_error: Option, } diff --git a/scripts/tests/download-concurrent.sh b/scripts/tests/download-concurrent.sh index c80cd0ef4..aed7b9163 100755 --- a/scripts/tests/download-concurrent.sh +++ b/scripts/tests/download-concurrent.sh @@ -35,6 +35,7 @@ rx_bytes=0 for flow in "${flows[@]}"; do assert_eq "$(get_flow_field "$flow" "inner_dst_ip")" "172.21.0.101" + assert_eq "$(get_flow_field "$flow" "inner_domain")" "download.httpbin" rx_bytes+="$(get_flow_field "$flow" "rx_bytes")" done diff --git a/scripts/tests/download.sh b/scripts/tests/download.sh index fc8cd9e05..6fddc4ce4 100755 --- a/scripts/tests/download.sh +++ b/scripts/tests/download.sh @@ -26,4 +26,5 @@ assert_eq "${#flows[@]}" 1 flow="${flows[0]}" assert_eq "$(get_flow_field "$flow" "inner_dst_ip")" "172.21.0.101" +assert_eq "$(get_flow_field "$flow" "inner_domain")" "download.httpbin" assert_gteq "$(get_flow_field "$flow" "rx_bytes")" 10000000